Opt: db handle not found func

pull/40/head
zijiren233 2 years ago
parent 71d1b96e69
commit b3f9733e35

@ -20,11 +20,35 @@ import (
) )
func InitDatabase(ctx context.Context) (err error) { func InitDatabase(ctx context.Context) (err error) {
var dialector gorm.Dialector dialector, err := createDialector(conf.Conf.Database)
if err != nil {
log.Fatalf("failed to create dialector: %s", err.Error())
}
var opts []gorm.Option var opts []gorm.Option
opts = append(opts, &gorm.Config{
TranslateError: true,
Logger: newDBLogger(),
PrepareStmt: true,
})
d, err := gorm.Open(dialector, opts...)
if err != nil {
log.Fatalf("failed to connect database: %s", err.Error())
}
sqlDB, err := d.DB()
if err != nil {
log.Fatalf("failed to get sqlDB: %s", err.Error())
}
if conf.Conf.Database.Type != conf.DatabaseTypeSqlite3 {
initRawDB(sqlDB)
}
return db.Init(d, conf.Conf.Database.Type)
}
func createDialector(dbConf conf.DatabaseConfig) (dialector gorm.Dialector, err error) {
var dsn string
switch conf.Conf.Database.Type { switch conf.Conf.Database.Type {
case conf.DatabaseTypeMysql: case conf.DatabaseTypeMysql:
var dsn string
if conf.Conf.Database.CustomDSN != "" { if conf.Conf.Database.CustomDSN != "" {
dsn = conf.Conf.Database.CustomDSN dsn = conf.Conf.Database.CustomDSN
} else if conf.Conf.Database.Port == 0 { } else if conf.Conf.Database.Port == 0 {
@ -55,9 +79,7 @@ func InitDatabase(ctx context.Context) (err error) {
DontSupportRenameColumn: true, DontSupportRenameColumn: true,
SkipInitializeWithVersion: false, SkipInitializeWithVersion: false,
}) })
// opts = append(opts, &gorm.Config{})
case conf.DatabaseTypeSqlite3: case conf.DatabaseTypeSqlite3:
var dsn string
if conf.Conf.Database.CustomDSN != "" { if conf.Conf.Database.CustomDSN != "" {
dsn = conf.Conf.Database.CustomDSN dsn = conf.Conf.Database.CustomDSN
} else if conf.Conf.Database.DBName == "memory" || strings.HasPrefix(conf.Conf.Database.DBName, ":memory:") { } else if conf.Conf.Database.DBName == "memory" || strings.HasPrefix(conf.Conf.Database.DBName, ":memory:") {
@ -75,9 +97,7 @@ func InitDatabase(ctx context.Context) (err error) {
log.Infof("sqlite3 database file: %s", conf.Conf.Database.DBName) log.Infof("sqlite3 database file: %s", conf.Conf.Database.DBName)
} }
dialector = sqlite.Open(dsn) dialector = sqlite.Open(dsn)
// opts = append(opts, &gorm.Config{})
case conf.DatabaseTypePostgres: case conf.DatabaseTypePostgres:
var dsn string
if conf.Conf.Database.CustomDSN != "" { if conf.Conf.Database.CustomDSN != "" {
dsn = conf.Conf.Database.CustomDSN dsn = conf.Conf.Database.CustomDSN
} else if conf.Conf.Database.Port == 0 { } else if conf.Conf.Database.Port == 0 {
@ -104,27 +124,10 @@ func InitDatabase(ctx context.Context) (err error) {
DSN: dsn, DSN: dsn,
PreferSimpleProtocol: true, PreferSimpleProtocol: true,
}) })
// opts = append(opts, &gorm.Config{})
default: default:
log.Fatalf("unknown database type: %s", conf.Conf.Database.Type) log.Fatalf("unknown database type: %s", conf.Conf.Database.Type)
} }
opts = append(opts, &gorm.Config{ return
TranslateError: true,
Logger: newDBLogger(),
PrepareStmt: true,
})
d, err := gorm.Open(dialector, opts...)
if err != nil {
log.Fatalf("failed to connect database: %s", err.Error())
}
sqlDB, err := d.DB()
if err != nil {
log.Fatalf("failed to get sqlDB: %s", err.Error())
}
if conf.Conf.Database.Type != conf.DatabaseTypeSqlite3 {
initRawDB(sqlDB)
}
return db.Init(d, conf.Conf.Database.Type)
} }
func newDBLogger() logger.Interface { func newDBLogger() logger.Interface {

@ -3,6 +3,7 @@ package db
import ( import (
"errors" "errors"
"fmt" "fmt"
"strings"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"github.com/synctv-org/synctv/internal/conf" "github.com/synctv-org/synctv/internal/conf"
@ -284,3 +285,23 @@ func WhereRoomUserStatus(status model.RoomUserStatus) func(db *gorm.DB) *gorm.DB
return db.Where("status = ?", status) return db.Where("status = ?", status)
} }
} }
func HandleNotFound(err error, errMsg ...string) error {
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
return fmt.Errorf("not found: %s", strings.Join(errMsg, " "))
}
return err
}
func Transactional(txFunc func(*gorm.DB) error) (err error) {
tx := db.Begin()
defer func() {
if err != nil {
tx.Rollback()
} else {
tx.Commit()
}
}()
err = txFunc(tx)
return
}

