diff --git a/.github/actions/spelling/expect.txt b/.github/actions/spelling/expect.txt index 8f1a5abb..75aa9e2a 100644 --- a/.github/actions/spelling/expect.txt +++ b/.github/actions/spelling/expect.txt @@ -47,6 +47,7 @@ cachediptoasn Caddyfile caninetools Cardyb +CAs celchecker celphase cerr @@ -203,8 +204,10 @@ kagi kagibot Keyfunc keypair +keypairreloader KHTML kinda +kpr KUBECONFIG lcj ldflags @@ -229,6 +232,7 @@ metarefresh metrix mimi Minfilia +minica mistralai mnt Mojeek @@ -313,6 +317,7 @@ searchbot searx sebest secretplans +selfsigned Semrush Seo setsebool diff --git a/data/botPolicies.yaml b/data/botPolicies.yaml index fdb8539b..fd1d1a34 100644 --- a/data/botPolicies.yaml +++ b/data/botPolicies.yaml @@ -174,6 +174,20 @@ status_codes: # metrics: # bind: ":9090" # network: "tcp" +# +# # To serve metrics over TLS, set the path to the right TLS certificate and key +# # here. When the files change on disk, they will automatically be reloaded. +# # +# # https://anubis.techaro.lol/docs/admin/policies#tls +# tls: +# certificate: /path/to/tls.crt +# key: /path/to/tls.key +# +# # If you want to secure your metrics endpoint using mutual TLS (mTLS), set +# # the path to a certificate authority public certificate here. +# # +# # https://anubis.techaro.lol/docs/admin/policies#mtls +# ca: /path/to/ca.crt # Anubis can store temporary data in one of a few backends. See the storage # backends section of the docs for more information: diff --git a/docs/docs/CHANGELOG.md b/docs/docs/CHANGELOG.md index 785dfa03..d1bd4935 100644 --- a/docs/docs/CHANGELOG.md +++ b/docs/docs/CHANGELOG.md @@ -21,6 +21,7 @@ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0 - Improve error messages and fix broken REDIRECT_DOMAINS link in docs ([#1193](https://github.com/TecharoHQ/anubis/issues/1193)) - Add Bulgarian locale ([#1394](https://github.com/TecharoHQ/anubis/pull/1394)) - Fix CEL internal errors when iterating `headers`/`query` map wrappers by implementing map iterators for `HTTPHeaders` and `URLValues` ([#1465](https://github.com/TecharoHQ/anubis/pull/1465)). +- Enable [metrics serving via TLS](./admin/policies.mdx#tls), including [mutual TLS (mTLS)](./admin/policies.mdx#mtls). ## v1.25.0: Necron diff --git a/docs/docs/admin/policies.mdx b/docs/docs/admin/policies.mdx index 37998948..06de8dd8 100644 --- a/docs/docs/admin/policies.mdx +++ b/docs/docs/admin/policies.mdx @@ -138,6 +138,39 @@ metrics: socketMode: "0700" # must be a string ``` +### TLS + +If you want to serve the metrics server over TLS, use the `tls` block: + +```yaml +metrics: + bind: ":9090" + network: "tcp" + + tls: + certificate: /path/to/tls.crt + key: /path/to/tls.key +``` + +The certificate and key will automatically be reloaded when the respective files change. + +### mTLS + +If you want to validate requests to ensure that they use a client certificate signed by a certificate authority (mutual TLS or mTLS), set the `ca` value in the `tls` block: + +```yaml +metrics: + bind: ":9090" + network: "tcp" + + tls: + certificate: /path/to/tls.crt + key: /path/to/tls.key + ca: /path/to/ca.crt +``` + +As it is not expected for certificate authority certificates to change often, the CA certificate will NOT be automatically reloaded when the respective file changes. + ## Imprint / Impressum support Anubis has support for showing imprint / impressum information. This is defined in the `impressum` block of your configuration. See [Imprint / Impressum configuration](./configuration/impressum.mdx) for more information. diff --git a/lib/config/metrics.go b/lib/config/metrics.go index 37585ab2..a4fcc7fb 100644 --- a/lib/config/metrics.go +++ b/lib/config/metrics.go @@ -1,24 +1,34 @@ package config import ( + "crypto/tls" + "crypto/x509" "errors" "fmt" + "os" "strconv" ) var ( - ErrInvalidMetricsConfig = errors.New("config: invalid metrics configuration") - ErrNoMetricsBind = errors.New("config.Metrics: must define bind") - ErrNoMetricsNetwork = errors.New("config.Metrics: must define network") - ErrNoMetricsSocketMode = errors.New("config.Metrics: must define socket mode when using unix sockets") - ErrInvalidMetricsSocketMode = errors.New("config.Metrics: invalid unix socket mode") - ErrInvalidMetricsNetwork = errors.New("config.Metrics: invalid metrics network") + ErrInvalidMetricsConfig = errors.New("config: invalid metrics configuration") + ErrInvalidMetricsTLSConfig = errors.New("config: invalid metrics TLS configuration") + ErrNoMetricsBind = errors.New("config.Metrics: must define bind") + ErrNoMetricsNetwork = errors.New("config.Metrics: must define network") + ErrNoMetricsSocketMode = errors.New("config.Metrics: must define socket mode when using unix sockets") + ErrInvalidMetricsSocketMode = errors.New("config.Metrics: invalid unix socket mode") + ErrInvalidMetricsNetwork = errors.New("config.Metrics: invalid metrics network") + ErrNoMetricsTLSCertificate = errors.New("config.Metrics.TLS: must define certificate file") + ErrNoMetricsTLSKey = errors.New("config.Metrics.TLS: must define key file") + ErrInvalidMetricsTLSKeypair = errors.New("config.Metrics.TLS: keypair is invalid") + ErrInvalidMetricsCACertificate = errors.New("config.Metrics.TLS: invalid CA certificate") + ErrCantReadFile = errors.New("config: can't read required file") ) type Metrics struct { - Bind string `json:"bind" yaml:"bind"` - Network string `json:"network" yaml:"network"` - SocketMode string `json:"socketMode" yaml:"socketMode"` + Bind string `json:"bind" yaml:"bind"` + Network string `json:"network" yaml:"network"` + SocketMode string `json:"socketMode" yaml:"socketMode"` + TLS *MetricsTLS `json:"tls" yaml:"tls"` } func (m *Metrics) Valid() error { @@ -46,9 +56,78 @@ func (m *Metrics) Valid() error { errs = append(errs, ErrInvalidMetricsNetwork) } + if m.TLS != nil { + if err := m.TLS.Valid(); err != nil { + errs = append(errs, err) + } + } + if len(errs) != 0 { return errors.Join(ErrInvalidMetricsConfig, errors.Join(errs...)) } return nil } + +type MetricsTLS struct { + Certificate string `json:"certificate" yaml:"certificate"` + Key string `json:"key" yaml:"key"` + CA string `json:"ca" yaml:"ca"` +} + +func (mt *MetricsTLS) Valid() error { + var errs []error + + if mt.Certificate == "" { + errs = append(errs, ErrNoMetricsTLSCertificate) + } + + if err := canReadFile(mt.Certificate); err != nil { + errs = append(errs, fmt.Errorf("%w %s: %w", ErrCantReadFile, mt.Certificate, err)) + } + + if mt.Key == "" { + errs = append(errs, ErrNoMetricsTLSKey) + } + + if err := canReadFile(mt.Key); err != nil { + errs = append(errs, fmt.Errorf("%w %s: %w", ErrCantReadFile, mt.Key, err)) + } + + if _, err := tls.LoadX509KeyPair(mt.Certificate, mt.Key); err != nil { + errs = append(errs, fmt.Errorf("%w: %w", ErrInvalidMetricsTLSKeypair, err)) + } + + if mt.CA != "" { + caCert, err := os.ReadFile(mt.CA) + if err != nil { + errs = append(errs, fmt.Errorf("%w %s: %w", ErrCantReadFile, mt.CA, err)) + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCert) { + errs = append(errs, fmt.Errorf("%w %s", ErrInvalidMetricsCACertificate, mt.CA)) + } + } + + if len(errs) != 0 { + return errors.Join(ErrInvalidMetricsTLSConfig, errors.Join(errs...)) + } + + return nil +} + +func canReadFile(fname string) error { + fin, err := os.Open(fname) + if err != nil { + return err + } + defer fin.Close() + + data := make([]byte, 64) + if _, err := fin.Read(data); err != nil { + return fmt.Errorf("can't read %s: %w", fname, err) + } + + return nil +} diff --git a/lib/config/metrics_test.go b/lib/config/metrics_test.go index a92277b6..2e3335f5 100644 --- a/lib/config/metrics_test.go +++ b/lib/config/metrics_test.go @@ -75,6 +75,88 @@ func TestMetricsValid(t *testing.T) { }, err: ErrInvalidMetricsNetwork, }, + { + name: "invalid TLS config", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{}, + }, + err: ErrInvalidMetricsTLSConfig, + }, + { + name: "selfsigned TLS cert", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{ + Certificate: "./testdata/tls/selfsigned.crt", + Key: "./testdata/tls/selfsigned.key", + }, + }, + }, + { + name: "wrong path to selfsigned TLS cert", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{ + Certificate: "./testdata/tls2/selfsigned.crt", + Key: "./testdata/tls2/selfsigned.key", + }, + }, + err: ErrCantReadFile, + }, + { + name: "unparseable TLS cert", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{ + Certificate: "./testdata/tls/invalid.crt", + Key: "./testdata/tls/invalid.key", + }, + }, + err: ErrInvalidMetricsTLSKeypair, + }, + { + name: "mTLS with CA", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{ + Certificate: "./testdata/tls/selfsigned.crt", + Key: "./testdata/tls/selfsigned.key", + CA: "./testdata/tls/minica.pem", + }, + }, + }, + { + name: "mTLS with nonexistent CA", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{ + Certificate: "./testdata/tls/selfsigned.crt", + Key: "./testdata/tls/selfsigned.key", + CA: "./testdata/tls/nonexistent.crt", + }, + }, + err: ErrCantReadFile, + }, + { + name: "mTLS with invalid CA", + input: &Metrics{ + Bind: ":9090", + Network: "tcp", + TLS: &MetricsTLS{ + Certificate: "./testdata/tls/selfsigned.crt", + Key: "./testdata/tls/selfsigned.key", + CA: "./testdata/tls/invalid.crt", + }, + }, + err: ErrInvalidMetricsCACertificate, + }, } { t.Run(tt.name, func(t *testing.T) { if err := tt.input.Valid(); !errors.Is(err, tt.err) { diff --git a/lib/config/testdata/tls/1.1.1.1/cert.pem b/lib/config/testdata/tls/1.1.1.1/cert.pem new file mode 100644 index 00000000..ef50a00a --- /dev/null +++ b/lib/config/testdata/tls/1.1.1.1/cert.pem @@ -0,0 +1,12 @@ +-----BEGIN CERTIFICATE----- +MIIB1zCCAVygAwIBAgIIYO0SAFtXlVgwCgYIKoZIzj0EAwMwIDEeMBwGA1UEAxMV +bWluaWNhIHJvb3QgY2EgNDE2MmMwMB4XDTI2MDQyMjIzMjUwMVoXDTI4MDUyMjIz +MjUwMVowEjEQMA4GA1UEAxMHMS4xLjEuMTB2MBAGByqGSM49AgEGBSuBBAAiA2IA +BLsuA2LKGbEBuSA4LTm1KaKc7/QCkUOsipXR4+D5/3sWBZiAH7iWUgHwpx5YZf2q +kZn6oRda+ks0JLTQ6VhteQedmb7l86bMeDMR8p4Lg2b38l/xEr7S25UfUDKudXrO +AqNxMG8wDgYDVR0PAQH/BAQDAgWgMB0GA1UdJQQWMBQGCCsGAQUFBwMBBggrBgEF +BQcDAjAMBgNVHRMBAf8EAjAAMB8GA1UdIwQYMBaAFE/7VDxF2+cUs9bu0pJM3xoC +L1TSMA8GA1UdEQQIMAaHBAEBAQEwCgYIKoZIzj0EAwMDaQAwZgIxAPLXds9MMH4K +F5FxTf9i0PKPsLQARsABVTgwB94hMR70rqW8Pwbjl7ZGNaYlaeRHUwIxAPMQ8zoF +nim+YS1xLqQek/LXuJto8jxcfkQQBsboVzcTa5uaNRhNd5YwrpomGl3lKA== +-----END CERTIFICATE----- diff --git a/lib/config/testdata/tls/1.1.1.1/key.pem b/lib/config/testdata/tls/1.1.1.1/key.pem new file mode 100644 index 00000000..b41e05ad --- /dev/null +++ b/lib/config/testdata/tls/1.1.1.1/key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDBN8QsHxxHGJpStu8K7 +D/FmaBBNo6c514KGFSIfqGFuREF5aOL3gN/W11yk2OIibdWhZANiAAS7LgNiyhmx +AbkgOC05tSminO/0ApFDrIqV0ePg+f97FgWYgB+4llIB8KceWGX9qpGZ+qEXWvpL +NCS00OlYbXkHnZm+5fOmzHgzEfKeC4Nm9/Jf8RK+0tuVH1AyrnV6zgI= +-----END PRIVATE KEY----- diff --git a/lib/config/testdata/tls/invalid.crt b/lib/config/testdata/tls/invalid.crt new file mode 100644 index 00000000..e69de29b diff --git a/lib/config/testdata/tls/invalid.key b/lib/config/testdata/tls/invalid.key new file mode 100644 index 00000000..e69de29b diff --git a/lib/config/testdata/tls/minica-key.pem b/lib/config/testdata/tls/minica-key.pem new file mode 100644 index 00000000..2110df56 --- /dev/null +++ b/lib/config/testdata/tls/minica-key.pem @@ -0,0 +1,6 @@ +-----BEGIN PRIVATE KEY----- +MIG2AgEAMBAGByqGSM49AgEGBSuBBAAiBIGeMIGbAgEBBDDr9QQo7ZaTgUL6d73G +2BG7+YRTFJHAZa0FogRglfc+jYttL1J4/xTig3RmHoqSgrehZANiAASDhijM9Xe0 +G9Vam6AJMeKC6aWDNSLwrxNVmPxemsY/yJ1urBgnxRd9GFH6YW1ki/B8rS+Xl1UX +NnhBrukLaXvgAQQq782/5IUYGsvK5jw8+dSscYVMCQJwGfmQuaNeczQ= +-----END PRIVATE KEY----- diff --git a/lib/config/testdata/tls/minica.pem b/lib/config/testdata/tls/minica.pem new file mode 100644 index 00000000..c8acb8ab --- /dev/null +++ b/lib/config/testdata/tls/minica.pem @@ -0,0 +1,13 @@ +-----BEGIN CERTIFICATE----- +MIIB+zCCAYKgAwIBAgIIQWLAtv4ijQ0wCgYIKoZIzj0EAwMwIDEeMBwGA1UEAxMV +bWluaWNhIHJvb3QgY2EgNDE2MmMwMCAXDTI2MDQyMjIzMjUwMVoYDzIxMjYwNDIy +MjMyNTAxWjAgMR4wHAYDVQQDExVtaW5pY2Egcm9vdCBjYSA0MTYyYzAwdjAQBgcq +hkjOPQIBBgUrgQQAIgNiAASDhijM9Xe0G9Vam6AJMeKC6aWDNSLwrxNVmPxemsY/ +yJ1urBgnxRd9GFH6YW1ki/B8rS+Xl1UXNnhBrukLaXvgAQQq782/5IUYGsvK5jw8 ++dSscYVMCQJwGfmQuaNeczSjgYYwgYMwDgYDVR0PAQH/BAQDAgKEMB0GA1UdJQQW +MBQGCCsGAQUFBwMBBggrBgEFBQcDAjASBgNVHRMBAf8ECDAGAQH/AgEAMB0GA1Ud +DgQWBBRP+1Q8RdvnFLPW7tKSTN8aAi9U0jAfBgNVHSMEGDAWgBRP+1Q8RdvnFLPW +7tKSTN8aAi9U0jAKBggqhkjOPQQDAwNnADBkAjBfY7vb7cuLTjg7uoe+kl07FMYT +BGMSnWdhN3yXqMUS3A6XZxD/LntXT6V7yFOlAJYCMH7w8/ATYaTqbk2jBRyQt9/x +ajN+kZ6ZK+fKttqE8CD62mbHg09xoNxRq+K2I3PVyQ== +-----END CERTIFICATE----- diff --git a/lib/config/testdata/tls/selfsigned.crt b/lib/config/testdata/tls/selfsigned.crt new file mode 100644 index 00000000..a431395c --- /dev/null +++ b/lib/config/testdata/tls/selfsigned.crt @@ -0,0 +1,11 @@ +-----BEGIN CERTIFICATE----- +MIIBnzCCAVGgAwIBAgIUK39B3Ft+kU5o81IuISs79O4u1ncwBQYDK2VwMEUxCzAJ +BgNVBAYTAkFVMRMwEQYDVQQIDApTb21lLVN0YXRlMSEwHwYDVQQKDBhJbnRlcm5l +dCBXaWRnaXRzIFB0eSBMdGQwHhcNMjYwNDIyMTQyNjE4WhcNMjYwNTIyMTQyNjE4 +WjBFMQswCQYDVQQGEwJBVTETMBEGA1UECAwKU29tZS1TdGF0ZTEhMB8GA1UECgwY +SW50ZXJuZXQgV2lkZ2l0cyBQdHkgTHRkMCowBQYDK2VwAyEAfgpAUpp8MIOOdQpH +fxaw3R7mFKQRMR6Kmxzk1Xn/2VujUzBRMB0GA1UdDgQWBBSmkBmzo0RiZ2iocMR8 +uIIpz9cZyTAfBgNVHSMEGDAWgBSmkBmzo0RiZ2iocMR8uIIpz9cZyTAPBgNVHRMB +Af8EBTADAQH/MAUGAytlcANBAG37XXZrVUUzGyy3T9qsPIzvJQAGpGhdjJ7bt9O6 +sBhzrliTONPrudYuyUggWsHgFb0JlN2xs4/2HhKU+PY7AAQ= +-----END CERTIFICATE----- diff --git a/lib/config/testdata/tls/selfsigned.key b/lib/config/testdata/tls/selfsigned.key new file mode 100644 index 00000000..73e73217 --- /dev/null +++ b/lib/config/testdata/tls/selfsigned.key @@ -0,0 +1,3 @@ +-----BEGIN PRIVATE KEY----- +MC4CAQAwBQYDK2VwBCIEIL0HxjjfVlg6zQPB9/zTLq0IBzfp8gEoifEYzQZYIj+T +-----END PRIVATE KEY----- diff --git a/lib/metrics/keypairreloader.go b/lib/metrics/keypairreloader.go new file mode 100644 index 00000000..be0a61ca --- /dev/null +++ b/lib/metrics/keypairreloader.go @@ -0,0 +1,81 @@ +package metrics + +import ( + "crypto/tls" + "fmt" + "log/slog" + "os" + "sync" + "time" +) + +type KeypairReloader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string + modTime time.Time + lg *slog.Logger +} + +func NewKeypairReloader(certPath, keyPath string, lg *slog.Logger) (*KeypairReloader, error) { + result := &KeypairReloader{ + certPath: certPath, + keyPath: keyPath, + lg: lg, + } + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + result.cert = &cert + + st, err := os.Stat(certPath) + if err != nil { + return nil, err + } + result.modTime = st.ModTime() + + return result, nil +} + +func (kpr *KeypairReloader) maybeReload() error { + kpr.lg.Debug("loading new keypair", "cert", kpr.certPath, "key", kpr.keyPath) + newCert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath) + if err != nil { + return err + } + + st, err := os.Stat(kpr.certPath) + if err != nil { + return err + } + + kpr.certMu.Lock() + defer kpr.certMu.Unlock() + kpr.cert = &newCert + kpr.modTime = st.ModTime() + + return nil +} + +func (kpr *KeypairReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + st, err := os.Stat(kpr.certPath) + if err != nil { + return nil, fmt.Errorf("stat(%q): %w", kpr.certPath, err) + } + + kpr.certMu.RLock() + needsReload := st.ModTime().After(kpr.modTime) + kpr.certMu.RUnlock() + + if needsReload { + if err := kpr.maybeReload(); err != nil { + return nil, fmt.Errorf("reload cert: %w", err) + } + } + + kpr.certMu.RLock() + defer kpr.certMu.RUnlock() + return kpr.cert, nil +} diff --git a/lib/metrics/keypairreloader_test.go b/lib/metrics/keypairreloader_test.go new file mode 100644 index 00000000..2ec514e4 --- /dev/null +++ b/lib/metrics/keypairreloader_test.go @@ -0,0 +1,265 @@ +package metrics + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "log/slog" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.DiscardHandler) +} + +// writeKeypair generates a fresh self-signed cert + RSA key and writes them +// as PEM files in dir. Returns the paths and the cert's DER bytes so callers +// can identify which pair was loaded. +func writeKeypair(t *testing.T, dir, prefix string) (certPath, keyPath string, certDER []byte) { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{CommonName: "keypairreloader-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"keypairreloader-test"}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("x509.CreateCertificate: %v", err) + } + + certPath = filepath.Join(dir, prefix+"cert.pem") + keyPath = filepath.Join(dir, prefix+"key.pem") + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + if err := os.WriteFile(certPath, certPEM, 0o600); err != nil { + t.Fatalf("write cert: %v", err) + } + if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + + return certPath, keyPath, der +} + +func TestNewKeypairReloader(t *testing.T) { + dir := t.TempDir() + goodCert, goodKey, _ := writeKeypair(t, dir, "good-") + + garbagePath := filepath.Join(dir, "garbage.pem") + if err := os.WriteFile(garbagePath, []byte("not a pem file"), 0o600); err != nil { + t.Fatalf("write garbage: %v", err) + } + + tests := []struct { + name string + certPath string + keyPath string + wantErr error + wantNil bool + }{ + { + name: "valid cert and key", + certPath: goodCert, + keyPath: goodKey, + }, + { + name: "missing cert file", + certPath: filepath.Join(dir, "does-not-exist.pem"), + keyPath: goodKey, + wantErr: os.ErrNotExist, + wantNil: true, + }, + { + name: "missing key file", + certPath: goodCert, + keyPath: filepath.Join(dir, "does-not-exist-key.pem"), + wantErr: os.ErrNotExist, + wantNil: true, + }, + { + name: "cert file is garbage", + certPath: garbagePath, + keyPath: goodKey, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + kpr, err := NewKeypairReloader(tt.certPath, tt.keyPath, discardLogger()) + + if tt.wantErr != nil && !errors.Is(err, tt.wantErr) { + t.Errorf("err = %v, want errors.Is(..., %v)", err, tt.wantErr) + } + if tt.wantErr == nil && !tt.wantNil && err != nil { + t.Errorf("unexpected err: %v", err) + } + if tt.wantNil && kpr != nil { + t.Errorf("kpr = %+v, want nil", kpr) + } + if !tt.wantNil && kpr == nil { + t.Errorf("kpr is nil, want non-nil") + } + }) + } +} + +func TestKeypairReloader_GetCertificate(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + { + name: "returns loaded cert", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, wantDER := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + got, err := kpr.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if len(got.Certificate) == 0 || !bytes.Equal(got.Certificate[0], wantDER) { + t.Errorf("GetCertificate returned wrong cert bytes") + } + }, + }, + { + name: "reloads when mtime advances", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + // Overwrite with a new pair at the same paths and bump mtime. + newCertPath, newKeyPath, newDER := writeKeypair(t, dir, "b-") + mustRename(t, newCertPath, certPath) + mustRename(t, newKeyPath, keyPath) + future := time.Now().Add(time.Hour) + if err := os.Chtimes(certPath, future, future); err != nil { + t.Fatalf("Chtimes: %v", err) + } + + got, err := kpr.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if len(got.Certificate) == 0 || !bytes.Equal(got.Certificate[0], newDER) { + t.Errorf("GetCertificate did not return reloaded cert") + } + }, + }, + { + name: "does not reload when mtime unchanged", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, originalDER := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + // Overwrite the cert/key files with a *different* keypair, then + // rewind mtime so the reloader must not pick up the change. + newCertPath, newKeyPath, newDER := writeKeypair(t, dir, "b-") + mustRename(t, newCertPath, certPath) + mustRename(t, newKeyPath, keyPath) + past := time.Unix(0, 0) + if err := os.Chtimes(certPath, past, past); err != nil { + t.Fatalf("Chtimes: %v", err) + } + + got, err := kpr.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if len(got.Certificate) == 0 { + t.Fatal("empty cert chain") + } + if bytes.Equal(got.Certificate[0], newDER) { + t.Errorf("GetCertificate reloaded despite unchanged mtime") + } + if !bytes.Equal(got.Certificate[0], originalDER) { + t.Errorf("GetCertificate did not return original cert") + } + }, + }, + { + name: "does not panic when reload fails after mtime bump", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + // Corrupt the cert file and bump mtime. maybeReload will fail. + if err := os.WriteFile(certPath, []byte("not a pem file"), 0o600); err != nil { + t.Fatalf("corrupt cert: %v", err) + } + future := time.Now().Add(time.Hour) + if err := os.Chtimes(certPath, future, future); err != nil { + t.Fatalf("Chtimes: %v", err) + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("GetCertificate panicked on reload failure: %v", r) + } + }() + + got, err := kpr.GetCertificate(nil) + if err == nil { + t.Errorf("GetCertificate returned nil err for corrupt cert; got %+v", got) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func mustRename(t *testing.T, from, to string) { + t.Helper() + if err := os.Rename(from, to); err != nil { + t.Fatalf("rename %q -> %q: %v", from, to, err) + } +} diff --git a/lib/metrics/metrics.go b/lib/metrics/metrics.go index c0384be4..d94c089d 100644 --- a/lib/metrics/metrics.go +++ b/lib/metrics/metrics.go @@ -2,11 +2,14 @@ package metrics import ( "context" + "crypto/tls" + "crypto/x509" "errors" "fmt" "log/slog" "net/http" "net/http/pprof" + "os" "time" "github.com/TecharoHQ/anubis/internal" @@ -20,10 +23,16 @@ type Server struct { Log *slog.Logger } -func (s *Server) Run(ctx context.Context, done func()) error { +func (s *Server) Run(ctx context.Context, done func()) { defer done() lg := s.Log.With("subsystem", "metrics") + if err := s.run(ctx, lg); err != nil { + lg.Error("can't serve metrics server", "err", err) + } +} + +func (s *Server) run(ctx context.Context, lg *slog.Logger) error { mux := http.NewServeMux() mux.HandleFunc("GET /debug/pprof/", pprof.Index) mux.HandleFunc("GET /debug/pprof/cmdline", pprof.Cmdline) @@ -62,6 +71,32 @@ func (s *Server) Run(ctx context.Context, done func()) error { defer ln.Close() + if s.Config.TLS != nil { + kpr, err := NewKeypairReloader(s.Config.TLS.Certificate, s.Config.TLS.Key, lg) + if err != nil { + return fmt.Errorf("can't setup keypair reloader: %w", err) + } + + srv.TLSConfig = &tls.Config{ + GetCertificate: kpr.GetCertificate, + } + + if s.Config.TLS.CA != "" { + caCert, err := os.ReadFile(s.Config.TLS.CA) + if err != nil { + return fmt.Errorf("%w %s: %w", config.ErrCantReadFile, s.Config.TLS.CA, err) + } + + certPool := x509.NewCertPool() + if !certPool.AppendCertsFromPEM(caCert) { + return fmt.Errorf("%w %s", config.ErrInvalidMetricsCACertificate, s.Config.TLS.CA) + } + + srv.TLSConfig.ClientCAs = certPool + srv.TLSConfig.ClientAuth = tls.RequireAndVerifyClientCert + } + } + lg.Debug("listening for metrics", "url", metricsURL) go func() { @@ -73,8 +108,15 @@ func (s *Server) Run(ctx context.Context, done func()) error { } }() - if err := srv.Serve(ln); !errors.Is(err, http.ErrServerClosed) { - return fmt.Errorf("can't serve metrics server: %w", err) + switch s.Config.TLS != nil { + case true: + if err := srv.ServeTLS(ln, "", ""); !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("can't serve TLS metrics server: %w", err) + } + case false: + if err := srv.Serve(ln); !errors.Is(err, http.ErrServerClosed) { + return fmt.Errorf("can't serve metrics server: %w", err) + } } return nil