refactor(lib): keep server URL state local

Signed-off-by: Jason Cameron <jason.cameron@stanwith.me>
This commit is contained in:
Jason Cameron
2026-05-14 01:10:29 -04:00
parent c184028d42
commit b250fea00b
4 changed files with 190 additions and 48 deletions
+50 -15
View File
@@ -13,6 +13,7 @@ import (
"net/url" "net/url"
"strconv" "strconv"
"strings" "strings"
"sync"
"time" "time"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
@@ -97,6 +98,8 @@ type Server struct {
OGTags *ogtags.OGTagCache OGTags *ogtags.OGTagCache
logger *slog.Logger logger *slog.Logger
opts Options opts Options
basePrefix string
publicURL string
ed25519Priv ed25519.PrivateKey ed25519Priv ed25519.PrivateKey
hs512Secret []byte hs512Secret []byte
} }
@@ -123,6 +126,50 @@ func (s *Server) getRequestLogger(r *http.Request) (*slog.Logger, *http.Request)
return lg, r return lg, r
} }
var anubisBasePrefixMu sync.Mutex
func (s *Server) configuredBasePrefix() string {
if s.basePrefix != "" {
return s.basePrefix
}
return strings.TrimRight(s.opts.BasePrefix, "/")
}
func (s *Server) configuredPublicURL() string {
if s.publicURL != "" {
return s.publicURL
}
return s.opts.PublicUrl
}
func (s *Server) prefixedPath(path string) string {
if !strings.HasPrefix(path, "/") {
path = "/" + path
}
return s.configuredBasePrefix() + path
}
func (s *Server) cookiePath() string {
basePrefix := s.configuredBasePrefix()
if basePrefix == "" {
return "/"
}
return basePrefix + "/"
}
func (s *Server) withAnubisBasePrefix(fn func()) {
anubisBasePrefixMu.Lock()
defer anubisBasePrefixMu.Unlock()
oldBasePrefix := anubis.BasePrefix
anubis.BasePrefix = s.configuredBasePrefix()
defer func() {
anubis.BasePrefix = oldBasePrefix
}()
fn()
}
func (s *Server) getTokenKeyfunc() jwt.Keyfunc { func (s *Server) getTokenKeyfunc() jwt.Keyfunc {
// return ED25519 key if HS512 is not set // return ED25519 key if HS512 is not set
if len(s.hs512Secret) == 0 { if len(s.hs512Secret) == 0 {
@@ -246,11 +293,7 @@ func (s *Server) maybeReverseProxy(w http.ResponseWriter, r *http.Request, httpS
return return
} }
// Adjust cookie path if base prefix is not empty cookiePath := s.cookiePath()
cookiePath := "/"
if anubis.BasePrefix != "" {
cookiePath = strings.TrimSuffix(anubis.BasePrefix, "/") + "/"
}
cr, rule, err := s.check(r, lg) cr, rule, err := s.check(r, lg)
if err != nil { if err != nil {
@@ -344,11 +387,7 @@ func (s *Server) maybeReverseProxy(w http.ResponseWriter, r *http.Request, httpS
} }
func (s *Server) checkRules(w http.ResponseWriter, r *http.Request, cr policy.CheckResult, lg *slog.Logger, rule *policy.Bot) bool { func (s *Server) checkRules(w http.ResponseWriter, r *http.Request, cr policy.CheckResult, lg *slog.Logger, rule *policy.Bot) bool {
// Adjust cookie path if base prefix is not empty cookiePath := s.cookiePath()
cookiePath := "/"
if anubis.BasePrefix != "" {
cookiePath = strings.TrimSuffix(anubis.BasePrefix, "/") + "/"
}
localizer := localization.GetLocalizer(r) localizer := localization.GetLocalizer(r)
@@ -511,11 +550,7 @@ func (s *Server) PassChallenge(w http.ResponseWriter, r *http.Request) {
return return
} }
// Adjust cookie path if base prefix is not empty cookiePath := s.cookiePath()
cookiePath := "/"
if anubis.BasePrefix != "" {
cookiePath = strings.TrimSuffix(anubis.BasePrefix, "/") + "/"
}
if _, err := r.Cookie(anubis.TestCookieName); errors.Is(err, http.ErrNoCookie) { if _, err := r.Cookie(anubis.TestCookieName); errors.Is(err, http.ErrNoCookie) {
s.ClearCookie(w, CookieOpts{Path: cookiePath, Host: r.Host}) s.ClearCookie(w, CookieOpts{Path: cookiePath, Host: r.Host})
+92 -2
View File
@@ -2,6 +2,7 @@ package lib
import ( import (
"bytes" "bytes"
"compress/gzip"
"context" "context"
"encoding/json" "encoding/json"
"fmt" "fmt"
@@ -501,8 +502,11 @@ func TestBasePrefix(t *testing.T) {
for _, tc := range testCases { for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// Reset the global BasePrefix before each test originalBasePrefix := anubis.BasePrefix
anubis.BasePrefix = "" anubis.BasePrefix = "/not-this-server"
t.Cleanup(func() {
anubis.BasePrefix = originalBasePrefix
})
pol := loadPolicies(t, "", 4) pol := loadPolicies(t, "", 4)
@@ -631,10 +635,96 @@ func TestBasePrefix(t *testing.T) {
if ckie.Path != expectedPath { if ckie.Path != expectedPath {
t.Errorf("cookie path is wrong, wanted %s, got: %s", expectedPath, ckie.Path) t.Errorf("cookie path is wrong, wanted %s, got: %s", expectedPath, ckie.Path)
} }
if anubis.BasePrefix != "/not-this-server" {
t.Errorf("New should not overwrite anubis.BasePrefix, got %q", anubis.BasePrefix)
}
}) })
} }
} }
func TestBasePrefixUsesServerConfigNotPackageGlobal(t *testing.T) {
originalBasePrefix := anubis.BasePrefix
anubis.BasePrefix = "/global"
t.Cleanup(func() {
anubis.BasePrefix = originalBasePrefix
})
pol := loadPolicies(t, "", 4)
srv := spawnAnubis(t, Options{
Next: http.NewServeMux(),
Policy: pol,
BasePrefix: "/local/",
})
if anubis.BasePrefix != "/global" {
t.Fatalf("New should leave anubis.BasePrefix alone, got %q", anubis.BasePrefix)
}
req := httptest.NewRequest(http.MethodPost, "/local/.within.website/x/cmd/anubis/api/make-challenge?redir=/", nil)
req.Header.Set("X-Real-Ip", "127.0.0.1")
rr := httptest.NewRecorder()
srv.ServeHTTP(rr, req)
if rr.Code != http.StatusOK {
t.Fatalf("expected local server prefix route to work, got status %d body %q", rr.Code, rr.Body.String())
}
}
func TestRenderIndexUsesServerBasePrefixNotPackageGlobal(t *testing.T) {
originalBasePrefix := anubis.BasePrefix
anubis.BasePrefix = "/global"
t.Cleanup(func() {
anubis.BasePrefix = originalBasePrefix
})
pol := loadPolicies(t, "", 4)
srv := spawnAnubis(t, Options{
Next: http.NewServeMux(),
Policy: pol,
BasePrefix: "/local/",
})
req := httptest.NewRequest(http.MethodGet, "/local/protected", nil)
req.Header.Set("X-Real-Ip", "127.0.0.1")
req.Header.Set("User-Agent", "Mozilla/5.0")
req.Header.Set("Accept-Encoding", "gzip")
rr := httptest.NewRecorder()
srv.ServeHTTP(rr, req)
resp := rr.Result()
defer resp.Body.Close()
var body strings.Builder
reader := io.Reader(resp.Body)
if resp.Header.Get("Content-Encoding") == "gzip" {
gzipReader, err := gzip.NewReader(resp.Body)
if err != nil {
t.Fatalf("opening gzip response should not fail: %v", err)
}
defer gzipReader.Close()
reader = gzipReader
}
if _, err := io.Copy(&body, reader); err != nil {
t.Fatalf("reading challenge response should not fail: %v", err)
}
if resp.StatusCode != http.StatusOK {
t.Fatalf("expected challenge status %d, got %d body %q", http.StatusOK, resp.StatusCode, body.String())
}
if !strings.Contains(body.String(), "/local/.within.website/x/cmd/anubis/static/js/main.mjs") {
t.Fatalf("expected challenge assets to use server base prefix, body %q", body.String())
}
if strings.Contains(body.String(), "/global/.within.website") {
t.Fatalf("challenge body used package global base prefix: %q", body.String())
}
if anubis.BasePrefix != "/global" {
t.Fatalf("rendering should restore anubis.BasePrefix, got %q", anubis.BasePrefix)
}
}
func TestCustomStatusCodes(t *testing.T) { func TestCustomStatusCodes(t *testing.T) {
h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { h := http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
t.Log(r.UserAgent()) t.Log(r.UserAgent())
+13 -17
View File
@@ -112,8 +112,7 @@ func New(opts Options) (*Server, error) {
opts.ED25519PrivateKey = priv opts.ED25519PrivateKey = priv
} }
anubis.BasePrefix = strings.TrimRight(opts.BasePrefix, "/") opts.BasePrefix = strings.TrimRight(opts.BasePrefix, "/")
anubis.PublicUrl = opts.PublicUrl
result := &Server{ result := &Server{
next: opts.Next, next: opts.Next,
@@ -121,6 +120,8 @@ func New(opts Options) (*Server, error) {
hs512Secret: opts.HS512Secret, hs512Secret: opts.HS512Secret,
policy: opts.Policy, policy: opts.Policy,
opts: opts, opts: opts,
basePrefix: opts.BasePrefix,
publicURL: opts.PublicUrl,
OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store, ogtags.TargetOptions{ OGTags: ogtags.NewOGTagCache(opts.Target, opts.Policy.OpenGraph, opts.Policy.Store, ogtags.TargetOptions{
Host: opts.TargetHost, Host: opts.TargetHost,
SNI: opts.TargetSNI, SNI: opts.TargetSNI,
@@ -131,28 +132,20 @@ func New(opts Options) (*Server, error) {
} }
mux := http.NewServeMux() mux := http.NewServeMux()
xess.Mount(mux) xessPrefix := result.prefixedPath(xess.BasePrefix)
mux.Handle(xessPrefix, internal.UnchangingCache(http.StripPrefix(xessPrefix, http.FileServerFS(xess.Static))))
// Helper to add global prefix // Helper to add the server-local base prefix.
registerWithPrefix := func(pattern string, handler http.Handler, method string) { registerWithPrefix := func(pattern string, handler http.Handler, method string) {
if method != "" { if method != "" {
method = method + " " // methods must end with a space to register with them method = method + " " // methods must end with a space to register with them
} }
// Ensure there's no double slash when concatenating BasePrefix and pattern mux.Handle(method+result.prefixedPath(pattern), handler)
basePrefix := strings.TrimSuffix(anubis.BasePrefix, "/")
prefix := method + basePrefix
// If pattern doesn't start with a slash, add one
if !strings.HasPrefix(pattern, "/") {
pattern = "/" + pattern
}
mux.Handle(prefix+pattern, handler)
} }
// Ensure there's no double slash when concatenating BasePrefix and StaticPath // Ensure there's no double slash when concatenating BasePrefix and StaticPath
stripPrefix := strings.TrimSuffix(anubis.BasePrefix, "/") + anubis.StaticPath stripPrefix := result.prefixedPath(anubis.StaticPath)
registerWithPrefix(anubis.StaticPath, internal.UnchangingCache(internal.NoBrowsing(http.StripPrefix(stripPrefix, http.FileServerFS(web.Static)))), "") registerWithPrefix(anubis.StaticPath, internal.UnchangingCache(internal.NoBrowsing(http.StripPrefix(stripPrefix, http.FileServerFS(web.Static)))), "")
if opts.ServeRobotsTXT { if opts.ServeRobotsTXT {
@@ -166,9 +159,12 @@ func New(opts Options) (*Server, error) {
if opts.Policy.Impressum != nil { if opts.Policy.Impressum != nil {
registerWithPrefix(anubis.APIPrefix+"imprint", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { registerWithPrefix(anubis.APIPrefix+"imprint", http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
templ.Handler( handler := templ.Handler(
web.Base(opts.Policy.Impressum.Page.Title, opts.Policy.Impressum.Page, opts.Policy.Impressum, localization.GetLocalizer(r)), web.Base(opts.Policy.Impressum.Page.Title, opts.Policy.Impressum.Page, opts.Policy.Impressum, localization.GetLocalizer(r)),
).ServeHTTP(w, r) )
result.withAnubisBasePrefix(func() {
handler.ServeHTTP(w, r)
})
}), "GET") }), "GET")
} }
+35 -14
View File
@@ -193,7 +193,7 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr policy.C
localizer := localization.GetLocalizer(r) localizer := localization.GetLocalizer(r)
if returnHTTPStatusOnly { if returnHTTPStatusOnly {
if s.opts.PublicUrl == "" { if s.configuredPublicURL() == "" {
w.WriteHeader(http.StatusUnauthorized) w.WriteHeader(http.StatusUnauthorized)
w.Write([]byte(localizer.T("authorization_required"))) w.Write([]byte(localizer.T("authorization_required")))
} else { } else {
@@ -270,7 +270,10 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr policy.C
Store: s.store, Store: s.store,
} }
component, err := impl.Issue(w, r, lg, in) var component templ.Component
s.withAnubisBasePrefix(func() {
component, err = impl.Issue(w, r, lg, in)
})
if err != nil { if err != nil {
lg.Error("[unexpected] challenge component render failed, please open an issue", "err", err) // This is likely a bug in the template. Should never be triggered as CI tests for this. lg.Error("[unexpected] challenge component render failed, please open an issue", "err", err) // This is likely a bug in the template. Should never be triggered as CI tests for this.
s.respondWithError(w, r, fmt.Sprintf("%s \"RenderIndex\"", localizer.T("internal_server_error")), makeCode(err)) s.respondWithError(w, r, fmt.Sprintf("%s \"RenderIndex\"", localizer.T("internal_server_error")), makeCode(err))
@@ -291,7 +294,9 @@ func (s *Server) RenderIndex(w http.ResponseWriter, r *http.Request, cr policy.C
page, page,
templ.WithStatus(s.opts.Policy.StatusCodes.Challenge), templ.WithStatus(s.opts.Policy.StatusCodes.Challenge),
))) )))
handler.ServeHTTP(w, r) s.withAnubisBasePrefix(func() {
handler.ServeHTTP(w, r)
})
} }
func (s *Server) constructRedirectURL(r *http.Request) (string, error) { func (s *Server) constructRedirectURL(r *http.Request) (string, error) {
@@ -323,15 +328,18 @@ func (s *Server) constructRedirectURL(r *http.Request) (string, error) {
redir := proto + "://" + host + uri redir := proto + "://" + host + uri
escapedURL := url.QueryEscape(redir) escapedURL := url.QueryEscape(redir)
return fmt.Sprintf("%s/.within.website/?redir=%s", s.opts.PublicUrl, escapedURL), nil return fmt.Sprintf("%s/.within.website/?redir=%s", s.configuredPublicURL(), escapedURL), nil
} }
func (s *Server) RenderBench(w http.ResponseWriter, r *http.Request) { func (s *Server) RenderBench(w http.ResponseWriter, r *http.Request) {
localizer := localization.GetLocalizer(r) localizer := localization.GetLocalizer(r)
templ.Handler( handler := templ.Handler(
web.Base(localizer.T("benchmarking_anubis"), web.Bench(localizer), s.policy.Impressum, localizer), web.Base(localizer.T("benchmarking_anubis"), web.Bench(localizer), s.policy.Impressum, localizer),
).ServeHTTP(w, r) )
s.withAnubisBasePrefix(func() {
handler.ServeHTTP(w, r)
})
} }
func (s *Server) respondWithError(w http.ResponseWriter, r *http.Request, message, code string) { func (s *Server) respondWithError(w http.ResponseWriter, r *http.Request, message, code string) {
@@ -348,21 +356,31 @@ func (s *Server) respondWithStatus(w http.ResponseWriter, r *http.Request, msg,
localizer, localizer,
) )
handler := internal.NoStoreCache(templ.Handler(component, templ.WithStatus(status))) handler := internal.NoStoreCache(templ.Handler(component, templ.WithStatus(status)))
handler.ServeHTTP(w, r) s.withAnubisBasePrefix(func() {
handler.ServeHTTP(w, r)
})
} }
func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, anubis.BasePrefix+anubis.StaticPath) { if strings.HasPrefix(r.URL.Path, s.prefixedPath(anubis.StaticPath)) {
s.mux.ServeHTTP(w, r) s.mux.ServeHTTP(w, r)
return return
} else if strings.HasPrefix(r.URL.Path, anubis.BasePrefix+xess.BasePrefix) { } else if strings.HasPrefix(r.URL.Path, s.prefixedPath(xess.BasePrefix)) {
s.mux.ServeHTTP(w, r) s.mux.ServeHTTP(w, r)
return return
} }
// Forward robots.txt requests to mux when ServeRobotsTXT is enabled // Forward robots.txt requests to mux when ServeRobotsTXT is enabled
if s.opts.ServeRobotsTXT { if s.opts.ServeRobotsTXT {
path := strings.TrimPrefix(r.URL.Path, anubis.BasePrefix) path := r.URL.Path
basePrefix := s.configuredBasePrefix()
if basePrefix != "" {
if !strings.HasPrefix(path, basePrefix) {
s.maybeReverseProxyOrPage(w, r)
return
}
path = strings.TrimPrefix(path, basePrefix)
}
if path == "/robots.txt" || path == "/.well-known/robots.txt" { if path == "/robots.txt" || path == "/.well-known/robots.txt" {
s.mux.ServeHTTP(w, r) s.mux.ServeHTTP(w, r)
return return
@@ -373,11 +391,11 @@ func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) {
} }
func (s *Server) stripBasePrefixFromRequest(r *http.Request) *http.Request { func (s *Server) stripBasePrefixFromRequest(r *http.Request) *http.Request {
if !s.opts.StripBasePrefix || s.opts.BasePrefix == "" { basePrefix := s.configuredBasePrefix()
if !s.opts.StripBasePrefix || basePrefix == "" {
return r return r
} }
basePrefix := strings.TrimSuffix(s.opts.BasePrefix, "/")
path := r.URL.Path path := r.URL.Path
if !strings.HasPrefix(path, basePrefix) { if !strings.HasPrefix(path, basePrefix) {
@@ -441,9 +459,12 @@ func (s *Server) ServeHTTPNext(w http.ResponseWriter, r *http.Request) {
return return
} }
templ.Handler( handler := templ.Handler(
web.Base(localizer.T("you_are_not_a_bot"), web.StaticHappy(localizer), s.policy.Impressum, localizer), web.Base(localizer.T("you_are_not_a_bot"), web.StaticHappy(localizer), s.policy.Impressum, localizer),
).ServeHTTP(w, r) )
s.withAnubisBasePrefix(func() {
handler.ServeHTTP(w, r)
})
} else { } else {
asn, asnDesc := asnFromContext(r.Context()) asn, asnDesc := asnFromContext(r.Context())
requestsProxied.WithLabelValues(r.Host, asn, asnDesc).Inc() requestsProxied.WithLabelValues(r.Host, asn, asnDesc).Inc()