refactor(auth): replace untyped JWT claims with typed Claims struct
Introduced a typed Claims struct in core/auth to replace the raw map[string]any approach used for JWT claims throughout the codebase. This provides compile-time safety and better readability when creating, validating, and extracting JWT tokens. Also upgraded lestrrat-go/jwx from v2 to v3 and go-chi/jwtauth to v5.4.0, adapting all callers to the new API where token accessor methods now return tuples instead of bare values. Updated all affected handlers, middleware, and tests. Signed-off-by: Deluan <deluan@navidrome.org>
This commit is contained in:
+22
-36
@@ -4,12 +4,11 @@ import (
|
||||
"cmp"
|
||||
"context"
|
||||
"crypto/sha256"
|
||||
"maps"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
@@ -46,38 +45,30 @@ func Init(ds model.DataStore) {
|
||||
})
|
||||
}
|
||||
|
||||
func createBaseClaims() map[string]any {
|
||||
tokenClaims := map[string]any{}
|
||||
tokenClaims[jwt.IssuerKey] = consts.JWTIssuer
|
||||
return tokenClaims
|
||||
}
|
||||
|
||||
func CreatePublicToken(claims map[string]any) (string, error) {
|
||||
tokenClaims := createBaseClaims()
|
||||
maps.Copy(tokenClaims, claims)
|
||||
_, token, err := TokenAuth.Encode(tokenClaims)
|
||||
|
||||
func CreatePublicToken(claims Claims) (string, error) {
|
||||
claims.Issuer = consts.JWTIssuer
|
||||
_, token, err := TokenAuth.Encode(claims.ToMap())
|
||||
return token, err
|
||||
}
|
||||
|
||||
func CreateExpiringPublicToken(exp time.Time, claims map[string]any) (string, error) {
|
||||
tokenClaims := createBaseClaims()
|
||||
func CreateExpiringPublicToken(exp time.Time, claims Claims) (string, error) {
|
||||
claims.Issuer = consts.JWTIssuer
|
||||
if !exp.IsZero() {
|
||||
tokenClaims[jwt.ExpirationKey] = exp.UTC().Unix()
|
||||
claims.ExpiresAt = exp
|
||||
}
|
||||
maps.Copy(tokenClaims, claims)
|
||||
_, token, err := TokenAuth.Encode(tokenClaims)
|
||||
|
||||
_, token, err := TokenAuth.Encode(claims.ToMap())
|
||||
return token, err
|
||||
}
|
||||
|
||||
func CreateToken(u *model.User) (string, error) {
|
||||
claims := createBaseClaims()
|
||||
claims[jwt.SubjectKey] = u.UserName
|
||||
claims[jwt.IssuedAtKey] = time.Now().UTC().Unix()
|
||||
claims["uid"] = u.ID
|
||||
claims["adm"] = u.IsAdmin
|
||||
token, _, err := TokenAuth.Encode(claims)
|
||||
claims := Claims{
|
||||
Issuer: consts.JWTIssuer,
|
||||
Subject: u.UserName,
|
||||
IssuedAt: time.Now(),
|
||||
UserID: u.ID,
|
||||
IsAdmin: u.IsAdmin,
|
||||
}
|
||||
token, _, err := TokenAuth.Encode(claims.ToMap())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -86,23 +77,18 @@ func CreateToken(u *model.User) (string, error) {
|
||||
}
|
||||
|
||||
func TouchToken(token jwt.Token) (string, error) {
|
||||
claims, err := token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
|
||||
claims[jwt.ExpirationKey] = time.Now().UTC().Add(conf.Server.SessionTimeout).Unix()
|
||||
_, newToken, err := TokenAuth.Encode(claims)
|
||||
|
||||
claims := ClaimsFromToken(token).
|
||||
WithExpiresAt(time.Now().UTC().Add(conf.Server.SessionTimeout))
|
||||
_, newToken, err := TokenAuth.Encode(claims.ToMap())
|
||||
return newToken, err
|
||||
}
|
||||
|
||||
func Validate(tokenStr string) (map[string]any, error) {
|
||||
func Validate(tokenStr string) (Claims, error) {
|
||||
token, err := jwtauth.VerifyToken(TokenAuth, tokenStr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
return Claims{}, err
|
||||
}
|
||||
return token.AsMap(context.Background())
|
||||
return ClaimsFromToken(token), nil
|
||||
}
|
||||
|
||||
func WithAdminUser(ctx context.Context, ds model.DataStore) context.Context {
|
||||
|
||||
@@ -54,7 +54,7 @@ var _ = Describe("Auth", func() {
|
||||
|
||||
decodedClaims, err := auth.Validate(tokenStr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
Expect(decodedClaims["iss"]).To(Equal("issuer"))
|
||||
Expect(decodedClaims.Issuer).To(Equal("issuer"))
|
||||
})
|
||||
|
||||
It("returns ErrExpired if the `exp` field is in the past", func() {
|
||||
@@ -82,11 +82,11 @@ var _ = Describe("Auth", func() {
|
||||
claims, err := auth.Validate(tokenStr)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
Expect(claims["iss"]).To(Equal(consts.JWTIssuer))
|
||||
Expect(claims["sub"]).To(Equal("johndoe"))
|
||||
Expect(claims["uid"]).To(Equal("123"))
|
||||
Expect(claims["adm"]).To(Equal(true))
|
||||
Expect(claims["exp"]).To(BeTemporally(">", time.Now()))
|
||||
Expect(claims.Issuer).To(Equal(consts.JWTIssuer))
|
||||
Expect(claims.Subject).To(Equal("johndoe"))
|
||||
Expect(claims.UserID).To(Equal("123"))
|
||||
Expect(claims.IsAdmin).To(Equal(true))
|
||||
Expect(claims.ExpiresAt).To(BeTemporally(">", time.Now()))
|
||||
})
|
||||
})
|
||||
|
||||
@@ -104,8 +104,7 @@ var _ = Describe("Auth", func() {
|
||||
|
||||
decodedClaims, err := auth.Validate(touched)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
exp := decodedClaims["exp"].(time.Time)
|
||||
Expect(exp.Sub(yesterday)).To(BeNumerically(">=", oneDay))
|
||||
Expect(decodedClaims.ExpiresAt.Sub(yesterday)).To(BeNumerically(">=", oneDay))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -0,0 +1,94 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v3/jwt"
|
||||
)
|
||||
|
||||
// Claims represents the typed JWT claims used throughout Navidrome,
|
||||
// replacing the untyped map[string]any approach.
|
||||
type Claims struct {
|
||||
// Standard JWT claims
|
||||
Issuer string
|
||||
Subject string // username for session tokens
|
||||
IssuedAt time.Time
|
||||
ExpiresAt time.Time
|
||||
|
||||
// Custom claims
|
||||
UserID string // "uid"
|
||||
IsAdmin bool // "adm"
|
||||
ID string // "id" - artwork/mediafile ID
|
||||
Format string // "f" - audio format
|
||||
BitRate int // "b" - audio bitrate
|
||||
}
|
||||
|
||||
// ToMap converts Claims to a map[string]any for use with TokenAuth.Encode().
|
||||
// Only non-zero fields are included.
|
||||
func (c Claims) ToMap() map[string]any {
|
||||
m := make(map[string]any)
|
||||
if c.Issuer != "" {
|
||||
m[jwt.IssuerKey] = c.Issuer
|
||||
}
|
||||
if c.Subject != "" {
|
||||
m[jwt.SubjectKey] = c.Subject
|
||||
}
|
||||
if !c.IssuedAt.IsZero() {
|
||||
m[jwt.IssuedAtKey] = c.IssuedAt.UTC().Unix()
|
||||
}
|
||||
if !c.ExpiresAt.IsZero() {
|
||||
m[jwt.ExpirationKey] = c.ExpiresAt.UTC().Unix()
|
||||
}
|
||||
if c.UserID != "" {
|
||||
m["uid"] = c.UserID
|
||||
}
|
||||
if c.IsAdmin {
|
||||
m["adm"] = c.IsAdmin
|
||||
}
|
||||
if c.ID != "" {
|
||||
m["id"] = c.ID
|
||||
}
|
||||
if c.Format != "" {
|
||||
m["f"] = c.Format
|
||||
}
|
||||
if c.BitRate != 0 {
|
||||
m["b"] = c.BitRate
|
||||
}
|
||||
return m
|
||||
}
|
||||
|
||||
func (c Claims) WithExpiresAt(t time.Time) Claims {
|
||||
c.ExpiresAt = t
|
||||
return c
|
||||
}
|
||||
|
||||
// ClaimsFromToken extracts Claims directly from a jwt.Token using token.Get().
|
||||
func ClaimsFromToken(token jwt.Token) Claims {
|
||||
var c Claims
|
||||
c.Issuer, _ = token.Issuer()
|
||||
c.Subject, _ = token.Subject()
|
||||
c.IssuedAt, _ = token.IssuedAt()
|
||||
c.ExpiresAt, _ = token.Expiration()
|
||||
|
||||
var uid string
|
||||
if err := token.Get("uid", &uid); err == nil {
|
||||
c.UserID = uid
|
||||
}
|
||||
var adm bool
|
||||
if err := token.Get("adm", &adm); err == nil {
|
||||
c.IsAdmin = adm
|
||||
}
|
||||
var id string
|
||||
if err := token.Get("id", &id); err == nil {
|
||||
c.ID = id
|
||||
}
|
||||
var f string
|
||||
if err := token.Get("f", &f); err == nil {
|
||||
c.Format = f
|
||||
}
|
||||
var b int
|
||||
if err := token.Get("b", &b); err == nil {
|
||||
c.BitRate = b
|
||||
}
|
||||
return c
|
||||
}
|
||||
@@ -0,0 +1,99 @@
|
||||
package auth_test
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
"github.com/navidrome/navidrome/core/auth"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
|
||||
var _ = Describe("Claims", func() {
|
||||
Describe("ToMap", func() {
|
||||
It("includes only non-zero fields", func() {
|
||||
c := auth.Claims{
|
||||
Issuer: "ND",
|
||||
Subject: "johndoe",
|
||||
UserID: "123",
|
||||
IsAdmin: true,
|
||||
}
|
||||
m := c.ToMap()
|
||||
Expect(m).To(HaveKeyWithValue("iss", "ND"))
|
||||
Expect(m).To(HaveKeyWithValue("sub", "johndoe"))
|
||||
Expect(m).To(HaveKeyWithValue("uid", "123"))
|
||||
Expect(m).To(HaveKeyWithValue("adm", true))
|
||||
Expect(m).NotTo(HaveKey("exp"))
|
||||
Expect(m).NotTo(HaveKey("iat"))
|
||||
Expect(m).NotTo(HaveKey("id"))
|
||||
Expect(m).NotTo(HaveKey("f"))
|
||||
Expect(m).NotTo(HaveKey("b"))
|
||||
})
|
||||
|
||||
It("includes expiration and issued-at when set", func() {
|
||||
now := time.Now()
|
||||
c := auth.Claims{
|
||||
IssuedAt: now,
|
||||
ExpiresAt: now.Add(time.Hour),
|
||||
}
|
||||
m := c.ToMap()
|
||||
Expect(m).To(HaveKey("iat"))
|
||||
Expect(m).To(HaveKey("exp"))
|
||||
})
|
||||
|
||||
It("includes custom claims for public tokens", func() {
|
||||
c := auth.Claims{
|
||||
ID: "al-123",
|
||||
Format: "mp3",
|
||||
BitRate: 192,
|
||||
}
|
||||
m := c.ToMap()
|
||||
Expect(m).To(HaveKeyWithValue("id", "al-123"))
|
||||
Expect(m).To(HaveKeyWithValue("f", "mp3"))
|
||||
Expect(m).To(HaveKeyWithValue("b", 192))
|
||||
})
|
||||
})
|
||||
|
||||
Describe("ClaimsFromToken", func() {
|
||||
It("round-trips session claims through encode/decode", func() {
|
||||
tokenAuth := jwtauth.New("HS256", []byte("test-secret"), nil)
|
||||
now := time.Now().Truncate(time.Second)
|
||||
original := auth.Claims{
|
||||
Issuer: "ND",
|
||||
Subject: "johndoe",
|
||||
UserID: "123",
|
||||
IsAdmin: true,
|
||||
}
|
||||
m := original.ToMap()
|
||||
m["iat"] = now.UTC().Unix()
|
||||
token, _, err := tokenAuth.Encode(m)
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := auth.ClaimsFromToken(token)
|
||||
Expect(c.Issuer).To(Equal("ND"))
|
||||
Expect(c.Subject).To(Equal("johndoe"))
|
||||
Expect(c.UserID).To(Equal("123"))
|
||||
Expect(c.IsAdmin).To(BeTrue())
|
||||
Expect(c.IssuedAt.UTC()).To(Equal(now.UTC()))
|
||||
})
|
||||
|
||||
It("round-trips public token claims through encode/decode", func() {
|
||||
tokenAuth := jwtauth.New("HS256", []byte("test-secret"), nil)
|
||||
original := auth.Claims{
|
||||
Issuer: "ND",
|
||||
ID: "al-456",
|
||||
Format: "opus",
|
||||
BitRate: 128,
|
||||
}
|
||||
token, _, err := tokenAuth.Encode(original.ToMap())
|
||||
Expect(err).NotTo(HaveOccurred())
|
||||
|
||||
c := auth.ClaimsFromToken(token)
|
||||
Expect(c.Issuer).To(Equal("ND"))
|
||||
Expect(c.ID).To(Equal("al-456"))
|
||||
Expect(c.Format).To(Equal("opus"))
|
||||
Expect(c.BitRate).To(Equal(128))
|
||||
})
|
||||
})
|
||||
|
||||
})
|
||||
@@ -18,7 +18,7 @@ import (
|
||||
// ImageURL generates a public URL for artwork images.
|
||||
// It creates a signed token for the artwork ID and builds a complete public URL.
|
||||
func ImageURL(req *http.Request, artID model.ArtworkID, size int) string {
|
||||
token, _ := auth.CreatePublicToken(map[string]any{"id": artID.String()})
|
||||
token, _ := auth.CreatePublicToken(auth.Claims{ID: artID.String()})
|
||||
uri := path.Join(consts.URLPathPublicImages, token)
|
||||
params := url.Values{}
|
||||
if size > 0 {
|
||||
|
||||
Reference in New Issue
Block a user