fix(ogtags): respect target host/SNI/insecure flags in OG passthrough

This commit is contained in:
Jason Cameron
2025-11-16 02:43:32 -05:00
parent da1890380e
commit bba8b8001e
11 changed files with 400 additions and 75 deletions

View File

@@ -439,26 +439,29 @@ func main() {
} }
s, err := libanubis.New(libanubis.Options{ s, err := libanubis.New(libanubis.Options{
BasePrefix: *basePrefix, BasePrefix: *basePrefix,
StripBasePrefix: *stripBasePrefix, StripBasePrefix: *stripBasePrefix,
Next: rp, Next: rp,
Policy: policy, Policy: policy,
ServeRobotsTXT: *robotsTxt, TargetHost: *targetHost,
ED25519PrivateKey: ed25519Priv, TargetSNI: *targetSNI,
HS512Secret: []byte(*hs512Secret), TargetInsecureSkipVerify: *targetInsecureSkipVerify,
CookieDomain: *cookieDomain, ServeRobotsTXT: *robotsTxt,
CookieDynamicDomain: *cookieDynamicDomain, ED25519PrivateKey: ed25519Priv,
CookieExpiration: *cookieExpiration, HS512Secret: []byte(*hs512Secret),
CookiePartitioned: *cookiePartitioned, CookieDomain: *cookieDomain,
RedirectDomains: redirectDomainsList, CookieDynamicDomain: *cookieDynamicDomain,
Target: *target, CookieExpiration: *cookieExpiration,
WebmasterEmail: *webmasterEmail, CookiePartitioned: *cookiePartitioned,
OpenGraph: policy.OpenGraph, RedirectDomains: redirectDomainsList,
CookieSecure: *cookieSecure, Target: *target,
CookieSameSite: parseSameSite(*cookieSameSite), WebmasterEmail: *webmasterEmail,
PublicUrl: *publicUrl, OpenGraph: policy.OpenGraph,
JWTRestrictionHeader: *jwtRestrictionHeader, CookieSecure: *cookieSecure,
DifficultyInJWT: *difficultyInJWT, CookieSameSite: parseSameSite(*cookieSameSite),
PublicUrl: *publicUrl,
JWTRestrictionHeader: *jwtRestrictionHeader,
DifficultyInJWT: *difficultyInJWT,
}) })
if err != nil { if err != nil {
log.Fatalf("can't construct libanubis.Server: %v", err) log.Fatalf("can't construct libanubis.Server: %v", err)

View File

@@ -24,7 +24,7 @@ func TestCacheReturnsDefault(t *testing.T) {
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
Override: want, Override: want,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
u, err := url.Parse("https://anubis.techaro.lol") u, err := url.Parse("https://anubis.techaro.lol")
if err != nil { if err != nil {
@@ -52,7 +52,7 @@ func TestCheckCache(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Set up test data // Set up test data
urlStr := "http://example.com/page" urlStr := "http://example.com/page"
@@ -115,7 +115,7 @@ func TestGetOGTags(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Parse the test server URL // Parse the test server URL
parsedURL, err := url.Parse(ts.URL) parsedURL, err := url.Parse(ts.URL)
@@ -271,7 +271,7 @@ func TestGetOGTagsWithHostConsideration(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: tc.ogCacheConsiderHost, ConsiderHost: tc.ogCacheConsiderHost,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
for i, req := range tc.requests { for i, req := range tc.requests {
ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host) ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host)

View File

@@ -2,6 +2,7 @@ package ogtags
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -27,7 +28,10 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri
} }
// Set the Host header to the original host // Set the Host header to the original host
if originalHost != "" { switch {
case c.targetHost != "":
req.Host = c.targetHost
case originalHost != "":
req.Host = originalHost req.Host = originalHost
} }
@@ -35,8 +39,34 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri
req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("X-Forwarded-Proto", "https")
req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes
client := c.client
if c.targetSNIAuto {
serverName := originalHost
if c.targetHost != "" {
serverName = c.targetHost
}
if serverName != "" {
transport := c.transport.Clone()
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.ServerName = serverName
if c.insecureSkipVerify {
transport.TLSClientConfig.InsecureSkipVerify = true
}
client = &http.Client{
Timeout: httpTimeout,
Transport: transport,
}
defer transport.CloseIdleConnections()
}
}
// Send the request // Send the request
resp, err := c.client.Do(req) resp, err := client.Do(req)
if err != nil { if err != nil {
var netErr net.Error var netErr net.Error
if errors.As(err, &netErr) && netErr.Timeout() { if errors.As(err, &netErr) && netErr.Timeout() {

View File

@@ -87,7 +87,7 @@ func TestFetchHTMLDocument(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything") doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything")
if tt.expectError { if tt.expectError {
@@ -118,7 +118,7 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything") doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything")

View File

@@ -111,7 +111,7 @@ func TestIntegrationGetOGTags(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Create URL for test // Create URL for test
testURL, _ := url.Parse(ts.URL) testURL, _ := url.Parse(ts.URL)

View File

@@ -31,7 +31,7 @@ func BenchmarkGetTarget(b *testing.B) {
for _, tt := range tests { for _, tt := range tests {
b.Run(tt.name, func(b *testing.B) { b.Run(tt.name, func(b *testing.B) {
cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context())) cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context()), TargetOptions{})
urls := make([]*url.URL, len(tt.paths)) urls := make([]*url.URL, len(tt.paths))
for i, path := range tt.paths { for i, path := range tt.paths {
u, _ := url.Parse(path) u, _ := url.Parse(path)
@@ -67,7 +67,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
</head><body><div><p>Content</p></div></body></html>`, </head><body><div><p>Content</p></div></body></html>`,
} }
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(b.Context()), TargetOptions{})
docs := make([]*html.Node, len(htmlSamples)) docs := make([]*html.Node, len(htmlSamples))
for i, sample := range htmlSamples { for i, sample := range htmlSamples {
@@ -85,7 +85,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
// Memory usage test // Memory usage test
func TestMemoryUsage(t *testing.T) { func TestMemoryUsage(t *testing.T) {
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(t.Context()), TargetOptions{})
// Force GC and wait for it to complete // Force GC and wait for it to complete
runtime.GC() runtime.GC()

View File

@@ -2,11 +2,13 @@ package ogtags
import ( import (
"context" "context"
"crypto/tls"
"log/slog" "log/slog"
"net" "net"
"net/http" "net/http"
"net/url" "net/url"
"strings" "strings"
"sync"
"time" "time"
"github.com/TecharoHQ/anubis/lib/policy/config" "github.com/TecharoHQ/anubis/lib/policy/config"
@@ -25,6 +27,7 @@ type OGTagCache struct {
cache store.JSON[map[string]string] cache store.JSON[map[string]string]
targetURL *url.URL targetURL *url.URL
client *http.Client client *http.Client
transport *http.Transport
// Pre-built strings for optimization // Pre-built strings for optimization
unixPrefix string // "http://unix" unixPrefix string // "http://unix"
@@ -34,9 +37,20 @@ type OGTagCache struct {
ogCacheConsiderHost bool ogCacheConsiderHost bool
ogPassthrough bool ogPassthrough bool
ogOverride map[string]string ogOverride map[string]string
targetHost string
targetSNI string
targetSNIAuto bool
insecureSkipVerify bool
transportMu sync.Mutex
} }
func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface) *OGTagCache { type TargetOptions struct {
Host string
SNI string
InsecureSkipVerify bool
}
func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface, targetOpts TargetOptions) *OGTagCache {
// Predefined approved tags and prefixes // Predefined approved tags and prefixes
defaultApprovedTags := []string{"description", "keywords", "author"} defaultApprovedTags := []string{"description", "keywords", "author"}
defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"} defaultApprovedPrefixes := []string{"og:", "twitter:", "fediverse:"}
@@ -62,20 +76,37 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface
} }
} }
client := &http.Client{ transport := http.DefaultTransport.(*http.Transport).Clone()
Timeout: httpTimeout,
}
// Configure custom transport for Unix sockets // Configure custom transport for Unix sockets
if parsedTargetURL.Scheme == "unix" { if parsedTargetURL.Scheme == "unix" {
socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path socketPath := parsedTargetURL.Path // For unix scheme, path is the socket path
client.Transport = &http.Transport{ transport.DialContext = func(_ context.Context, _, _ string) (net.Conn, error) {
DialContext: func(_ context.Context, _, _ string) (net.Conn, error) { return net.Dial("unix", socketPath)
return net.Dial("unix", socketPath)
},
} }
} }
targetSNIAuto := targetOpts.SNI == "auto"
if targetOpts.SNI != "" && !targetSNIAuto {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.ServerName = targetOpts.SNI
}
if targetOpts.InsecureSkipVerify {
if transport.TLSClientConfig == nil {
transport.TLSClientConfig = &tls.Config{}
}
transport.TLSClientConfig.InsecureSkipVerify = true
}
client := &http.Client{
Timeout: httpTimeout,
Transport: transport,
}
return &OGTagCache{ return &OGTagCache{
cache: store.JSON[map[string]string]{ cache: store.JSON[map[string]string]{
Underlying: backend, Underlying: backend,
@@ -89,7 +120,12 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface
approvedTags: defaultApprovedTags, approvedTags: defaultApprovedTags,
approvedPrefixes: defaultApprovedPrefixes, approvedPrefixes: defaultApprovedPrefixes,
client: client, client: client,
transport: transport,
unixPrefix: "http://unix", unixPrefix: "http://unix",
targetHost: targetOpts.Host,
targetSNI: targetOpts.SNI,
targetSNIAuto: targetSNIAuto,
insecureSkipVerify: targetOpts.InsecureSkipVerify,
} }
} }

View File

@@ -48,7 +48,7 @@ func FuzzGetTarget(f *testing.F) {
} }
// Create cache - should not panic // Create cache - should not panic
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
// Create URL // Create URL
u := &url.URL{ u := &url.URL{
@@ -132,7 +132,7 @@ func FuzzExtractOGTags(f *testing.F) {
return return
} }
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
// Should not panic // Should not panic
tags := cache.extractOGTags(doc) tags := cache.extractOGTags(doc)
@@ -188,7 +188,7 @@ func FuzzGetTargetRoundTrip(f *testing.F) {
t.Skip() t.Skip()
} }
cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache(target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
u := &url.URL{Path: path, RawQuery: query} u := &url.URL{Path: path, RawQuery: query}
result := cache.getTarget(u) result := cache.getTarget(u)
@@ -245,7 +245,7 @@ func FuzzExtractMetaTagInfo(f *testing.F) {
}, },
} }
cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache("http://example.com", config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
// Should not panic // Should not panic
property, content := cache.extractMetaTagInfo(node) property, content := cache.extractMetaTagInfo(node)
@@ -298,7 +298,7 @@ func BenchmarkFuzzedGetTarget(b *testing.B) {
for _, input := range inputs { for _, input := range inputs {
b.Run(input.name, func(b *testing.B) { b.Run(input.name, func(b *testing.B) {
cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background())) cache := NewOGTagCache(input.target, config.OpenGraph{}, memory.New(context.Background()), TargetOptions{})
u := &url.URL{Path: input.path, RawQuery: input.query} u := &url.URL{Path: input.path, RawQuery: input.query}
b.ResetTimer() b.ResetTimer()

View File

@@ -2,15 +2,23 @@ package ogtags
import ( import (
"context" "context"
"crypto/rand"
"crypto/rsa"
"crypto/tls"
"crypto/x509"
"crypto/x509/pkix"
"errors" "errors"
"fmt" "fmt"
"math/big"
"net" "net"
"net/http" "net/http"
"net/http/httptest"
"net/url" "net/url"
"os" "os"
"path/filepath" "path/filepath"
"reflect" "reflect"
"strings" "strings"
"sync"
"testing" "testing"
"time" "time"
@@ -45,7 +53,7 @@ func TestNewOGTagCache(t *testing.T) {
Enabled: tt.ogPassthrough, Enabled: tt.ogPassthrough,
TimeToLive: tt.ogTimeToLive, TimeToLive: tt.ogTimeToLive,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
if cache == nil { if cache == nil {
t.Fatal("expected non-nil cache, got nil") t.Fatal("expected non-nil cache, got nil")
@@ -85,7 +93,7 @@ func TestNewOGTagCache_UnixSocket(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: 5 * time.Minute, TimeToLive: 5 * time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
if cache == nil { if cache == nil {
t.Fatal("expected non-nil cache, got nil") t.Fatal("expected non-nil cache, got nil")
@@ -170,7 +178,7 @@ func TestGetTarget(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
u := &url.URL{ u := &url.URL{
Path: tt.path, Path: tt.path,
@@ -243,7 +251,7 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
Enabled: true, Enabled: true,
TimeToLive: time.Minute, TimeToLive: time.Minute,
ConsiderHost: false, ConsiderHost: false,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Create a dummy URL for the request (path and query matter) // Create a dummy URL for the request (path and query matter)
testReqURL, _ := url.Parse("/some/page?query=1") testReqURL, _ := url.Parse("/some/page?query=1")
@@ -274,3 +282,244 @@ func TestIntegrationGetOGTags_UnixSocket(t *testing.T) {
t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags) t.Errorf("Expected cached OG tags %v, got %v", expectedTags, cachedTags)
} }
} }
func TestGetOGTagsWithTargetHostOverride(t *testing.T) {
originalHost := "example.test"
overrideHost := "backend.internal"
seenHosts := make(chan string, 10)
ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
seenHosts <- r.Host
w.Header().Set("Content-Type", "text/html")
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="HostOverride" /></head><body>ok</body></html>`)
}))
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
conf := config.OpenGraph{
Enabled: true,
TimeToLive: time.Minute,
ConsiderHost: false,
}
t.Run("default host uses original", func(t *testing.T) {
cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{})
if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("GetOGTags failed: %v", err)
}
select {
case host := <-seenHosts:
if host != originalHost {
t.Fatalf("expected host %q, got %q", originalHost, host)
}
case <-time.After(time.Second):
t.Fatal("server did not receive request")
}
})
t.Run("override host respected", func(t *testing.T) {
cache := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
Host: overrideHost,
})
if _, err := cache.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("GetOGTags failed: %v", err)
}
select {
case host := <-seenHosts:
if host != overrideHost {
t.Fatalf("expected host %q, got %q", overrideHost, host)
}
case <-time.After(time.Second):
t.Fatal("server did not receive request")
}
})
}
func TestGetOGTagsWithInsecureSkipVerify(t *testing.T) {
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
fmt.Fprintln(w, `<!DOCTYPE html><html><head><meta property="og:title" content="Self-Signed" /></head><body>hello</body></html>`)
})
ts := httptest.NewTLSServer(handler)
defer ts.Close()
parsedURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
conf := config.OpenGraph{
Enabled: true,
TimeToLive: time.Minute,
ConsiderHost: false,
}
// Without skip verify we should get a TLS error
cacheStrict := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{})
if _, err := cacheStrict.GetOGTags(t.Context(), parsedURL, parsedURL.Host); err == nil {
t.Fatal("expected TLS verification error without InsecureSkipVerify")
}
cacheSkip := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
InsecureSkipVerify: true,
})
tags, err := cacheSkip.GetOGTags(t.Context(), parsedURL, parsedURL.Host)
if err != nil {
t.Fatalf("expected successful fetch with InsecureSkipVerify, got: %v", err)
}
if tags["og:title"] != "Self-Signed" {
t.Fatalf("expected og:title to be %q, got %q", "Self-Signed", tags["og:title"])
}
}
func TestGetOGTagsWithTargetSNI(t *testing.T) {
originalHost := "hecate.test"
conf := config.OpenGraph{
Enabled: true,
TimeToLive: time.Minute,
ConsiderHost: false,
}
t.Run("explicit SNI override", func(t *testing.T) {
expectedSNI := "backend.internal"
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Works" /></head><body>ok</body></html>`)
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
cacheExplicit := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
SNI: expectedSNI,
InsecureSkipVerify: true,
})
if _, err := cacheExplicit.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("expected successful fetch with explicit SNI, got: %v", err)
}
if got := recorder.last(); got != expectedSNI {
t.Fatalf("expected server to see SNI %q, got %q", expectedSNI, got)
}
})
t.Run("auto SNI uses original host", func(t *testing.T) {
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Auto" /></head><body>ok</body></html>`)
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
cacheAuto := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
SNI: "auto",
InsecureSkipVerify: true,
})
if _, err := cacheAuto.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("expected successful fetch with auto SNI, got: %v", err)
}
if got := recorder.last(); got != originalHost {
t.Fatalf("expected server to see SNI %q with auto, got %q", originalHost, got)
}
})
t.Run("default SNI uses backend host", func(t *testing.T) {
ts, recorder := newSNIServer(t, `<!DOCTYPE html><html><head><meta property="og:title" content="SNI Default" /></head><body>ok</body></html>`)
defer ts.Close()
targetURL, err := url.Parse(ts.URL)
if err != nil {
t.Fatalf("failed to parse server URL: %v", err)
}
cacheDefault := NewOGTagCache(ts.URL, conf, memory.New(t.Context()), TargetOptions{
InsecureSkipVerify: true,
})
if _, err := cacheDefault.GetOGTags(t.Context(), targetURL, originalHost); err != nil {
t.Fatalf("expected successful fetch without explicit SNI, got: %v", err)
}
wantSNI := ""
if net.ParseIP(targetURL.Hostname()) == nil {
wantSNI = targetURL.Hostname()
}
if got := recorder.last(); got != wantSNI {
t.Fatalf("expected default SNI %q, got %q", wantSNI, got)
}
})
}
func newSNIServer(t *testing.T, body string) (*httptest.Server, *sniRecorder) {
t.Helper()
handler := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
w.Header().Set("Content-Type", "text/html")
fmt.Fprint(w, body)
})
recorder := &sniRecorder{}
ts := httptest.NewUnstartedServer(handler)
cert := mustCertificateForHost(t, "sni.test")
ts.TLS = &tls.Config{
GetCertificate: func(hello *tls.ClientHelloInfo) (*tls.Certificate, error) {
recorder.record(hello.ServerName)
return &cert, nil
},
}
ts.StartTLS()
return ts, recorder
}
func mustCertificateForHost(t *testing.T, host string) tls.Certificate {
t.Helper()
priv, err := rsa.GenerateKey(rand.Reader, 2048)
if err != nil {
t.Fatalf("failed to generate key: %v", err)
}
template := &x509.Certificate{
SerialNumber: big.NewInt(1),
Subject: pkix.Name{
CommonName: host,
},
NotBefore: time.Now().Add(-time.Hour),
NotAfter: time.Now().Add(time.Hour),
ExtKeyUsage: []x509.ExtKeyUsage{x509.ExtKeyUsageServerAuth},
KeyUsage: x509.KeyUsageDigitalSignature | x509.KeyUsageKeyEncipherment,
BasicConstraintsValid: true,
DNSNames: []string{host},
}
der, err := x509.CreateCertificate(rand.Reader, template, template, &priv.PublicKey, priv)
if err != nil {
t.Fatalf("failed to create certificate: %v", err)
}
return tls.Certificate{
Certificate: [][]byte{der},
PrivateKey: priv,
}
}
type sniRecorder struct {
mu sync.Mutex
names []string
}
func (r *sniRecorder) record(name string) {
r.mu.Lock()
defer r.mu.Unlock()
r.names = append(r.names, name)
}
func (r *sniRecorder) last() string {
r.mu.Lock()
defer r.mu.Unlock()
if len(r.names) == 0 {
return ""
}
return r.names[len(r.names)-1]
}

View File

@@ -18,7 +18,7 @@ func TestExtractOGTags(t *testing.T) {
Enabled: false, Enabled: false,
ConsiderHost: false, ConsiderHost: false,
TimeToLive: time.Minute, TimeToLive: time.Minute,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
// Manually set approved tags/prefixes based on the user request for clarity // Manually set approved tags/prefixes based on the user request for clarity
testCache.approvedTags = []string{"description"} testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"} testCache.approvedPrefixes = []string{"og:"}
@@ -199,7 +199,7 @@ func TestExtractMetaTagInfo(t *testing.T) {
Enabled: false, Enabled: false,
ConsiderHost: false, ConsiderHost: false,
TimeToLive: time.Minute, TimeToLive: time.Minute,
}, memory.New(t.Context())) }, memory.New(t.Context()), TargetOptions{})
testCache.approvedTags = []string{"description"} testCache.approvedTags = []string{"description"}
testCache.approvedPrefixes = []string{"og:"} testCache.approvedPrefixes = []string{"og:"}

View File

@@ -27,27 +27,30 @@ import (
) )
type Options struct { type Options struct {
Next http.Handler Next http.Handler
Policy *policy.ParsedConfig Policy *policy.ParsedConfig
Target string Target string
CookieDynamicDomain bool TargetHost string
CookieDomain string TargetSNI string
CookieExpiration time.Duration TargetInsecureSkipVerify bool
CookiePartitioned bool CookieDynamicDomain bool
BasePrefix string CookieDomain string
WebmasterEmail string CookieExpiration time.Duration
RedirectDomains []string CookiePartitioned bool
ED25519PrivateKey ed25519.PrivateKey BasePrefix string
HS512Secret []byte WebmasterEmail string
StripBasePrefix bool RedirectDomains []string
OpenGraph config.OpenGraph ED25519PrivateKey ed25519.PrivateKey
ServeRobotsTXT bool HS512Secret []byte
CookieSecure bool StripBasePrefix bool
CookieSameSite http.SameSite OpenGraph config.OpenGraph
Logger *slog.Logger ServeRobotsTXT bool
PublicUrl string CookieSecure bool
JWTRestrictionHeader string CookieSameSite http.SameSite
DifficultyInJWT bool Logger *slog.Logger
PublicUrl string
JWTRestrictionHeader string
DifficultyInJWT bool
} }
func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) { func LoadPoliciesOrDefault(ctx context.Context, fname string, defaultDifficulty int) (*policy.ParsedConfig, error) {
@@ -116,9 +119,13 @@ func New(opts Options) (*Server, error) {
hs512Secret: opts.HS512Secret, hs512Secret: opts.HS512Secret,
policy: opts.Policy, policy: opts.Policy,
opts: opts, opts: opts,
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store), OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store, ogtags.TargetOptions{
store: opts.Policy.Store, Host: opts.TargetHost,
logger: opts.Logger, SNI: opts.TargetSNI,
InsecureSkipVerify: opts.TargetInsecureSkipVerify,
}),
store: opts.Policy.Store,
logger: opts.Logger,
} }
mux := http.NewServeMux() mux := http.NewServeMux()