Add config reloading

Instead of updating the configuration, we configure a new Server instance and
then migrate Listeners that still exist to it. Open client connections are
left completely untouched.

Closes https://todo.sr.ht/~emersion/tlstunnel/1
This commit is contained in:
minus 2020-12-22 12:06:14 +01:00 committed by Simon Ser
parent 09d28676a6
commit 4548a7fe65
No known key found for this signature in database
GPG Key ID: 0FDE7BE0E88F5E48
3 changed files with 152 additions and 27 deletions

View File

@ -2,7 +2,11 @@ package main
import ( import (
"flag" "flag"
"fmt"
"log" "log"
"os"
"os/signal"
"syscall"
"git.sr.ht/~emersion/go-scfg" "git.sr.ht/~emersion/go-scfg"
"git.sr.ht/~emersion/tlstunnel" "git.sr.ht/~emersion/tlstunnel"
@ -15,13 +19,10 @@ var (
certDataPath = "" certDataPath = ""
) )
func main() { func newServer() (*tlstunnel.Server, error) {
flag.StringVar(&configPath, "config", configPath, "path to configuration file")
flag.Parse()
cfg, err := scfg.Load(configPath) cfg, err := scfg.Load(configPath)
if err != nil { if err != nil {
log.Fatalf("failed to load config file: %v", err) return nil, fmt.Errorf("failed to load config file: %w", err)
} }
srv := tlstunnel.NewServer() srv := tlstunnel.NewServer()
@ -37,7 +38,7 @@ func main() {
} }
logger, err := loggerCfg.Build() logger, err := loggerCfg.Build()
if err != nil { if err != nil {
log.Fatalf("failed to initialize zap logger: %v", err) return nil, fmt.Errorf("failed to initialize zap logger: %w", err)
} }
srv.ACMEConfig.Logger = logger srv.ACMEConfig.Logger = logger
srv.ACMEManager.Logger = logger srv.ACMEManager.Logger = logger
@ -47,12 +48,48 @@ func main() {
} }
if err := srv.Load(cfg); err != nil { if err := srv.Load(cfg); err != nil {
log.Fatal(err) return nil, err
} }
return srv, nil
}
func main() {
flag.StringVar(&configPath, "config", configPath, "path to configuration file")
flag.Parse()
srv, err := newServer()
if err != nil {
log.Fatalf("failed to create server: %v", err)
}
sigCh := make(chan os.Signal, 1)
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
if err := srv.Start(); err != nil { if err := srv.Start(); err != nil {
log.Fatal(err) log.Fatal(err)
} }
select {} for sig := range sigCh {
switch sig {
case syscall.SIGINT:
case syscall.SIGTERM:
srv.Stop()
return
case syscall.SIGHUP:
log.Print("caught SIGHUP, reloading config")
newSrv, err := newServer()
if err != nil {
log.Printf("reload failed: %v", err)
continue
}
err = newSrv.Replace(srv)
if err != nil {
log.Printf("reload failed: %v", err)
continue
}
srv = newSrv
log.Print("successfully reloaded config")
}
}
} }

122
server.go
View File

