Move back directive processing to tlstunnel package

This commit is contained in:
Simon Ser 2020-09-10 15:05:43 +02:00
parent ec2a768909
commit 2fdea9d4ed
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 121 additions and 106 deletions

View File

@ -2,11 +2,7 @@ package main
import ( import (
"flag" "flag"
"fmt"
"log" "log"
"net"
"net/url"
"strings"
"git.sr.ht/~emersion/tlstunnel" "git.sr.ht/~emersion/tlstunnel"
) )
@ -24,19 +20,8 @@ func main() {
srv := tlstunnel.NewServer() srv := tlstunnel.NewServer()
for _, d := range cfg.Children { if err := srv.Load(cfg); err != nil {
var err error log.Fatal(err)
switch d.Name {
case "frontend":
err = parseFrontend(srv, d)
case "tls":
err = parseTLS(srv, d)
default:
log.Fatalf("unknown %q directive", d.Name)
}
if err != nil {
log.Fatalf("directive %q: %v", d.Name, err)
}
} }
if err := srv.Start(); err != nil { if err := srv.Start(); err != nil {
@ -45,92 +30,3 @@ func main() {
select {} select {}
} }
func parseFrontend(srv *tlstunnel.Server, d *tlstunnel.Directive) error {
frontend := &tlstunnel.Frontend{Server: srv}
srv.Frontends = append(srv.Frontends, frontend)
// TODO: support multiple backends
backendDirective := d.ChildByName("backend")
if backendDirective == nil {
return fmt.Errorf("missing backend directive in frontend block")
}
if err := parseBackend(&frontend.Backend, backendDirective); err != nil {
return err
}
for _, listenAddr := range d.Params {
host, port, err := net.SplitHostPort(listenAddr)
if err != nil {
return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err)
}
// TODO: come up with something more robust
var name string
if host != "" && host != "localhost" && net.ParseIP(host) == nil {
name = host
host = ""
srv.ManagedNames = append(srv.ManagedNames, name)
}
addr := net.JoinHostPort(host, port)
ln := srv.RegisterListener(addr)
if err := ln.RegisterFrontend(name, frontend); err != nil {
return err
}
}
return nil
}
func parseBackend(backend *tlstunnel.Backend, d *tlstunnel.Directive) error {
var backendURI string
if err := d.ParseParams(&backendURI); err != nil {
return err
}
if !strings.Contains(backendURI, ":/") {
// This is a raw domain name, make it an URL with an empty scheme
backendURI = "//" + backendURI
}
u, err := url.Parse(backendURI)
if err != nil {
return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err)
}
if strings.HasSuffix(u.Scheme, "+proxy") {
u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy")
backend.Proxy = true
}
switch u.Scheme {
case "", "tcp":
backend.Network = "tcp"
backend.Address = u.Host
case "unix":
backend.Network = "unix"
backend.Address = u.Host
default:
return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI)
}
return nil
}
func parseTLS(srv *tlstunnel.Server, d *tlstunnel.Directive) error {
for _, child := range d.Children {
switch child.Name {
case "acme_ca":
var caURL string
if err := child.ParseParams(&caURL); err != nil {
return err
}
srv.ACMEManager.CA = caURL
default:
return fmt.Errorf("unknown %q directive", child.Name)
}
}
return nil
}

115
directives.go Normal file
View File

@ -0,0 +1,115 @@
package tlstunnel
import (
"fmt"
"net"
"net/url"
"strings"
)
func parseConfig(srv *Server, cfg *Directive) error {
for _, d := range cfg.Children {
var err error
switch d.Name {
case "frontend":
err = parseFrontend(srv, d)
case "tls":
err = parseTLS(srv, d)
default:
return fmt.Errorf("unknown %q directive", d.Name)
}
if err != nil {
return fmt.Errorf("directive %q: %v", d.Name, err)
}
}
return nil
}
func parseFrontend(srv *Server, d *Directive) error {
frontend := &Frontend{Server: srv}
srv.Frontends = append(srv.Frontends, frontend)
// TODO: support multiple backends
backendDirective := d.ChildByName("backend")
if backendDirective == nil {
return fmt.Errorf("missing backend directive in frontend block")
}
if err := parseBackend(&frontend.Backend, backendDirective); err != nil {
return err
}
for _, listenAddr := range d.Params {
host, port, err := net.SplitHostPort(listenAddr)
if err != nil {
return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err)
}
// TODO: come up with something more robust
var name string
if host != "" && host != "localhost" && net.ParseIP(host) == nil {
name = host
host = ""
srv.ManagedNames = append(srv.ManagedNames, name)
}
addr := net.JoinHostPort(host, port)
ln := srv.RegisterListener(addr)
if err := ln.RegisterFrontend(name, frontend); err != nil {
return err
}
}
return nil
}
func parseBackend(backend *Backend, d *Directive) error {
var backendURI string
if err := d.ParseParams(&backendURI); err != nil {
return err
}
if !strings.Contains(backendURI, ":/") {
// This is a raw domain name, make it an URL with an empty scheme
backendURI = "//" + backendURI
}
u, err := url.Parse(backendURI)
if err != nil {
return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err)
}
if strings.HasSuffix(u.Scheme, "+proxy") {
u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy")
backend.Proxy = true
}
switch u.Scheme {
case "", "tcp":
backend.Network = "tcp"
backend.Address = u.Host
case "unix":
backend.Network = "unix"
backend.Address = u.Host
default:
return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI)
}
return nil
}
func parseTLS(srv *Server, d *Directive) error {
for _, child := range d.Children {
switch child.Name {
case "acme_ca":
var caURL string
if err := child.ParseParams(&caURL); err != nil {
return err
}
srv.ACMEManager.CA = caURL
default:
return fmt.Errorf("unknown %q directive", child.Name)
}
}
return nil
}

View File

@ -38,6 +38,10 @@ func NewServer() *Server {
} }
} }
func (srv *Server) Load(cfg *Directive) error {
return parseConfig(srv, cfg)
}
func (srv *Server) RegisterListener(addr string) *Listener { func (srv *Server) RegisterListener(addr string) *Listener {
// TODO: normalize addr with net.LookupPort // TODO: normalize addr with net.LookupPort
ln, ok := srv.Listeners[addr] ln, ok := srv.Listeners[addr]