Replace beego/orm with dbx (#2693)
* Start migration to dbx package * Fix annotations and bookmarks bindings * Fix tests * Fix more tests * Remove remaining references to beego/orm * Add PostScanner/PostMapper interfaces * Fix importing SmartPlaylists * Renaming * More renaming * Fix artist DB mapping * Fix playlist updates * Remove bookmarks at the end of the test * Remove remaining `orm` struct tags * Fix user timestamps DB access * Fix smart playlist evaluated_at DB access * Fix search3
This commit is contained in:
@@ -6,11 +6,11 @@ import (
|
||||
"strings"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type albumRepository struct {
|
||||
@@ -18,10 +18,10 @@ type albumRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewAlbumRepository(ctx context.Context, o orm.QueryExecutor) model.AlbumRepository {
|
||||
func NewAlbumRepository(ctx context.Context, db dbx.Builder) model.AlbumRepository {
|
||||
r := &albumRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "album"
|
||||
r.sortMappings = map[string]string{
|
||||
"name": "order_album_name asc, order_album_artist_name asc",
|
||||
|
||||
@@ -3,7 +3,6 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
@@ -16,7 +15,7 @@ var _ = Describe("AlbumRepository", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := request.WithUser(log.NewContext(context.TODO()), model.User{ID: "userid", UserName: "johndoe"})
|
||||
repo = NewAlbumRepository(ctx, orm.NewOrm())
|
||||
repo = NewAlbumRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Get", func() {
|
||||
|
||||
@@ -8,13 +8,13 @@ import (
|
||||
"strings"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type artistRepository struct {
|
||||
@@ -24,14 +24,40 @@ type artistRepository struct {
|
||||
}
|
||||
|
||||
type dbArtist struct {
|
||||
model.Artist `structs:",flatten"`
|
||||
SimilarArtists string `structs:"similar_artists" json:"similarArtists"`
|
||||
*model.Artist `structs:",flatten"`
|
||||
SimilarArtists string `structs:"-" json:"similarArtists"`
|
||||
}
|
||||
|
||||
func NewArtistRepository(ctx context.Context, o orm.QueryExecutor) model.ArtistRepository {
|
||||
func (a *dbArtist) PostScan() error {
|
||||
if a.SimilarArtists == "" {
|
||||
return nil
|
||||
}
|
||||
for _, s := range strings.Split(a.SimilarArtists, ";") {
|
||||
fields := strings.Split(s, ":")
|
||||
if len(fields) != 2 {
|
||||
continue
|
||||
}
|
||||
name, _ := url.QueryUnescape(fields[1])
|
||||
a.Artist.SimilarArtists = append(a.Artist.SimilarArtists, model.Artist{
|
||||
ID: fields[0],
|
||||
Name: name,
|
||||
})
|
||||
}
|
||||
return nil
|
||||
}
|
||||
func (a *dbArtist) PostMapArgs(m map[string]any) error {
|
||||
var sa []string
|
||||
for _, s := range a.Artist.SimilarArtists {
|
||||
sa = append(sa, fmt.Sprintf("%s:%s", s.ID, url.QueryEscape(s.Name)))
|
||||
}
|
||||
m["similar_artists"] = strings.Join(sa, ";")
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewArtistRepository(ctx context.Context, db dbx.Builder) model.ArtistRepository {
|
||||
r := &artistRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.indexGroups = utils.ParseIndexGroups(conf.Server.IndexGroups)
|
||||
r.tableName = "artist"
|
||||
r.sortMappings = map[string]string{
|
||||
@@ -62,7 +88,7 @@ func (r *artistRepository) Exists(id string) (bool, error) {
|
||||
|
||||
func (r *artistRepository) Put(a *model.Artist, colsToUpdate ...string) error {
|
||||
a.FullText = getFullText(a.Name, a.SortArtistName)
|
||||
dba := r.fromModel(a)
|
||||
dba := &dbArtist{Artist: a}
|
||||
_, err := r.put(dba.ID, dba, colsToUpdate...)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -102,41 +128,11 @@ func (r *artistRepository) GetAll(options ...model.QueryOptions) (model.Artists,
|
||||
func (r *artistRepository) toModels(dba []dbArtist) model.Artists {
|
||||
res := model.Artists{}
|
||||
for i := range dba {
|
||||
a := dba[i]
|
||||
res = append(res, *r.toModel(&a))
|
||||
res = append(res, *dba[i].Artist)
|
||||
}
|
||||
return res
|
||||
}
|
||||
|
||||
func (r *artistRepository) toModel(dba *dbArtist) *model.Artist {
|
||||
a := dba.Artist
|
||||
a.SimilarArtists = nil
|
||||
for _, s := range strings.Split(dba.SimilarArtists, ";") {
|
||||
fields := strings.Split(s, ":")
|
||||
if len(fields) != 2 {
|
||||
continue
|
||||
}
|
||||
name, _ := url.QueryUnescape(fields[1])
|
||||
a.SimilarArtists = append(a.SimilarArtists, model.Artist{
|
||||
ID: fields[0],
|
||||
Name: name,
|
||||
})
|
||||
}
|
||||
return &a
|
||||
}
|
||||
|
||||
func (r *artistRepository) fromModel(a *model.Artist) *dbArtist {
|
||||
dba := &dbArtist{Artist: *a}
|
||||
var sa []string
|
||||
|
||||
for _, s := range a.SimilarArtists {
|
||||
sa = append(sa, fmt.Sprintf("%s:%s", s.ID, url.QueryEscape(s.Name)))
|
||||
}
|
||||
|
||||
dba.SimilarArtists = strings.Join(sa, ";")
|
||||
return dba
|
||||
}
|
||||
|
||||
func (r *artistRepository) getIndexKey(a *model.Artist) string {
|
||||
name := strings.ToLower(utils.NoArticle(a.Name))
|
||||
for k, v := range r.indexGroups {
|
||||
|
||||
@@ -3,7 +3,7 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/fatih/structs"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
@@ -18,7 +18,7 @@ var _ = Describe("ArtistRepository", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
repo = NewArtistRepository(ctx, orm.NewOrm())
|
||||
repo = NewArtistRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
@@ -71,8 +71,17 @@ var _ = Describe("ArtistRepository", func() {
|
||||
}}
|
||||
})
|
||||
It("maps fields", func() {
|
||||
dba := repo.(*artistRepository).fromModel(a)
|
||||
actual := repo.(*artistRepository).toModel(dba)
|
||||
dba := &dbArtist{Artist: a}
|
||||
m := structs.Map(dba)
|
||||
Expect(dba.PostMapArgs(m)).To(Succeed())
|
||||
Expect(m).To(HaveKeyWithValue("similar_artists", "2:AC%2FDC;-1:Test%3BWith%3ASep%2CChars"))
|
||||
|
||||
other := dbArtist{SimilarArtists: m["similar_artists"].(string), Artist: &model.Artist{
|
||||
ID: "1", Name: "Van Halen",
|
||||
}}
|
||||
Expect(other.PostScan()).To(Succeed())
|
||||
|
||||
actual := other.Artist
|
||||
Expect(*actual).To(MatchFields(IgnoreExtras, Fields{
|
||||
"ID": Equal(a.ID),
|
||||
"Name": Equal(a.Name),
|
||||
|
||||
@@ -4,9 +4,9 @@ import (
|
||||
"context"
|
||||
|
||||
"github.com/google/uuid"
|
||||
"github.com/pocketbase/dbx"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
@@ -17,10 +17,10 @@ type genreRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewGenreRepository(ctx context.Context, o orm.QueryExecutor) model.GenreRepository {
|
||||
func NewGenreRepository(ctx context.Context, db dbx.Builder) model.GenreRepository {
|
||||
r := &genreRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "genre"
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"name": containsFilter,
|
||||
@@ -29,7 +29,12 @@ func NewGenreRepository(ctx context.Context, o orm.QueryExecutor) model.GenreRep
|
||||
}
|
||||
|
||||
func (r *genreRepository) GetAll(opt ...model.QueryOptions) (model.Genres, error) {
|
||||
sq := r.newSelect(opt...).Columns("genre.id", "genre.name", "a.album_count", "m.song_count").
|
||||
sq := r.newSelect(opt...).Columns(
|
||||
"genre.id",
|
||||
"genre.name",
|
||||
"coalesce(a.album_count, 0) as album_count",
|
||||
"coalesce(m.song_count, 0) as song_count",
|
||||
).
|
||||
LeftJoin("(select ag.genre_id, count(ag.album_id) as album_count from album_genres ag group by ag.genre_id) a on a.genre_id = genre.id").
|
||||
LeftJoin("(select mg.genre_id, count(mg.media_file_id) as song_count from media_file_genres mg group by mg.genre_id) m on m.genre_id = genre.id")
|
||||
res := model.Genres{}
|
||||
|
||||
@@ -3,26 +3,27 @@ package persistence_test
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/db"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/persistence"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
var _ = Describe("GenreRepository", func() {
|
||||
var repo model.GenreRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), orm.NewOrm())
|
||||
repo = persistence.NewGenreRepository(log.NewContext(context.TODO()), dbx.NewFromDB(db.Db(), db.Driver))
|
||||
})
|
||||
|
||||
Describe("GetAll()", func() {
|
||||
It("returns all records", func() {
|
||||
genres, err := repo.GetAll()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(ConsistOf(
|
||||
model.Genre{ID: "gn-1", Name: "Electronic", AlbumCount: 1, SongCount: 2},
|
||||
model.Genre{ID: "gn-2", Name: "Rock", AlbumCount: 3, SongCount: 3},
|
||||
@@ -43,13 +44,14 @@ var _ = Describe("GenreRepository", func() {
|
||||
It("insert non-existent genre names", func() {
|
||||
g := model.Genre{Name: "Reggae"}
|
||||
err := repo.Put(&g)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
// ID is a uuid
|
||||
_, err = uuid.Parse(g.ID)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
genres, _ := repo.GetAll()
|
||||
genres, err := repo.GetAll()
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(genres).To(HaveLen(3))
|
||||
Expect(genres).To(ContainElement(model.Genre{ID: g.ID, Name: "Reggae", AlbumCount: 0, SongCount: 0}))
|
||||
})
|
||||
|
||||
+23
-4
@@ -1,6 +1,7 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql/driver"
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
@@ -10,14 +11,32 @@ import (
|
||||
"github.com/fatih/structs"
|
||||
)
|
||||
|
||||
func toSqlArgs(rec interface{}) (map[string]interface{}, error) {
|
||||
type PostMapper interface {
|
||||
PostMapArgs(map[string]any) error
|
||||
}
|
||||
|
||||
func toSQLArgs(rec interface{}) (map[string]interface{}, error) {
|
||||
m := structs.Map(rec)
|
||||
for k, v := range m {
|
||||
if t, ok := v.(time.Time); ok {
|
||||
switch t := v.(type) {
|
||||
case time.Time:
|
||||
m[k] = t.Format(time.RFC3339Nano)
|
||||
case *time.Time:
|
||||
if t != nil {
|
||||
m[k] = t.Format(time.RFC3339Nano)
|
||||
}
|
||||
case driver.Valuer:
|
||||
var err error
|
||||
m[k], err = t.Value()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
if t, ok := v.(*time.Time); ok && t != nil {
|
||||
m[k] = t.Format(time.RFC3339Nano)
|
||||
}
|
||||
if r, ok := rec.(PostMapper); ok {
|
||||
err := r.PostMapArgs(m)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
return m, nil
|
||||
|
||||
@@ -23,7 +23,7 @@ var _ = Describe("Helpers", func() {
|
||||
Expect(toSnakeCase("snake_case")).To(Equal("snake_case"))
|
||||
})
|
||||
})
|
||||
Describe("toSqlArgs", func() {
|
||||
Describe("toSQLArgs", func() {
|
||||
type Embed struct{}
|
||||
type Model struct {
|
||||
Embed `structs:"-"`
|
||||
@@ -37,10 +37,11 @@ var _ = Describe("Helpers", func() {
|
||||
It("returns a map with snake_case keys", func() {
|
||||
now := time.Now()
|
||||
m := &Model{ID: "123", AlbumId: "456", CreatedAt: now, UpdatedAt: &now, PlayCount: 2}
|
||||
args, err := toSqlArgs(m)
|
||||
args, err := toSQLArgs(m)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(args).To(HaveKeyWithValue("id", "123"))
|
||||
Expect(args).To(HaveKeyWithValue("album_id", "456"))
|
||||
Expect(args).To(HaveKeyWithValue("play_count", 2))
|
||||
Expect(args).To(HaveKeyWithValue("updated_at", now.Format(time.RFC3339Nano)))
|
||||
Expect(args).To(HaveKeyWithValue("created_at", now.Format(time.RFC3339Nano)))
|
||||
Expect(args).ToNot(HaveKey("Embed"))
|
||||
|
||||
@@ -9,10 +9,10 @@ import (
|
||||
"unicode/utf8"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type mediaFileRepository struct {
|
||||
@@ -20,10 +20,10 @@ type mediaFileRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewMediaFileRepository(ctx context.Context, o orm.QueryExecutor) *mediaFileRepository {
|
||||
func NewMediaFileRepository(ctx context.Context, db dbx.Builder) *mediaFileRepository {
|
||||
r := &mediaFileRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "media_file"
|
||||
r.sortMappings = map[string]string{
|
||||
"artist": "order_artist_name asc, order_album_name asc, release_date asc, disc_number asc, track_number asc",
|
||||
@@ -146,7 +146,7 @@ func (r *mediaFileRepository) FindPathsRecursively(basePath string) ([]string, e
|
||||
sel := r.newSelect().Columns(fmt.Sprintf("distinct rtrim(path, replace(path, '%s', ''))", string(os.PathSeparator))).
|
||||
Where(pathStartsWith(path))
|
||||
var res []string
|
||||
err := r.queryAll(sel, &res)
|
||||
err := r.queryAllSlice(sel, &res)
|
||||
return res, err
|
||||
}
|
||||
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
@@ -20,7 +19,7 @@ var _ = Describe("MediaRepository", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
mr = NewMediaFileRepository(ctx, orm.NewOrm())
|
||||
mr = NewMediaFileRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
It("gets mediafile from the DB", func() {
|
||||
|
||||
@@ -3,16 +3,16 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type mediaFolderRepository struct {
|
||||
ctx context.Context
|
||||
}
|
||||
|
||||
func NewMediaFolderRepository(ctx context.Context, o orm.QueryExecutor) model.MediaFolderRepository {
|
||||
func NewMediaFolderRepository(ctx context.Context, _ dbx.Builder) model.MediaFolderRepository {
|
||||
return &mediaFolderRepository{ctx}
|
||||
}
|
||||
|
||||
|
||||
+28
-33
@@ -5,79 +5,78 @@ import (
|
||||
"database/sql"
|
||||
"reflect"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/db"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type SQLStore struct {
|
||||
orm orm.QueryExecutor
|
||||
db *sql.DB
|
||||
db dbx.Builder
|
||||
}
|
||||
|
||||
func New(db *sql.DB) model.DataStore {
|
||||
return &SQLStore{db: db}
|
||||
func New(conn *sql.DB) model.DataStore {
|
||||
return &SQLStore{db: dbx.NewFromDB(conn, db.Driver)}
|
||||
}
|
||||
|
||||
func (s *SQLStore) Album(ctx context.Context) model.AlbumRepository {
|
||||
return NewAlbumRepository(ctx, s.getOrmer())
|
||||
return NewAlbumRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Artist(ctx context.Context) model.ArtistRepository {
|
||||
return NewArtistRepository(ctx, s.getOrmer())
|
||||
return NewArtistRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) MediaFile(ctx context.Context) model.MediaFileRepository {
|
||||
return NewMediaFileRepository(ctx, s.getOrmer())
|
||||
return NewMediaFileRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) MediaFolder(ctx context.Context) model.MediaFolderRepository {
|
||||
return NewMediaFolderRepository(ctx, s.getOrmer())
|
||||
return NewMediaFolderRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Genre(ctx context.Context) model.GenreRepository {
|
||||
return NewGenreRepository(ctx, s.getOrmer())
|
||||
return NewGenreRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) PlayQueue(ctx context.Context) model.PlayQueueRepository {
|
||||
return NewPlayQueueRepository(ctx, s.getOrmer())
|
||||
return NewPlayQueueRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Playlist(ctx context.Context) model.PlaylistRepository {
|
||||
return NewPlaylistRepository(ctx, s.getOrmer())
|
||||
return NewPlaylistRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Property(ctx context.Context) model.PropertyRepository {
|
||||
return NewPropertyRepository(ctx, s.getOrmer())
|
||||
return NewPropertyRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Radio(ctx context.Context) model.RadioRepository {
|
||||
return NewRadioRepository(ctx, s.getOrmer())
|
||||
return NewRadioRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) UserProps(ctx context.Context) model.UserPropsRepository {
|
||||
return NewUserPropsRepository(ctx, s.getOrmer())
|
||||
return NewUserPropsRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Share(ctx context.Context) model.ShareRepository {
|
||||
return NewShareRepository(ctx, s.getOrmer())
|
||||
return NewShareRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) User(ctx context.Context) model.UserRepository {
|
||||
return NewUserRepository(ctx, s.getOrmer())
|
||||
return NewUserRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Transcoding(ctx context.Context) model.TranscodingRepository {
|
||||
return NewTranscodingRepository(ctx, s.getOrmer())
|
||||
return NewTranscodingRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Player(ctx context.Context) model.PlayerRepository {
|
||||
return NewPlayerRepository(ctx, s.getOrmer())
|
||||
return NewPlayerRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) ScrobbleBuffer(ctx context.Context) model.ScrobbleBufferRepository {
|
||||
return NewScrobbleBufferRepository(ctx, s.getOrmer())
|
||||
return NewScrobbleBufferRepository(ctx, s.getDBXBuilder())
|
||||
}
|
||||
|
||||
func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRepository {
|
||||
@@ -108,12 +107,12 @@ func (s *SQLStore) Resource(ctx context.Context, m interface{}) model.ResourceRe
|
||||
}
|
||||
|
||||
func (s *SQLStore) WithTx(block func(tx model.DataStore) error) error {
|
||||
o, err := orm.NewOrmWithDB(db.Driver, "default", s.db)
|
||||
if err != nil {
|
||||
return err
|
||||
conn, ok := s.db.(*dbx.DB)
|
||||
if !ok {
|
||||
conn = dbx.NewFromDB(db.Db(), db.Driver)
|
||||
}
|
||||
return o.DoTx(func(ctx context.Context, txOrm orm.TxOrmer) error {
|
||||
newDb := &SQLStore{orm: txOrm}
|
||||
return conn.Transactional(func(tx *dbx.Tx) error {
|
||||
newDb := &SQLStore{db: tx}
|
||||
return block(newDb)
|
||||
})
|
||||
}
|
||||
@@ -171,13 +170,9 @@ func (s *SQLStore) GC(ctx context.Context, rootFolder string) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (s *SQLStore) getOrmer() orm.QueryExecutor {
|
||||
if s.orm == nil {
|
||||
o, err := orm.NewOrmWithDB(db.Driver, "default", s.db)
|
||||
if err != nil {
|
||||
log.Error("Error obtaining new orm instance", err)
|
||||
}
|
||||
return o
|
||||
func (s *SQLStore) getDBXBuilder() dbx.Builder {
|
||||
if s.db == nil {
|
||||
return dbx.NewFromDB(db.Db(), db.Driver)
|
||||
}
|
||||
return s.orm
|
||||
return s.db
|
||||
}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"path/filepath"
|
||||
"testing"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
_ "github.com/mattn/go-sqlite3"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/db"
|
||||
@@ -15,6 +14,7 @@ import (
|
||||
"github.com/navidrome/navidrome/tests"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
func TestPersistence(t *testing.T) {
|
||||
@@ -23,13 +23,16 @@ func TestPersistence(t *testing.T) {
|
||||
//os.Remove("./test-123.db")
|
||||
//conf.Server.DbPath = "./test-123.db"
|
||||
conf.Server.DbPath = "file::memory:?cache=shared"
|
||||
_ = orm.RegisterDataBase("default", db.Driver, conf.Server.DbPath)
|
||||
db.Init()
|
||||
log.SetLevel(log.LevelError)
|
||||
RegisterFailHandler(Fail)
|
||||
RunSpecs(t, "Persistence Suite")
|
||||
}
|
||||
|
||||
func getDBXBuilder() *dbx.DB {
|
||||
return dbx.NewFromDB(db.Db(), db.Driver)
|
||||
}
|
||||
|
||||
var (
|
||||
genreElectronic = model.Genre{ID: "gn-1", Name: "Electronic"}
|
||||
genreRock = model.Genre{ID: "gn-2", Name: "Rock"}
|
||||
@@ -88,18 +91,18 @@ func P(path string) string {
|
||||
// Initialize test DB
|
||||
// TODO Load this data setup from file(s)
|
||||
var _ = BeforeSuite(func() {
|
||||
o := orm.NewOrm()
|
||||
conn := getDBXBuilder()
|
||||
ctx := log.NewContext(context.TODO())
|
||||
user := model.User{ID: "userid", UserName: "userid", IsAdmin: true}
|
||||
ctx = request.WithUser(ctx, user)
|
||||
|
||||
ur := NewUserRepository(ctx, o)
|
||||
ur := NewUserRepository(ctx, conn)
|
||||
err := ur.Put(&user)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
gr := NewGenreRepository(ctx, o)
|
||||
gr := NewGenreRepository(ctx, conn)
|
||||
for i := range testGenres {
|
||||
g := testGenres[i]
|
||||
err := gr.Put(&g)
|
||||
@@ -108,7 +111,7 @@ var _ = BeforeSuite(func() {
|
||||
}
|
||||
}
|
||||
|
||||
mr := NewMediaFileRepository(ctx, o)
|
||||
mr := NewMediaFileRepository(ctx, conn)
|
||||
for i := range testSongs {
|
||||
s := testSongs[i]
|
||||
err := mr.Put(&s)
|
||||
@@ -117,7 +120,7 @@ var _ = BeforeSuite(func() {
|
||||
}
|
||||
}
|
||||
|
||||
alr := NewAlbumRepository(ctx, o).(*albumRepository)
|
||||
alr := NewAlbumRepository(ctx, conn).(*albumRepository)
|
||||
for i := range testAlbums {
|
||||
a := testAlbums[i]
|
||||
err := alr.Put(&a)
|
||||
@@ -126,7 +129,7 @@ var _ = BeforeSuite(func() {
|
||||
}
|
||||
}
|
||||
|
||||
arr := NewArtistRepository(ctx, o)
|
||||
arr := NewArtistRepository(ctx, conn)
|
||||
for i := range testArtists {
|
||||
a := testArtists[i]
|
||||
err := arr.Put(&a)
|
||||
@@ -135,7 +138,7 @@ var _ = BeforeSuite(func() {
|
||||
}
|
||||
}
|
||||
|
||||
rar := NewRadioRepository(ctx, o)
|
||||
rar := NewRadioRepository(ctx, conn)
|
||||
for i := range testRadios {
|
||||
r := testRadios[i]
|
||||
err := rar.Put(&r)
|
||||
@@ -157,7 +160,7 @@ var _ = BeforeSuite(func() {
|
||||
plsCool.AddTracks([]string{"1004"})
|
||||
testPlaylists = []*model.Playlist{&plsBest, &plsCool}
|
||||
|
||||
pr := NewPlaylistRepository(ctx, o)
|
||||
pr := NewPlaylistRepository(ctx, conn)
|
||||
for i := range testPlaylists {
|
||||
err := pr.Put(testPlaylists[i])
|
||||
if err != nil {
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type playerRepository struct {
|
||||
@@ -15,10 +15,10 @@ type playerRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewPlayerRepository(ctx context.Context, o orm.QueryExecutor) model.PlayerRepository {
|
||||
func NewPlayerRepository(ctx context.Context, db dbx.Builder) model.PlayerRepository {
|
||||
r := &playerRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "player"
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"name": containsFilter,
|
||||
|
||||
@@ -2,18 +2,18 @@ package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/criteria"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type playlistRepository struct {
|
||||
@@ -23,13 +23,30 @@ type playlistRepository struct {
|
||||
|
||||
type dbPlaylist struct {
|
||||
model.Playlist `structs:",flatten"`
|
||||
RawRules string `structs:"rules" orm:"column(rules)"`
|
||||
Rules sql.NullString `structs:"-"`
|
||||
}
|
||||
|
||||
func NewPlaylistRepository(ctx context.Context, o orm.QueryExecutor) model.PlaylistRepository {
|
||||
func (p *dbPlaylist) PostScan() error {
|
||||
if p.Rules.String != "" {
|
||||
return json.Unmarshal([]byte(p.Rules.String), &p.Playlist.Rules)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (p dbPlaylist) PostMapArgs(args map[string]any) error {
|
||||
var err error
|
||||
if p.Playlist.IsSmartPlaylist() {
|
||||
args["rules"], err = json.Marshal(p.Playlist.Rules)
|
||||
return err
|
||||
}
|
||||
delete(args, "rules")
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewPlaylistRepository(ctx context.Context, db dbx.Builder) model.PlaylistRepository {
|
||||
r := &playlistRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "playlist"
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"q": playlistFilter,
|
||||
@@ -64,8 +81,8 @@ func (r *playlistRepository) userFilter() Sqlizer {
|
||||
}
|
||||
|
||||
func (r *playlistRepository) CountAll(options ...model.QueryOptions) (int64, error) {
|
||||
sql := Select().Where(r.userFilter())
|
||||
return r.count(sql, options...)
|
||||
sq := Select().Where(r.userFilter())
|
||||
return r.count(sq, options...)
|
||||
}
|
||||
|
||||
func (r *playlistRepository) Exists(id string) (bool, error) {
|
||||
@@ -88,13 +105,6 @@ func (r *playlistRepository) Delete(id string) error {
|
||||
|
||||
func (r *playlistRepository) Put(p *model.Playlist) error {
|
||||
pls := dbPlaylist{Playlist: *p}
|
||||
if p.IsSmartPlaylist() {
|
||||
j, err := json.Marshal(p.Rules)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
pls.RawRules = string(j)
|
||||
}
|
||||
if pls.ID == "" {
|
||||
pls.CreatedAt = time.Now()
|
||||
} else {
|
||||
@@ -161,22 +171,7 @@ func (r *playlistRepository) findBy(sql Sqlizer) (*model.Playlist, error) {
|
||||
return nil, model.ErrNotFound
|
||||
}
|
||||
|
||||
return r.toModel(pls[0])
|
||||
}
|
||||
|
||||
func (r *playlistRepository) toModel(pls dbPlaylist) (*model.Playlist, error) {
|
||||
var err error
|
||||
if strings.TrimSpace(pls.RawRules) != "" {
|
||||
var c criteria.Criteria
|
||||
err = json.Unmarshal([]byte(pls.RawRules), &c)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
pls.Playlist.Rules = &c
|
||||
} else {
|
||||
pls.Playlist.Rules = nil
|
||||
}
|
||||
return &pls.Playlist, err
|
||||
return &pls[0].Playlist, nil
|
||||
}
|
||||
|
||||
func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playlists, error) {
|
||||
@@ -188,11 +183,7 @@ func (r *playlistRepository) GetAll(options ...model.QueryOptions) (model.Playli
|
||||
}
|
||||
playlists := make(model.Playlists, len(res))
|
||||
for i, p := range res {
|
||||
pls, err := r.toModel(p)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
playlists[i] = *pls
|
||||
playlists[i] = p.Playlist
|
||||
}
|
||||
return playlists, err
|
||||
}
|
||||
@@ -204,13 +195,14 @@ func (r *playlistRepository) selectPlaylist(options ...model.QueryOptions) Selec
|
||||
|
||||
func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
|
||||
// Only refresh if it is a smart playlist and was not refreshed in the last 5 seconds
|
||||
if !pls.IsSmartPlaylist() || time.Since(pls.EvaluatedAt) < 5*time.Second {
|
||||
if !pls.IsSmartPlaylist() || (pls.EvaluatedAt != nil && time.Since(*pls.EvaluatedAt) < 5*time.Second) {
|
||||
return false
|
||||
}
|
||||
|
||||
// Never refresh other users' playlists
|
||||
usr := loggedUser(r.ctx)
|
||||
if pls.OwnerID != usr.ID {
|
||||
log.Trace(r.ctx, "Not refreshing smart playlist from other user", "playlist", pls.Name, "id", pls.ID)
|
||||
return false
|
||||
}
|
||||
|
||||
@@ -221,20 +213,21 @@ func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
|
||||
del := Delete("playlist_tracks").Where(Eq{"playlist_id": pls.ID})
|
||||
_, err := r.executeSQL(del)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error deleting old smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
|
||||
return false
|
||||
}
|
||||
|
||||
// Re-populate playlist based on Smart Playlist criteria
|
||||
rules := *pls.Rules
|
||||
sql := Select("row_number() over (order by "+rules.OrderBy()+") as id", "'"+pls.ID+"' as playlist_id", "media_file.id as media_file_id").
|
||||
sq := Select("row_number() over (order by "+rules.OrderBy()+") as id", "'"+pls.ID+"' as playlist_id", "media_file.id as media_file_id").
|
||||
From("media_file").LeftJoin("annotation on (" +
|
||||
"annotation.item_id = media_file.id" +
|
||||
" AND annotation.item_type = 'media_file'" +
|
||||
" AND annotation.user_id = '" + userId(r.ctx) + "')").
|
||||
LeftJoin("media_file_genres ag on media_file.id = ag.media_file_id").
|
||||
LeftJoin("genre on ag.genre_id = genre.id").GroupBy("media_file.id")
|
||||
sql = r.addCriteria(sql, rules)
|
||||
insSql := Insert("playlist_tracks").Columns("id", "playlist_id", "media_file_id").Select(sql)
|
||||
sq = r.addCriteria(sq, rules)
|
||||
insSql := Insert("playlist_tracks").Columns("id", "playlist_id", "media_file_id").Select(sq)
|
||||
_, err = r.executeSQL(insSql)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error refreshing smart playlist tracks", "playlist", pls.Name, "id", pls.ID, err)
|
||||
@@ -262,7 +255,7 @@ func (r *playlistRepository) refreshSmartPlaylist(pls *model.Playlist) bool {
|
||||
}
|
||||
|
||||
func (r *playlistRepository) addCriteria(sql SelectBuilder, c criteria.Criteria) SelectBuilder {
|
||||
sql = sql.Where(c.ToSql())
|
||||
sql = sql.Where(c)
|
||||
if c.Limit > 0 {
|
||||
sql = sql.Limit(uint64(c.Limit)).Offset(uint64(c.Offset))
|
||||
}
|
||||
@@ -318,7 +311,11 @@ func (r *playlistRepository) addTracks(playlistId string, startingPos int, media
|
||||
|
||||
// RefreshStatus updates total playlist duration, size and count
|
||||
func (r *playlistRepository) refreshCounters(pls *model.Playlist) error {
|
||||
statsSql := Select("sum(duration) as duration", "sum(size) as size", "count(*) as count").
|
||||
statsSql := Select(
|
||||
"coalesce(sum(duration), 0) as duration",
|
||||
"coalesce(sum(size), 0) as size",
|
||||
"count(*) as count",
|
||||
).
|
||||
From("media_file").
|
||||
Join("playlist_tracks f on f.media_file_id = media_file.id").
|
||||
Where(Eq{"playlist_id": pls.ID})
|
||||
@@ -347,7 +344,15 @@ func (r *playlistRepository) refreshCounters(pls *model.Playlist) error {
|
||||
|
||||
func (r *playlistRepository) loadTracks(sel SelectBuilder, id string) (model.PlaylistTracks, error) {
|
||||
tracksQuery := sel.
|
||||
Columns("starred", "starred_at", "play_count", "play_date", "rating", "f.*", "playlist_tracks.*").
|
||||
Columns(
|
||||
"coalesce(starred, 0)",
|
||||
"starred_at",
|
||||
"coalesce(play_count, 0)",
|
||||
"play_date",
|
||||
"coalesce(rating, 0)",
|
||||
"f.*",
|
||||
"playlist_tracks.*",
|
||||
).
|
||||
LeftJoin("annotation on (" +
|
||||
"annotation.item_id = media_file_id" +
|
||||
" AND annotation.item_type = 'media_file'" +
|
||||
@@ -402,7 +407,7 @@ func (r *playlistRepository) Update(id string, entity interface{}, cols ...strin
|
||||
if !usr.IsAdmin && current.OwnerID != usr.ID {
|
||||
return rest.ErrPermissionDenied
|
||||
}
|
||||
pls := entity.(*model.Playlist)
|
||||
pls := dbPlaylist{Playlist: *entity.(*model.Playlist)}
|
||||
pls.ID = id
|
||||
pls.UpdatedAt = time.Now()
|
||||
_, err = r.put(id, pls, append(cols, "updatedAt")...)
|
||||
@@ -447,8 +452,8 @@ func (r *playlistRepository) removeOrphans() error {
|
||||
|
||||
func (r *playlistRepository) renumber(id string) error {
|
||||
var ids []string
|
||||
sql := Select("media_file_id").From("playlist_tracks").Where(Eq{"playlist_id": id}).OrderBy("id")
|
||||
err := r.queryAll(sql, &ids)
|
||||
sq := Select("media_file_id").From("playlist_tracks").Where(Eq{"playlist_id": id}).OrderBy("id")
|
||||
err := r.queryAllSlice(sq, &ids)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -3,9 +3,9 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/criteria"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
. "github.com/onsi/gomega"
|
||||
@@ -17,7 +17,7 @@ var _ = Describe("PlaylistRepository", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
repo = NewPlaylistRepository(ctx, orm.NewOrm())
|
||||
repo = NewPlaylistRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
@@ -56,7 +56,7 @@ var _ = Describe("PlaylistRepository", func() {
|
||||
})
|
||||
It("returns all tracks", func() {
|
||||
pls, err := repo.GetWithTracks(plsBest.ID, true)
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(pls.Name).To(Equal(plsBest.Name))
|
||||
Expect(pls.Tracks).To(HaveLen(2))
|
||||
Expect(pls.Tracks[0].ID).To(Equal("1"))
|
||||
@@ -109,4 +109,23 @@ var _ = Describe("PlaylistRepository", func() {
|
||||
Expect(all[1].ID).To(Equal(plsCool.ID))
|
||||
})
|
||||
})
|
||||
|
||||
Context("Smart Playlists", func() {
|
||||
var rules *criteria.Criteria
|
||||
BeforeEach(func() {
|
||||
rules = &criteria.Criteria{
|
||||
Expression: criteria.All{
|
||||
criteria.Contains{"title": "love"},
|
||||
},
|
||||
}
|
||||
})
|
||||
It("Put/Get", func() {
|
||||
newPls := model.Playlist{Name: "Great!", OwnerID: "userid", Rules: rules}
|
||||
Expect(repo.Put(&newPls)).To(Succeed())
|
||||
|
||||
savedPls, err := repo.Get(newPls.ID)
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(savedPls.Rules).To(Equal(rules))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -1,6 +1,8 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
@@ -21,7 +23,7 @@ func (r *playlistRepository) Tracks(playlistId string, refreshSmartPlaylist bool
|
||||
p.playlistRepo = r
|
||||
p.playlistId = playlistId
|
||||
p.ctx = r.ctx
|
||||
p.ormer = r.ormer
|
||||
p.db = r.db
|
||||
p.tableName = "playlist_tracks"
|
||||
p.sortMappings = map[string]string{
|
||||
"id": "playlist_tracks.id",
|
||||
@@ -47,7 +49,15 @@ func (r *playlistTrackRepository) Read(id string) (interface{}, error) {
|
||||
"annotation.item_id = media_file_id"+
|
||||
" AND annotation.item_type = 'media_file'"+
|
||||
" AND annotation.user_id = '"+userId(r.ctx)+"')").
|
||||
Columns("starred", "starred_at", "play_count", "play_date", "rating", "f.*", "playlist_tracks.*").
|
||||
Columns(
|
||||
"coalesce(starred, 0)",
|
||||
"coalesce(play_count, 0)",
|
||||
"coalesce(rating, 0)",
|
||||
"starred_at",
|
||||
"play_date",
|
||||
"f.*",
|
||||
"playlist_tracks.*",
|
||||
).
|
||||
Join("media_file f on f.id = media_file_id").
|
||||
Where(And{Eq{"playlist_id": r.playlistId}, Eq{"id": id}})
|
||||
var trk model.PlaylistTrack
|
||||
@@ -77,7 +87,7 @@ func (r *playlistTrackRepository) GetAlbumIDs(options ...model.QueryOptions) ([]
|
||||
Join("media_file mf on mf.id = media_file_id").
|
||||
Where(Eq{"playlist_id": r.playlistId})
|
||||
var ids []string
|
||||
err := r.queryAll(sql, &ids)
|
||||
err := r.queryAllSlice(sql, &ids)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -112,20 +122,20 @@ func (r *playlistTrackRepository) Add(mediaFileIds []string) (int, error) {
|
||||
}
|
||||
|
||||
// Get next pos (ID) in playlist
|
||||
sql := r.newSelect().Columns("max(id) as max").Where(Eq{"playlist_id": r.playlistId})
|
||||
var max int
|
||||
err := r.queryOne(sql, &max)
|
||||
sq := r.newSelect().Columns("max(id) as max").Where(Eq{"playlist_id": r.playlistId})
|
||||
var res struct{ Max sql.NullInt32 }
|
||||
err := r.queryOne(sq, &res)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
return len(mediaFileIds), r.playlistRepo.addTracks(r.playlistId, max+1, mediaFileIds)
|
||||
return len(mediaFileIds), r.playlistRepo.addTracks(r.playlistId, int(res.Max.Int32+1), mediaFileIds)
|
||||
}
|
||||
|
||||
func (r *playlistTrackRepository) addMediaFileIds(cond Sqlizer) (int, error) {
|
||||
sq := Select("id").From("media_file").Where(cond).OrderBy("album_artist, album, release_date, disc_number, track_number")
|
||||
var ids []string
|
||||
err := r.queryAll(sq, &ids)
|
||||
err := r.queryAllSlice(sq, &ids)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error getting tracks to add to playlist", err)
|
||||
return 0, err
|
||||
@@ -156,7 +166,7 @@ func (r *playlistTrackRepository) AddDiscs(discs []model.DiscID) (int, error) {
|
||||
func (r *playlistTrackRepository) getTracks() ([]string, error) {
|
||||
all := r.newSelect().Columns("media_file_id").Where(Eq{"playlist_id": r.playlistId}).OrderBy("id")
|
||||
var ids []string
|
||||
err := r.queryAll(all, &ids)
|
||||
err := r.queryAllSlice(all, &ids)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error querying current tracks from playlist", "playlistId", r.playlistId, err)
|
||||
return nil, err
|
||||
|
||||
@@ -6,27 +6,27 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils/slice"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type playQueueRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewPlayQueueRepository(ctx context.Context, o orm.QueryExecutor) model.PlayQueueRepository {
|
||||
func NewPlayQueueRepository(ctx context.Context, db dbx.Builder) model.PlayQueueRepository {
|
||||
r := &playQueueRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "playqueue"
|
||||
return r
|
||||
}
|
||||
|
||||
type playQueue struct {
|
||||
ID string `structs:"id" orm:"column(id)"`
|
||||
UserID string `structs:"user_id" orm:"column(user_id)"`
|
||||
ID string `structs:"id"`
|
||||
UserID string `structs:"user_id"`
|
||||
Current string `structs:"current"`
|
||||
Position int64 `structs:"position"`
|
||||
ChangedBy string `structs:"changed_by"`
|
||||
@@ -116,7 +116,7 @@ func (r *playQueueRepository) loadTracks(tracks model.MediaFiles) model.MediaFil
|
||||
chunks := slice.BreakUp(ids, 50)
|
||||
|
||||
// Query each chunk of media_file ids and store results in a map
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db)
|
||||
trackMap := map[string]model.MediaFile{}
|
||||
for i := range chunks {
|
||||
idsFilter := Eq{"media_file.id": chunks[i]}
|
||||
|
||||
@@ -5,7 +5,6 @@ import (
|
||||
"time"
|
||||
|
||||
"github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
@@ -19,8 +18,8 @@ var _ = Describe("PlayQueueRepository", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "user1", UserName: "user1", IsAdmin: true})
|
||||
repo = NewPlayQueueRepository(ctx, orm.NewOrm())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
repo = NewPlayQueueRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("PlayQueues", func() {
|
||||
@@ -32,24 +31,24 @@ var _ = Describe("PlayQueueRepository", func() {
|
||||
It("stores and retrieves the playqueue for the user", func() {
|
||||
By("Storing a playqueue for the user")
|
||||
|
||||
expected := aPlayQueue("user1", songDayInALife.ID, 123, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(expected)).To(BeNil())
|
||||
expected := aPlayQueue("userid", songDayInALife.ID, 123, songComeTogether, songDayInALife)
|
||||
Expect(repo.Store(expected)).To(Succeed())
|
||||
|
||||
actual, err := repo.Retrieve("user1")
|
||||
Expect(err).To(BeNil())
|
||||
actual, err := repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
AssertPlayQueue(expected, actual)
|
||||
|
||||
By("Storing a new playqueue for the same user")
|
||||
|
||||
another := aPlayQueue("user1", songRadioactivity.ID, 321, songAntenna, songRadioactivity)
|
||||
Expect(repo.Store(another)).To(BeNil())
|
||||
another := aPlayQueue("userid", songRadioactivity.ID, 321, songAntenna, songRadioactivity)
|
||||
Expect(repo.Store(another)).To(Succeed())
|
||||
|
||||
actual, err = repo.Retrieve("user1")
|
||||
Expect(err).To(BeNil())
|
||||
actual, err = repo.Retrieve("userid")
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
AssertPlayQueue(another, actual)
|
||||
Expect(countPlayQueues(repo, "user1")).To(Equal(1))
|
||||
Expect(countPlayQueues(repo, "userid")).To(Equal(1))
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type propertyRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewPropertyRepository(ctx context.Context, o orm.QueryExecutor) model.PropertyRepository {
|
||||
func NewPropertyRepository(ctx context.Context, db dbx.Builder) model.PropertyRepository {
|
||||
r := &propertyRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "property"
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -3,7 +3,6 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
. "github.com/onsi/ginkgo/v2"
|
||||
@@ -14,7 +13,7 @@ var _ = Describe("Property Repository", func() {
|
||||
var pr model.PropertyRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
pr = NewPropertyRepository(log.NewContext(context.TODO()), orm.NewOrm())
|
||||
pr = NewPropertyRepository(log.NewContext(context.TODO()), getDBXBuilder())
|
||||
})
|
||||
|
||||
It("saves and restore a new property", func() {
|
||||
|
||||
@@ -7,10 +7,10 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type radioRepository struct {
|
||||
@@ -18,10 +18,10 @@ type radioRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewRadioRepository(ctx context.Context, o orm.QueryExecutor) model.RadioRepository {
|
||||
func NewRadioRepository(ctx context.Context, db dbx.Builder) model.RadioRepository {
|
||||
r := &radioRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "radio"
|
||||
r.filterMappings = map[string]filterFunc{
|
||||
"name": containsFilter,
|
||||
@@ -73,9 +73,9 @@ func (r *radioRepository) Put(radio *model.Radio) error {
|
||||
if radio.ID == "" {
|
||||
radio.CreatedAt = time.Now()
|
||||
radio.ID = strings.ReplaceAll(uuid.NewString(), "-", "")
|
||||
values, _ = toSqlArgs(*radio)
|
||||
values, _ = toSQLArgs(*radio)
|
||||
} else {
|
||||
values, _ = toSqlArgs(*radio)
|
||||
values, _ = toSQLArgs(*radio)
|
||||
update := Update(r.tableName).Where(Eq{"id": radio.ID}).SetMap(values)
|
||||
count, err := r.executeSQL(update)
|
||||
|
||||
|
||||
@@ -3,7 +3,6 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
@@ -23,7 +22,7 @@ var _ = Describe("RadioRepository", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: true})
|
||||
repo = NewRadioRepository(ctx, orm.NewOrm())
|
||||
repo = NewRadioRepository(ctx, getDBXBuilder())
|
||||
_ = repo.Put(&radioWithHomePage)
|
||||
})
|
||||
|
||||
@@ -120,7 +119,7 @@ var _ = Describe("RadioRepository", func() {
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid", UserName: "userid", IsAdmin: false})
|
||||
repo = NewRadioRepository(ctx, orm.NewOrm())
|
||||
repo = NewRadioRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Count", func() {
|
||||
|
||||
@@ -6,18 +6,18 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type scrobbleBufferRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewScrobbleBufferRepository(ctx context.Context, o orm.QueryExecutor) model.ScrobbleBufferRepository {
|
||||
func NewScrobbleBufferRepository(ctx context.Context, db dbx.Builder) model.ScrobbleBufferRepository {
|
||||
r := &scrobbleBufferRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "scrobble_buffer"
|
||||
return r
|
||||
}
|
||||
@@ -31,7 +31,7 @@ func (r *scrobbleBufferRepository) UserIDs(service string) ([]string, error) {
|
||||
GroupBy("user_id").
|
||||
OrderBy("count(*)")
|
||||
var userIds []string
|
||||
err := r.queryAll(sql, &userIds)
|
||||
err := r.queryAllSlice(sql, &userIds)
|
||||
return userIds, err
|
||||
}
|
||||
|
||||
|
||||
@@ -8,11 +8,11 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type shareRepository struct {
|
||||
@@ -20,10 +20,10 @@ type shareRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewShareRepository(ctx context.Context, o orm.QueryExecutor) model.ShareRepository {
|
||||
func NewShareRepository(ctx context.Context, db dbx.Builder) model.ShareRepository {
|
||||
r := &shareRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "share"
|
||||
return r
|
||||
}
|
||||
@@ -79,27 +79,27 @@ func (r *shareRepository) loadMedia(share *model.Share) error {
|
||||
}
|
||||
switch share.ResourceType {
|
||||
case "artist":
|
||||
albumRepo := NewAlbumRepository(r.ctx, r.ormer)
|
||||
albumRepo := NewAlbumRepository(r.ctx, r.db)
|
||||
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: Eq{"album_artist_id": ids}, Sort: "artist"})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db)
|
||||
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: Eq{"album_artist_id": ids}, Sort: "artist"})
|
||||
return err
|
||||
case "album":
|
||||
albumRepo := NewAlbumRepository(r.ctx, r.ormer)
|
||||
albumRepo := NewAlbumRepository(r.ctx, r.db)
|
||||
share.Albums, err = albumRepo.GetAll(model.QueryOptions{Filters: Eq{"id": ids}})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db)
|
||||
share.Tracks, err = mfRepo.GetAll(model.QueryOptions{Filters: Eq{"album_id": ids}, Sort: "album"})
|
||||
return err
|
||||
case "playlist":
|
||||
// Create a context with a fake admin user, to be able to access all playlists
|
||||
ctx := request.WithUser(r.ctx, model.User{IsAdmin: true})
|
||||
plsRepo := NewPlaylistRepository(ctx, r.ormer)
|
||||
plsRepo := NewPlaylistRepository(ctx, r.db)
|
||||
tracks, err := plsRepo.Tracks(ids[0], true).GetAll(model.QueryOptions{Sort: "id"})
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -109,7 +109,7 @@ func (r *shareRepository) loadMedia(share *model.Share) error {
|
||||
}
|
||||
return nil
|
||||
case "media_file":
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.ormer)
|
||||
mfRepo := NewMediaFileRepository(r.ctx, r.db)
|
||||
tracks, err := mfRepo.GetAll(model.QueryOptions{Filters: Eq{"id": ids}})
|
||||
share.Tracks = sortByIdPosition(tracks, ids)
|
||||
return err
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
@@ -42,7 +42,7 @@ func (r sqlRepository) annUpsert(values map[string]interface{}, itemIDs ...strin
|
||||
upd = upd.Set(f, v)
|
||||
}
|
||||
c, err := r.executeSQL(upd)
|
||||
if c == 0 || errors.Is(err, orm.ErrNoRows) {
|
||||
if c == 0 || errors.Is(err, sql.ErrNoRows) {
|
||||
for _, itemID := range itemIDs {
|
||||
values["ann_id"] = uuid.NewString()
|
||||
values["user_id"] = userId(r.ctx)
|
||||
@@ -73,7 +73,7 @@ func (r sqlRepository) IncPlayCount(itemID string, ts time.Time) error {
|
||||
Set("play_date", Expr("max(ifnull(play_date,''),?)", ts))
|
||||
c, err := r.executeSQL(upd)
|
||||
|
||||
if c == 0 || errors.Is(err, orm.ErrNoRows) {
|
||||
if c == 0 || errors.Is(err, sql.ErrNoRows) {
|
||||
values := map[string]interface{}{}
|
||||
values["ann_id"] = uuid.NewString()
|
||||
values["user_id"] = userId(r.ctx)
|
||||
|
||||
@@ -2,24 +2,25 @@ package persistence
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type sqlRepository struct {
|
||||
ctx context.Context
|
||||
tableName string
|
||||
ormer orm.QueryExecutor
|
||||
db dbx.Builder
|
||||
sortMappings map[string]string
|
||||
}
|
||||
|
||||
@@ -122,13 +123,13 @@ func (r sqlRepository) applyFilters(sq SelectBuilder, options ...model.QueryOpti
|
||||
}
|
||||
|
||||
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
|
||||
query, args, err := sq.ToSql()
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
start := time.Now()
|
||||
var c int64
|
||||
res, err := r.ormer.Raw(query, args...).Exec()
|
||||
res, err := r.db.NewQuery(query).Bind(args).Execute()
|
||||
if res != nil {
|
||||
c, _ = res.RowsAffected()
|
||||
}
|
||||
@@ -141,16 +142,31 @@ func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (r sqlRepository) toSQL(sq Sqlizer) (string, dbx.Params, error) {
|
||||
query, args, err := sq.ToSql()
|
||||
if err != nil {
|
||||
return "", nil, err
|
||||
}
|
||||
// Replace query placeholders with named params
|
||||
params := dbx.Params{}
|
||||
for i, arg := range args {
|
||||
p := fmt.Sprintf("p%d", i)
|
||||
query = strings.Replace(query, "?", "{:"+p+"}", 1)
|
||||
params[p] = arg
|
||||
}
|
||||
return query, params, nil
|
||||
}
|
||||
|
||||
// Note: Due to a bug in the QueryRow method, this function does not map any embedded structs (ex: annotations)
|
||||
// In this case, use the queryAll method and get the first item of the returned list
|
||||
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
|
||||
query, args, err := sq.ToSql()
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
err = r.ormer.Raw(query, args...).QueryRow(response)
|
||||
if errors.Is(err, orm.ErrNoRows) {
|
||||
err = r.db.NewQuery(query).Bind(args).One(response)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, 0, start)
|
||||
return model.ErrNotFound
|
||||
}
|
||||
@@ -162,17 +178,33 @@ func (r sqlRepository) queryAll(sq SelectBuilder, response interface{}, options
|
||||
if len(options) > 0 && options[0].Offset > 0 {
|
||||
sq = r.optimizePagination(sq, options[0])
|
||||
}
|
||||
query, args, err := sq.ToSql()
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
c, err := r.ormer.Raw(query, args...).QueryRows(response)
|
||||
if errors.Is(err, orm.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, c, start)
|
||||
err = r.db.NewQuery(query).Bind(args).All(response)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, -1, start)
|
||||
return model.ErrNotFound
|
||||
}
|
||||
r.logSQL(query, args, err, c, start)
|
||||
r.logSQL(query, args, err, -1, start)
|
||||
return err
|
||||
}
|
||||
|
||||
// queryAllSlice is a helper function to query a single column and return the result in a slice
|
||||
func (r sqlRepository) queryAllSlice(sq SelectBuilder, response interface{}) error {
|
||||
query, args, err := r.toSQL(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
start := time.Now()
|
||||
err = r.db.NewQuery(query).Bind(args).Column(response)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
r.logSQL(query, args, nil, -1, start)
|
||||
return model.ErrNotFound
|
||||
}
|
||||
r.logSQL(query, args, err, -1, start)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -207,7 +239,7 @@ func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOpt
|
||||
}
|
||||
|
||||
func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (newId string, err error) {
|
||||
values, _ := toSqlArgs(m)
|
||||
values, _ := toSQLArgs(m)
|
||||
// If there's an ID, try to update first
|
||||
if id != "" {
|
||||
updateValues := map[string]interface{}{}
|
||||
@@ -246,28 +278,28 @@ func (r sqlRepository) put(id string, m interface{}, colsToUpdate ...string) (ne
|
||||
func (r sqlRepository) delete(cond Sqlizer) error {
|
||||
del := Delete(r.tableName).Where(cond)
|
||||
_, err := r.executeSQL(del)
|
||||
if errors.Is(err, orm.ErrNoRows) {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return model.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) logSQL(sql string, args []interface{}, err error, rowsAffected int64, start time.Time) {
|
||||
func (r sqlRepository) logSQL(sql string, args dbx.Params, err error, rowsAffected int64, start time.Time) {
|
||||
elapsed := time.Since(start)
|
||||
var fmtArgs []string
|
||||
for i := range args {
|
||||
var f string
|
||||
switch a := args[i].(type) {
|
||||
case string:
|
||||
f = `'` + a + `'`
|
||||
default:
|
||||
f = fmt.Sprintf("%v", a)
|
||||
}
|
||||
fmtArgs = append(fmtArgs, f)
|
||||
}
|
||||
//var fmtArgs []string
|
||||
//for name, val := range args {
|
||||
// var f string
|
||||
// switch a := args[val].(type) {
|
||||
// case string:
|
||||
// f = `'` + a + `'`
|
||||
// default:
|
||||
// f = fmt.Sprintf("%v", a)
|
||||
// }
|
||||
// fmtArgs = append(fmtArgs, f)
|
||||
//}
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "SQL: `"+sql+"`", "args", `[`+strings.Join(fmtArgs, ",")+`]`, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
|
||||
log.Error(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed, err)
|
||||
} else {
|
||||
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", `[`+strings.Join(fmtArgs, ",")+`]`, "rowsAffected", rowsAffected, "elapsedTime", elapsed)
|
||||
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", args, "rowsAffected", rowsAffected, "elapsedTime", elapsed)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,11 +1,11 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
@@ -19,7 +19,7 @@ func (r sqlRepository) withBookmark(sql SelectBuilder, idField string) SelectBui
|
||||
"bookmark.item_id = " + idField +
|
||||
" AND bookmark.item_type = '" + r.tableName + "'" +
|
||||
" AND bookmark.user_id = '" + userId(r.ctx) + "')").
|
||||
Columns("position as bookmark_position")
|
||||
Columns("coalesce(position, 0) as bookmark_position")
|
||||
}
|
||||
|
||||
func (r sqlRepository) bmkID(itemID ...string) And {
|
||||
@@ -45,7 +45,7 @@ func (r sqlRepository) bmkUpsert(itemID, comment string, position int64) error {
|
||||
if err == nil {
|
||||
log.Debug(r.ctx, "Updated bookmark", "id", itemID, "user", user.UserName, "position", position, "comment", comment)
|
||||
}
|
||||
if c == 0 || errors.Is(err, orm.ErrNoRows) {
|
||||
if c == 0 || errors.Is(err, sql.ErrNoRows) {
|
||||
values["user_id"] = user.ID
|
||||
values["item_type"] = r.tableName
|
||||
values["item_id"] = itemID
|
||||
@@ -82,8 +82,8 @@ func (r sqlRepository) DeleteBookmark(id string) error {
|
||||
}
|
||||
|
||||
type bookmark struct {
|
||||
UserID string `json:"user_id" orm:"column(user_id)"`
|
||||
ItemID string `json:"item_id" orm:"column(item_id)"`
|
||||
UserID string `json:"user_id"`
|
||||
ItemID string `json:"item_id"`
|
||||
ItemType string `json:"item_type"`
|
||||
Comment string `json:"comment"`
|
||||
Position int64 `json:"position"`
|
||||
@@ -96,10 +96,10 @@ func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) {
|
||||
user, _ := request.UserFrom(r.ctx)
|
||||
|
||||
idField := r.tableName + ".id"
|
||||
sql := r.newSelectWithAnnotation(idField).Columns("*")
|
||||
sql = r.withBookmark(sql, idField).Where(NotEq{bookmarkTable + ".item_id": nil})
|
||||
sq := r.newSelectWithAnnotation(idField).Columns(r.tableName + ".*")
|
||||
sq = r.withBookmark(sq, idField).Where(NotEq{bookmarkTable + ".item_id": nil})
|
||||
var mfs model.MediaFiles
|
||||
err := r.queryAll(sql, &mfs)
|
||||
err := r.queryAll(sq, &mfs)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error getting mediafiles with bookmarks", "user", user.UserName, err)
|
||||
return nil, err
|
||||
@@ -117,9 +117,9 @@ func (r sqlRepository) GetBookmarks() (model.Bookmarks, error) {
|
||||
mfMap[mf.ID] = i
|
||||
}
|
||||
|
||||
sql = Select("*").From(bookmarkTable).Where(r.bmkID(ids...))
|
||||
sq = Select("*").From(bookmarkTable).Where(r.bmkID(ids...))
|
||||
var bmks []bookmark
|
||||
err = r.queryAll(sql, &bmks)
|
||||
err = r.queryAll(sq, &bmks)
|
||||
if err != nil {
|
||||
log.Error(r.ctx, "Error getting bookmarks", "user", user.UserName, "ids", ids, err)
|
||||
return nil, err
|
||||
|
||||
@@ -3,7 +3,6 @@ package persistence
|
||||
import (
|
||||
"context"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/model/request"
|
||||
@@ -16,8 +15,8 @@ var _ = Describe("sqlBookmarks", func() {
|
||||
|
||||
BeforeEach(func() {
|
||||
ctx := log.NewContext(context.TODO())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "user1"})
|
||||
mr = NewMediaFileRepository(ctx, orm.NewOrm())
|
||||
ctx = request.WithUser(ctx, model.User{ID: "userid"})
|
||||
mr = NewMediaFileRepository(ctx, getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Bookmarks", func() {
|
||||
@@ -30,7 +29,7 @@ var _ = Describe("sqlBookmarks", func() {
|
||||
Expect(mr.AddBookmark(songAntenna.ID, "this is a comment", 123)).To(BeNil())
|
||||
|
||||
bms, err := mr.GetBookmarks()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(bms).To(HaveLen(1))
|
||||
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
|
||||
@@ -46,7 +45,7 @@ var _ = Describe("sqlBookmarks", func() {
|
||||
Expect(mr.AddBookmark(songAntenna.ID, "another comment", 333)).To(BeNil())
|
||||
|
||||
bms, err = mr.GetBookmarks()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
|
||||
Expect(bms[0].Item.ID).To(Equal(songAntenna.ID))
|
||||
Expect(bms[0].Comment).To(Equal("another comment"))
|
||||
@@ -57,16 +56,19 @@ var _ = Describe("sqlBookmarks", func() {
|
||||
By("Saving another bookmark")
|
||||
Expect(mr.AddBookmark(songComeTogether.ID, "one more comment", 444)).To(BeNil())
|
||||
bms, err = mr.GetBookmarks()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bms).To(HaveLen(2))
|
||||
|
||||
By("Delete bookmark")
|
||||
Expect(mr.DeleteBookmark(songAntenna.ID))
|
||||
Expect(mr.DeleteBookmark(songAntenna.ID)).To(Succeed())
|
||||
bms, err = mr.GetBookmarks()
|
||||
Expect(err).To(BeNil())
|
||||
Expect(err).ToNot(HaveOccurred())
|
||||
Expect(bms).To(HaveLen(1))
|
||||
Expect(bms[0].Item.ID).To(Equal(songComeTogether.ID))
|
||||
Expect(bms[0].Item.Title).To(Equal(songComeTogether.Title))
|
||||
|
||||
Expect(mr.DeleteBookmark(songComeTogether.ID)).To(Succeed())
|
||||
Expect(mr.GetBookmarks()).To(BeEmpty())
|
||||
})
|
||||
})
|
||||
})
|
||||
|
||||
@@ -21,7 +21,7 @@ func (r sqlRepository) doSearch(q string, offset, size int, results interface{},
|
||||
return nil
|
||||
}
|
||||
|
||||
sq := r.newSelectWithAnnotation(r.tableName + ".id").Columns("*")
|
||||
sq := r.newSelectWithAnnotation(r.tableName + ".id").Columns(r.tableName + ".*")
|
||||
filter := fullTextExpr(q)
|
||||
if filter != nil {
|
||||
sq = sq.Where(filter)
|
||||
|
||||
@@ -5,9 +5,9 @@ import (
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type transcodingRepository struct {
|
||||
@@ -15,10 +15,10 @@ type transcodingRepository struct {
|
||||
sqlRestful
|
||||
}
|
||||
|
||||
func NewTranscodingRepository(ctx context.Context, o orm.QueryExecutor) model.TranscodingRepository {
|
||||
func NewTranscodingRepository(ctx context.Context, db dbx.Builder) model.TranscodingRepository {
|
||||
r := &transcodingRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "transcoding"
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -5,18 +5,18 @@ import (
|
||||
"errors"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type userPropsRepository struct {
|
||||
sqlRepository
|
||||
}
|
||||
|
||||
func NewUserPropsRepository(ctx context.Context, o orm.QueryExecutor) model.UserPropsRepository {
|
||||
func NewUserPropsRepository(ctx context.Context, db dbx.Builder) model.UserPropsRepository {
|
||||
r := &userPropsRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "user_props"
|
||||
return r
|
||||
}
|
||||
|
||||
@@ -10,7 +10,6 @@ import (
|
||||
"time"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/conf"
|
||||
@@ -18,6 +17,7 @@ import (
|
||||
"github.com/navidrome/navidrome/log"
|
||||
"github.com/navidrome/navidrome/model"
|
||||
"github.com/navidrome/navidrome/utils"
|
||||
"github.com/pocketbase/dbx"
|
||||
)
|
||||
|
||||
type userRepository struct {
|
||||
@@ -30,10 +30,10 @@ var (
|
||||
encKey []byte
|
||||
)
|
||||
|
||||
func NewUserRepository(ctx context.Context, o orm.QueryExecutor) model.UserRepository {
|
||||
func NewUserRepository(ctx context.Context, db dbx.Builder) model.UserRepository {
|
||||
r := &userRepository{}
|
||||
r.ctx = ctx
|
||||
r.ormer = o
|
||||
r.db = db
|
||||
r.tableName = "user"
|
||||
once.Do(func() {
|
||||
_ = r.initPasswordEncryptionKey()
|
||||
@@ -67,7 +67,7 @@ func (r *userRepository) Put(u *model.User) error {
|
||||
if u.NewPassword != "" {
|
||||
_ = r.encryptPassword(u)
|
||||
}
|
||||
values, _ := toSqlArgs(*u)
|
||||
values, _ := toSQLArgs(*u)
|
||||
delete(values, "current_password")
|
||||
update := Update(r.tableName).Where(Eq{"id": u.ID}).SetMap(values)
|
||||
count, err := r.executeSQL(update)
|
||||
@@ -268,7 +268,7 @@ func (r *userRepository) initPasswordEncryptionKey() error {
|
||||
key := keyTo32Bytes(conf.Server.PasswordEncryptionKey)
|
||||
keySum := fmt.Sprintf("%x", sha256.Sum256(key))
|
||||
|
||||
props := NewPropertyRepository(r.ctx, r.ormer)
|
||||
props := NewPropertyRepository(r.ctx, r.db)
|
||||
savedKeySum, err := props.Get(consts.PasswordsEncryptedKey)
|
||||
|
||||
// If passwords are already encrypted
|
||||
|
||||
@@ -4,7 +4,6 @@ import (
|
||||
"context"
|
||||
"errors"
|
||||
|
||||
"github.com/beego/beego/v2/client/orm"
|
||||
"github.com/deluan/rest"
|
||||
"github.com/google/uuid"
|
||||
"github.com/navidrome/navidrome/consts"
|
||||
@@ -19,7 +18,7 @@ var _ = Describe("UserRepository", func() {
|
||||
var repo model.UserRepository
|
||||
|
||||
BeforeEach(func() {
|
||||
repo = NewUserRepository(log.NewContext(context.TODO()), orm.NewOrm())
|
||||
repo = NewUserRepository(log.NewContext(context.TODO()), getDBXBuilder())
|
||||
})
|
||||
|
||||
Describe("Put/Get/FindByUsername", func() {
|
||||
|
||||
Reference in New Issue
Block a user