From 63e6a152807257a95ebb96ce01ff569a06d59720 Mon Sep 17 00:00:00 2001 From: Xe Iaso Date: Wed, 22 Apr 2026 13:52:40 -0400 Subject: [PATCH] feat(metrics): import keypairreloader from a private project Signed-off-by: Xe Iaso --- lib/metrics/keypairreloader.go | 81 +++++++++ lib/metrics/keypairreloader_test.go | 265 ++++++++++++++++++++++++++++ lib/metrics/metrics.go | 14 +- 3 files changed, 359 insertions(+), 1 deletion(-) create mode 100644 lib/metrics/keypairreloader.go create mode 100644 lib/metrics/keypairreloader_test.go diff --git a/lib/metrics/keypairreloader.go b/lib/metrics/keypairreloader.go new file mode 100644 index 00000000..be0a61ca --- /dev/null +++ b/lib/metrics/keypairreloader.go @@ -0,0 +1,81 @@ +package metrics + +import ( + "crypto/tls" + "fmt" + "log/slog" + "os" + "sync" + "time" +) + +type KeypairReloader struct { + certMu sync.RWMutex + cert *tls.Certificate + certPath string + keyPath string + modTime time.Time + lg *slog.Logger +} + +func NewKeypairReloader(certPath, keyPath string, lg *slog.Logger) (*KeypairReloader, error) { + result := &KeypairReloader{ + certPath: certPath, + keyPath: keyPath, + lg: lg, + } + cert, err := tls.LoadX509KeyPair(certPath, keyPath) + if err != nil { + return nil, err + } + result.cert = &cert + + st, err := os.Stat(certPath) + if err != nil { + return nil, err + } + result.modTime = st.ModTime() + + return result, nil +} + +func (kpr *KeypairReloader) maybeReload() error { + kpr.lg.Debug("loading new keypair", "cert", kpr.certPath, "key", kpr.keyPath) + newCert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath) + if err != nil { + return err + } + + st, err := os.Stat(kpr.certPath) + if err != nil { + return err + } + + kpr.certMu.Lock() + defer kpr.certMu.Unlock() + kpr.cert = &newCert + kpr.modTime = st.ModTime() + + return nil +} + +func (kpr *KeypairReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) { + st, err := os.Stat(kpr.certPath) + if err != nil { + return nil, fmt.Errorf("stat(%q): %w", kpr.certPath, err) + } + + kpr.certMu.RLock() + needsReload := st.ModTime().After(kpr.modTime) + kpr.certMu.RUnlock() + + if needsReload { + if err := kpr.maybeReload(); err != nil { + return nil, fmt.Errorf("reload cert: %w", err) + } + } + + kpr.certMu.RLock() + defer kpr.certMu.RUnlock() + return kpr.cert, nil +} diff --git a/lib/metrics/keypairreloader_test.go b/lib/metrics/keypairreloader_test.go new file mode 100644 index 00000000..2ec514e4 --- /dev/null +++ b/lib/metrics/keypairreloader_test.go @@ -0,0 +1,265 @@ +package metrics + +import ( + "bytes" + "crypto/rand" + "crypto/rsa" + "crypto/x509" + "crypto/x509/pkix" + "encoding/pem" + "errors" + "log/slog" + "math/big" + "os" + "path/filepath" + "testing" + "time" +) + +func discardLogger() *slog.Logger { + return slog.New(slog.DiscardHandler) +} + +// writeKeypair generates a fresh self-signed cert + RSA key and writes them +// as PEM files in dir. Returns the paths and the cert's DER bytes so callers +// can identify which pair was loaded. +func writeKeypair(t *testing.T, dir, prefix string) (certPath, keyPath string, certDER []byte) { + t.Helper() + + priv, err := rsa.GenerateKey(rand.Reader, 2048) + if err != nil { + t.Fatalf("rsa.GenerateKey: %v", err) + } + + tmpl := &x509.Certificate{ + SerialNumber: big.NewInt(time.Now().UnixNano()), + Subject: pkix.Name{CommonName: "keypairreloader-test"}, + NotBefore: time.Now().Add(-time.Hour), + NotAfter: time.Now().Add(time.Hour), + KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment, + ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth}, + BasicConstraintsValid: true, + DNSNames: []string{"keypairreloader-test"}, + } + + der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv) + if err != nil { + t.Fatalf("x509.CreateCertificate: %v", err) + } + + certPath = filepath.Join(dir, prefix+"cert.pem") + keyPath = filepath.Join(dir, prefix+"key.pem") + + certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der}) + keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)}) + + if err := os.WriteFile(certPath, certPEM, 0o600); err != nil { + t.Fatalf("write cert: %v", err) + } + if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil { + t.Fatalf("write key: %v", err) + } + + return certPath, keyPath, der +} + +func TestNewKeypairReloader(t *testing.T) { + dir := t.TempDir() + goodCert, goodKey, _ := writeKeypair(t, dir, "good-") + + garbagePath := filepath.Join(dir, "garbage.pem") + if err := os.WriteFile(garbagePath, []byte("not a pem file"), 0o600); err != nil { + t.Fatalf("write garbage: %v", err) + } + + tests := []struct { + name string + certPath string + keyPath string + wantErr error + wantNil bool + }{ + { + name: "valid cert and key", + certPath: goodCert, + keyPath: goodKey, + }, + { + name: "missing cert file", + certPath: filepath.Join(dir, "does-not-exist.pem"), + keyPath: goodKey, + wantErr: os.ErrNotExist, + wantNil: true, + }, + { + name: "missing key file", + certPath: goodCert, + keyPath: filepath.Join(dir, "does-not-exist-key.pem"), + wantErr: os.ErrNotExist, + wantNil: true, + }, + { + name: "cert file is garbage", + certPath: garbagePath, + keyPath: goodKey, + wantNil: true, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + kpr, err := NewKeypairReloader(tt.certPath, tt.keyPath, discardLogger()) + + if tt.wantErr != nil && !errors.Is(err, tt.wantErr) { + t.Errorf("err = %v, want errors.Is(..., %v)", err, tt.wantErr) + } + if tt.wantErr == nil && !tt.wantNil && err != nil { + t.Errorf("unexpected err: %v", err) + } + if tt.wantNil && kpr != nil { + t.Errorf("kpr = %+v, want nil", kpr) + } + if !tt.wantNil && kpr == nil { + t.Errorf("kpr is nil, want non-nil") + } + }) + } +} + +func TestKeypairReloader_GetCertificate(t *testing.T) { + tests := []struct { + name string + run func(t *testing.T) + }{ + { + name: "returns loaded cert", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, wantDER := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + got, err := kpr.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if len(got.Certificate) == 0 || !bytes.Equal(got.Certificate[0], wantDER) { + t.Errorf("GetCertificate returned wrong cert bytes") + } + }, + }, + { + name: "reloads when mtime advances", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + // Overwrite with a new pair at the same paths and bump mtime. + newCertPath, newKeyPath, newDER := writeKeypair(t, dir, "b-") + mustRename(t, newCertPath, certPath) + mustRename(t, newKeyPath, keyPath) + future := time.Now().Add(time.Hour) + if err := os.Chtimes(certPath, future, future); err != nil { + t.Fatalf("Chtimes: %v", err) + } + + got, err := kpr.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if len(got.Certificate) == 0 || !bytes.Equal(got.Certificate[0], newDER) { + t.Errorf("GetCertificate did not return reloaded cert") + } + }, + }, + { + name: "does not reload when mtime unchanged", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, originalDER := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + // Overwrite the cert/key files with a *different* keypair, then + // rewind mtime so the reloader must not pick up the change. + newCertPath, newKeyPath, newDER := writeKeypair(t, dir, "b-") + mustRename(t, newCertPath, certPath) + mustRename(t, newKeyPath, keyPath) + past := time.Unix(0, 0) + if err := os.Chtimes(certPath, past, past); err != nil { + t.Fatalf("Chtimes: %v", err) + } + + got, err := kpr.GetCertificate(nil) + if err != nil { + t.Fatalf("GetCertificate: %v", err) + } + if len(got.Certificate) == 0 { + t.Fatal("empty cert chain") + } + if bytes.Equal(got.Certificate[0], newDER) { + t.Errorf("GetCertificate reloaded despite unchanged mtime") + } + if !bytes.Equal(got.Certificate[0], originalDER) { + t.Errorf("GetCertificate did not return original cert") + } + }, + }, + { + name: "does not panic when reload fails after mtime bump", + run: func(t *testing.T) { + dir := t.TempDir() + certPath, keyPath, _ := writeKeypair(t, dir, "a-") + + kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger()) + if err != nil { + t.Fatalf("NewKeypairReloader: %v", err) + } + + // Corrupt the cert file and bump mtime. maybeReload will fail. + if err := os.WriteFile(certPath, []byte("not a pem file"), 0o600); err != nil { + t.Fatalf("corrupt cert: %v", err) + } + future := time.Now().Add(time.Hour) + if err := os.Chtimes(certPath, future, future); err != nil { + t.Fatalf("Chtimes: %v", err) + } + + defer func() { + if r := recover(); r != nil { + t.Fatalf("GetCertificate panicked on reload failure: %v", r) + } + }() + + got, err := kpr.GetCertificate(nil) + if err == nil { + t.Errorf("GetCertificate returned nil err for corrupt cert; got %+v", got) + } + }, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + tt.run(t) + }) + } +} + +func mustRename(t *testing.T, from, to string) { + t.Helper() + if err := os.Rename(from, to); err != nil { + t.Fatalf("rename %q -> %q: %v", from, to, err) + } +} diff --git a/lib/metrics/metrics.go b/lib/metrics/metrics.go index f5f34d42..7edb6251 100644 --- a/lib/metrics/metrics.go +++ b/lib/metrics/metrics.go @@ -2,6 +2,7 @@ package metrics import ( "context" + "crypto/tls" "errors" "fmt" "log/slog" @@ -62,6 +63,17 @@ func (s *Server) Run(ctx context.Context, done func()) error { defer ln.Close() + if s.Config.TLS != nil { + kpr, err := NewKeypairReloader(s.Config.TLS.Certificate, s.Config.TLS.Key, lg) + if err != nil { + return fmt.Errorf("can't setup keypair reloader: %w", err) + } + + srv.TLSConfig = &tls.Config{ + GetCertificate: kpr.GetCertificate, + } + } + lg.Debug("listening for metrics", "url", metricsURL) go func() { @@ -75,7 +87,7 @@ func (s *Server) Run(ctx context.Context, done func()) error { switch s.Config.TLS != nil { case true: - if err := srv.ServeTLS(ln, s.Config.TLS.Certificate, s.Config.TLS.Key); !errors.Is(err, http.ErrServerClosed) { + if err := srv.ServeTLS(ln, "", ""); !errors.Is(err, http.ErrServerClosed) { return fmt.Errorf("can't serve TLS metrics server: %w", err) } case false: