diff --git a/lib/http.go b/lib/http.go index 65f96400..60e51d4f 100644 --- a/lib/http.go +++ b/lib/http.go @@ -296,6 +296,16 @@ func (s *Server) constructRedirectURL(r *http.Request) (string, error) { if proto == "" || host == "" || uri == "" { return "", errors.New(localizer.T("missing_required_forwarded_headers")) } + + switch proto { + case "http", "https": + // allowed + default: + lg := internal.GetRequestLogger(s.logger, r) + lg.Warn("invalid protocol in X-Forwarded-Proto", "proto", proto) + return "", errors.New(localizer.T("invalid_redirect")) + } + // Check if host is allowed in RedirectDomains (supports '*' via glob) if len(s.opts.RedirectDomains) > 0 && !matchRedirectDomain(s.opts.RedirectDomains, host) { lg := internal.GetRequestLogger(s.logger, r) @@ -369,16 +379,31 @@ func (s *Server) ServeHTTPNext(w http.ResponseWriter, r *http.Request) { localizer := localization.GetLocalizer(r) redir := r.FormValue("redir") - urlParsed, err := r.URL.Parse(redir) + urlParsed, err := url.ParseRequestURI(redir) if err != nil { - s.respondWithStatus(w, r, localizer.T("redirect_not_parseable"), makeCode(err), http.StatusBadRequest) + // if ParseRequestURI fails, try as relative URL + urlParsed, err = r.URL.Parse(redir) + if err != nil { + s.respondWithStatus(w, r, localizer.T("redirect_not_parseable"), makeCode(err), http.StatusBadRequest) + return + } + } + + // validate URL scheme to prevent javascript:, data:, file:, tel:, etc. + switch urlParsed.Scheme { + case "", "http", "https": + // allowed: empty scheme means relative URL + default: + lg := internal.GetRequestLogger(s.logger, r) + lg.Warn("XSS attempt blocked, invalid redirect scheme", "scheme", urlParsed.Scheme, "redir", redir) + s.respondWithStatus(w, r, localizer.T("invalid_redirect"), "", http.StatusBadRequest) return } hostNotAllowed := len(urlParsed.Host) > 0 && len(s.opts.RedirectDomains) != 0 && !matchRedirectDomain(s.opts.RedirectDomains, urlParsed.Host) - hostMismatch := r.URL.Host != "" && urlParsed.Host != r.URL.Host + hostMismatch := r.URL.Host != "" && urlParsed.Host != "" && urlParsed.Host != r.URL.Host if hostNotAllowed || hostMismatch { lg := internal.GetRequestLogger(s.logger, r) diff --git a/lib/redirect_security_test.go b/lib/redirect_security_test.go new file mode 100644 index 00000000..e16206a3 --- /dev/null +++ b/lib/redirect_security_test.go @@ -0,0 +1,296 @@ +package lib + +import ( + "log/slog" + "net/http" + "net/http/httptest" + "net/url" + "strings" + "testing" + + "github.com/TecharoHQ/anubis/lib/policy" +) + +func TestRedirectSecurity(t *testing.T) { + tests := []struct { + name string + testType string // "constructRedirectURL", "serveHTTPNext", "renderIndex" + + // For constructRedirectURL tests + xForwardedProto string + xForwardedHost string + xForwardedUri string + + // For serveHTTPNext tests + redirParam string + reqHost string + + // For renderIndex tests + returnHTTPStatusOnly bool + + // Expected results + expectedStatus int + shouldError bool + shouldNotRedirect bool + shouldBlock bool + errorContains string + }{ + // constructRedirectURL tests - X-Forwarded-Proto validation + { + name: "constructRedirectURL: javascript protocol should be rejected", + testType: "constructRedirectURL", + xForwardedProto: "javascript", + xForwardedHost: "example.com", + xForwardedUri: "alert(1)", + shouldError: true, + errorContains: "invalid", + }, + { + name: "constructRedirectURL: data protocol should be rejected", + testType: "constructRedirectURL", + xForwardedProto: "data", + xForwardedHost: "text/html", + xForwardedUri: ",", + shouldError: true, + errorContains: "invalid", + }, + { + name: "constructRedirectURL: file protocol should be rejected", + testType: "constructRedirectURL", + xForwardedProto: "file", + xForwardedHost: "", + xForwardedUri: "/etc/passwd", + shouldError: true, + errorContains: "invalid", + }, + { + name: "constructRedirectURL: ftp protocol should be rejected", + testType: "constructRedirectURL", + xForwardedProto: "ftp", + xForwardedHost: "example.com", + xForwardedUri: "/file.txt", + shouldError: true, + errorContains: "invalid", + }, + { + name: "constructRedirectURL: https protocol should be allowed", + testType: "constructRedirectURL", + xForwardedProto: "https", + xForwardedHost: "example.com", + xForwardedUri: "/foo", + shouldError: false, + }, + { + name: "constructRedirectURL: http protocol should be allowed", + testType: "constructRedirectURL", + xForwardedProto: "http", + xForwardedHost: "example.com", + xForwardedUri: "/bar", + shouldError: false, + }, + + // serveHTTPNext tests - redir parameter validation + { + name: "serveHTTPNext: javascript: URL should be rejected", + testType: "serveHTTPNext", + redirParam: "javascript:alert(1)", + reqHost: "example.com", + expectedStatus: http.StatusBadRequest, + shouldNotRedirect: true, + }, + { + name: "serveHTTPNext: data: URL should be rejected", + testType: "serveHTTPNext", + redirParam: "data:text/html,", + reqHost: "example.com", + expectedStatus: http.StatusBadRequest, + shouldNotRedirect: true, + }, + { + name: "serveHTTPNext: file: URL should be rejected", + testType: "serveHTTPNext", + redirParam: "file:///etc/passwd", + reqHost: "example.com", + expectedStatus: http.StatusBadRequest, + shouldNotRedirect: true, + }, + { + name: "serveHTTPNext: vbscript: URL should be rejected", + testType: "serveHTTPNext", + redirParam: "vbscript:msgbox(1)", + reqHost: "example.com", + expectedStatus: http.StatusBadRequest, + shouldNotRedirect: true, + }, + { + name: "serveHTTPNext: valid https URL should work", + testType: "serveHTTPNext", + redirParam: "https://example.com/foo", + reqHost: "example.com", + expectedStatus: http.StatusFound, + }, + { + name: "serveHTTPNext: valid relative URL should work", + testType: "serveHTTPNext", + redirParam: "/foo/bar", + reqHost: "example.com", + expectedStatus: http.StatusFound, + }, + { + name: "serveHTTPNext: external domain should be blocked", + testType: "serveHTTPNext", + redirParam: "https://evil.com/phishing", + reqHost: "example.com", + expectedStatus: http.StatusBadRequest, + shouldBlock: true, + }, + { + name: "serveHTTPNext: relative path should work", + testType: "serveHTTPNext", + redirParam: "/safe/path", + reqHost: "example.com", + expectedStatus: http.StatusFound, + }, + { + name: "serveHTTPNext: empty redir should show success page", + testType: "serveHTTPNext", + redirParam: "", + reqHost: "example.com", + expectedStatus: http.StatusOK, + }, + + // renderIndex tests - full subrequest auth flow + { + name: "renderIndex: javascript protocol in X-Forwarded-Proto", + testType: "renderIndex", + xForwardedProto: "javascript", + xForwardedHost: "example.com", + xForwardedUri: "alert(1)", + returnHTTPStatusOnly: true, + expectedStatus: http.StatusBadRequest, + }, + { + name: "renderIndex: data protocol in X-Forwarded-Proto", + testType: "renderIndex", + xForwardedProto: "data", + xForwardedHost: "example.com", + xForwardedUri: "text/html,", + returnHTTPStatusOnly: true, + expectedStatus: http.StatusBadRequest, + }, + { + name: "renderIndex: valid https redirect", + testType: "renderIndex", + xForwardedProto: "https", + xForwardedHost: "example.com", + xForwardedUri: "/protected/page", + returnHTTPStatusOnly: true, + expectedStatus: http.StatusTemporaryRedirect, + }, + } + + s := &Server{ + opts: Options{ + PublicUrl: "https://anubis.example.com", + RedirectDomains: []string{}, + }, + logger: slog.Default(), + policy: &policy.ParsedConfig{}, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + switch tt.testType { + case "constructRedirectURL": + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", tt.xForwardedProto) + req.Header.Set("X-Forwarded-Host", tt.xForwardedHost) + req.Header.Set("X-Forwarded-Uri", tt.xForwardedUri) + + redirectURL, err := s.constructRedirectURL(req) + + if tt.shouldError { + if err == nil { + t.Errorf("expected error containing %q, got nil", tt.errorContains) + t.Logf("got redirect URL: %s", redirectURL) + } else if tt.errorContains != "" && !strings.Contains(err.Error(), tt.errorContains) { + t.Logf("expected error containing %q, got: %v", tt.errorContains, err) + } + } else { + if err != nil { + t.Errorf("expected no error, got: %v", err) + } + // Verify the redirect URL is safe + if redirectURL != "" { + parsed, err := url.Parse(redirectURL) + if err != nil { + t.Errorf("failed to parse redirect URL: %v", err) + } + redirParam := parsed.Query().Get("redir") + if redirParam != "" { + redirParsed, err := url.Parse(redirParam) + if err != nil { + t.Errorf("failed to parse redir parameter: %v", err) + } + if redirParsed.Scheme != "http" && redirParsed.Scheme != "https" { + t.Errorf("redir parameter has unsafe scheme: %s", redirParsed.Scheme) + } + } + } + } + + case "serveHTTPNext": + req := httptest.NewRequest("GET", "/.within.website/?redir="+url.QueryEscape(tt.redirParam), nil) + req.Host = tt.reqHost + req.URL.Host = tt.reqHost + rr := httptest.NewRecorder() + + s.ServeHTTPNext(rr, req) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + t.Logf("body: %s", rr.Body.String()) + } + + if tt.shouldNotRedirect { + location := rr.Header().Get("Location") + if location != "" { + t.Errorf("expected no redirect, but got Location header: %s", location) + } + } + + if tt.shouldBlock { + location := rr.Header().Get("Location") + if location != "" && strings.Contains(location, "evil.com") { + t.Errorf("redirect to evil.com was not blocked: %s", location) + } + } + + case "renderIndex": + req := httptest.NewRequest("GET", "/", nil) + req.Header.Set("X-Forwarded-Proto", tt.xForwardedProto) + req.Header.Set("X-Forwarded-Host", tt.xForwardedHost) + req.Header.Set("X-Forwarded-Uri", tt.xForwardedUri) + + rr := httptest.NewRecorder() + s.RenderIndex(rr, req, policy.CheckResult{}, nil, tt.returnHTTPStatusOnly) + + if rr.Code != tt.expectedStatus { + t.Errorf("expected status %d, got %d", tt.expectedStatus, rr.Code) + } + + if tt.expectedStatus == http.StatusTemporaryRedirect { + location := rr.Header().Get("Location") + if location == "" { + t.Error("expected Location header, got none") + } else { + // Verify the location doesn't contain javascript: + if strings.Contains(location, "javascript") { + t.Errorf("Location header contains 'javascript': %s", location) + } + } + } + } + }) + } +}