mirror of
https://github.com/TecharoHQ/anubis.git
synced 2026-04-23 08:36:41 +00:00
feat(internal): add basic auth HTTP middleware
Signed-off-by: Xe Iaso <me@xeiaso.net>
This commit is contained in:
@@ -0,0 +1,52 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"crypto/subtle"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
"net/http"
|
||||
)
|
||||
|
||||
// BasicAuth wraps next in HTTP Basic authentication using the provided
|
||||
// credentials. If either username or password is empty, next is returned
|
||||
// unchanged and a debug log line is emitted.
|
||||
//
|
||||
// Credentials are compared in constant time to avoid leaking information
|
||||
// through timing side channels.
|
||||
func BasicAuth(realm, username, password string, next http.Handler) http.Handler {
|
||||
if username == "" || password == "" {
|
||||
slog.Debug("skipping middleware, basic auth credentials are empty")
|
||||
return next
|
||||
}
|
||||
|
||||
expectedUser := sha256.Sum256([]byte(username))
|
||||
expectedPass := sha256.Sum256([]byte(password))
|
||||
challenge := fmt.Sprintf("Basic realm=%q, charset=\"UTF-8\"", realm)
|
||||
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
user, pass, ok := r.BasicAuth()
|
||||
if !ok {
|
||||
unauthorized(w, challenge)
|
||||
return
|
||||
}
|
||||
|
||||
gotUser := sha256.Sum256([]byte(user))
|
||||
gotPass := sha256.Sum256([]byte(pass))
|
||||
|
||||
userMatch := subtle.ConstantTimeCompare(gotUser[:], expectedUser[:])
|
||||
passMatch := subtle.ConstantTimeCompare(gotPass[:], expectedPass[:])
|
||||
|
||||
if userMatch&passMatch != 1 {
|
||||
unauthorized(w, challenge)
|
||||
return
|
||||
}
|
||||
|
||||
next.ServeHTTP(w, r)
|
||||
})
|
||||
}
|
||||
|
||||
func unauthorized(w http.ResponseWriter, challenge string) {
|
||||
w.Header().Set("WWW-Authenticate", challenge)
|
||||
http.Error(w, "Unauthorized", http.StatusUnauthorized)
|
||||
}
|
||||
@@ -0,0 +1,138 @@
|
||||
package internal
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"strings"
|
||||
"testing"
|
||||
)
|
||||
|
||||
func okHandler() http.Handler {
|
||||
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("ok"))
|
||||
})
|
||||
}
|
||||
|
||||
func TestBasicAuth(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
const (
|
||||
realm = "test-realm"
|
||||
username = "admin"
|
||||
password = "hunter2"
|
||||
)
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
setAuth bool
|
||||
user string
|
||||
pass string
|
||||
wantStatus int
|
||||
wantBody string
|
||||
wantChall bool
|
||||
}{
|
||||
{
|
||||
name: "valid credentials",
|
||||
setAuth: true,
|
||||
user: username,
|
||||
pass: password,
|
||||
wantStatus: http.StatusOK,
|
||||
wantBody: "ok",
|
||||
},
|
||||
{
|
||||
name: "missing credentials",
|
||||
setAuth: false,
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantChall: true,
|
||||
},
|
||||
{
|
||||
name: "wrong username",
|
||||
setAuth: true,
|
||||
user: "nobody",
|
||||
pass: password,
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantChall: true,
|
||||
},
|
||||
{
|
||||
name: "wrong password",
|
||||
setAuth: true,
|
||||
user: username,
|
||||
pass: "wrong",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantChall: true,
|
||||
},
|
||||
{
|
||||
name: "empty supplied credentials",
|
||||
setAuth: true,
|
||||
user: "",
|
||||
pass: "",
|
||||
wantStatus: http.StatusUnauthorized,
|
||||
wantChall: true,
|
||||
},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := BasicAuth(realm, username, password, okHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
if tt.setAuth {
|
||||
req.SetBasicAuth(tt.user, tt.pass)
|
||||
}
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != tt.wantStatus {
|
||||
t.Errorf("status = %d, want %d", rec.Code, tt.wantStatus)
|
||||
}
|
||||
|
||||
if tt.wantBody != "" && rec.Body.String() != tt.wantBody {
|
||||
t.Errorf("body = %q, want %q", rec.Body.String(), tt.wantBody)
|
||||
}
|
||||
|
||||
chall := rec.Header().Get("WWW-Authenticate")
|
||||
if tt.wantChall {
|
||||
if chall == "" {
|
||||
t.Error("WWW-Authenticate header missing on 401")
|
||||
}
|
||||
if !strings.Contains(chall, realm) {
|
||||
t.Errorf("WWW-Authenticate = %q, want realm %q", chall, realm)
|
||||
}
|
||||
} else if chall != "" {
|
||||
t.Errorf("unexpected WWW-Authenticate header: %q", chall)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestBasicAuthPassthrough(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
for _, tt := range []struct {
|
||||
name string
|
||||
username string
|
||||
password string
|
||||
}{
|
||||
{name: "empty username", username: "", password: "hunter2"},
|
||||
{name: "empty password", username: "admin", password: ""},
|
||||
{name: "both empty", username: "", password: ""},
|
||||
} {
|
||||
t.Run(tt.name, func(t *testing.T) {
|
||||
t.Parallel()
|
||||
|
||||
h := BasicAuth("realm", tt.username, tt.password, okHandler())
|
||||
|
||||
req := httptest.NewRequest(http.MethodGet, "/", nil)
|
||||
rec := httptest.NewRecorder()
|
||||
h.ServeHTTP(rec, req)
|
||||
|
||||
if rec.Code != http.StatusOK {
|
||||
t.Errorf("status = %d, want %d (passthrough expected)", rec.Code, http.StatusOK)
|
||||
}
|
||||
if rec.Body.String() != "ok" {
|
||||
t.Errorf("body = %q, want %q", rec.Body.String(), "ok")
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
Reference in New Issue
Block a user