Replace beego/orm with dbx (#2693)

* Start migration to dbx package

* Fix annotations and bookmarks bindings

* Fix tests

* Fix more tests

* Remove remaining references to beego/orm

* Add PostScanner/PostMapper interfaces

* Fix importing SmartPlaylists

* Renaming

* More renaming

* Fix artist DB mapping

* Fix playlist updates

* Remove bookmarks at the end of the test

* Remove remaining `orm` struct tags

* Fix user timestamps DB access

* Fix smart playlist evaluated_at DB access

* Fix search3
This commit is contained in:
Deluan Quintão
2023-12-09 13:52:17 -05:00
committed by GitHub
parent 7074455e0e
commit 0ca0d5da22
60 changed files with 461 additions and 376 deletions
+3 -3
View File
@@ -6,11 +6,11 @@ import (
"strings"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type albumRepository struct {
@@ -18,10 +18,10 @@ type albumRepository struct {
sqlRestful
}
func NewAlbumRepository(ctx context.Context, o orm.QueryExecutor) model.AlbumRepository {
func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumRepository {
r := &albumRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "album"
r.sortMappings = map[string]string{
"name": "order_album_name asc, order_album_artist_name asc",
+1 -2
View File
@@ -3,7 +3,6 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
@@ -16,7 +15,7 @@ var _ = Describe("AlbumRepository", func() {
BeforeEach(func() {
ctx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe"})
repo = NewAlbumRepository(ctx, orm.NewOrm())
repo = NewAlbumRepository(ctx, getDBXBuilder())
})
Describe("Get", func() {
+33 -37
View File
@@ -8,13 +8,13 @@ import (
"strings"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/consts"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils"
"github.com/pocketbase/dbx"
)
type artistRepository struct {
@@ -24,14 +24,40 @@ type artistRepository struct {
}
type dbArtist struct {
model.Artist `structs:",flatten"`
SimilarArtists string `structs:"similar_artists" json:"similarArtists"`
*model.Artist `structs:",flatten"`
SimilarArtists string `structs:"-" json:"similarArtists"`
}
func NewArtistRepository(ctx context.Context, o orm.QueryExecutor) model.ArtistRepository {
func (a *dbArtist) PostScan() error {
if a.SimilarArtists == "" {
return nil
}
for _, s := range strings.Split(a.SimilarArtists, ";") {
fields := strings.Split(s, ":")
if len(fields) != 2 {
continue
}
name, _ := url.QueryUnescape(fields[1])
a.Artist.SimilarArtists = append(a.Artist.SimilarArtists, model.Artist{
ID: fields[0],
Name: name,
})
}
return nil
}
func (a *dbArtist) PostMapArgs(m map[string]any) error {
var sa []string
for _, s := range a.Artist.SimilarArtists {
sa = append(sa, fmt.Sprintf("%s:%s", s.ID, url.QueryEscape(s.Name)))
}
m["similar_artists"] = strings.Join(sa, ";")
return nil
}
func NewArtistRepository(ctx context.Context, db dbx.Builder) model.ArtistRepository {
r := &artistRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.indexGroups = utils.ParseIndexGroups(conf.Server.IndexGroups)
r.tableName = "artist"
r.sortMappings = map[string]string{
@@ -62,7 +88,7 @@ func (r *artistRepository) Exists(id string) (bool, error) {
func (r *artistRepository) Put(a *model.Artist, colsToUpdate ...string) error {
a.FullText = getFullText(a.Name, a.SortArtistName)
dba := r.fromModel(a)
dba := &dbArtist{Artist: a}
_, err := r.put(dba.ID, dba, colsToUpdate...)
if err != nil {
return err
@@ -102,41 +128,11 @@ func (r *artistRepository) GetAll(options ...model.QueryOptions) (model.Artists,
func (r *artistRepository) toModels(dba []dbArtist) model.Artists {
res := model.Artists{}
for i := range dba {
a := dba[i]
res = append(res, *r.toModel(&a))
res = append(res, *dba[i].Artist)
}
return res
}
func (r *artistRepository) toModel(dba *dbArtist) *model.Artist {
a := dba.Artist
a.SimilarArtists = nil
for _, s := range strings.Split(dba.SimilarArtists, ";") {
fields := strings.Split(s, ":")
if len(fields) != 2 {
continue
}
name, _ := url.QueryUnescape(fields[1])
a.SimilarArtists = append(a.SimilarArtists, model.Artist{
ID: fields[0],
Name: name,
})
}
return &a
}
func (r *artistRepository) fromModel(a *model.Artist) *dbArtist {
dba := &dbArtist{Artist: *a}
var sa []string
for _, s := range a.SimilarArtists {
sa = append(sa, fmt.Sprintf("%s:%s", s.ID, url.QueryEscape(s.Name)))
}
dba.SimilarArtists = strings.Join(sa, ";")
return dba
}
func (r *artistRepository) getIndexKey(a *model.Artist) string {
name := strings.ToLower(utils.NoArticle(a.Name))
for k, v := range r.indexGroups {
+13 -4
View File
@@ -3,7 +3,7 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/fatih/structs"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
@@ -18,7 +18,7 @@ var _ = Describe("ArtistRepository", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
repo = NewArtistRepository(ctx, orm.NewOrm())
repo = NewArtistRepository(ctx, getDBXBuilder())
})
Describe("Count", func() {
@@ -71,8 +71,17 @@ var _ = Describe("ArtistRepository", func() {
}}
})
It("maps fields", func() {
dba := repo.(*artistRepository).fromModel(a)
actual := repo.(*artistRepository).toModel(dba)
dba := &dbArtist{Artist: a}
m := structs.Map(dba)
Expect(dba.PostMapArgs(m)).To(Succeed())
Expect(m).To(HaveKeyWithValue("similar_artists", "2:AC%2FDC;-1:Test%3BWith%3ASep%2CChars"))
other := dbArtist{SimilarArtists: m["similar_artists"].(string), Artist: &model.Artist{
ID: "1", Name: "Van Halen",
}}
Expect(other.PostScan()).To(Succeed())
actual := other.Artist
Expect(*actual).To(MatchFields(IgnoreExtras, Fields{
"ID": Equal(a.ID),
"Name": Equal(a.Name),
+9 -4
View File
@@ -4,9 +4,9 @@ import (
"context"
"github.com/google/uuid"
"github.com/pocketbase/dbx"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
@@ -17,10 +17,10 @@ type genreRepository struct {
sqlRestful
}
func NewGenreRepository(ctx context.Context, o orm.QueryExecutor) model.GenreRepository {
func NewGenreRepository(ctx context.Context, db dbx.Builder) model.GenreRepository {
r := &genreRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "genre"
r.filterMappings = map[string]filterFunc{
"name": containsFilter,
@@ -29,7 +29,12 @@ func NewGenreRepository(ctx context.Context, o orm.QueryExecutor) model.GenreRep
}
func (r *genreRepository) GetAll(opt ...model.QueryOptions) (model.Genres, error) {
sq := r.newSelect(opt...).Columns("genre.id", "genre.name", "a.album_count", "m.song_count").
sq := r.newSelect(opt...).Columns(
"genre.id",
"genre.name",
"coalesce(a.album_count, 0) as album_count",
"coalesce(m.song_count, 0) as song_count",
).
LeftJoin("(select ag.genre_id, count(ag.album_id) as album_count from album_genres ag group by ag.genre_id) a on a.genre_id = genre.id").
LeftJoin("(select mg.genre_id, count(mg.media_file_id) as song_count from media_file_genres mg group by mg.genre_id) m on m.genre_id = genre.id")
res := model.Genres{}
+8 -6
View File
@@ -3,26 +3,27 @@ package persistence_test
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/google/uuid"
"github.com/navidrome/navidrome/db"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/persistence"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
var _ = Describe("GenreRepository", func() {
var repo model.GenreRepository
BeforeEach(func() {
repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), orm.NewOrm())
repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), dbx.NewFromDB(db.Db(), db.Driver))
})
Describe("GetAll()", func() {
It("returns all records", func() {
genres, err := repo.GetAll()
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(ConsistOf(
model.Genre{ID: "gn-1", Name: "Electronic", AlbumCount: 1, SongCount: 2},
model.Genre{ID: "gn-2", Name: "Rock", AlbumCount: 3, SongCount: 3},
@@ -43,13 +44,14 @@ var _ = Describe("GenreRepository", func() {
It("insert non-existent genre names", func() {
g := model.Genre{Name: "Reggae"}
err := repo.Put(&g)
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
// ID is a uuid
_, err = uuid.Parse(g.ID)
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
genres, _ := repo.GetAll()
genres, err := repo.GetAll()
Expect(err).ToNot(HaveOccurred())
Expect(genres).To(HaveLen(3))
Expect(genres).To(ContainElement(model.Genre{ID: g.ID, Name: "Reggae", AlbumCount: 0, SongCount: 0}))
})
+23 -4
View File
@@ -1,6 +1,7 @@
package persistence
import (
"database/sql/driver"
"fmt"
"regexp"
"strings"
@@ -10,14 +11,32 @@ import (
"github.com/fatih/structs"
)
func toSqlArgs(rec interface{}) (map[string]interface{}, error) {
type PostMapper interface {
PostMapArgs(map[string]any) error
}
func toSQLArgs(rec interface{}) (map[string]interface{}, error) {
m := structs.Map(rec)
for k, v := range m {
if t, ok := v.(time.Time); ok {
switch t := v.(type) {
case time.Time:
m[k] = t.Format(time.RFC3339Nano)
case *time.Time:
if t != nil {
m[k] = t.Format(time.RFC3339Nano)
}
case driver.Valuer:
var err error
m[k], err = t.Value()
if err != nil {
return nil, err
}
}
if t, ok := v.(*time.Time); ok && t != nil {
m[k] = t.Format(time.RFC3339Nano)
}
if r, ok := rec.(PostMapper); ok {
err := r.PostMapArgs(m)
if err != nil {
return nil, err
}
}
return m, nil
+3 -2
View File
@@ -23,7 +23,7 @@ var _ = Describe("Helpers", func() {
Expect(toSnakeCase("snake_case")).To(Equal("snake_case"))
})
})
Describe("toSqlArgs", func() {
Describe("toSQLArgs", func() {
type Embed struct{}
type Model struct {
Embed `structs:"-"`
@@ -37,10 +37,11 @@ var _ = Describe("Helpers", func() {
It("returns a map with snake_case keys", func() {
now := time.Now()
m := &Model{ID: "123", AlbumId: "456", CreatedAt: now, UpdatedAt: &now, PlayCount: 2}
args, err := toSqlArgs(m)
args, err := toSQLArgs(m)
Expect(err).To(BeNil())
Expect(args).To(HaveKeyWithValue("id", "123"))
Expect(args).To(HaveKeyWithValue("album_id", "456"))
Expect(args).To(HaveKeyWithValue("play_count", 2))
Expect(args).To(HaveKeyWithValue("updated_at", now.Format(time.RFC3339Nano)))
Expect(args).To(HaveKeyWithValue("created_at", now.Format(time.RFC3339Nano)))
Expect(args).ToNot(HaveKey("Embed"))
+4 -4
View File
@@ -9,10 +9,10 @@ import (
"unicode/utf8"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type mediaFileRepository struct {
@@ -20,10 +20,10 @@ type mediaFileRepository struct {
sqlRestful
}
func NewMediaFileRepository(ctx context.Context, o orm.QueryExecutor) *mediaFileRepository {
func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepository {
r := &mediaFileRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "media_file"
r.sortMappings = map[string]string{
"artist": "order_artist_name asc, order_album_name asc, release_date asc, disc_number asc, track_number asc",
@@ -146,7 +146,7 @@ func (r *mediaFileRepository) FindPathsRecursively(basePath string) ([]string, e
sel := r.newSelect().Columns(fmt.Sprintf("distinct rtrim(path, replace(path, '%s', ''))", string(os.PathSeparator))).
Where(pathStartsWith(path))
var res []string
err := r.queryAll(sel, &res)
err := r.queryAllSlice(sel, &res)
return res, err
}
+1 -2
View File
@@ -5,7 +5,6 @@ import (
"time"
"github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/google/uuid"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
@@ -20,7 +19,7 @@ var _ = Describe("MediaRepository", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
mr = NewMediaFileRepository(ctx, orm.NewOrm())
mr = NewMediaFileRepository(ctx, getDBXBuilder())
})
It("gets mediafile from the DB", func() {
+2 -2
View File
@@ -3,16 +3,16 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type mediaFolderRepository struct {
ctx context.Context
}
func NewMediaFolderRepository(ctx context.Context, o orm.QueryExecutor) model.MediaFolderRepository {
func NewMediaFolderRepository(ctx context.Context, _ dbx.Builder) model.MediaFolderRepository {
return &mediaFolderRepository{ctx}
}
+28 -33
View File
@@ -5,79 +5,78 @@ import (
"database/sql"
"reflect"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/db"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type SQLStore struct {
orm orm.QueryExecutor
db *sql.DB
db dbx.Builder
}
func New(db *sql.DB) model.DataStore {
return &SQLStore{db: db}
func New(conn *sql.DB) model.DataStore {
return &SQLStore{db: dbx.NewFromDB(conn, db.Driver)}
}
func (s *SQLStore) Album(ctx context.Context) model.AlbumRepository {
return NewAlbumRepository(ctx, s.getOrmer())
return NewAlbumRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Artist(ctx context.Context) model.ArtistRepository {
return NewArtistRepository(ctx, s.getOrmer())
return NewArtistRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) MediaFile(ctx context.Context) model.MediaFileRepository {
return NewMediaFileRepository(ctx, s.getOrmer())
return NewMediaFileRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) MediaFolder(ctx context.Context) model.MediaFolderRepository {
return NewMediaFolderRepository(ctx, s.getOrmer())
return NewMediaFolderRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Genre(ctx context.Context) model.GenreRepository {
return NewGenreRepository(ctx, s.getOrmer())
return NewGenreRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) PlayQueue(ctx context.Context) model.PlayQueueRepository {
return NewPlayQueueRepository(ctx, s.getOrmer())
return NewPlayQueueRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Playlist(ctx context.Context) model.PlaylistRepository {
return NewPlaylistRepository(ctx, s.getOrmer())
return NewPlaylistRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Property(ctx context.Context) model.PropertyRepository {
return NewPropertyRepository(ctx, s.getOrmer())
return NewPropertyRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Radio(ctx context.Context) model.RadioRepository {
return NewRadioRepository(ctx, s.getOrmer())
return NewRadioRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) UserProps(ctx context.Context) model.UserPropsRepository {
return NewUserPropsRepository(ctx, s.getOrmer())
return NewUserPropsRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Share(ctx context.Context) model.ShareRepository {
return NewShareRepository(ctx, s.getOrmer())
return NewShareRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) User(ctx context.Context) model.UserRepository {
return NewUserRepository(ctx, s.getOrmer())
return NewUserRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Transcoding(ctx context.Context) model.TranscodingRepository {
return NewTranscodingRepository(ctx, s.getOrmer())
return NewTranscodingRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Player(ctx context.Context) model.PlayerRepository {
return NewPlayerRepository(ctx, s.getOrmer())
return NewPlayerRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) ScrobbleBuffer(ctx context.Context) model.ScrobbleBufferRepository {
return NewScrobbleBufferRepository(ctx, s.getOrmer())
return NewScrobbleBufferRepository(ctx, s.getDBXBuilder())
}
func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRepository {
@@ -108,12 +107,12 @@ func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRe
}
func (s *SQLStore) WithTx(block func(tx model.DataStore) error) error {
o, err := orm.NewOrmWithDB(db.Driver, "default", s.db)
if err != nil {
return err
conn, ok := s.db.(*dbx.DB)
if !ok {
conn = dbx.NewFromDB(db.Db(), db.Driver)
}
return o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error {
newDb := &SQLStore{orm: txOrm}
return conn.Transactional(func(tx *dbx.Tx) error {
newDb := &SQLStore{db: tx}
return block(newDb)
})
}
@@ -171,13 +170,9 @@ func (s *SQLStore) GC(ctx context.Context, rootFolder string) error {
return err
}
func (s *SQLStore) getOrmer() orm.QueryExecutor {
if s.orm == nil {
o, err := orm.NewOrmWithDB(db.Driver, "default", s.db)
if err != nil {
log.Error("Error obtaining new orm instance", err)
}
return o
func (s *SQLStore) getDBXBuilder() dbx.Builder {
if s.db == nil {
return dbx.NewFromDB(db.Db(), db.Driver)
}
return s.orm
return s.db
}
+13 -10
View File
@@ -5,7 +5,6 @@ import (
"path/filepath"
"testing"
"github.com/beego/beego/v2/client/orm"
_ "github.com/mattn/go-sqlite3"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/db"
@@ -15,6 +14,7 @@ import (
"github.com/navidrome/navidrome/tests"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
"github.com/pocketbase/dbx"
)
func TestPersistence(t *testing.T) {
@@ -23,13 +23,16 @@ func TestPersistence(t *testing.T) {
//os.Remove("./test-123.db")
//conf.Server.DbPath = "./test-123.db"
conf.Server.DbPath = "file::memory:?cache=shared"
_ = orm.RegisterDataBase("default", db.Driver, conf.Server.DbPath)
db.Init()
log.SetLevel(log.LevelError)
RegisterFailHandler(Fail)
RunSpecs(t, "Persistence Suite")
}
func getDBXBuilder() *dbx.DB {
return dbx.NewFromDB(db.Db(), db.Driver)
}
var (
genreElectronic = model.Genre{ID: "gn-1", Name: "Electronic"}
genreRock = model.Genre{ID: "gn-2", Name: "Rock"}
@@ -88,18 +91,18 @@ func P(path string) string {
// Initialize test DB
// TODO Load this data setup from file(s)
var _ = BeforeSuite(func() {
o := orm.NewOrm()
conn := getDBXBuilder()
ctx := log.NewContext(context.TODO())
user := model.User{ID: "userid", UserName: "userid", IsAdmin: true}
ctx = request.WithUser(ctx, user)
ur := NewUserRepository(ctx, o)
ur := NewUserRepository(ctx, conn)
err := ur.Put(&user)
if err != nil {
panic(err)
}
gr := NewGenreRepository(ctx, o)
gr := NewGenreRepository(ctx, conn)
for i := range testGenres {
g := testGenres[i]
err := gr.Put(&g)
@@ -108,7 +111,7 @@ var _ = BeforeSuite(func() {
}
}
mr := NewMediaFileRepository(ctx, o)
mr := NewMediaFileRepository(ctx, conn)
for i := range testSongs {
s := testSongs[i]
err := mr.Put(&s)
@@ -117,7 +120,7 @@ var _ = BeforeSuite(func() {
}
}
alr := NewAlbumRepository(ctx, o).(*albumRepository)
alr := NewAlbumRepository(ctx, conn).(*albumRepository)
for i := range testAlbums {
a := testAlbums[i]
err := alr.Put(&a)
@@ -126,7 +129,7 @@ var _ = BeforeSuite(func() {
}
}
arr := NewArtistRepository(ctx, o)
arr := NewArtistRepository(ctx, conn)
for i := range testArtists {
a := testArtists[i]
err := arr.Put(&a)
@@ -135,7 +138,7 @@ var _ = BeforeSuite(func() {
}
}
rar := NewRadioRepository(ctx, o)
rar := NewRadioRepository(ctx, conn)
for i := range testRadios {
r := testRadios[i]
err := rar.Put(&r)
@@ -157,7 +160,7 @@ var _ = BeforeSuite(func() {
plsCool.AddTracks([]string{"1004"})
testPlaylists = []*model.Playlist{&plsBest, &plsCool}
pr := NewPlaylistRepository(ctx, o)
pr := NewPlaylistRepository(ctx, conn)
for i := range testPlaylists {
err := pr.Put(testPlaylists[i])
if err != nil {
+3 -3
View File
@@ -5,9 +5,9 @@ import (
"errors"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type playerRepository struct {
@@ -15,10 +15,10 @@ type playerRepository struct {
sqlRestful
}
func NewPlayerRepository(ctx context.Context, o orm.QueryExecutor) model.PlayerRepository {
func NewPlayerRepository(ctx context.Context, db dbx.Builder) model.PlayerRepository {
r := &playerRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "player"
r.filterMappings = map[string]filterFunc{
"name": containsFilter,
+50 -45
View File
@@ -2,18 +2,18 @@ package persistence
import (
"context"
"database/sql"
"encoding/json"
"errors"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/criteria"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type playlistRepository struct {
@@ -23,13 +23,30 @@ type playlistRepository struct {
type dbPlaylist struct {
model.Playlist `structs:",flatten"`
RawRules string `structs:"rules" orm:"column(rules)"`
Rules sql.NullString `structs:"-"`
}
func NewPlaylistRepository(ctx context.Context, o orm.QueryExecutor) model.PlaylistRepository {
func (p *dbPlaylist) PostScan() error {
if p.Rules.String != "" {
return json.Unmarshal([]byte(p.Rules.String), &p.Playlist.Rules)
}
return nil
}
func (p dbPlaylist) PostMapArgs(args map[string]any) error {
var err error
if p.Playlist.IsSmartPlaylist() {
args["rules"], err = json.Marshal(p.Playlist.Rules)
return err
}
delete(args, "rules")
return nil
}
func NewPlaylistRepository(ctx context.Context, db dbx.Builder) model.PlaylistRepository {
r := &playlistRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "playlist"
r.filterMappings = map[string]filterFunc{
"q": playlistFilter,
@@ -64,8 +81,8 @@ func (r *playlistRepository) userFilter() Sqlizer {
}
func (r *playlistRepository) CountAll(options ...model.QueryOptions) (int64, error) {
sql := Select().Where(r.userFilter())
return r.count(sql, options...)
sq := Select().Where(r.userFilter())
return r.count(sq, options...)
}
func (r *playlistRepository) Exists(id string) (bool, error) {
@@ -88,13 +105,6 @@ func (r *playlistRepository) Delete(id string) error {
func (r *playlistRepository) Put(p *model.Playlist) error {
pls := dbPlaylist{Playlist: *p}
if p.IsSmartPlaylist() {
j, err := json.Marshal(p.Rules)
if err != nil {
return err
}
pls.RawRules = string(j)
}
if pls.ID == "" {
pls.CreatedAt = time.Now()
} else {
@@ -161,22 +171,7 @@ func (r *playlistRepository) findBy(sql Sqlizer) (*model.Playlist, error) {
return nil, model.ErrNotFound
}
return r.toModel(pls[0])
}
func (r *playlistRepository) toModel(pls dbPlaylist) (*model.Playlist, error) {
var err error
if strings.TrimSpace(pls.RawRules) != "" {
var c criteria.Criteria
err = json.Unmarshal([]byte(pls.RawRules), &c)
if err != nil {
return nil, err
}
pls.Playlist.Rules = &c
} else {
pls.Playlist.Rules = nil
}
return &pls.Playlist, err
return &pls[0].Playlist, nil
}
func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) {
@@ -188,11 +183,7 @@ func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playli
}
playlists := make(model.Playlists, len(res))
for i, p := range res {
pls, err := r.toModel(p)
if err != nil {
return nil, err
}
playlists[i] = *pls
playlists[i] = p.Playlist
}
return playlists, err
}
@@ -204,13 +195,14 @@ func (r *playlistRepository) selectPlaylist(options ...model.QueryOptions) Selec
func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
// Only refresh if it is a smart playlist and was not refreshed in the last 5 seconds
if !pls.IsSmartPlaylist() || time.Since(pls.EvaluatedAt) < 5*time.Second {
if !pls.IsSmartPlaylist() || (pls.EvaluatedAt != nil && time.Since(*pls.EvaluatedAt) < 5*time.Second) {
return false
}
// Never refresh other users' playlists
usr := loggedUser(r.ctx)
if pls.OwnerID != usr.ID {
log.Trace(r.ctx, "Not refreshing smart playlist from other user", "playlist", pls.Name, "id", pls.ID)
return false
}
@@ -221,20 +213,21 @@ func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
del := Delete("playlist_tracks").Where(Eq{"playlist_id": pls.ID})
_, err := r.executeSQL(del)
if err != nil {
log.Error(r.ctx, "Error deleting old smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
return false
}
// Re-populate playlist based on Smart Playlist criteria
rules := *pls.Rules
sql := Select("row_number() over (order by "+rules.OrderBy()+") as id", "'"+pls.ID+"' as playlist_id", "media_file.id as media_file_id").
sq := Select("row_number() over (order by "+rules.OrderBy()+") as id", "'"+pls.ID+"' as playlist_id", "media_file.id as media_file_id").
From("media_file").LeftJoin("annotation on (" +
"annotation.item_id = media_file.id" +
" AND annotation.item_type = 'media_file'" +
" AND annotation.user_id = '" + userId(r.ctx) + "')").
LeftJoin("media_file_genres ag on media_file.id = ag.media_file_id").
LeftJoin("genre on ag.genre_id = genre.id").GroupBy("media_file.id")
sql = r.addCriteria(sql, rules)
insSql := Insert("playlist_tracks").Columns("id", "playlist_id", "media_file_id").Select(sql)
sq = r.addCriteria(sq, rules)
insSql := Insert("playlist_tracks").Columns("id", "playlist_id", "media_file_id").Select(sq)
_, err = r.executeSQL(insSql)
if err != nil {
log.Error(r.ctx, "Error refreshing smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
@@ -262,7 +255,7 @@ func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
}
func (r *playlistRepository) addCriteria(sql SelectBuilder, c criteria.Criteria) SelectBuilder {
sql = sql.Where(c.ToSql())
sql = sql.Where(c)
if c.Limit > 0 {
sql = sql.Limit(uint64(c.Limit)).Offset(uint64(c.Offset))
}
@@ -318,7 +311,11 @@ func (r *playlistRepository) addTracks(playlistId string, startingPos int, media
// RefreshStatus updates total playlist duration, size and count
func (r *playlistRepository) refreshCounters(pls *model.Playlist) error {
statsSql := Select("sum(duration) as duration", "sum(size) as size", "count(*) as count").
statsSql := Select(
"coalesce(sum(duration), 0) as duration",
"coalesce(sum(size), 0) as size",
"count(*) as count",
).
From("media_file").
Join("playlist_tracks f on f.media_file_id = media_file.id").
Where(Eq{"playlist_id": pls.ID})
@@ -347,7 +344,15 @@ func (r *playlistRepository) refreshCounters(pls *model.Playlist) error {
func (r *playlistRepository) loadTracks(sel SelectBuilder, id string) (model.PlaylistTracks, error) {
tracksQuery := sel.
Columns("starred", "starred_at", "play_count", "play_date", "rating", "f.*", "playlist_tracks.*").
Columns(
"coalesce(starred, 0)",
"starred_at",
"coalesce(play_count, 0)",
"play_date",
"coalesce(rating, 0)",
"f.*",
"playlist_tracks.*",
).
LeftJoin("annotation on (" +
"annotation.item_id = media_file_id" +
" AND annotation.item_type = 'media_file'" +
@@ -402,7 +407,7 @@ func (r *playlistRepository) Update(id string, entity interface{}, cols ...strin
if !usr.IsAdmin && current.OwnerID != usr.ID {
return rest.ErrPermissionDenied
}
pls := entity.(*model.Playlist)
pls := dbPlaylist{Playlist: *entity.(*model.Playlist)}
pls.ID = id
pls.UpdatedAt = time.Now()
_, err = r.put(id, pls, append(cols, "updatedAt")...)
@@ -447,8 +452,8 @@ func (r *playlistRepository) removeOrphans() error {
func (r *playlistRepository) renumber(id string) error {
var ids []string
sql := Select("media_file_id").From("playlist_tracks").Where(Eq{"playlist_id": id}).OrderBy("id")
err := r.queryAll(sql, &ids)
sq := Select("media_file_id").From("playlist_tracks").Where(Eq{"playlist_id": id}).OrderBy("id")
err := r.queryAllSlice(sq, &ids)
if err != nil {
return err
}
+22 -3
View File
@@ -3,9 +3,9 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/criteria"
"github.com/navidrome/navidrome/model/request"
. "github.com/onsi/ginkgo/v2"
. "github.com/onsi/gomega"
@@ -17,7 +17,7 @@ var _ = Describe("PlaylistRepository", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
repo = NewPlaylistRepository(ctx, orm.NewOrm())
repo = NewPlaylistRepository(ctx, getDBXBuilder())
})
Describe("Count", func() {
@@ -56,7 +56,7 @@ var _ = Describe("PlaylistRepository", func() {
})
It("returns all tracks", func() {
pls, err := repo.GetWithTracks(plsBest.ID, true)
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(pls.Name).To(Equal(plsBest.Name))
Expect(pls.Tracks).To(HaveLen(2))
Expect(pls.Tracks[0].ID).To(Equal("1"))
@@ -109,4 +109,23 @@ var _ = Describe("PlaylistRepository", func() {
Expect(all[1].ID).To(Equal(plsCool.ID))
})
})
Context("Smart Playlists", func() {
var rules *criteria.Criteria
BeforeEach(func() {
rules = &criteria.Criteria{
Expression: criteria.All{
criteria.Contains{"title": "love"},
},
}
})
It("Put/Get", func() {
newPls := model.Playlist{Name: "Great!", OwnerID: "userid", Rules: rules}
Expect(repo.Put(&newPls)).To(Succeed())
savedPls, err := repo.Get(newPls.ID)
Expect(err).ToNot(HaveOccurred())
Expect(savedPls.Rules).To(Equal(rules))
})
})
})
+19 -9
View File
@@ -1,6 +1,8 @@
package persistence
import (
"database/sql"
. "github.com/Masterminds/squirrel"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
@@ -21,7 +23,7 @@ func (r *playlistRepository) Tracks(playlistId string, refreshSmartPlaylist bool
p.playlistRepo = r
p.playlistId = playlistId
p.ctx = r.ctx
p.ormer = r.ormer
p.db = r.db
p.tableName = "playlist_tracks"
p.sortMappings = map[string]string{
"id": "playlist_tracks.id",
@@ -47,7 +49,15 @@ func (r *playlistTrackRepository) Read(id string) (interface{}, error) {
"annotation.item_id = media_file_id"+
" AND annotation.item_type = 'media_file'"+
" AND annotation.user_id = '"+userId(r.ctx)+"')").
Columns("starred", "starred_at", "play_count", "play_date", "rating", "f.*", "playlist_tracks.*").
Columns(
"coalesce(starred, 0)",
"coalesce(play_count, 0)",
"coalesce(rating, 0)",
"starred_at",
"play_date",
"f.*",
"playlist_tracks.*",
).
Join("media_file f on f.id = media_file_id").
Where(And{Eq{"playlist_id": r.playlistId}, Eq{"id": id}})
var trk model.PlaylistTrack
@@ -77,7 +87,7 @@ func (r *playlistTrackRepository) GetAlbumIDs(options ...model.QueryOptions) ([]
Join("media_file mf on mf.id = media_file_id").
Where(Eq{"playlist_id": r.playlistId})
var ids []string
err := r.queryAll(sql, &ids)
err := r.queryAllSlice(sql, &ids)
if err != nil {
return nil, err
}
@@ -112,20 +122,20 @@ func (r *playlistTrackRepository) Add(mediaFileIds []string) (int, error) {
}
// Get next pos (ID) in playlist
sql := r.newSelect().Columns("max(id) as max").Where(Eq{"playlist_id": r.playlistId})
var max int
err := r.queryOne(sql, &max)
sq := r.newSelect().Columns("max(id) as max").Where(Eq{"playlist_id": r.playlistId})
var res struct{ Max sql.NullInt32 }
err := r.queryOne(sq, &res)
if err != nil {
return 0, err
}
return len(mediaFileIds), r.playlistRepo.addTracks(r.playlistId, max+1, mediaFileIds)
return len(mediaFileIds), r.playlistRepo.addTracks(r.playlistId, int(res.Max.Int32+1), mediaFileIds)
}
func (r *playlistTrackRepository) addMediaFileIds(cond Sqlizer) (int, error) {
sq := Select("id").From("media_file").Where(cond).OrderBy("album_artist, album, release_date, disc_number, track_number")
var ids []string
err := r.queryAll(sq, &ids)
err := r.queryAllSlice(sq, &ids)
if err != nil {
log.Error(r.ctx, "Error getting tracks to add to playlist", err)
return 0, err
@@ -156,7 +166,7 @@ func (r *playlistTrackRepository) AddDiscs(discs []model.DiscID) (int, error) {
func (r *playlistTrackRepository) getTracks() ([]string, error) {
all := r.newSelect().Columns("media_file_id").Where(Eq{"playlist_id": r.playlistId}).OrderBy("id")
var ids []string
err := r.queryAll(all, &ids)
err := r.queryAllSlice(all, &ids)
if err != nil {
log.Error(r.ctx, "Error querying current tracks from playlist", "playlistId", r.playlistId, err)
return nil, err
+6 -6
View File
@@ -6,27 +6,27 @@ import (
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils/slice"
"github.com/pocketbase/dbx"
)
type playQueueRepository struct {
sqlRepository
}
func NewPlayQueueRepository(ctx context.Context, o orm.QueryExecutor) model.PlayQueueRepository {
func NewPlayQueueRepository(ctx context.Context, db dbx.Builder) model.PlayQueueRepository {
r := &playQueueRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "playqueue"
return r
}
type playQueue struct {
ID string `structs:"id" orm:"column(id)"`
UserID string `structs:"user_id" orm:"column(user_id)"`
ID string `structs:"id"`
UserID string `structs:"user_id"`
Current string `structs:"current"`
Position int64 `structs:"position"`
ChangedBy string `structs:"changed_by"`
@@ -116,7 +116,7 @@ func (r *playQueueRepository) loadTracks(tracks model.MediaFiles) model.MediaFil
chunks := slice.BreakUp(ids, 50)
// Query each chunk of media_file ids and store results in a map
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
mfRepo := NewMediaFileRepository(r.ctx, r.db)
trackMap := map[string]model.MediaFile{}
for i := range chunks {
idsFilter := Eq{"media_file.id": chunks[i]}
+11 -12
View File
@@ -5,7 +5,6 @@ import (
"time"
"github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/google/uuid"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
@@ -19,8 +18,8 @@ var _ = Describe("PlayQueueRepository", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "user1", UserName: "user1", IsAdmin: true})
repo = NewPlayQueueRepository(ctx, orm.NewOrm())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
repo = NewPlayQueueRepository(ctx, getDBXBuilder())
})
Describe("PlayQueues", func() {
@@ -32,24 +31,24 @@ var _ = Describe("PlayQueueRepository", func() {
It("stores and retrieves the playqueue for the user", func() {
By("Storing a playqueue for the user")
expected := aPlayQueue("user1", songDayInALife.ID, 123, songComeTogether, songDayInALife)
Expect(repo.Store(expected)).To(BeNil())
expected := aPlayQueue("userid", songDayInALife.ID, 123, songComeTogether, songDayInALife)
Expect(repo.Store(expected)).To(Succeed())
actual, err := repo.Retrieve("user1")
Expect(err).To(BeNil())
actual, err := repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
AssertPlayQueue(expected, actual)
By("Storing a new playqueue for the same user")
another := aPlayQueue("user1", songRadioactivity.ID, 321, songAntenna, songRadioactivity)
Expect(repo.Store(another)).To(BeNil())
another := aPlayQueue("userid", songRadioactivity.ID, 321, songAntenna, songRadioactivity)
Expect(repo.Store(another)).To(Succeed())
actual, err = repo.Retrieve("user1")
Expect(err).To(BeNil())
actual, err = repo.Retrieve("userid")
Expect(err).ToNot(HaveOccurred())
AssertPlayQueue(another, actual)
Expect(countPlayQueues(repo, "user1")).To(Equal(1))
Expect(countPlayQueues(repo, "userid")).To(Equal(1))
})
})
})
+3 -3
View File
@@ -5,18 +5,18 @@ import (
"errors"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type propertyRepository struct {
sqlRepository
}
func NewPropertyRepository(ctx context.Context, o orm.QueryExecutor) model.PropertyRepository {
func NewPropertyRepository(ctx context.Context, db dbx.Builder) model.PropertyRepository {
r := &propertyRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "property"
return r
}
+1 -2
View File
@@ -3,7 +3,6 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
. "github.com/onsi/ginkgo/v2"
@@ -14,7 +13,7 @@ var _ = Describe("Property Repository", func() {
var pr model.PropertyRepository
BeforeEach(func() {
pr = NewPropertyRepository(log.NewContext(context.TODO()), orm.NewOrm())
pr = NewPropertyRepository(log.NewContext(context.TODO()), getDBXBuilder())
})
It("saves and restore a new property", func() {
+5 -5
View File
@@ -7,10 +7,10 @@ import (
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/google/uuid"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type radioRepository struct {
@@ -18,10 +18,10 @@ type radioRepository struct {
sqlRestful
}
func NewRadioRepository(ctx context.Context, o orm.QueryExecutor) model.RadioRepository {
func NewRadioRepository(ctx context.Context, db dbx.Builder) model.RadioRepository {
r := &radioRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "radio"
r.filterMappings = map[string]filterFunc{
"name": containsFilter,
@@ -73,9 +73,9 @@ func (r *radioRepository) Put(radio *model.Radio) error {
if radio.ID == "" {
radio.CreatedAt = time.Now()
radio.ID = strings.ReplaceAll(uuid.NewString(), "-", "")
values, _ = toSqlArgs(*radio)
values, _ = toSQLArgs(*radio)
} else {
values, _ = toSqlArgs(*radio)
values, _ = toSQLArgs(*radio)
update := Update(r.tableName).Where(Eq{"id": radio.ID}).SetMap(values)
count, err := r.executeSQL(update)
+2 -3
View File
@@ -3,7 +3,6 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
@@ -23,7 +22,7 @@ var _ = Describe("RadioRepository", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
repo = NewRadioRepository(ctx, orm.NewOrm())
repo = NewRadioRepository(ctx, getDBXBuilder())
_ = repo.Put(&radioWithHomePage)
})
@@ -120,7 +119,7 @@ var _ = Describe("RadioRepository", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: false})
repo = NewRadioRepository(ctx, orm.NewOrm())
repo = NewRadioRepository(ctx, getDBXBuilder())
})
Describe("Count", func() {
+4 -4
View File
@@ -6,18 +6,18 @@ import (
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type scrobbleBufferRepository struct {
sqlRepository
}
func NewScrobbleBufferRepository(ctx context.Context, o orm.QueryExecutor) model.ScrobbleBufferRepository {
func NewScrobbleBufferRepository(ctx context.Context, db dbx.Builder) model.ScrobbleBufferRepository {
r := &scrobbleBufferRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "scrobble_buffer"
return r
}
@@ -31,7 +31,7 @@ func (r *scrobbleBufferRepository) UserIDs(service string) ([]string, error) {
GroupBy("user_id").
OrderBy("count(*)")
var userIds []string
err := r.queryAll(sql, &userIds)
err := r.queryAllSlice(sql, &userIds)
return userIds, err
}
+9 -9
View File
@@ -8,11 +8,11 @@ import (
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
"github.com/pocketbase/dbx"
)
type shareRepository struct {
@@ -20,10 +20,10 @@ type shareRepository struct {
sqlRestful
}
func NewShareRepository(ctx context.Context, o orm.QueryExecutor) model.ShareRepository {
func NewShareRepository(ctx context.Context, db dbx.Builder) model.ShareRepository {
r := &shareRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "share"
return r
}
@@ -79,27 +79,27 @@ func (r *shareRepository) loadMedia(share *model.Share) error {
}
switch share.ResourceType {
case "artist":
albumRepo := NewAlbumRepository(r.ctx, r.ormer)
albumRepo := NewAlbumRepository(r.ctx, r.db)
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: Eq{"album_artist_id": ids}, Sort: "artist"})
if err != nil {
return err
}
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
mfRepo := NewMediaFileRepository(r.ctx, r.db)
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: Eq{"album_artist_id": ids}, Sort: "artist"})
return err
case "album":
albumRepo := NewAlbumRepository(r.ctx, r.ormer)
albumRepo := NewAlbumRepository(r.ctx, r.db)
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: Eq{"id": ids}})
if err != nil {
return err
}
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
mfRepo := NewMediaFileRepository(r.ctx, r.db)
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: Eq{"album_id": ids}, Sort: "album"})
return err
case "playlist":
// Create a context with a fake admin user, to be able to access all playlists
ctx := request.WithUser(r.ctx, model.User{IsAdmin: true})
plsRepo := NewPlaylistRepository(ctx, r.ormer)
plsRepo := NewPlaylistRepository(ctx, r.db)
tracks, err := plsRepo.Tracks(ids[0], true).GetAll(model.QueryOptions{Sort: "id"})
if err != nil {
return err
@@ -109,7 +109,7 @@ func (r *shareRepository) loadMedia(share *model.Share) error {
}
return nil
case "media_file":
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
mfRepo := NewMediaFileRepository(r.ctx, r.db)
tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: Eq{"id": ids}})
share.Tracks = sortByIdPosition(tracks, ids)
return err
+3 -3
View File
@@ -1,11 +1,11 @@
package persistence
import (
"database/sql"
"errors"
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/google/uuid"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
@@ -42,7 +42,7 @@ func (r sqlRepository) annUpsert(values map[string]interface{}, itemIDs ...strin
upd = upd.Set(f, v)
}
c, err := r.executeSQL(upd)
if c == 0 || errors.Is(err, orm.ErrNoRows) {
if c == 0 || errors.Is(err, sql.ErrNoRows) {
for _, itemID := range itemIDs {
values["ann_id"] = uuid.NewString()
values["user_id"] = userId(r.ctx)
@@ -73,7 +73,7 @@ func (r sqlRepository) IncPlayCount(itemID string, ts time.Time) error {
Set("play_date", Expr("max(ifnull(play_date,''),?)", ts))
c, err := r.executeSQL(upd)
if c == 0 || errors.Is(err, orm.ErrNoRows) {
if c == 0 || errors.Is(err, sql.ErrNoRows) {
values := map[string]interface{}{}
values["ann_id"] = uuid.NewString()
values["user_id"] = userId(r.ctx)
+60 -28
View File
@@ -2,24 +2,25 @@ package persistence
import (
"context"
"database/sql"
"errors"
"fmt"
"strings"
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/google/uuid"
"github.com/navidrome/navidrome/conf"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
"github.com/pocketbase/dbx"
)
type sqlRepository struct {
ctx context.Context
tableName string
ormer orm.QueryExecutor
db dbx.Builder
sortMappings map[string]string
}
@@ -122,13 +123,13 @@ func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOpti
}
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
query, args, err := sq.ToSql()
query, args, err := r.toSQL(sq)
if err != nil {
return 0, err
}
start := time.Now()
var c int64
res, err := r.ormer.Raw(query, args...).Exec()
res, err := r.db.NewQuery(query).Bind(args).Execute()
if res != nil {
c, _ = res.RowsAffected()
}
@@ -141,16 +142,31 @@ func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
return res.RowsAffected()
}
func (r sqlRepository) toSQL(sq Sqlizer) (string, dbx.Params, error) {
query, args, err := sq.ToSql()
if err != nil {
return "", nil, err
}
// Replace query placeholders with named params
params := dbx.Params{}
for i, arg := range args {
p := fmt.Sprintf("p%d", i)
query = strings.Replace(query, "?", "{:"+p+"}", 1)
params[p] = arg
}
return query, params, nil
}
// Note: Due to a bug in the QueryRow method, this function does not map any embedded structs (ex: annotations)
// In this case, use the queryAll method and get the first item of the returned list
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
query, args, err := sq.ToSql()
query, args, err := r.toSQL(sq)
if err != nil {
return err
}
start := time.Now()
err = r.ormer.Raw(query, args...).QueryRow(response)
if errors.Is(err, orm.ErrNoRows) {
err = r.db.NewQuery(query).Bind(args).One(response)
if errors.Is(err, sql.ErrNoRows) {
r.logSQL(query, args, nil, 0, start)
return model.ErrNotFound
}
@@ -162,17 +178,33 @@ func (r sqlRepository) queryAll(sq SelectBuilder, response interface{}, options
if len(options) > 0 && options[0].Offset > 0 {
sq = r.optimizePagination(sq, options[0])
}
query, args, err := sq.ToSql()
query, args, err := r.toSQL(sq)
if err != nil {
return err
}
start := time.Now()
c, err := r.ormer.Raw(query, args...).QueryRows(response)
if errors.Is(err, orm.ErrNoRows) {
r.logSQL(query, args, nil, c, start)
err = r.db.NewQuery(query).Bind(args).All(response)
if errors.Is(err, sql.ErrNoRows) {
r.logSQL(query, args, nil, -1, start)
return model.ErrNotFound
}
r.logSQL(query, args, err, c, start)
r.logSQL(query, args, err, -1, start)
return err
}
// queryAllSlice is a helper function to query a single column and return the result in a slice
func (r sqlRepository) queryAllSlice(sq SelectBuilder, response interface{}) error {
query, args, err := r.toSQL(sq)
if err != nil {
return err
}
start := time.Now()
err = r.db.NewQuery(query).Bind(args).Column(response)
if errors.Is(err, sql.ErrNoRows) {
r.logSQL(query, args, nil, -1, start)
return model.ErrNotFound
}
r.logSQL(query, args, err, -1, start)
return err
}
@@ -207,7 +239,7 @@ func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOpt
}
func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) {
values, _ := toSqlArgs(m)
values, _ := toSQLArgs(m)
// If there's an ID, try to update first
if id != "" {
updateValues := map[string]interface{}{}
@@ -246,28 +278,28 @@ func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (ne
func (r sqlRepository) delete(cond Sqlizer) error {
del := Delete(r.tableName).Where(cond)
_, err := r.executeSQL(del)
if errors.Is(err, orm.ErrNoRows) {
if errors.Is(err, sql.ErrNoRows) {
return model.ErrNotFound
}
return err
}
func (r sqlRepository) logSQL(sql string, args []interface{}, err error, rowsAffected int64, start time.Time) {
func (r sqlRepository) logSQL(sql string, args dbx.Params, err error, rowsAffected int64, start time.Time) {
elapsed := time.Since(start)
var fmtArgs []string
for i := range args {
var f string
switch a := args[i].(type) {
case string:
f = `'` + a + `'`
default:
f = fmt.Sprintf("%v", a)
}
fmtArgs = append(fmtArgs, f)
}
//var fmtArgs []string
//for name, val := range args {
// var f string
// switch a := args[val].(type) {
// case string:
// f = `'` + a + `'`
// default:
// f = fmt.Sprintf("%v", a)
// }
// fmtArgs = append(fmtArgs, f)
//}
if err != nil {
log.Error(r.ctx, "SQL: `"+sql+"`", "args", `[`+strings.Join(fmtArgs, ",")+`]`, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
log.Error(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
} else {
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", `[`+strings.Join(fmtArgs, ",")+`]`, "rowsAffected", rowsAffected, "elapsedTime", elapsed)
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed)
}
}
+10 -10
View File
@@ -1,11 +1,11 @@
package persistence
import (
"database/sql"
"errors"
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
@@ -19,7 +19,7 @@ func (r sqlRepository) withBookmark(sql SelectBuilder, idField string) SelectBui
"bookmark.item_id = " + idField +
" AND bookmark.item_type = '" + r.tableName + "'" +
" AND bookmark.user_id = '" + userId(r.ctx) + "')").
Columns("position as bookmark_position")
Columns("coalesce(position, 0) as bookmark_position")
}
func (r sqlRepository) bmkID(itemID ...string) And {
@@ -45,7 +45,7 @@ func (r sqlRepository) bmkUpsert(itemID, comment string, position int64) error {
if err == nil {
log.Debug(r.ctx, "Updated bookmark", "id", itemID, "user", user.UserName, "position", position, "comment", comment)
}
if c == 0 || errors.Is(err, orm.ErrNoRows) {
if c == 0 || errors.Is(err, sql.ErrNoRows) {
values["user_id"] = user.ID
values["item_type"] = r.tableName
values["item_id"] = itemID
@@ -82,8 +82,8 @@ func (r sqlRepository) DeleteBookmark(id string) error {
}
type bookmark struct {
UserID string `json:"user_id" orm:"column(user_id)"`
ItemID string `json:"item_id" orm:"column(item_id)"`
UserID string `json:"user_id"`
ItemID string `json:"item_id"`
ItemType string `json:"item_type"`
Comment string `json:"comment"`
Position int64 `json:"position"`
@@ -96,10 +96,10 @@ func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) {
user, _ := request.UserFrom(r.ctx)
idField := r.tableName + ".id"
sql := r.newSelectWithAnnotation(idField).Columns("*")
sql = r.withBookmark(sql, idField).Where(NotEq{bookmarkTable + ".item_id": nil})
sq := r.newSelectWithAnnotation(idField).Columns(r.tableName + ".*")
sq = r.withBookmark(sq, idField).Where(NotEq{bookmarkTable + ".item_id": nil})
var mfs model.MediaFiles
err := r.queryAll(sql, &mfs)
err := r.queryAll(sq, &mfs)
if err != nil {
log.Error(r.ctx, "Error getting mediafiles with bookmarks", "user", user.UserName, err)
return nil, err
@@ -117,9 +117,9 @@ func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) {
mfMap[mf.ID] = i
}
sql = Select("*").From(bookmarkTable).Where(r.bmkID(ids...))
sq = Select("*").From(bookmarkTable).Where(r.bmkID(ids...))
var bmks []bookmark
err = r.queryAll(sql, &bmks)
err = r.queryAll(sq, &bmks)
if err != nil {
log.Error(r.ctx, "Error getting bookmarks", "user", user.UserName, "ids", ids, err)
return nil, err
+10 -8
View File
@@ -3,7 +3,6 @@ package persistence
import (
"context"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/model/request"
@@ -16,8 +15,8 @@ var _ = Describe("sqlBookmarks", func() {
BeforeEach(func() {
ctx := log.NewContext(context.TODO())
ctx = request.WithUser(ctx, model.User{ID: "user1"})
mr = NewMediaFileRepository(ctx, orm.NewOrm())
ctx = request.WithUser(ctx, model.User{ID: "userid"})
mr = NewMediaFileRepository(ctx, getDBXBuilder())
})
Describe("Bookmarks", func() {
@@ -30,7 +29,7 @@ var _ = Describe("sqlBookmarks", func() {
Expect(mr.AddBookmark(songAntenna.ID, "this is a comment", 123)).To(BeNil())
bms, err := mr.GetBookmarks()
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(bms).To(HaveLen(1))
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
@@ -46,7 +45,7 @@ var _ = Describe("sqlBookmarks", func() {
Expect(mr.AddBookmark(songAntenna.ID, "another comment", 333)).To(BeNil())
bms, err = mr.GetBookmarks()
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
Expect(bms[0].Comment).To(Equal("another comment"))
@@ -57,16 +56,19 @@ var _ = Describe("sqlBookmarks", func() {
By("Saving another bookmark")
Expect(mr.AddBookmark(songComeTogether.ID, "one more comment", 444)).To(BeNil())
bms, err = mr.GetBookmarks()
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(bms).To(HaveLen(2))
By("Delete bookmark")
Expect(mr.DeleteBookmark(songAntenna.ID))
Expect(mr.DeleteBookmark(songAntenna.ID)).To(Succeed())
bms, err = mr.GetBookmarks()
Expect(err).To(BeNil())
Expect(err).ToNot(HaveOccurred())
Expect(bms).To(HaveLen(1))
Expect(bms[0].Item.ID).To(Equal(songComeTogether.ID))
Expect(bms[0].Item.Title).To(Equal(songComeTogether.Title))
Expect(mr.DeleteBookmark(songComeTogether.ID)).To(Succeed())
Expect(mr.GetBookmarks()).To(BeEmpty())
})
})
})
+1 -1
View File
@@ -21,7 +21,7 @@ func (r sqlRepository) doSearch(q string, offset, size int, results interface{},
return nil
}
sq := r.newSelectWithAnnotation(r.tableName + ".id").Columns("*")
sq := r.newSelectWithAnnotation(r.tableName + ".id").Columns(r.tableName + ".*")
filter := fullTextExpr(q)
if filter != nil {
sq = sq.Where(filter)
+3 -3
View File
@@ -5,9 +5,9 @@ import (
"errors"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type transcodingRepository struct {
@@ -15,10 +15,10 @@ type transcodingRepository struct {
sqlRestful
}
func NewTranscodingRepository(ctx context.Context, o orm.QueryExecutor) model.TranscodingRepository {
func NewTranscodingRepository(ctx context.Context, db dbx.Builder) model.TranscodingRepository {
r := &transcodingRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "transcoding"
return r
}
+3 -3
View File
@@ -5,18 +5,18 @@ import (
"errors"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/navidrome/navidrome/model"
"github.com/pocketbase/dbx"
)
type userPropsRepository struct {
sqlRepository
}
func NewUserPropsRepository(ctx context.Context, o orm.QueryExecutor) model.UserPropsRepository {
func NewUserPropsRepository(ctx context.Context, db dbx.Builder) model.UserPropsRepository {
r := &userPropsRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "user_props"
return r
}
+5 -5
View File
@@ -10,7 +10,6 @@ import (
"time"
. "github.com/Masterminds/squirrel"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/google/uuid"
"github.com/navidrome/navidrome/conf"
@@ -18,6 +17,7 @@ import (
"github.com/navidrome/navidrome/log"
"github.com/navidrome/navidrome/model"
"github.com/navidrome/navidrome/utils"
"github.com/pocketbase/dbx"
)
type userRepository struct {
@@ -30,10 +30,10 @@ var (
encKey []byte
)
func NewUserRepository(ctx context.Context, o orm.QueryExecutor) model.UserRepository {
func NewUserRepository(ctx context.Context, db dbx.Builder) model.UserRepository {
r := &userRepository{}
r.ctx = ctx
r.ormer = o
r.db = db
r.tableName = "user"
once.Do(func() {
_ = r.initPasswordEncryptionKey()
@@ -67,7 +67,7 @@ func (r *userRepository) Put(u *model.User) error {
if u.NewPassword != "" {
_ = r.encryptPassword(u)
}
values, _ := toSqlArgs(*u)
values, _ := toSQLArgs(*u)
delete(values, "current_password")
update := Update(r.tableName).Where(Eq{"id": u.ID}).SetMap(values)
count, err := r.executeSQL(update)
@@ -268,7 +268,7 @@ func (r *userRepository) initPasswordEncryptionKey() error {
key := keyTo32Bytes(conf.Server.PasswordEncryptionKey)
keySum := fmt.Sprintf("%x", sha256.Sum256(key))
props := NewPropertyRepository(r.ctx, r.ormer)
props := NewPropertyRepository(r.ctx, r.db)
savedKeySum, err := props.Get(consts.PasswordsEncryptedKey)
// If passwords are already encrypted
+1 -2
View File
@@ -4,7 +4,6 @@ import (
"context"
"errors"
"github.com/beego/beego/v2/client/orm"
"github.com/deluan/rest"
"github.com/google/uuid"
"github.com/navidrome/navidrome/consts"
@@ -19,7 +18,7 @@ var _ = Describe("UserRepository", func() {
var repo model.UserRepository
BeforeEach(func() {
repo = NewUserRepository(log.NewContext(context.TODO()), orm.NewOrm())
repo = NewUserRepository(log.NewContext(context.TODO()), getDBXBuilder())
})
Describe("Put/Get/FindByUsername", func() {