diff --git a/cmd/anubis/main.go b/cmd/anubis/main.go index 65241b95..6ad10274 100644 --- a/cmd/anubis/main.go +++ b/cmd/anubis/main.go @@ -439,26 +439,29 @@ func main() { } s, err := libanubis.New(libanubis.Options{ - BasePrefix: *basePrefix, - StripBasePrefix: *stripBasePrefix, - Next: rp, - Policy: policy, - ServeRobotsTXT: *robotsTxt, - ED25519PrivateKey: ed25519Priv, - HS512Secret: []byte(*hs512Secret), - CookieDomain: *cookieDomain, - CookieDynamicDomain: *cookieDynamicDomain, - CookieExpiration: *cookieExpiration, - CookiePartitioned: *cookiePartitioned, - RedirectDomains: redirectDomainsList, - Target: *target, - WebmasterEmail: *webmasterEmail, - OpenGraph: policy.OpenGraph, - CookieSecure: *cookieSecure, - CookieSameSite: parseSameSite(*cookieSameSite), - PublicUrl: *publicUrl, - JWTRestrictionHeader: *jwtRestrictionHeader, - DifficultyInJWT: *difficultyInJWT, + BasePrefix: *basePrefix, + StripBasePrefix: *stripBasePrefix, + Next: rp, + Policy: policy, + TargetHost: *targetHost, + TargetSNI: *targetSNI, + TargetInsecureSkipVerify: *targetInsecureSkipVerify, + ServeRobotsTXT: *robotsTxt, + ED25519PrivateKey: ed25519Priv, + HS512Secret: []byte(*hs512Secret), + CookieDomain: *cookieDomain, + CookieDynamicDomain: *cookieDynamicDomain, + CookieExpiration: *cookieExpiration, + CookiePartitioned: *cookiePartitioned, + RedirectDomains: redirectDomainsList, + Target: *target, + WebmasterEmail: *webmasterEmail, + OpenGraph: policy.OpenGraph, + CookieSecure: *cookieSecure, + CookieSameSite: parseSameSite(*cookieSameSite), + PublicUrl: *publicUrl, + JWTRestrictionHeader: *jwtRestrictionHeader, + DifficultyInJWT: *difficultyInJWT, }) if err != nil { log.Fatalf("can't construct libanubis.Server: %v", err) diff --git a/internal/ogtags/cache_test.go b/internal/ogtags/cache_test.go index 08bf4e34..89ba2299 100644 --- a/internal/ogtags/cache_test.go +++ b/internal/ogtags/cache_test.go @@ -24,7 +24,7 @@ func TestCacheReturnsDefault(t *testing.T) { TimeToLive: time.Minute, ConsiderHost: false, Override: want, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) u, err := url.Parse("https://anubis.techaro.lol") if err != nil { @@ -52,7 +52,7 @@ func TestCheckCache(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Set up test data urlStr := "http://example.com/page" @@ -115,7 +115,7 @@ func TestGetOGTags(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Parse the test server URL parsedURL, err := url.Parse(ts.URL) @@ -271,7 +271,7 @@ func TestGetOGTagsWithHostConsideration(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: tc.ogCacheConsiderHost, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) for i, req := range tc.requests { ogTags, err := cache.GetOGTags(t.Context(), parsedURL, req.host) diff --git a/internal/ogtags/fetch.go b/internal/ogtags/fetch.go index 26a0af2a..384864ca 100644 --- a/internal/ogtags/fetch.go +++ b/internal/ogtags/fetch.go @@ -2,6 +2,7 @@ package ogtags import ( "context" + "crypto/tls" "errors" "fmt" "io" @@ -27,7 +28,10 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri } // Set the Host header to the original host - if originalHost != "" { + switch { + case c.targetHost != "": + req.Host = c.targetHost + case originalHost != "": req.Host = originalHost } @@ -35,8 +39,34 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes + client := c.client + + if c.targetSNIAuto { + serverName := originalHost + if c.targetHost != "" { + serverName = c.targetHost + } + + if serverName != "" { + transport := c.transport.Clone() + if transport.TLSClientConfig == nil { + transport.TLSClientConfig = &tls.Config{} + } + transport.TLSClientConfig.ServerName = serverName + if c.insecureSkipVerify { + transport.TLSClientConfig.InsecureSkipVerify = true + } + + client = &http.Client{ + Timeout: httpTimeout, + Transport: transport, + } + defer transport.CloseIdleConnections() + } + } + // Send the request - resp, err := c.client.Do(req) + resp, err := client.Do(req) if err != nil { var netErr net.Error if errors.As(err, &netErr) && netErr.Timeout() { diff --git a/internal/ogtags/fetch_test.go b/internal/ogtags/fetch_test.go index c986272a..864e8f2b 100644 --- a/internal/ogtags/fetch_test.go +++ b/internal/ogtags/fetch_test.go @@ -87,7 +87,7 @@ func TestFetchHTMLDocument(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) doc, err := cache.fetchHTMLDocument(t.Context(), ts.URL, "anything") if tt.expectError { @@ -118,7 +118,7 @@ func TestFetchHTMLDocumentInvalidURL(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) doc, err := cache.fetchHTMLDocument(t.Context(), "http://invalid.url.that.doesnt.exist.example", "anything") diff --git a/internal/ogtags/integration_test.go b/internal/ogtags/integration_test.go index 574172d0..af56668b 100644 --- a/internal/ogtags/integration_test.go +++ b/internal/ogtags/integration_test.go @@ -111,7 +111,7 @@ func TestIntegrationGetOGTags(t *testing.T) { Enabled: true, TimeToLive: time.Minute, ConsiderHost: false, - }, memory.New(t.Context())) + }, memory.New(t.Context()), TargetOptions{}) // Create URL for test testURL, _ := url.Parse(ts.URL) diff --git a/internal/ogtags/mem_test.go b/internal/ogtags/mem_test.go index b415cda4..7d2ac0cb 100644 --- a/internal/ogtags/mem_test.go +++ b/internal/ogtags/mem_test.go @@ -31,7 +31,7 @@ func BenchmarkGetTarget(b *testing.B) { for _, tt := range tests { b.Run(tt.name, func(b *testing.B) { - cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context())) + cache := NewOGTagCache(tt.target, config.OpenGraph{}, memory.New(b.Context()), TargetOptions{}) urls := make([]*url.URL, len(tt.paths)) for i, path := range tt.paths { u, _ := url.Parse(path) @@ -67,7 +67,7 @@ func BenchmarkExtractOGTags(b *testing.B) {
Content