@ -1,9 +1,6 @@
package db package db
import ( import (
"errors"
"fmt"
"github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
@ -25,86 +22,53 @@ func GetAllMoviesByRoomID(roomID string) []*model.Movie {
func DeleteMovieByID(roomID, id string) error { func DeleteMovieByID(roomID, id string) error {
err := db.Unscoped().Where("room_id = ? AND id = ?", roomID, id).Delete(&model.Movie{}).Error err := db.Unscoped().Where("room_id = ? AND id = ?", roomID, id).Delete(&model.Movie{}).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or movie")
return errors.New("room or movie not found")
}
return err
} }
func LoadAndDeleteMovieByID(roomID, id string, columns ...clause.Column) (*model.Movie, error) { func LoadAndDeleteMovieByID(roomID, id string, columns ...clause.Column) (*model.Movie, error) {
movie := &model.Movie{} movie := &model.Movie{}
err := db.Unscoped().Clauses(clause.Returning{Columns: columns}).Where("room_id = ? AND id = ?", roomID, id).Delete(movie).Error err := db.Unscoped().Clauses(clause.Returning{Columns: columns}).Where("room_id = ? AND id = ?", roomID, id).Delete(movie).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return movie, HandleNotFound(err, "room or movie")
return movie, errors.New("room or movie not found")
}
return movie, err
} }
func DeleteMoviesByRoomID(roomID string) error { func DeleteMoviesByRoomID(roomID string) error {
err := db.Unscoped().Where("room_id = ?", roomID).Delete(&model.Movie{}).Error err := db.Unscoped().Where("room_id = ?", roomID).Delete(&model.Movie{}).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room")
return errors.New("room not found")
}
return err
} }
func LoadAndDeleteMoviesByRoomID(roomID string, columns ...clause.Column) ([]*model.Movie, error) { func LoadAndDeleteMoviesByRoomID(roomID string, columns ...clause.Column) ([]*model.Movie, error) {
movies := []*model.Movie{} movies := []*model.Movie{}
err := db.Unscoped().Clauses(clause.Returning{Columns: columns}).Where("room_id = ?", roomID).Delete(&movies).Error err := db.Unscoped().Clauses(clause.Returning{Columns: columns}).Where("room_id = ?", roomID).Delete(&movies).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return movies, HandleNotFound(err, "room")
return nil, errors.New("room not found")
}
return movies, err
} }
func UpdateMovie(movie *model.Movie, columns ...clause.Column) error { func UpdateMovie(movie *model.Movie, columns ...clause.Column) error {
err := db.Model(movie).Clauses(clause.Returning{Columns: columns}).Where("room_id = ? AND id = ?", movie.RoomID, movie.ID).Updates(movie).Error err := db.Model(movie).Clauses(clause.Returning{Columns: columns}).Where("room_id = ? AND id = ?", movie.RoomID, movie.ID).Updates(movie).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or movie")
return errors.New("room or movie not found")
}
return err
} }
func SaveMovie(movie *model.Movie, columns ...clause.Column) error { func SaveMovie(movie *model.Movie, columns ...clause.Column) error {
err := db.Model(movie).Clauses(clause.Returning{Columns: columns}).Where("room_id = ? AND id = ?", movie.RoomID, movie.ID).Save(movie).Error err := db.Model(movie).Clauses(clause.Returning{Columns: columns}).Where("room_id = ? AND id = ?", movie.RoomID, movie.ID).Save(movie).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or movie")
return errors.New("room or movie not found")
}
return err
} }
func SwapMoviePositions(roomID, movie1ID, movie2ID string) (err error) { func SwapMoviePositions(roomID, movie1ID, movie2ID string) (err error) {
tx := db.Begin() return Transactional(func(tx *gorm.DB) error {
defer func() { movie1 := &model.Movie{}
movie2 := &model.Movie{}
err = tx.Where("room_id = ? AND id = ?", roomID, movie1ID).First(movie1).Error
if err != nil { if err != nil {
tx.Rollback() return HandleNotFound(err, "movie1")
} else {
tx.Commit()
} }
}() err = tx.Where("room_id = ? AND id = ?", roomID, movie2ID).First(movie2).Error
movie1 := &model.Movie{} if err != nil {
movie2 := &model.Movie{} return HandleNotFound(err, "movie2")
err = tx.Select("position").Where("room_id = ? AND id = ?", roomID, movie1ID).First(movie1).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
err = fmt.Errorf("movie with id %s not found", movie1ID)
} }
return movie1.Position, movie2.Position = movie2.Position, movie1.Position
} err = tx.Save(movie1).Error
err = tx.Select("position").Where("room_id = ? AND id = ?", roomID, movie2ID).First(movie2).Error if err != nil {
if err != nil { return err
if errors.Is(err, gorm.ErrRecordNotFound) {
err = fmt.Errorf("movie with id %s not found", movie2ID)
} }
return return tx.Save(movie2).Error
} })
err = tx.Model(&model.Movie{}).Where("room_id = ? AND id = ?", roomID, movie1ID).Update("position", movie2.Position).Error
if err != nil {
return
}
err = tx.Model(&model.Movie{}).Where("room_id = ? AND id = ?", roomID, movie2ID).Update("position", movie1.Position).Error
if err != nil {
return
}
return
} }