@ -8,6 +8,7 @@ import (
"log" "log"
"net" "net"
"strings" "strings"
"sync/atomic"
"git.sr.ht/~emersion/go-scfg" "git.sr.ht/~emersion/go-scfg"
"github.com/caddyserver/certmagic" "github.com/caddyserver/certmagic"
@ -24,6 +25,8 @@ type Server struct {
ACMEManager *certmagic.ACMEManager ACMEManager *certmagic.ACMEManager
ACMEConfig *certmagic.Config ACMEConfig *certmagic.Config
cancelACME context.CancelFunc
} }
func NewServer() *Server { func NewServer() *Server {
@ -57,17 +60,28 @@ func (srv *Server) RegisterListener(addr string) *Listener {
return ln return ln
} }
func (srv *Server) Start() error { func (srv *Server) startACME() error {
var ctx context.Context
ctx, srv.cancelACME = context.WithCancel(context.Background())
for _, cert := range srv.UnmanagedCerts { for _, cert := range srv.UnmanagedCerts {
if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil { if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
return err return err
} }
} }
if err := srv.ACMEConfig.ManageAsync(context.Background(), srv.ManagedNames); err != nil { if err := srv.ACMEConfig.ManageAsync(ctx, srv.ManagedNames); err != nil {
return fmt.Errorf("failed to manage TLS certificates: %v", err) return fmt.Errorf("failed to manage TLS certificates: %v", err)
} }
return nil
}
func (srv *Server) Start() error {
if err := srv.startACME(); err != nil {
return err
}
for _, ln := range srv.Listeners { for _, ln := range srv.Listeners {
if err := ln.Start(); err != nil { if err := ln.Start(); err != nil {
return err return err
@ -76,37 +90,94 @@ func (srv *Server) Start() error {
return nil return nil
} }
type Listener struct { func (srv *Server) Stop() {
Address string srv.cancelACME()
// TODO: clean cached unmanaged certs
for _, ln := range srv.Listeners {
ln.Stop()
}
}
// Replace starts the server but takes over existing listeners from an old
// Server instance. The old instance keeps running unchanged if Replace
// returns an error.
func (srv *Server) Replace(old *Server) error {
// Try to start new listeners
for addr, ln := range srv.Listeners {
if _, ok := old.Listeners[addr]; ok {
continue
}
if err := ln.Start(); err != nil {
for _, ln2 := range srv.Listeners {
ln2.Stop()
}
return err
}
}
// Restart ACME
old.cancelACME()
if err := srv.startACME(); err != nil {
for _, ln2 := range srv.Listeners {
ln2.Stop()
}
return err
}
// TODO: clean cached unmanaged certs
// Take over existing listeners and terminate old ones
for addr, oldLn := range old.Listeners {
if ln, ok := srv.Listeners[addr]; ok {
srv.Listeners[addr] = oldLn.UpdateFrom(ln)
} else {
oldLn.Stop()
}
}
return nil
}
type listenerHandles struct {
Server *Server Server *Server
Frontends map[string]*Frontend // indexed by server name Frontends map[string]*Frontend // indexed by server name
} }
type Listener struct {
Address string
netLn net.Listener
atomic atomic.Value
}
func newListener(srv *Server, addr string) *Listener { func newListener(srv *Server, addr string) *Listener {
return &Listener{ ln := &Listener{
Address: addr, Address: addr,
}
ln.atomic.Store(&listenerHandles{
Server: srv, Server: srv,
Frontends: make(map[string]*Frontend), Frontends: make(map[string]*Frontend),
} })
return ln
} }
func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error { func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
if _, ok := ln.Frontends[name]; ok { fes := ln.atomic.Load().(*listenerHandles).Frontends
if _, ok := fes[name]; ok {
return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name) return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
} }
ln.Frontends[name] = fe fes[name] = fe
return nil return nil
} }
func (ln *Listener) Start() error { func (ln *Listener) Start() error {
netLn, err := net.Listen("tcp", ln.Address) var err error
ln.netLn, err = net.Listen("tcp", ln.Address)
if err != nil { if err != nil {
return err return err
} }
log.Printf("listening on %q", ln.Address) log.Printf("listening on %q", ln.Address)
go func() { go func() {
if err := ln.serve(netLn); err != nil { if err := ln.serve(); err != nil {
log.Fatalf("listener %q: %v", ln.Address, err) log.Fatalf("listener %q: %v", ln.Address, err)
} }
}() }()
@ -114,10 +185,22 @@ func (ln *Listener) Start() error {
return nil return nil
} }
func (ln *Listener) serve(netLn net.Listener) error { func (ln *Listener) Stop() {
ln.netLn.Close()
}
func (ln *Listener) UpdateFrom(new *Listener) *Listener {
ln.atomic.Store(new.atomic.Load())
return ln
}
func (ln *Listener) serve() error {
for { for {
conn, err := netLn.Accept() conn, err := ln.netLn.Accept()
if err != nil { if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
// Listening socket has been closed by Stop()
return nil
} else if err != nil {
return fmt.Errorf("failed to accept connection: %v", err) return fmt.Errorf("failed to accept connection: %v", err)
} }
@ -131,9 +214,10 @@ func (ln *Listener) serve(netLn net.Listener) error {
func (ln *Listener) handle(conn net.Conn) error { func (ln *Listener) handle(conn net.Conn) error {
defer conn.Close() defer conn.Close()
srv := ln.atomic.Load().(*listenerHandles).Server
// TODO: setup timeouts // TODO: setup timeouts
tlsConfig := ln.Server.ACMEConfig.TLSConfig() tlsConfig := srv.ACMEConfig.TLSConfig()
getConfigForClient := tlsConfig.GetConfigForClient getConfigForClient := tlsConfig.GetConfigForClient
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) { tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
// Call previous GetConfigForClient function, if any // Call previous GetConfigForClient function, if any
@ -145,7 +229,7 @@ func (ln *Listener) handle(conn net.Conn) error {
return nil, err return nil, err
} }
} else { } else {
tlsConfig = ln.Server.ACMEConfig.TLSConfig() tlsConfig = srv.ACMEConfig.TLSConfig()
} }
fe, err := ln.matchFrontend(hello.ServerName) fe, err := ln.matchFrontend(hello.ServerName)
@ -171,18 +255,20 @@ func (ln *Listener) handle(conn net.Conn) error {
} }
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) { func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
fe, ok := ln.Frontends[serverName] fes := ln.atomic.Load().(*listenerHandles).Frontends
fe, ok := fes[serverName]
if !ok { if !ok {
// Match wildcard certificates, allowing only a single, non-partial // Match wildcard certificates, allowing only a single, non-partial
// wildcard, in the left-most label // wildcard, in the left-most label
i := strings.IndexByte(serverName, '.') i := strings.IndexByte(serverName, '.')
// Don't allow wildcards with only a TLD (e.g. *.com) // Don't allow wildcards with only a TLD (e.g. *.com)
if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 { if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 {
fe, ok = ln.Frontends["*"+serverName[i:]] fe, ok = fes["*"+serverName[i:]]
} }
} }
if !ok { if !ok {
fe, ok = ln.Frontends[""] fe, ok = fes[""]
} }
if !ok { if !ok {
return nil, fmt.Errorf("can't find frontend for server name %q", serverName) return nil, fmt.Errorf("can't find frontend for server name %q", serverName)

View File

@ -27,6 +27,8 @@ The config file has one directive per line. Directives have a name, followed
by parameters separated by space characters. Directives may have children in by parameters separated by space characters. Directives may have children in
blocks delimited by "{" and "}". Lines beginning with "#" are comments. blocks delimited by "{" and "}". Lines beginning with "#" are comments.
tlstunnel will reload the config file when it receives the HUP signal.
Example: Example:
``` ```