diff --git a/internal/ogtags/fetch.go b/internal/ogtags/fetch.go index 384864ca..0bfb0a11 100644 --- a/internal/ogtags/fetch.go +++ b/internal/ogtags/fetch.go @@ -2,7 +2,6 @@ package ogtags import ( "context" - "crypto/tls" "errors" "fmt" "io" @@ -28,42 +27,26 @@ func (c *OGTagCache) fetchHTMLDocumentWithCache(ctx context.Context, urlStr stri } // Set the Host header to the original host + var hostForRequest string switch { case c.targetHost != "": - req.Host = c.targetHost + hostForRequest = c.targetHost case originalHost != "": - req.Host = originalHost + hostForRequest = originalHost + } + if hostForRequest != "" { + req.Host = hostForRequest } // Add proxy headers req.Header.Set("X-Forwarded-Proto", "https") req.Header.Set("User-Agent", "Anubis-OGTag-Fetcher/1.0") // For tracking purposes - client := c.client - - if c.targetSNIAuto { - serverName := originalHost - if c.targetHost != "" { - serverName = c.targetHost - } - - if serverName != "" { - transport := c.transport.Clone() - if transport.TLSClientConfig == nil { - transport.TLSClientConfig = &tls.Config{} - } - transport.TLSClientConfig.ServerName = serverName - if c.insecureSkipVerify { - transport.TLSClientConfig.InsecureSkipVerify = true - } - - client = &http.Client{ - Timeout: httpTimeout, - Transport: transport, - } - defer transport.CloseIdleConnections() - } + serverName := hostForRequest + if serverName == "" { + serverName = req.URL.Hostname() } + client := c.clientForSNI(serverName) // Send the request resp, err := client.Do(req) diff --git a/internal/ogtags/ogtags.go b/internal/ogtags/ogtags.go index 1d6cafc2..66c13078 100644 --- a/internal/ogtags/ogtags.go +++ b/internal/ogtags/ogtags.go @@ -8,6 +8,7 @@ import ( "net/http" "net/url" "strings" + "sync" "time" "github.com/TecharoHQ/anubis/lib/policy/config" @@ -40,6 +41,8 @@ type OGTagCache struct { targetSNI string targetSNIAuto bool insecureSkipVerify bool + sniClients map[string]*http.Client + transportMu sync.RWMutex } type TargetOptions struct { @@ -124,6 +127,7 @@ func NewOGTagCache(target string, conf config.OpenGraph, backend store.Interface targetSNI: targetOpts.SNI, targetSNIAuto: targetSNIAuto, insecureSkipVerify: targetOpts.InsecureSkipVerify, + sniClients: make(map[string]*http.Client), } } diff --git a/internal/ogtags/sni.go b/internal/ogtags/sni.go new file mode 100644 index 00000000..46cfe031 --- /dev/null +++ b/internal/ogtags/sni.go @@ -0,0 +1,42 @@ +package ogtags + +import ( + "crypto/tls" + "net/http" +) + +// clientForSNI returns a cached client for the given server name, creating one if needed. +func (c *OGTagCache) clientForSNI(serverName string) *http.Client { + if !c.targetSNIAuto || serverName == "" { + return c.client + } + + c.transportMu.RLock() + cli, ok := c.sniClients[serverName] + c.transportMu.RUnlock() + if ok { + return cli + } + + c.transportMu.Lock() + defer c.transportMu.Unlock() + if cli, ok := c.sniClients[serverName]; ok { + return cli + } + + tr := c.transport.Clone() + if tr.TLSClientConfig == nil { + tr.TLSClientConfig = &tls.Config{} + } + tr.TLSClientConfig.ServerName = serverName + if c.insecureSkipVerify { + tr.TLSClientConfig.InsecureSkipVerify = true + } + + cli = &http.Client{ + Timeout: httpTimeout, + Transport: tr, + } + c.sniClients[serverName] = cli + return cli +}