refactor(auth): replace untyped JWT claims with typed Claims struct
Introduced a typed Claims struct in core/auth to replace the raw map[string]any approach used for JWT claims throughout the codebase. This provides compile-time safety and better readability when creating, validating, and extracting JWT tokens. Also upgraded lestrrat-go/jwx from v2 to v3 and go-chi/jwtauth to v5.4.0, adapting all callers to the new API where token accessor methods now return tuples instead of bare values. Updated all affected handlers, middleware, and tests. Signed-off-by: Deluan <deluan@navidrome.org>
This commit is contained in:
+8
-4
@@ -185,12 +185,16 @@ func tokenFromHeader(r *http.Request) string {
|
||||
}
|
||||
|
||||
func UsernameFromToken(r *http.Request) string {
|
||||
token, claims, err := jwtauth.FromContext(r.Context())
|
||||
if err != nil || claims["sub"] == nil || token == nil {
|
||||
token, _, err := jwtauth.FromContext(r.Context())
|
||||
if err != nil || token == nil {
|
||||
return ""
|
||||
}
|
||||
log.Trace(r, "Found username in JWT token", "username", token.Subject())
|
||||
return token.Subject()
|
||||
sub, _ := token.Subject()
|
||||
if sub == "" {
|
||||
return ""
|
||||
}
|
||||
log.Trace(r, "Found username in JWT token", "username", sub)
|
||||
return sub
|
||||
}
|
||||
|
||||
func UsernameFromExtAuthHeader(r *http.Request) string {
|
||||
|
||||
@@ -7,7 +7,6 @@ import (
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/navidrome/navidrome/core/artwork"
|
||||
"github.com/navidrome/navidrome/core/auth"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
@@ -76,22 +75,14 @@ func decodeArtworkID(tokenString string) (model.ArtworkID, error) {
|
||||
if token == nil {
|
||||
return model.ArtworkID{}, errors.New("unauthorized")
|
||||
}
|
||||
err = jwt.Validate(token, jwt.WithRequiredClaim("id"))
|
||||
if err != nil {
|
||||
return model.ArtworkID{}, err
|
||||
c := auth.ClaimsFromToken(token)
|
||||
if c.ID == "" {
|
||||
return model.ArtworkID{}, errors.New("required claim \"id\" not found")
|
||||
}
|
||||
claims, err := token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
return model.ArtworkID{}, err
|
||||
}
|
||||
id, ok := claims["id"].(string)
|
||||
if !ok {
|
||||
return model.ArtworkID{}, errors.New("invalid id type")
|
||||
}
|
||||
artID, err := model.ParseArtworkID(id)
|
||||
artID, err := model.ParseArtworkID(c.ID)
|
||||
if err == nil {
|
||||
return artID, nil
|
||||
}
|
||||
// Try to default to mediafile artworkId (if used with a mediafileShare token)
|
||||
return model.ParseArtworkID("mf-" + id)
|
||||
return model.ParseArtworkID("mf-" + c.ID)
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package public
|
||||
import (
|
||||
"github.com/go-chi/jwtauth/v5"
|
||||
"github.com/navidrome/navidrome/core/auth"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
)
|
||||
@@ -15,18 +14,11 @@ var _ = Describe("decodeArtworkID", func() {
|
||||
|
||||
It("fails to decode an invalid token", func() {
|
||||
_, err := decodeArtworkID("xx-123")
|
||||
Expect(err).To(MatchError("invalid JWT"))
|
||||
})
|
||||
|
||||
It("defaults to kind mediafile for empty artwork ID", func() {
|
||||
token, _ := auth.CreatePublicToken(map[string]any{"id": ""})
|
||||
id, err := decodeArtworkID(token)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(id.Kind).To(Equal(model.KindMediaFileArtwork))
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
It("fails to decode a token without an id", func() {
|
||||
token, _ := auth.CreatePublicToken(map[string]any{})
|
||||
token, _ := auth.CreatePublicToken(auth.Claims{})
|
||||
_, err := decodeArtworkID(token)
|
||||
Expect(err).To(HaveOccurred())
|
||||
})
|
||||
|
||||
@@ -97,12 +97,10 @@ func (pub *Router) mapShareToM3U(r *http.Request, s model.Share) *model.Share {
|
||||
}
|
||||
|
||||
func encodeMediafileShare(s model.Share, id string) string {
|
||||
claims := map[string]any{"id": id}
|
||||
if s.Format != "" {
|
||||
claims["f"] = s.Format
|
||||
}
|
||||
if s.MaxBitRate != 0 {
|
||||
claims["b"] = s.MaxBitRate
|
||||
claims := auth.Claims{
|
||||
ID: id,
|
||||
Format: s.Format,
|
||||
BitRate: s.MaxBitRate,
|
||||
}
|
||||
token, _ := auth.CreateExpiringPublicToken(V(s.ExpiresAt), claims)
|
||||
return token
|
||||
|
||||
@@ -1,13 +1,11 @@
|
||||
package public
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"strconv"
|
||||
|
||||
"github.com/lestrrat-go/jwx/v2/jwt"
|
||||
"github.com/navidrome/navidrome/core/auth"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/utils/req"
|
||||
@@ -85,21 +83,13 @@ func decodeStreamInfo(tokenString string) (shareTrackInfo, error) {
|
||||
if token == nil {
|
||||
return shareTrackInfo{}, errors.New("unauthorized")
|
||||
}
|
||||
err = jwt.Validate(token, jwt.WithRequiredClaim("id"))
|
||||
if err != nil {
|
||||
return shareTrackInfo{}, err
|
||||
c := auth.ClaimsFromToken(token)
|
||||
if c.ID == "" {
|
||||
return shareTrackInfo{}, errors.New("required claim \"id\" not found")
|
||||
}
|
||||
claims, err := token.AsMap(context.Background())
|
||||
if err != nil {
|
||||
return shareTrackInfo{}, err
|
||||
}
|
||||
id, ok := claims["id"].(string)
|
||||
if !ok {
|
||||
return shareTrackInfo{}, errors.New("invalid id type")
|
||||
}
|
||||
resp := shareTrackInfo{}
|
||||
resp.id = id
|
||||
resp.format, _ = claims["f"].(string)
|
||||
resp.bitrate, _ = claims["b"].(int)
|
||||
return resp, nil
|
||||
return shareTrackInfo{
|
||||
id: c.ID,
|
||||
format: c.Format,
|
||||
bitrate: c.BitRate,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -159,7 +159,7 @@ func validateCredentials(user *model.User, pass, token, salt, jwt string) error
|
||||
switch {
|
||||
case jwt != "":
|
||||
claims, err := auth.Validate(jwt)
|
||||
valid = err == nil && claims["sub"] == user.UserName
|
||||
valid = err == nil && claims.Subject == user.UserName
|
||||
case pass != "":
|
||||
if strings.HasPrefix(pass, "enc:") {
|
||||
if dec, err := hex.DecodeString(pass[4:]); err == nil {
|
||||
|
||||
Reference in New Issue
Block a user