diff --git a/cmd/osiris/internal/config/bind.go b/cmd/osiris/internal/config/bind.go index 5467cc62..ba459030 100644 --- a/cmd/osiris/internal/config/bind.go +++ b/cmd/osiris/internal/config/bind.go @@ -7,7 +7,7 @@ import ( ) var ( - ErrCantBindToPort = errors.New("bind: can't bind to host:port") + ErrInvalidHostpost = errors.New("bind: invalid host:port") ) type Bind struct { @@ -19,25 +19,16 @@ type Bind struct { func (b *Bind) Valid() error { var errs []error - ln, err := net.Listen("tcp", b.HTTP) - if err != nil { - errs = append(errs, fmt.Errorf("%w %q: %w", ErrCantBindToPort, b.HTTP, err)) - } else { - defer ln.Close() + if _, _, err := net.SplitHostPort(b.HTTP); err != nil { + errs = append(errs, fmt.Errorf("%w %q: %w", ErrInvalidHostpost, b.HTTP, err)) } - ln, err = net.Listen("tcp", b.HTTPS) - if err != nil { - errs = append(errs, fmt.Errorf("%w %q: %w", ErrCantBindToPort, b.HTTPS, err)) - } else { - defer ln.Close() + if _, _, err := net.SplitHostPort(b.HTTPS); err != nil { + errs = append(errs, fmt.Errorf("%w %q: %w", ErrInvalidHostpost, b.HTTPS, err)) } - ln, err = net.Listen("tcp", b.Metrics) - if err != nil { - errs = append(errs, fmt.Errorf("%w %q: %w", ErrCantBindToPort, b.Metrics, err)) - } else { - defer ln.Close() + if _, _, err := net.SplitHostPort(b.Metrics); err != nil { + errs = append(errs, fmt.Errorf("%w %q: %w", ErrInvalidHostpost, b.Metrics, err)) } if len(errs) != 0 { diff --git a/cmd/osiris/internal/config/bind_test.go b/cmd/osiris/internal/config/bind_test.go index 56f298be..439b13a0 100644 --- a/cmd/osiris/internal/config/bind_test.go +++ b/cmd/osiris/internal/config/bind_test.go @@ -24,7 +24,7 @@ func TestBindValid(t *testing.T) { err: nil, }, { - name: "reused ports", + name: "invalid ports", precondition: func(t *testing.T) { ln, err := net.Listen("tcp", ":8081") if err != nil { @@ -33,11 +33,11 @@ func TestBindValid(t *testing.T) { t.Cleanup(func() { ln.Close() }) }, bind: Bind{ - HTTP: ":8081", - HTTPS: ":8081", - Metrics: ":8081", + HTTP: "", + HTTPS: "", + Metrics: "", }, - err: ErrCantBindToPort, + err: ErrInvalidHostpost, }, } { t.Run(tt.name, func(t *testing.T) { diff --git a/cmd/osiris/internal/entrypoint/entrypoint.go b/cmd/osiris/internal/entrypoint/entrypoint.go index a64779f9..0380366d 100644 --- a/cmd/osiris/internal/entrypoint/entrypoint.go +++ b/cmd/osiris/internal/entrypoint/entrypoint.go @@ -33,6 +33,8 @@ func Main(ctx context.Context, opts Options) error { if err != nil { return err } + rtr.opts = opts + go rtr.backgroundReloadConfig(ctx) g, gCtx := errgroup.WithContext(ctx) diff --git a/cmd/osiris/internal/entrypoint/router.go b/cmd/osiris/internal/entrypoint/router.go index 72f05822..b936dbb6 100644 --- a/cmd/osiris/internal/entrypoint/router.go +++ b/cmd/osiris/internal/entrypoint/router.go @@ -10,13 +10,18 @@ import ( "net/http" "net/http/httputil" "net/url" + "os" + "os/signal" "strings" "sync" + "syscall" + "time" "github.com/TecharoHQ/anubis/cmd/osiris/internal/config" "github.com/TecharoHQ/anubis/internal" "github.com/TecharoHQ/anubis/internal/fingerprint" "github.com/felixge/httpsnoop" + "github.com/hashicorp/hcl/v2/hclsimple" "github.com/lum8rjack/go-ja4h" "github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus/promauto" @@ -52,6 +57,7 @@ type Router struct { lock sync.RWMutex routes map[string]http.Handler tlsCerts map[string]*tls.Certificate + opts Options } func (rtr *Router) setConfig(c config.Toplevel) error { @@ -143,6 +149,48 @@ func (rtr *Router) GetCertificate(hello *tls.ClientHelloInfo) (*tls.Certificate, return cert, nil } +func (rtr *Router) loadConfig() error { + slog.Info("reloading config", "fname", rtr.opts.ConfigFname) + var cfg config.Toplevel + if err := hclsimple.DecodeFile(rtr.opts.ConfigFname, nil, &cfg); err != nil { + return err + } + + if err := cfg.Valid(); err != nil { + return err + } + + if err := rtr.setConfig(cfg); err != nil { + return err + } + + slog.Info("done!") + + return nil +} + +func (rtr *Router) backgroundReloadConfig(ctx context.Context) { + t := time.NewTicker(time.Hour) + defer t.Stop() + ch := make(chan os.Signal, 1) + signal.Notify(ch, syscall.SIGHUP) + + for { + select { + case <-ctx.Done(): + return + case <-t.C: + if err := rtr.loadConfig(); err != nil { + slog.Error("can't reload config", "fname", rtr.opts.ConfigFname, "err", err) + } + case <-ch: + if err := rtr.loadConfig(); err != nil { + slog.Error("can't reload config", "fname", rtr.opts.ConfigFname, "err", err) + } + } + } +} + func NewRouter(c config.Toplevel) (*Router, error) { result := &Router{ routes: map[string]http.Handler{},