diff --git a/cmd/tlstunnel/main.go b/cmd/tlstunnel/main.go index 97487e3..94cbd0b 100644 --- a/cmd/tlstunnel/main.go +++ b/cmd/tlstunnel/main.go @@ -2,11 +2,7 @@ package main import ( "flag" - "fmt" "log" - "net" - "net/url" - "strings" "git.sr.ht/~emersion/tlstunnel" ) @@ -24,19 +20,8 @@ func main() { srv := tlstunnel.NewServer() - for _, d := range cfg.Children { - var err error - switch d.Name { - case "frontend": - err = parseFrontend(srv, d) - case "tls": - err = parseTLS(srv, d) - default: - log.Fatalf("unknown %q directive", d.Name) - } - if err != nil { - log.Fatalf("directive %q: %v", d.Name, err) - } + if err := srv.Load(cfg); err != nil { + log.Fatal(err) } if err := srv.Start(); err != nil { @@ -45,92 +30,3 @@ func main() { select {} } - -func parseFrontend(srv *tlstunnel.Server, d *tlstunnel.Directive) error { - frontend := &tlstunnel.Frontend{Server: srv} - srv.Frontends = append(srv.Frontends, frontend) - - // TODO: support multiple backends - backendDirective := d.ChildByName("backend") - if backendDirective == nil { - return fmt.Errorf("missing backend directive in frontend block") - } - if err := parseBackend(&frontend.Backend, backendDirective); err != nil { - return err - } - - for _, listenAddr := range d.Params { - host, port, err := net.SplitHostPort(listenAddr) - if err != nil { - return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err) - } - - // TODO: come up with something more robust - var name string - if host != "" && host != "localhost" && net.ParseIP(host) == nil { - name = host - host = "" - - srv.ManagedNames = append(srv.ManagedNames, name) - } - - addr := net.JoinHostPort(host, port) - - ln := srv.RegisterListener(addr) - if err := ln.RegisterFrontend(name, frontend); err != nil { - return err - } - } - - return nil -} - -func parseBackend(backend *tlstunnel.Backend, d *tlstunnel.Directive) error { - var backendURI string - if err := d.ParseParams(&backendURI); err != nil { - return err - } - if !strings.Contains(backendURI, ":/") { - // This is a raw domain name, make it an URL with an empty scheme - backendURI = "//" + backendURI - } - - u, err := url.Parse(backendURI) - if err != nil { - return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err) - } - - if strings.HasSuffix(u.Scheme, "+proxy") { - u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy") - backend.Proxy = true - } - - switch u.Scheme { - case "", "tcp": - backend.Network = "tcp" - backend.Address = u.Host - case "unix": - backend.Network = "unix" - backend.Address = u.Host - default: - return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI) - } - - return nil -} - -func parseTLS(srv *tlstunnel.Server, d *tlstunnel.Directive) error { - for _, child := range d.Children { - switch child.Name { - case "acme_ca": - var caURL string - if err := child.ParseParams(&caURL); err != nil { - return err - } - srv.ACMEManager.CA = caURL - default: - return fmt.Errorf("unknown %q directive", child.Name) - } - } - return nil -} diff --git a/directives.go b/directives.go new file mode 100644 index 0000000..8f3de54 --- /dev/null +++ b/directives.go @@ -0,0 +1,115 @@ +package tlstunnel + +import ( + "fmt" + "net" + "net/url" + "strings" +) + +func parseConfig(srv *Server, cfg *Directive) error { + for _, d := range cfg.Children { + var err error + switch d.Name { + case "frontend": + err = parseFrontend(srv, d) + case "tls": + err = parseTLS(srv, d) + default: + return fmt.Errorf("unknown %q directive", d.Name) + } + if err != nil { + return fmt.Errorf("directive %q: %v", d.Name, err) + } + } + return nil +} + +func parseFrontend(srv *Server, d *Directive) error { + frontend := &Frontend{Server: srv} + srv.Frontends = append(srv.Frontends, frontend) + + // TODO: support multiple backends + backendDirective := d.ChildByName("backend") + if backendDirective == nil { + return fmt.Errorf("missing backend directive in frontend block") + } + if err := parseBackend(&frontend.Backend, backendDirective); err != nil { + return err + } + + for _, listenAddr := range d.Params { + host, port, err := net.SplitHostPort(listenAddr) + if err != nil { + return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err) + } + + // TODO: come up with something more robust + var name string + if host != "" && host != "localhost" && net.ParseIP(host) == nil { + name = host + host = "" + + srv.ManagedNames = append(srv.ManagedNames, name) + } + + addr := net.JoinHostPort(host, port) + + ln := srv.RegisterListener(addr) + if err := ln.RegisterFrontend(name, frontend); err != nil { + return err + } + } + + return nil +} + +func parseBackend(backend *Backend, d *Directive) error { + var backendURI string + if err := d.ParseParams(&backendURI); err != nil { + return err + } + if !strings.Contains(backendURI, ":/") { + // This is a raw domain name, make it an URL with an empty scheme + backendURI = "//" + backendURI + } + + u, err := url.Parse(backendURI) + if err != nil { + return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err) + } + + if strings.HasSuffix(u.Scheme, "+proxy") { + u.Scheme = strings.TrimSuffix(u.Scheme, "+proxy") + backend.Proxy = true + } + + switch u.Scheme { + case "", "tcp": + backend.Network = "tcp" + backend.Address = u.Host + case "unix": + backend.Network = "unix" + backend.Address = u.Host + default: + return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI) + } + + return nil +} + +func parseTLS(srv *Server, d *Directive) error { + for _, child := range d.Children { + switch child.Name { + case "acme_ca": + var caURL string + if err := child.ParseParams(&caURL); err != nil { + return err + } + srv.ACMEManager.CA = caURL + default: + return fmt.Errorf("unknown %q directive", child.Name) + } + } + return nil +} diff --git a/server.go b/server.go index 3f90eeb..590dd25 100644 --- a/server.go +++ b/server.go @@ -38,6 +38,10 @@ func NewServer() *Server { } } +func (srv *Server) Load(cfg *Directive) error { + return parseConfig(srv, cfg) +} + func (srv *Server) RegisterListener(addr string) *Listener { // TODO: normalize addr with net.LookupPort ln, ok := srv.Listeners[addr]