diff --git a/internal/basicauth.go b/internal/basicauth.go new file mode 100644 index 00000000..69725ab0 --- /dev/null +++ b/internal/basicauth.go @@ -0,0 +1,52 @@ +package internal + +import ( + "crypto/sha256" + "crypto/subtle" + "fmt" + "log/slog" + "net/http" +) + +// BasicAuth wraps next in HTTP Basic authentication using the provided +// credentials. If either username or password is empty, next is returned +// unchanged and a debug log line is emitted. +// +// Credentials are compared in constant time to avoid leaking information +// through timing side channels. +func BasicAuth(realm, username, password string, next http.Handler) http.Handler { + if username == "" || password == "" { + slog.Debug("skipping middleware, basic auth credentials are empty") + return next + } + + expectedUser := sha256.Sum256([]byte(username)) + expectedPass := sha256.Sum256([]byte(password)) + challenge := fmt.Sprintf("Basic realm=%q, charset=\"UTF-8\"", realm) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + user, pass, ok := r.BasicAuth() + if !ok { + unauthorized(w, challenge) + return + } + + gotUser := sha256.Sum256([]byte(user)) + gotPass := sha256.Sum256([]byte(pass)) + + userMatch := subtle.ConstantTimeCompare(gotUser[:], expectedUser[:]) + passMatch := subtle.ConstantTimeCompare(gotPass[:], expectedPass[:]) + + if userMatch&passMatch != 1 { + unauthorized(w, challenge) + return + } + + next.ServeHTTP(w, r) + }) +} + +func unauthorized(w http.ResponseWriter, challenge string) { + w.Header().Set("WWW-Authenticate", challenge) + http.Error(w, "Unauthorized", http.StatusUnauthorized) +} \ No newline at end of file diff --git a/internal/basicauth_test.go b/internal/basicauth_test.go new file mode 100644 index 00000000..8492d226 --- /dev/null +++ b/internal/basicauth_test.go @@ -0,0 +1,138 @@ +package internal + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func okHandler() http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(http.StatusOK) + _, _ = w.Write([]byte("ok")) + }) +} + +func TestBasicAuth(t *testing.T) { + t.Parallel() + + const ( + realm = "test-realm" + username = "admin" + password = "hunter2" + ) + + for _, tt := range []struct { + name string + setAuth bool + user string + pass string + wantStatus int + wantBody string + wantChall bool + }{ + { + name: "valid credentials", + setAuth: true, + user: username, + pass: password, + wantStatus: http.StatusOK, + wantBody: "ok", + }, + { + name: "missing credentials", + setAuth: false, + wantStatus: http.StatusUnauthorized, + wantChall: true, + }, + { + name: "wrong username", + setAuth: true, + user: "nobody", + pass: password, + wantStatus: http.StatusUnauthorized, + wantChall: true, + }, + { + name: "wrong password", + setAuth: true, + user: username, + pass: "wrong", + wantStatus: http.StatusUnauthorized, + wantChall: true, + }, + { + name: "empty supplied credentials", + setAuth: true, + user: "", + pass: "", + wantStatus: http.StatusUnauthorized, + wantChall: true, + }, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := BasicAuth(realm, username, password, okHandler()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + if tt.setAuth { + req.SetBasicAuth(tt.user, tt.pass) + } + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + + if rec.Code != tt.wantStatus { + t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus) + } + + if tt.wantBody != "" && rec.Body.String() != tt.wantBody { + t.Errorf("body = %q, want %q", rec.Body.String(), tt.wantBody) + } + + chall := rec.Header().Get("WWW-Authenticate") + if tt.wantChall { + if chall == "" { + t.Error("WWW-Authenticate header missing on 401") + } + if !strings.Contains(chall, realm) { + t.Errorf("WWW-Authenticate = %q, want realm %q", chall, realm) + } + } else if chall != "" { + t.Errorf("unexpected WWW-Authenticate header: %q", chall) + } + }) + } +} + +func TestBasicAuthPassthrough(t *testing.T) { + t.Parallel() + + for _, tt := range []struct { + name string + username string + password string + }{ + {name: "empty username", username: "", password: "hunter2"}, + {name: "empty password", username: "admin", password: ""}, + {name: "both empty", username: "", password: ""}, + } { + t.Run(tt.name, func(t *testing.T) { + t.Parallel() + + h := BasicAuth("realm", tt.username, tt.password, okHandler()) + + req := httptest.NewRequest(http.MethodGet, "/", nil) + rec := httptest.NewRecorder() + h.ServeHTTP(rec, req) + + if rec.Code != http.StatusOK { + t.Errorf("status = %d, want %d (passthrough expected)", rec.Code, http.StatusOK) + } + if rec.Body.String() != "ok" { + t.Errorf("body = %q, want %q", rec.Body.String(), "ok") + } + }) + } +} \ No newline at end of file