From c0f5ca6b39257434f681f40bd680cc5a1a4f6ab8 Mon Sep 17 00:00:00 2001 From: Simon Ser Date: Tue, 8 Sep 2020 17:13:39 +0200 Subject: [PATCH] Implement basic TCP proxy --- config.go | 50 +++++++++++++++++++++++++------ main.go | 90 ++++++++++++++++++++++++++++++++++++++++++++++++++++++- server.go | 59 ++++++++++++++++++++++++++++++++++++ 3 files changed, 189 insertions(+), 10 deletions(-) create mode 100644 server.go diff --git a/config.go b/config.go index 48105a8..9032094 100644 --- a/config.go +++ b/config.go @@ -1,19 +1,52 @@ package main import ( - "os" - "io" "bufio" "fmt" + "io" + "os" "github.com/google/shlex" ) type Directive struct { - Params []string + Name string + Params []string Children []*Directive } +func (d *Directive) ParseParams(params ...*string) error { + if len(d.Params) < len(params) { + return fmt.Errorf("directive %q: want %v params, got %v", d.Name, len(params), len(d.Params)) + } + for i, ptr := range params { + if ptr == nil { + continue + } + *ptr = d.Params[i] + } + return nil +} + +func (d *Directive) ChildrenByName(name string) []*Directive { + l := make([]*Directive, 0, len(d.Children)) + for _, child := range d.Children { + if child.Name == name { + l = append(l, child) + } + } + return l +} + +func (d *Directive) ChildByName(name string) *Directive { + for _, child := range d.Children { + if child.Name == name { + return child + } + } + return nil +} + func Load(path string) ([]*Directive, error) { f, err := os.Open(path) if err != nil { @@ -38,7 +71,7 @@ func Parse(r io.Reader) ([]*Directive, error) { continue } - if len(words) == 1 && l[len(l) - 1] == '}' { + if len(words) == 1 && l[len(l)-1] == '}' { if cur == nil { return nil, fmt.Errorf("unexpected '}'") } @@ -47,14 +80,13 @@ func Parse(r io.Reader) ([]*Directive, error) { } var d *Directive - if words[len(words) - 1] == "{" && l[len(l) - 1] == '{' { - d = &Directive{ - Params: words[:len(words) - 1], - } + if words[len(words)-1] == "{" && l[len(l)-1] == '{' { + words = words[:len(words)-1] + d = &Directive{Params: words} cur = d directives = append(directives, d) } else { - d = &Directive{Params: words} + d = &Directive{Name: words[0], Params: words[1:]} if cur != nil { cur.Children = append(cur.Children, d) } else { diff --git a/main.go b/main.go index 7f4bd33..0b17b08 100644 --- a/main.go +++ b/main.go @@ -1,7 +1,11 @@ package main import ( + "fmt" "log" + "net" + "net/url" + "strings" ) func main() { @@ -9,5 +13,89 @@ func main() { if err != nil { log.Fatalf("failed to load config file: %v", err) } - _ = directives + + srv := &Server{} + + for _, d := range directives { + if err := parseFrontend(srv, d); err != nil { + log.Fatalf("failed to parse frontend: %v", err) + } + } + + select {} +} + +func parseFrontend(srv *Server, d *Directive) error { + frontend := &Frontend{Server: srv} + srv.Frontends = append(srv.Frontends, frontend) + + // TODO: support multiple backends + backendDirective := d.ChildByName("backend") + if backendDirective == nil { + return fmt.Errorf("missing backend directive in frontend block") + } + if err := parseBackend(&frontend.Backend, backendDirective); err != nil { + return err + } + + for _, listenDirective := range d.ChildrenByName("listen") { + var listenAddr string + if err := listenDirective.ParseParams(&listenAddr); err != nil { + return err + } + + host, port, err := net.SplitHostPort(listenAddr) + if err != nil { + return fmt.Errorf("failed to parse listen address %q: %v", listenAddr, err) + } + + // TODO: come up with something more robust + if host != "localhost" && net.ParseIP(host) == nil { + host = "" + } + + ln, err := net.Listen("tcp", net.JoinHostPort(host, port)) + if err != nil { + return fmt.Errorf("failed to listen on %q: %v", listenAddr, err) + } + + go func() { + if err := frontend.Serve(ln); err != nil { + log.Fatalf("failed to serve: %v", err) + } + }() + } + + return nil +} + +func parseBackend(backend *Backend, d *Directive) error { + var backendURI string + if err := d.ParseParams(&backendURI); err != nil { + return err + } + if !strings.Contains(backendURI, ":/") { + // This is a raw domain name, make it an URL with an empty scheme + backendURI = "//" + backendURI + } + + u, err := url.Parse(backendURI) + if err != nil { + return fmt.Errorf("failed to parse backend URI %q: %v", backendURI, err) + } + + // TODO: +proxy to use the PROXY protocol + + switch u.Scheme { + case "", "tcp": + backend.Network = "tcp" + backend.Address = u.Host + case "unix": + backend.Network = "unix" + backend.Address = u.Host + default: + return fmt.Errorf("failed to setup backend %q: unsupported URI scheme", backendURI) + } + + return nil } diff --git a/server.go b/server.go new file mode 100644 index 0000000..34d09d0 --- /dev/null +++ b/server.go @@ -0,0 +1,59 @@ +package main + +import ( + "fmt" + "io" + "net" +) + +type Server struct { + Frontends []*Frontend +} + +type Frontend struct { + Server *Server + Backend Backend +} + +func (fe *Frontend) Serve(ln net.Listener) error { + for { + conn, err := ln.Accept() + if err != nil { + return fmt.Errorf("failed to accept connection: %v", err) + } + + // TODO: log errors to debug log + go fe.handle(conn) + } +} + +func (fe *Frontend) handle(downstream net.Conn) error { + defer downstream.Close() + + be := &fe.Backend + upstream, err := net.Dial(be.Network, be.Address) + if err != nil { + return fmt.Errorf("failed to dial backend: %v", err) + } + defer upstream.Close() + + return duplexCopy(upstream, downstream) +} + +type Backend struct { + Network string + Address string +} + +func duplexCopy(a, b io.ReadWriter) error { + done := make(chan error, 2) + go func() { + _, err := io.Copy(a, b) + done <- err + }() + go func() { + _, err := io.Copy(b, a) + done <- err + }() + return <-done +}