Add config reloading
Instead of updating the configuration, we configure a new Server instance and then migrate Listeners that still exist to it. Open client connections are left completely untouched. Closes https://todo.sr.ht/~emersion/tlstunnel/1
This commit is contained in:
parent
09d28676a6
commit
4548a7fe65
|
@ -2,7 +2,11 @@ package main
|
||||||
|
|
||||||
import (
|
import (
|
||||||
"flag"
|
"flag"
|
||||||
|
"fmt"
|
||||||
"log"
|
"log"
|
||||||
|
"os"
|
||||||
|
"os/signal"
|
||||||
|
"syscall"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/go-scfg"
|
"git.sr.ht/~emersion/go-scfg"
|
||||||
"git.sr.ht/~emersion/tlstunnel"
|
"git.sr.ht/~emersion/tlstunnel"
|
||||||
|
@ -15,13 +19,10 @@ var (
|
||||||
certDataPath = ""
|
certDataPath = ""
|
||||||
)
|
)
|
||||||
|
|
||||||
func main() {
|
func newServer() (*tlstunnel.Server, error) {
|
||||||
flag.StringVar(&configPath, "config", configPath, "path to configuration file")
|
|
||||||
flag.Parse()
|
|
||||||
|
|
||||||
cfg, err := scfg.Load(configPath)
|
cfg, err := scfg.Load(configPath)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to load config file: %v", err)
|
return nil, fmt.Errorf("failed to load config file: %w", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
srv := tlstunnel.NewServer()
|
srv := tlstunnel.NewServer()
|
||||||
|
@ -37,7 +38,7 @@ func main() {
|
||||||
}
|
}
|
||||||
logger, err := loggerCfg.Build()
|
logger, err := loggerCfg.Build()
|
||||||
if err != nil {
|
if err != nil {
|
||||||
log.Fatalf("failed to initialize zap logger: %v", err)
|
return nil, fmt.Errorf("failed to initialize zap logger: %w", err)
|
||||||
}
|
}
|
||||||
srv.ACMEConfig.Logger = logger
|
srv.ACMEConfig.Logger = logger
|
||||||
srv.ACMEManager.Logger = logger
|
srv.ACMEManager.Logger = logger
|
||||||
|
@ -47,12 +48,48 @@ func main() {
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.Load(cfg); err != nil {
|
if err := srv.Load(cfg); err != nil {
|
||||||
log.Fatal(err)
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return srv, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func main() {
|
||||||
|
flag.StringVar(&configPath, "config", configPath, "path to configuration file")
|
||||||
|
flag.Parse()
|
||||||
|
|
||||||
|
srv, err := newServer()
|
||||||
|
if err != nil {
|
||||||
|
log.Fatalf("failed to create server: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
sigCh := make(chan os.Signal, 1)
|
||||||
|
signal.Notify(sigCh, syscall.SIGINT, syscall.SIGTERM, syscall.SIGHUP)
|
||||||
|
|
||||||
if err := srv.Start(); err != nil {
|
if err := srv.Start(); err != nil {
|
||||||
log.Fatal(err)
|
log.Fatal(err)
|
||||||
}
|
}
|
||||||
|
|
||||||
select {}
|
for sig := range sigCh {
|
||||||
|
switch sig {
|
||||||
|
case syscall.SIGINT:
|
||||||
|
case syscall.SIGTERM:
|
||||||
|
srv.Stop()
|
||||||
|
return
|
||||||
|
case syscall.SIGHUP:
|
||||||
|
log.Print("caught SIGHUP, reloading config")
|
||||||
|
newSrv, err := newServer()
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("reload failed: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
err = newSrv.Replace(srv)
|
||||||
|
if err != nil {
|
||||||
|
log.Printf("reload failed: %v", err)
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
srv = newSrv
|
||||||
|
log.Print("successfully reloaded config")
|
||||||
|
}
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|
124
server.go
124
server.go
|
@ -8,6 +8,7 @@ import (
|
||||||
"log"
|
"log"
|
||||||
"net"
|
"net"
|
||||||
"strings"
|
"strings"
|
||||||
|
"sync/atomic"
|
||||||
|
|
||||||
"git.sr.ht/~emersion/go-scfg"
|
"git.sr.ht/~emersion/go-scfg"
|
||||||
"github.com/caddyserver/certmagic"
|
"github.com/caddyserver/certmagic"
|
||||||
|
@ -24,6 +25,8 @@ type Server struct {
|
||||||
|
|
||||||
ACMEManager *certmagic.ACMEManager
|
ACMEManager *certmagic.ACMEManager
|
||||||
ACMEConfig *certmagic.Config
|
ACMEConfig *certmagic.Config
|
||||||
|
|
||||||
|
cancelACME context.CancelFunc
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewServer() *Server {
|
func NewServer() *Server {
|
||||||
|
@ -57,17 +60,28 @@ func (srv *Server) RegisterListener(addr string) *Listener {
|
||||||
return ln
|
return ln
|
||||||
}
|
}
|
||||||
|
|
||||||
func (srv *Server) Start() error {
|
func (srv *Server) startACME() error {
|
||||||
|
var ctx context.Context
|
||||||
|
ctx, srv.cancelACME = context.WithCancel(context.Background())
|
||||||
|
|
||||||
for _, cert := range srv.UnmanagedCerts {
|
for _, cert := range srv.UnmanagedCerts {
|
||||||
if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
if err := srv.ACMEConfig.CacheUnmanagedTLSCertificate(cert, nil); err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
if err := srv.ACMEConfig.ManageAsync(context.Background(), srv.ManagedNames); err != nil {
|
if err := srv.ACMEConfig.ManageAsync(ctx, srv.ManagedNames); err != nil {
|
||||||
return fmt.Errorf("failed to manage TLS certificates: %v", err)
|
return fmt.Errorf("failed to manage TLS certificates: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (srv *Server) Start() error {
|
||||||
|
if err := srv.startACME(); err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
for _, ln := range srv.Listeners {
|
for _, ln := range srv.Listeners {
|
||||||
if err := ln.Start(); err != nil {
|
if err := ln.Start(); err != nil {
|
||||||
return err
|
return err
|
||||||
|
@ -76,37 +90,94 @@ func (srv *Server) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
type Listener struct {
|
func (srv *Server) Stop() {
|
||||||
Address string
|
srv.cancelACME()
|
||||||
|
// TODO: clean cached unmanaged certs
|
||||||
|
for _, ln := range srv.Listeners {
|
||||||
|
ln.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Replace starts the server but takes over existing listeners from an old
|
||||||
|
// Server instance. The old instance keeps running unchanged if Replace
|
||||||
|
// returns an error.
|
||||||
|
func (srv *Server) Replace(old *Server) error {
|
||||||
|
// Try to start new listeners
|
||||||
|
for addr, ln := range srv.Listeners {
|
||||||
|
if _, ok := old.Listeners[addr]; ok {
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
if err := ln.Start(); err != nil {
|
||||||
|
for _, ln2 := range srv.Listeners {
|
||||||
|
ln2.Stop()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// Restart ACME
|
||||||
|
old.cancelACME()
|
||||||
|
if err := srv.startACME(); err != nil {
|
||||||
|
for _, ln2 := range srv.Listeners {
|
||||||
|
ln2.Stop()
|
||||||
|
}
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
// TODO: clean cached unmanaged certs
|
||||||
|
|
||||||
|
// Take over existing listeners and terminate old ones
|
||||||
|
for addr, oldLn := range old.Listeners {
|
||||||
|
if ln, ok := srv.Listeners[addr]; ok {
|
||||||
|
srv.Listeners[addr] = oldLn.UpdateFrom(ln)
|
||||||
|
} else {
|
||||||
|
oldLn.Stop()
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
type listenerHandles struct {
|
||||||
Server *Server
|
Server *Server
|
||||||
Frontends map[string]*Frontend // indexed by server name
|
Frontends map[string]*Frontend // indexed by server name
|
||||||
}
|
}
|
||||||
|
|
||||||
|
type Listener struct {
|
||||||
|
Address string
|
||||||
|
netLn net.Listener
|
||||||
|
atomic atomic.Value
|
||||||
|
}
|
||||||
|
|
||||||
func newListener(srv *Server, addr string) *Listener {
|
func newListener(srv *Server, addr string) *Listener {
|
||||||
return &Listener{
|
ln := &Listener{
|
||||||
Address: addr,
|
Address: addr,
|
||||||
|
}
|
||||||
|
ln.atomic.Store(&listenerHandles{
|
||||||
Server: srv,
|
Server: srv,
|
||||||
Frontends: make(map[string]*Frontend),
|
Frontends: make(map[string]*Frontend),
|
||||||
}
|
})
|
||||||
|
return ln
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
|
func (ln *Listener) RegisterFrontend(name string, fe *Frontend) error {
|
||||||
if _, ok := ln.Frontends[name]; ok {
|
fes := ln.atomic.Load().(*listenerHandles).Frontends
|
||||||
|
if _, ok := fes[name]; ok {
|
||||||
return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
|
return fmt.Errorf("listener %q: duplicate frontends for server name %q", ln.Address, name)
|
||||||
}
|
}
|
||||||
ln.Frontends[name] = fe
|
fes[name] = fe
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *Listener) Start() error {
|
func (ln *Listener) Start() error {
|
||||||
netLn, err := net.Listen("tcp", ln.Address)
|
var err error
|
||||||
|
ln.netLn, err = net.Listen("tcp", ln.Address)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return err
|
||||||
}
|
}
|
||||||
log.Printf("listening on %q", ln.Address)
|
log.Printf("listening on %q", ln.Address)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
if err := ln.serve(netLn); err != nil {
|
if err := ln.serve(); err != nil {
|
||||||
log.Fatalf("listener %q: %v", ln.Address, err)
|
log.Fatalf("listener %q: %v", ln.Address, err)
|
||||||
}
|
}
|
||||||
}()
|
}()
|
||||||
|
@ -114,10 +185,22 @@ func (ln *Listener) Start() error {
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *Listener) serve(netLn net.Listener) error {
|
func (ln *Listener) Stop() {
|
||||||
|
ln.netLn.Close()
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Listener) UpdateFrom(new *Listener) *Listener {
|
||||||
|
ln.atomic.Store(new.atomic.Load())
|
||||||
|
return ln
|
||||||
|
}
|
||||||
|
|
||||||
|
func (ln *Listener) serve() error {
|
||||||
for {
|
for {
|
||||||
conn, err := netLn.Accept()
|
conn, err := ln.netLn.Accept()
|
||||||
if err != nil {
|
if err != nil && strings.Contains(err.Error(), "use of closed network connection") {
|
||||||
|
// Listening socket has been closed by Stop()
|
||||||
|
return nil
|
||||||
|
} else if err != nil {
|
||||||
return fmt.Errorf("failed to accept connection: %v", err)
|
return fmt.Errorf("failed to accept connection: %v", err)
|
||||||
}
|
}
|
||||||
|
|
||||||
|
@ -131,9 +214,10 @@ func (ln *Listener) serve(netLn net.Listener) error {
|
||||||
|
|
||||||
func (ln *Listener) handle(conn net.Conn) error {
|
func (ln *Listener) handle(conn net.Conn) error {
|
||||||
defer conn.Close()
|
defer conn.Close()
|
||||||
|
srv := ln.atomic.Load().(*listenerHandles).Server
|
||||||
|
|
||||||
// TODO: setup timeouts
|
// TODO: setup timeouts
|
||||||
tlsConfig := ln.Server.ACMEConfig.TLSConfig()
|
tlsConfig := srv.ACMEConfig.TLSConfig()
|
||||||
getConfigForClient := tlsConfig.GetConfigForClient
|
getConfigForClient := tlsConfig.GetConfigForClient
|
||||||
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
tlsConfig.GetConfigForClient = func(hello *tls.ClientHelloInfo) (*tls.Config, error) {
|
||||||
// Call previous GetConfigForClient function, if any
|
// Call previous GetConfigForClient function, if any
|
||||||
|
@ -145,7 +229,7 @@ func (ln *Listener) handle(conn net.Conn) error {
|
||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
} else {
|
} else {
|
||||||
tlsConfig = ln.Server.ACMEConfig.TLSConfig()
|
tlsConfig = srv.ACMEConfig.TLSConfig()
|
||||||
}
|
}
|
||||||
|
|
||||||
fe, err := ln.matchFrontend(hello.ServerName)
|
fe, err := ln.matchFrontend(hello.ServerName)
|
||||||
|
@ -171,18 +255,20 @@ func (ln *Listener) handle(conn net.Conn) error {
|
||||||
}
|
}
|
||||||
|
|
||||||
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
|
func (ln *Listener) matchFrontend(serverName string) (*Frontend, error) {
|
||||||
fe, ok := ln.Frontends[serverName]
|
fes := ln.atomic.Load().(*listenerHandles).Frontends
|
||||||
|
|
||||||
|
fe, ok := fes[serverName]
|
||||||
if !ok {
|
if !ok {
|
||||||
// Match wildcard certificates, allowing only a single, non-partial
|
// Match wildcard certificates, allowing only a single, non-partial
|
||||||
// wildcard, in the left-most label
|
// wildcard, in the left-most label
|
||||||
i := strings.IndexByte(serverName, '.')
|
i := strings.IndexByte(serverName, '.')
|
||||||
// Don't allow wildcards with only a TLD (e.g. *.com)
|
// Don't allow wildcards with only a TLD (e.g. *.com)
|
||||||
if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 {
|
if i >= 0 && strings.IndexByte(serverName[i+1:], '.') >= 0 {
|
||||||
fe, ok = ln.Frontends["*"+serverName[i:]]
|
fe, ok = fes["*"+serverName[i:]]
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
fe, ok = ln.Frontends[""]
|
fe, ok = fes[""]
|
||||||
}
|
}
|
||||||
if !ok {
|
if !ok {
|
||||||
return nil, fmt.Errorf("can't find frontend for server name %q", serverName)
|
return nil, fmt.Errorf("can't find frontend for server name %q", serverName)
|
||||||
|
|
|
@ -27,6 +27,8 @@ The config file has one directive per line. Directives have a name, followed
|
||||||
by parameters separated by space characters. Directives may have children in
|
by parameters separated by space characters. Directives may have children in
|
||||||
blocks delimited by "{" and "}". Lines beginning with "#" are comments.
|
blocks delimited by "{" and "}". Lines beginning with "#" are comments.
|
||||||
|
|
||||||
|
tlstunnel will reload the config file when it receives the HUP signal.
|
||||||
|
|
||||||
Example:
|
Example:
|
||||||
|
|
||||||
```
|
```
|
||||||
|
|
Loading…
Reference in New Issue