@ -1,8 +1,6 @@
package db package db
import ( import (
"errors"
"github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/internal/model"
"gorm.io/gorm" "gorm.io/gorm"
) )
@ -38,42 +36,27 @@ func FirstOrCreateRoomUserRelation(roomID, userID string, conf ...CreateRoomUser
func GetRoomUserRelation(roomID, userID string) (*model.RoomUserRelation, error) { func GetRoomUserRelation(roomID, userID string) (*model.RoomUserRelation, error) {
roomUserRelation := &model.RoomUserRelation{} roomUserRelation := &model.RoomUserRelation{}
err := db.Where("room_id = ? AND user_id = ?", roomID, userID).First(roomUserRelation).Error err := db.Where("room_id = ? AND user_id = ?", roomID, userID).First(roomUserRelation).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return roomUserRelation, HandleNotFound(err, "room or user")
return roomUserRelation, errors.New("room or user not found")
}
return roomUserRelation, err
} }
func SetRoomUserStatus(roomID string, userID string, status model.RoomUserStatus) error { func SetRoomUserStatus(roomID string, userID string, status model.RoomUserStatus) error {
err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("status", status).Error err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("status", status).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or user")
return errors.New("room or user not found")
}
return err
} }
func SetUserPermission(roomID string, userID string, permission model.RoomUserPermission) error { func SetUserPermission(roomID string, userID string, permission model.RoomUserPermission) error {
err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("permissions", permission).Error err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("permissions", permission).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or user")
return errors.New("room or user not found")
}
return err
} }
func AddUserPermission(roomID string, userID string, permission model.RoomUserPermission) error { func AddUserPermission(roomID string, userID string, permission model.RoomUserPermission) error {
err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("permissions", db.Raw("permissions | ?", permission)).Error err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("permissions", db.Raw("permissions | ?", permission)).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or user")
return errors.New("room or user not found")
}
return err
} }
func RemoveUserPermission(roomID string, userID string, permission model.RoomUserPermission) error { func RemoveUserPermission(roomID string, userID string, permission model.RoomUserPermission) error {
err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("permissions", db.Raw("permissions & ?", ^permission)).Error err := db.Model(&model.RoomUserRelation{}).Where("room_id = ? AND user_id = ?", roomID, userID).Update("permissions", db.Raw("permissions & ?", ^permission)).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room or user")
return errors.New("room or user not found")
}
return err
} }
func GetAllRoomUsersRelation(roomID string, scopes ...func(*gorm.DB) *gorm.DB) []*model.RoomUserRelation { func GetAllRoomUsersRelation(roomID string, scopes ...func(*gorm.DB) *gorm.DB) []*model.RoomUserRelation {

@ -58,25 +58,23 @@ func CreateRoom(name, password string, maxCount int64, conf ...CreateRoomConfig)
} }
} }
tx := db.Begin() return r, Transactional(func(tx *gorm.DB) error {
if maxCount != 0 { if maxCount != 0 {
var count int64 var count int64
tx.Model(&model.Room{}).Where("creator_id = ?", r.CreatorID).Count(&count) tx.Model(&model.Room{}).Where("creator_id = ?", r.CreatorID).Count(&count)
if count >= maxCount { if count >= maxCount {
tx.Rollback() return errors.New("room count is over limit")
return nil, errors.New("room count is over limit") }
} }
} err := tx.Create(r).Error
err := tx.Create(r).Error if err != nil {
if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) {
tx.Rollback() return errors.New("room already exists")
if errors.Is(err, gorm.ErrDuplicatedKey) { }
return r, errors.New("room already exists") return err
} }
return r, err return nil
} })
tx.Commit()
return r, nil
} }
func GetRoomByID(id string) (*model.Room, error) { func GetRoomByID(id string) (*model.Room, error) {
@ -85,50 +83,17 @@ func GetRoomByID(id string) (*model.Room, error) {
} }
r := &model.Room{} r := &model.Room{}
err := db.Where("id = ?", id).First(r).Error err := db.Where("id = ?", id).First(r).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return r, HandleNotFound(err, "room")
return r, errors.New("room not found")
}
return r, err
} }
func SaveRoomSettings(roomID string, setting model.RoomSettings) error { func SaveRoomSettings(roomID string, setting model.RoomSettings) error {
err := db.Model(&model.Room{}).Where("id = ?", roomID).Update("setting", setting).Error err := db.Model(&model.Room{}).Where("id = ?", roomID).Update("setting", setting).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room")
return errors.New("room not found")
}
return err
} }
func DeleteRoomByID(roomID string) error { func DeleteRoomByID(roomID string) error {
err := db.Unscoped().Where("id = ?", roomID).Delete(&model.Room{}).Error err := db.Unscoped().Where("id = ?", roomID).Delete(&model.Room{}).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room")
return errors.New("room not found")
}
return err
}
func HasRoom(roomID string) (bool, error) {
r := &model.Room{}
err := db.Where("id = ?", roomID).First(r).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
err = nil
}
return false, err
}
return true, nil
}
func HasRoomByName(name string) (bool, error) {
r := &model.Room{}
err := db.Where("name = ?", name).First(r).Error
if err != nil {
if errors.Is(err, gorm.ErrRecordNotFound) {
err = nil
}
return false, err
}
return true, nil
} }
func SetRoomPassword(roomID, password string) error { func SetRoomPassword(roomID, password string) error {
@ -145,10 +110,7 @@ func SetRoomPassword(roomID, password string) error {
func SetRoomHashedPassword(roomID string, hashedPassword []byte) error { func SetRoomHashedPassword(roomID string, hashedPassword []byte) error {
err := db.Model(&model.Room{}).Where("id = ?", roomID).Update("hashed_password", hashedPassword).Error err := db.Model(&model.Room{}).Where("id = ?", roomID).Update("hashed_password", hashedPassword).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room")
return errors.New("room not found")
}
return err
} }
func GetAllRooms(scopes ...func(*gorm.DB) *gorm.DB) []*model.Room { func GetAllRooms(scopes ...func(*gorm.DB) *gorm.DB) []*model.Room {
@ -177,10 +139,7 @@ func GetAllRoomsByUserID(userID string) []*model.Room {
func SetRoomStatus(roomID string, status model.RoomStatus) error { func SetRoomStatus(roomID string, status model.RoomStatus) error {
err := db.Model(&model.Room{}).Where("id = ?", roomID).Update("status", status).Error err := db.Model(&model.Room{}).Where("id = ?", roomID).Update("status", status).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "room")
return errors.New("room not found")
}
return err
} }
func SetRoomStatusByCreator(userID string, status model.RoomStatus) error { func SetRoomStatusByCreator(userID string, status model.RoomStatus) error {

@ -143,26 +143,14 @@ func CreateOrLoadUserWithProvider(username, password string, p provider.OAuth2Pr
func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, error) { func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, error) {
var user model.User var user model.User
if err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error; err != nil { err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return &user, HandleNotFound(err, "user")
return &user, errors.New("user not found")
} else {
return &user, err
}
}
return &user, nil
} }
func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) { func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) {
var userProvider model.UserProvider var userProvider model.UserProvider
if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id").First(&userProvider).Error; err != nil { err := db.Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id").First(&userProvider).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return userProvider.UserID, HandleNotFound(err, "user")
return "", errors.New("user not found")
} else {
return "", err
}
}
return userProvider.UserID, nil
} }
func BindProvider(uid string, p provider.OAuth2Provider, puid string) error { func BindProvider(uid string, p provider.OAuth2Provider, puid string) error {
@ -179,45 +167,31 @@ func BindProvider(uid string, p provider.OAuth2Provider, puid string) error {
// 当用户是通过provider注册的时候则最少保留一个provider否则禁止解除绑定 // 当用户是通过provider注册的时候则最少保留一个provider否则禁止解除绑定
func UnBindProvider(uid string, p provider.OAuth2Provider) error { func UnBindProvider(uid string, p provider.OAuth2Provider) error {
tx := db.Begin() return Transactional(func(tx *gorm.DB) error {
user := model.User{} user := model.User{}
if err := tx.Scopes(PreloadUserProviders()).Where("id = ?", uid).First(&user).Error; err != nil { if err := tx.Scopes(PreloadUserProviders()).Where("id = ?", uid).First(&user).Error; err != nil {
tx.Rollback() return HandleNotFound(err, "user")
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("user not found")
} }
return err if user.RegisteredByProvider && len(user.UserProviders) == 1 {
} return errors.New("user must have at least one provider")
if user.RegisteredByProvider && len(user.UserProviders) == 1 {
tx.Rollback()
return errors.New("user must have at least one provider")
}
if err := tx.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error; err != nil {
tx.Rollback()
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("user could not bind provider")
} }
return err if err := tx.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error; err != nil {
} return HandleNotFound(err, "provider")
return tx.Commit().Error }
return nil
})
} }
func GetBindProviders(uid string) ([]*model.UserProvider, error) { func GetBindProviders(uid string) ([]*model.UserProvider, error) {
var providers []*model.UserProvider var providers []*model.UserProvider
err := db.Where("user_id = ?", uid).Find(&providers).Error err := db.Where("user_id = ?", uid).Find(&providers).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return providers, HandleNotFound(err, "user")
return providers, errors.New("user not found")
}
return providers, err
} }
func GetUserByUsername(username string) (*model.User, error) { func GetUserByUsername(username string) (*model.User, error) {
u := &model.User{} u := &model.User{}
err := db.Where("username = ?", username).First(u).Error err := db.Where("username = ?", username).First(u).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return u, HandleNotFound(err, "user")
return u, errors.New("user not found")
}
return u, err
} }
func GetUserByUsernameLike(username string, scopes ...func(*gorm.DB) *gorm.DB) []*model.User { func GetUserByUsernameLike(username string, scopes ...func(*gorm.DB) *gorm.DB) []*model.User {
@ -241,10 +215,7 @@ func GerUsersIDByIDLike(id string, scopes ...func(*gorm.DB) *gorm.DB) []string {
func GetUserByIDOrUsernameLike(idOrUsername string, scopes ...func(*gorm.DB) *gorm.DB) ([]*model.User, error) { func GetUserByIDOrUsernameLike(idOrUsername string, scopes ...func(*gorm.DB) *gorm.DB) ([]*model.User, error) {
var users []*model.User var users []*model.User
err := db.Where("id = ? OR username LIKE ?", idOrUsername, fmt.Sprintf("%%%s%%", idOrUsername)).Scopes(scopes...).Find(&users).Error err := db.Where("id = ? OR username LIKE ?", idOrUsername, fmt.Sprintf("%%%s%%", idOrUsername)).Scopes(scopes...).Find(&users).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return users, HandleNotFound(err, "user")
return users, errors.New("user not found")
}
return users, err
} }
func GetUserByID(id string) (*model.User, error) { func GetUserByID(id string) (*model.User, error) {
@ -253,10 +224,7 @@ func GetUserByID(id string) (*model.User, error) {
} }
u := &model.User{} u := &model.User{}
err := db.Where("id = ?", id).First(u).Error err := db.Where("id = ?", id).First(u).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return u, HandleNotFound(err, "user")
return u, errors.New("user not found")
}
return u, err
} }
func BanUser(u *model.User) error { func BanUser(u *model.User) error {
@ -269,10 +237,7 @@ func BanUser(u *model.User) error {
func BanUserByID(userID string) error { func BanUserByID(userID string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleBanned).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleBanned).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func UnbanUser(u *model.User) error { func UnbanUser(u *model.User) error {
@ -285,18 +250,12 @@ func UnbanUser(u *model.User) error {
func UnbanUserByID(userID string) error { func UnbanUserByID(userID string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleUser).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleUser).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func DeleteUserByID(userID string) error { func DeleteUserByID(userID string) error {
err := db.Unscoped().Where("id = ?", userID).Delete(&model.User{}).Error err := db.Unscoped().Where("id = ?", userID).Delete(&model.User{}).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func LoadAndDeleteUserByID(userID string, columns ...clause.Column) (*model.User, error) { func LoadAndDeleteUserByID(userID string, columns ...clause.Column) (*model.User, error) {
@ -338,18 +297,12 @@ func GetAdmins() []*model.User {
func AddAdminByID(userID string) error { func AddAdminByID(userID string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleAdmin).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleAdmin).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func RemoveAdminByID(userID string) error { func RemoveAdminByID(userID string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleUser).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleUser).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func AddRoot(u *model.User) error { func AddRoot(u *model.User) error {
@ -370,18 +323,12 @@ func RemoveRoot(u *model.User) error {
func AddRootByID(userID string) error { func AddRootByID(userID string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleRoot).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleRoot).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func RemoveRootByID(userID string) error { func RemoveRootByID(userID string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleUser).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", model.RoleUser).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func GetRoots() []*model.User { func GetRoots() []*model.User {
@ -397,18 +344,12 @@ func SetRole(u *model.User, role model.Role) error {
func SetRoleByID(userID string, role model.Role) error { func SetRoleByID(userID string, role model.Role) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", role).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("role", role).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func SetUsernameByID(userID string, username string) error { func SetUsernameByID(userID string, username string) error {
err := db.Model(&model.User{}).Where("id = ?", userID).Update("username", username).Error err := db.Model(&model.User{}).Where("id = ?", userID).Update("username", username).Error
if errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }
func GetAllUserCount(scopes ...func(*gorm.DB) *gorm.DB) int64 { func GetAllUserCount(scopes ...func(*gorm.DB) *gorm.DB) int64 {
@ -425,8 +366,5 @@ func GetAllUsers(scopes ...func(*gorm.DB) *gorm.DB) []*model.User {
func SetUserHashedPassword(id string, hashedPassword []byte) error { func SetUserHashedPassword(id string, hashedPassword []byte) error {
err := db.Model(&model.User{}).Where("id = ?", id).Update("hashed_password", hashedPassword).Error err := db.Model(&model.User{}).Where("id = ?", id).Update("hashed_password", hashedPassword).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return HandleNotFound(err, "user")
return errors.New("user not found")
}
return err
} }

@ -1,11 +1,9 @@
package db package db
import ( import (
"errors"
"net/http" "net/http"
"github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/internal/model"
"gorm.io/gorm"
) )
func GetVendorByUserID(userID string) ([]*model.StreamingVendorInfo, error) { func GetVendorByUserID(userID string) ([]*model.StreamingVendorInfo, error) {
@ -20,10 +18,7 @@ func GetVendorByUserID(userID string) ([]*model.StreamingVendorInfo, error) {
func GetVendorByUserIDAndVendor(userID string, vendor model.StreamingVendor) (*model.StreamingVendorInfo, error) { func GetVendorByUserIDAndVendor(userID string, vendor model.StreamingVendor) (*model.StreamingVendorInfo, error) {
var vendorInfo model.StreamingVendorInfo var vendorInfo model.StreamingVendorInfo
err := db.Where("user_id = ? AND vendor = ?", userID, vendor).First(&vendorInfo).Error err := db.Where("user_id = ? AND vendor = ?", userID, vendor).First(&vendorInfo).Error
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { return &vendorInfo, HandleNotFound(err, "vendor")
return nil, errors.New("vendor not found")
}
return &vendorInfo, err
} }
type CreateVendorConfig func(*model.StreamingVendorInfo) type CreateVendorConfig func(*model.StreamingVendorInfo)

@ -122,26 +122,6 @@ func PeopleNum(roomID string) int64 {
return 0 return 0
} }
func HasRoom(roomID string) bool {
_, ok := roomCache.Load(roomID)
if ok {
return true
}
ok, err := db.HasRoom(roomID)
if err != nil {
return false
}
return ok
}
func HasRoomByName(name string) bool {
ok, err := db.HasRoomByName(name)
if err != nil {
return false
}
return ok
}
func GetAllRoomsInCacheWithNoNeedPassword() []*Room { func GetAllRoomsInCacheWithNoNeedPassword() []*Room {
rooms := make([]*Room, 0) rooms := make([]*Room, 0)
roomCache.Range(func(key string, value *synccache.Entry[*Room]) bool { roomCache.Range(func(key string, value *synccache.Entry[*Room]) bool {

@ -12,27 +12,25 @@ import (
) )
func Init(e *gin.Engine) { func Init(e *gin.Engine) {
{ e.GET("/", func(ctx *gin.Context) {
e.GET("/", func(ctx *gin.Context) { ctx.Redirect(http.StatusMovedPermanently, "/web/")
ctx.Redirect(http.StatusMovedPermanently, "/web/") })
})
web := e.Group("/web") web := e.Group("/web")
web.Use(middlewares.NewDistCacheControl("/web/")) web.Use(middlewares.NewDistCacheControl("/web/"))
err := initFSRouter(web, public.Public.(fs.ReadDirFS), ".") err := initFSRouter(web, public.Public.(fs.ReadDirFS), ".")
if err != nil { if err != nil {
panic(err) panic(err)
}
e.NoRoute(func(ctx *gin.Context) {
if strings.HasPrefix(ctx.Request.URL.Path, "/web/") {
ctx.FileFromFS("", http.FS(public.Public))
return
}
})
} }
e.NoRoute(func(ctx *gin.Context) {
if strings.HasPrefix(ctx.Request.URL.Path, "/web/") {
ctx.FileFromFS("", http.FS(public.Public))
return
}
})
} }
func initFSRouter(e *gin.RouterGroup, f fs.ReadDirFS, path string) error { func initFSRouter(e *gin.RouterGroup, f fs.ReadDirFS, path string) error {

@ -92,12 +92,8 @@ func In[T comparable](items []T, item T) bool {
} }
func Exists(name string) bool { func Exists(name string) bool {
if _, err := os.Stat(name); err != nil { _, err := os.Stat(name)
if os.IsNotExist(err) { return !os.IsNotExist(err)
return false
}
}
return true
} }
func WriteYaml(file string, module any) error { func WriteYaml(file string, module any) error {

Loading…
Cancel
Save