refactor: new persistence, more SQL, less ORM
This commit is contained in:
+116
-80
@@ -1,40 +1,51 @@
|
||||
package persistence
|
||||
|
||||
import (
|
||||
"github.com/Masterminds/squirrel"
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
. "github.com/Masterminds/squirrel"
|
||||
"github.com/astaxie/beego/orm"
|
||||
"github.com/deluan/navidrome/log"
|
||||
"github.com/deluan/navidrome/model"
|
||||
"github.com/deluan/rest"
|
||||
)
|
||||
|
||||
type sqlRepository struct {
|
||||
tableName string
|
||||
ormer orm.Ormer
|
||||
ctx context.Context
|
||||
tableName string
|
||||
fieldNames []string
|
||||
ormer orm.Ormer
|
||||
}
|
||||
|
||||
func (r *sqlRepository) newQuery(options ...model.QueryOptions) orm.QuerySeter {
|
||||
q := r.ormer.QueryTable(r.tableName)
|
||||
if len(options) > 0 {
|
||||
opts := options[0]
|
||||
q = q.Offset(opts.Offset)
|
||||
if opts.Max > 0 {
|
||||
q = q.Limit(opts.Max)
|
||||
}
|
||||
if opts.Sort != "" {
|
||||
if opts.Order == "desc" {
|
||||
q = q.OrderBy("-" + opts.Sort)
|
||||
} else {
|
||||
q = q.OrderBy(opts.Sort)
|
||||
}
|
||||
}
|
||||
for field, value := range opts.Filters {
|
||||
q = q.Filter(field, value)
|
||||
}
|
||||
const invalidUserId = "-1"
|
||||
|
||||
func userId(ctx context.Context) string {
|
||||
user := ctx.Value("user")
|
||||
if user == nil {
|
||||
return invalidUserId
|
||||
}
|
||||
return q
|
||||
usr := user.(*model.User)
|
||||
return usr.ID
|
||||
}
|
||||
|
||||
func (r *sqlRepository) newRawQuery(options ...model.QueryOptions) squirrel.SelectBuilder {
|
||||
sq := squirrel.Select("*").From(r.tableName)
|
||||
func (r *sqlRepository) newSelectWithAnnotation(itemType, idField string, options ...model.QueryOptions) SelectBuilder {
|
||||
return r.newSelect(options...).
|
||||
LeftJoin("annotation on ("+
|
||||
"annotation.item_id = "+idField+
|
||||
" AND annotation.item_type = '"+itemType+"'"+
|
||||
" AND annotation.user_id = '"+userId(r.ctx)+"')").
|
||||
Columns("starred", "starred_at", "play_count", "play_date", "rating")
|
||||
}
|
||||
|
||||
func (r *sqlRepository) newSelect(options ...model.QueryOptions) SelectBuilder {
|
||||
sq := Select().From(r.tableName)
|
||||
sq = r.applyOptions(sq, options...)
|
||||
return sq
|
||||
}
|
||||
|
||||
func (r *sqlRepository) applyOptions(sq SelectBuilder, options ...model.QueryOptions) SelectBuilder {
|
||||
if len(options) > 0 {
|
||||
if options[0].Max > 0 {
|
||||
sq = sq.Limit(uint64(options[0].Max))
|
||||
@@ -49,83 +60,108 @@ func (r *sqlRepository) newRawQuery(options ...model.QueryOptions) squirrel.Sele
|
||||
sq = sq.OrderBy(options[0].Sort)
|
||||
}
|
||||
}
|
||||
if len(options[0].Filters) > 0 {
|
||||
for f, v := range options[0].Filters {
|
||||
sq = sq.Where(Eq{f: v})
|
||||
}
|
||||
}
|
||||
}
|
||||
return sq
|
||||
}
|
||||
|
||||
func (r *sqlRepository) CountAll() (int64, error) {
|
||||
return r.newQuery().Count()
|
||||
}
|
||||
|
||||
func (r *sqlRepository) Exists(id string) (bool, error) {
|
||||
c, err := r.newQuery().Filter("id", id).Count()
|
||||
return c == 1, err
|
||||
}
|
||||
|
||||
// "Hack" to bypass Postgres driver limitation
|
||||
func (r *sqlRepository) insert(record interface{}) error {
|
||||
_, err := r.ormer.Insert(record)
|
||||
if err != nil && err.Error() != "LastInsertId is not supported by this driver" {
|
||||
return err
|
||||
func (r sqlRepository) executeSQL(sq Sqlizer) (int64, error) {
|
||||
query, args, err := r.toSql(sq)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return nil
|
||||
res, err := r.ormer.Raw(query, args...).Exec()
|
||||
if err != nil {
|
||||
if err.Error() != "LastInsertId is not supported by this driver" {
|
||||
return 0, err
|
||||
}
|
||||
}
|
||||
return res.RowsAffected()
|
||||
}
|
||||
|
||||
func (r *sqlRepository) put(id string, a interface{}) error {
|
||||
c, err := r.newQuery().Filter("id", id).Count()
|
||||
func (r sqlRepository) queryOne(sq Sqlizer, response interface{}) error {
|
||||
query, args, err := r.toSql(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if c == 0 {
|
||||
err = r.insert(a)
|
||||
if err != nil && err.Error() == "LastInsertId is not supported by this driver" {
|
||||
err = nil
|
||||
}
|
||||
err = r.ormer.Raw(query, args...).QueryRow(response)
|
||||
if err == orm.ErrNoRows {
|
||||
return model.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) queryAll(sq Sqlizer, response interface{}) error {
|
||||
query, args, err := r.toSql(sq)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = r.ormer.Update(a)
|
||||
_, err = r.ormer.Raw(query, args...).QueryRows(response)
|
||||
if err == orm.ErrNoRows {
|
||||
return model.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func paginateSlice(slice []string, skip int, size int) []string {
|
||||
if skip > len(slice) {
|
||||
skip = len(slice)
|
||||
func (r sqlRepository) exists(existsQuery SelectBuilder) (bool, error) {
|
||||
existsQuery = existsQuery.Columns("count(*) as count").From(r.tableName)
|
||||
query, args, err := r.toSql(existsQuery)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
end := skip + size
|
||||
if end > len(slice) {
|
||||
end = len(slice)
|
||||
}
|
||||
|
||||
return slice[skip:end]
|
||||
var res struct{ Count int64 }
|
||||
err = r.ormer.Raw(query, args...).QueryRow(&res)
|
||||
return res.Count > 0, err
|
||||
}
|
||||
|
||||
func difference(slice1 []string, slice2 []string) []string {
|
||||
var diffStr []string
|
||||
m := map[string]int{}
|
||||
|
||||
for _, s1Val := range slice1 {
|
||||
m[s1Val] = 1
|
||||
func (r sqlRepository) count(countQuery SelectBuilder, options ...model.QueryOptions) (int64, error) {
|
||||
countQuery = countQuery.Columns("count(*) as count").From(r.tableName)
|
||||
countQuery = r.applyOptions(countQuery, options...)
|
||||
query, args, err := r.toSql(countQuery)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
for _, s2Val := range slice2 {
|
||||
m[s2Val] = m[s2Val] + 1
|
||||
var res struct{ Count int64 }
|
||||
err = r.ormer.Raw(query, args...).QueryRow(&res)
|
||||
if err == orm.ErrNoRows {
|
||||
return 0, model.ErrNotFound
|
||||
}
|
||||
return res.Count, nil
|
||||
}
|
||||
|
||||
for mKey, mVal := range m {
|
||||
if mVal == 1 {
|
||||
diffStr = append(diffStr, mKey)
|
||||
func (r sqlRepository) delete(cond Sqlizer) error {
|
||||
del := Delete(r.tableName).Where(cond)
|
||||
_, err := r.executeSQL(del)
|
||||
if err == orm.ErrNoRows {
|
||||
return model.ErrNotFound
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (r sqlRepository) toSql(sq Sqlizer) (string, []interface{}, error) {
|
||||
sql, args, err := sq.ToSql()
|
||||
if err == nil {
|
||||
log.Trace(r.ctx, "SQL: `"+sql+"`", "args", strings.TrimPrefix(fmt.Sprintf("%#v", args), "[]interface {}"))
|
||||
}
|
||||
return sql, args, err
|
||||
}
|
||||
|
||||
func (r sqlRepository) parseRestOptions(options ...rest.QueryOptions) model.QueryOptions {
|
||||
qo := model.QueryOptions{}
|
||||
if len(options) > 0 {
|
||||
qo.Sort = toSnakeCase(options[0].Sort)
|
||||
qo.Order = options[0].Order
|
||||
qo.Max = options[0].Max
|
||||
qo.Offset = options[0].Offset
|
||||
if len(options[0].Filters) > 0 {
|
||||
for f, v := range options[0].Filters {
|
||||
qo.Filters = Like{f: fmt.Sprintf("%s%%", v)}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return diffStr
|
||||
}
|
||||
|
||||
func (r *sqlRepository) Delete(id string) error {
|
||||
_, err := r.newQuery().Filter("id", id).Delete()
|
||||
return err
|
||||
}
|
||||
|
||||
func (r *sqlRepository) DeleteAll() error {
|
||||
_, err := r.newQuery().Filter("id__isnull", false).Delete()
|
||||
return err
|
||||
return qo
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user