mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-23 16:46:40 +00:00
feat(metrics): import keypairreloader from a private project
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
@@ -0,0 +1,81 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"crypto/tls"
|
||||||
|
"fmt"
|
||||||
|
"log/slog"
|
||||||
|
"os"
|
||||||
|
"sync"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
type KeypairReloader struct {
|
||||||
|
certMu sync.RWMutex
|
||||||
|
cert *tls.Certificate
|
||||||
|
certPath string
|
||||||
|
keyPath string
|
||||||
|
modTime time.Time
|
||||||
|
lg *slog.Logger
|
||||||
|
}
|
||||||
|
|
||||||
|
func NewKeypairReloader(certPath, keyPath string, lg *slog.Logger) (*KeypairReloader, error) {
|
||||||
|
result := &KeypairReloader{
|
||||||
|
certPath: certPath,
|
||||||
|
keyPath: keyPath,
|
||||||
|
lg: lg,
|
||||||
|
}
|
||||||
|
cert, err := tls.LoadX509KeyPair(certPath, keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result.cert = &cert
|
||||||
|
|
||||||
|
st, err := os.Stat(certPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
result.modTime = st.ModTime()
|
||||||
|
|
||||||
|
return result, nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kpr *KeypairReloader) maybeReload() error {
|
||||||
|
kpr.lg.Debug("loading new keypair", "cert", kpr.certPath, "key", kpr.keyPath)
|
||||||
|
newCert, err := tls.LoadX509KeyPair(kpr.certPath, kpr.keyPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
st, err := os.Stat(kpr.certPath)
|
||||||
|
if err != nil {
|
||||||
|
return err
|
||||||
|
}
|
||||||
|
|
||||||
|
kpr.certMu.Lock()
|
||||||
|
defer kpr.certMu.Unlock()
|
||||||
|
kpr.cert = &newCert
|
||||||
|
kpr.modTime = st.ModTime()
|
||||||
|
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func (kpr *KeypairReloader) GetCertificate(*tls.ClientHelloInfo) (*tls.Certificate, error) {
|
||||||
|
st, err := os.Stat(kpr.certPath)
|
||||||
|
if err != nil {
|
||||||
|
return nil, fmt.Errorf("stat(%q): %w", kpr.certPath, err)
|
||||||
|
}
|
||||||
|
|
||||||
|
kpr.certMu.RLock()
|
||||||
|
needsReload := st.ModTime().After(kpr.modTime)
|
||||||
|
kpr.certMu.RUnlock()
|
||||||
|
|
||||||
|
if needsReload {
|
||||||
|
if err := kpr.maybeReload(); err != nil {
|
||||||
|
return nil, fmt.Errorf("reload cert: %w", err)
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
kpr.certMu.RLock()
|
||||||
|
defer kpr.certMu.RUnlock()
|
||||||
|
return kpr.cert, nil
|
||||||
|
}
|
||||||
@@ -0,0 +1,265 @@
|
|||||||
|
package metrics
|
||||||
|
|
||||||
|
import (
|
||||||
|
"bytes"
|
||||||
|
"crypto/rand"
|
||||||
|
"crypto/rsa"
|
||||||
|
"crypto/x509"
|
||||||
|
"crypto/x509/pkix"
|
||||||
|
"encoding/pem"
|
||||||
|
"errors"
|
||||||
|
"log/slog"
|
||||||
|
"math/big"
|
||||||
|
"os"
|
||||||
|
"path/filepath"
|
||||||
|
"testing"
|
||||||
|
"time"
|
||||||
|
)
|
||||||
|
|
||||||
|
func discardLogger() *slog.Logger {
|
||||||
|
return slog.New(slog.DiscardHandler)
|
||||||
|
}
|
||||||
|
|
||||||
|
// writeKeypair generates a fresh self-signed cert + RSA key and writes them
|
||||||
|
// as PEM files in dir. Returns the paths and the cert's DER bytes so callers
|
||||||
|
// can identify which pair was loaded.
|
||||||
|
func writeKeypair(t *testing.T, dir, prefix string) (certPath, keyPath string, certDER []byte) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
priv, err := rsa.GenerateKey(rand.Reader, 2048)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("rsa.GenerateKey: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tmpl := &x509.Certificate{
|
||||||
|
SerialNumber: big.NewInt(time.Now().UnixNano()),
|
||||||
|
Subject: pkix.Name{CommonName: "keypairreloader-test"},
|
||||||
|
NotBefore: time.Now().Add(-time.Hour),
|
||||||
|
NotAfter: time.Now().Add(time.Hour),
|
||||||
|
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
|
||||||
|
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
|
||||||
|
BasicConstraintsValid: true,
|
||||||
|
DNSNames: []string{"keypairreloader-test"},
|
||||||
|
}
|
||||||
|
|
||||||
|
der, err := x509.CreateCertificate(rand.Reader, tmpl, tmpl, &priv.PublicKey, priv)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("x509.CreateCertificate: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
certPath = filepath.Join(dir, prefix+"cert.pem")
|
||||||
|
keyPath = filepath.Join(dir, prefix+"key.pem")
|
||||||
|
|
||||||
|
certPEM := pem.EncodeToMemory(&pem.Block{Type: "CERTIFICATE", Bytes: der})
|
||||||
|
keyPEM := pem.EncodeToMemory(&pem.Block{Type: "RSA PRIVATE KEY", Bytes: x509.MarshalPKCS1PrivateKey(priv)})
|
||||||
|
|
||||||
|
if err := os.WriteFile(certPath, certPEM, 0o600); err != nil {
|
||||||
|
t.Fatalf("write cert: %v", err)
|
||||||
|
}
|
||||||
|
if err := os.WriteFile(keyPath, keyPEM, 0o600); err != nil {
|
||||||
|
t.Fatalf("write key: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
return certPath, keyPath, der
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestNewKeypairReloader(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
goodCert, goodKey, _ := writeKeypair(t, dir, "good-")
|
||||||
|
|
||||||
|
garbagePath := filepath.Join(dir, "garbage.pem")
|
||||||
|
if err := os.WriteFile(garbagePath, []byte("not a pem file"), 0o600); err != nil {
|
||||||
|
t.Fatalf("write garbage: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
certPath string
|
||||||
|
keyPath string
|
||||||
|
wantErr error
|
||||||
|
wantNil bool
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "valid cert and key",
|
||||||
|
certPath: goodCert,
|
||||||
|
keyPath: goodKey,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing cert file",
|
||||||
|
certPath: filepath.Join(dir, "does-not-exist.pem"),
|
||||||
|
keyPath: goodKey,
|
||||||
|
wantErr: os.ErrNotExist,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "missing key file",
|
||||||
|
certPath: goodCert,
|
||||||
|
keyPath: filepath.Join(dir, "does-not-exist-key.pem"),
|
||||||
|
wantErr: os.ErrNotExist,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "cert file is garbage",
|
||||||
|
certPath: garbagePath,
|
||||||
|
keyPath: goodKey,
|
||||||
|
wantNil: true,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
kpr, err := NewKeypairReloader(tt.certPath, tt.keyPath, discardLogger())
|
||||||
|
|
||||||
|
if tt.wantErr != nil && !errors.Is(err, tt.wantErr) {
|
||||||
|
t.Errorf("err = %v, want errors.Is(..., %v)", err, tt.wantErr)
|
||||||
|
}
|
||||||
|
if tt.wantErr == nil && !tt.wantNil && err != nil {
|
||||||
|
t.Errorf("unexpected err: %v", err)
|
||||||
|
}
|
||||||
|
if tt.wantNil && kpr != nil {
|
||||||
|
t.Errorf("kpr = %+v, want nil", kpr)
|
||||||
|
}
|
||||||
|
if !tt.wantNil && kpr == nil {
|
||||||
|
t.Errorf("kpr is nil, want non-nil")
|
||||||
|
}
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestKeypairReloader_GetCertificate(t *testing.T) {
|
||||||
|
tests := []struct {
|
||||||
|
name string
|
||||||
|
run func(t *testing.T)
|
||||||
|
}{
|
||||||
|
{
|
||||||
|
name: "returns loaded cert",
|
||||||
|
run: func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath, keyPath, wantDER := writeKeypair(t, dir, "a-")
|
||||||
|
|
||||||
|
kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeypairReloader: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := kpr.GetCertificate(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCertificate: %v", err)
|
||||||
|
}
|
||||||
|
if len(got.Certificate) == 0 || !bytes.Equal(got.Certificate[0], wantDER) {
|
||||||
|
t.Errorf("GetCertificate returned wrong cert bytes")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "reloads when mtime advances",
|
||||||
|
run: func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath, keyPath, _ := writeKeypair(t, dir, "a-")
|
||||||
|
|
||||||
|
kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeypairReloader: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overwrite with a new pair at the same paths and bump mtime.
|
||||||
|
newCertPath, newKeyPath, newDER := writeKeypair(t, dir, "b-")
|
||||||
|
mustRename(t, newCertPath, certPath)
|
||||||
|
mustRename(t, newKeyPath, keyPath)
|
||||||
|
future := time.Now().Add(time.Hour)
|
||||||
|
if err := os.Chtimes(certPath, future, future); err != nil {
|
||||||
|
t.Fatalf("Chtimes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := kpr.GetCertificate(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCertificate: %v", err)
|
||||||
|
}
|
||||||
|
if len(got.Certificate) == 0 || !bytes.Equal(got.Certificate[0], newDER) {
|
||||||
|
t.Errorf("GetCertificate did not return reloaded cert")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "does not reload when mtime unchanged",
|
||||||
|
run: func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath, keyPath, originalDER := writeKeypair(t, dir, "a-")
|
||||||
|
|
||||||
|
kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeypairReloader: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Overwrite the cert/key files with a *different* keypair, then
|
||||||
|
// rewind mtime so the reloader must not pick up the change.
|
||||||
|
newCertPath, newKeyPath, newDER := writeKeypair(t, dir, "b-")
|
||||||
|
mustRename(t, newCertPath, certPath)
|
||||||
|
mustRename(t, newKeyPath, keyPath)
|
||||||
|
past := time.Unix(0, 0)
|
||||||
|
if err := os.Chtimes(certPath, past, past); err != nil {
|
||||||
|
t.Fatalf("Chtimes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
got, err := kpr.GetCertificate(nil)
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("GetCertificate: %v", err)
|
||||||
|
}
|
||||||
|
if len(got.Certificate) == 0 {
|
||||||
|
t.Fatal("empty cert chain")
|
||||||
|
}
|
||||||
|
if bytes.Equal(got.Certificate[0], newDER) {
|
||||||
|
t.Errorf("GetCertificate reloaded despite unchanged mtime")
|
||||||
|
}
|
||||||
|
if !bytes.Equal(got.Certificate[0], originalDER) {
|
||||||
|
t.Errorf("GetCertificate did not return original cert")
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
{
|
||||||
|
name: "does not panic when reload fails after mtime bump",
|
||||||
|
run: func(t *testing.T) {
|
||||||
|
dir := t.TempDir()
|
||||||
|
certPath, keyPath, _ := writeKeypair(t, dir, "a-")
|
||||||
|
|
||||||
|
kpr, err := NewKeypairReloader(certPath, keyPath, discardLogger())
|
||||||
|
if err != nil {
|
||||||
|
t.Fatalf("NewKeypairReloader: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
// Corrupt the cert file and bump mtime. maybeReload will fail.
|
||||||
|
if err := os.WriteFile(certPath, []byte("not a pem file"), 0o600); err != nil {
|
||||||
|
t.Fatalf("corrupt cert: %v", err)
|
||||||
|
}
|
||||||
|
future := time.Now().Add(time.Hour)
|
||||||
|
if err := os.Chtimes(certPath, future, future); err != nil {
|
||||||
|
t.Fatalf("Chtimes: %v", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
defer func() {
|
||||||
|
if r := recover(); r != nil {
|
||||||
|
t.Fatalf("GetCertificate panicked on reload failure: %v", r)
|
||||||
|
}
|
||||||
|
}()
|
||||||
|
|
||||||
|
got, err := kpr.GetCertificate(nil)
|
||||||
|
if err == nil {
|
||||||
|
t.Errorf("GetCertificate returned nil err for corrupt cert; got %+v", got)
|
||||||
|
}
|
||||||
|
},
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
for _, tt := range tests {
|
||||||
|
t.Run(tt.name, func(t *testing.T) {
|
||||||
|
tt.run(t)
|
||||||
|
})
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func mustRename(t *testing.T, from, to string) {
|
||||||
|
t.Helper()
|
||||||
|
if err := os.Rename(from, to); err != nil {
|
||||||
|
t.Fatalf("rename %q -> %q: %v", from, to, err)
|
||||||
|
}
|
||||||
|
}
|
||||||
+13
-1
@@ -2,6 +2,7 @@ package metrics
|
|||||||
|
|
||||||
import (
|
import (
|
||||||
"context"
|
"context"
|
||||||
|
"crypto/tls"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"log/slog"
|
"log/slog"
|
||||||
@@ -62,6 +63,17 @@ func (s *Server) Run(ctx context.Context, done func()) error {
|
|||||||
|
|
||||||
defer ln.Close()
|
defer ln.Close()
|
||||||
|
|
||||||
|
if s.Config.TLS != nil {
|
||||||
|
kpr, err := NewKeypairReloader(s.Config.TLS.Certificate, s.Config.TLS.Key, lg)
|
||||||
|
if err != nil {
|
||||||
|
return fmt.Errorf("can't setup keypair reloader: %w", err)
|
||||||
|
}
|
||||||
|
|
||||||
|
srv.TLSConfig = &tls.Config{
|
||||||
|
GetCertificate: kpr.GetCertificate,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
lg.Debug("listening for metrics", "url", metricsURL)
|
lg.Debug("listening for metrics", "url", metricsURL)
|
||||||
|
|
||||||
go func() {
|
go func() {
|
||||||
@@ -75,7 +87,7 @@ func (s *Server) Run(ctx context.Context, done func()) error {
|
|||||||
|
|
||||||
switch s.Config.TLS != nil {
|
switch s.Config.TLS != nil {
|
||||||
case true:
|
case true:
|
||||||
if err := srv.ServeTLS(ln, s.Config.TLS.Certificate, s.Config.TLS.Key); !errors.Is(err, http.ErrServerClosed) {
|
if err := srv.ServeTLS(ln, "", ""); !errors.Is(err, http.ErrServerClosed) {
|
||||||
return fmt.Errorf("can't serve TLS metrics server: %w", err)
|
return fmt.Errorf("can't serve TLS metrics server: %w", err)
|
||||||
}
|
}
|
||||||
case false:
|
case false:
|
||||||
|
|||||||
Reference in New Issue
Block a user