diff --git a/go.mod b/go.mod index ed3b744..32ebebc 100644 --- a/go.mod +++ b/go.mod @@ -3,7 +3,6 @@ module github.com/synctv-org/synctv go 1.20 require ( - github.com/bluele/gcache v0.0.2 github.com/caarlos0/env/v9 v9.0.0 github.com/cavaliergopher/grab/v3 v3.0.1 github.com/gin-contrib/cors v1.4.0 diff --git a/go.sum b/go.sum index 0f580b8..b235774 100644 --- a/go.sum +++ b/go.sum @@ -19,8 +19,6 @@ github.com/beorn7/perks v0.0.0-20180321164747-3a771d992973/go.mod h1:Dwedo/Wpr24 github.com/beorn7/perks v1.0.0/go.mod h1:KWe93zE9D1o94FZ5RNwFwVgaQK1VOXiVxmqh+CedLV8= github.com/beorn7/perks v1.0.1/go.mod h1:G2ZrVWU2WbWT9wwq4/hrbKbnv/1ERSJQ0ibhJ6rlkpw= github.com/bgentry/speakeasy v0.1.0/go.mod h1:+zsyZBPWlz7T6j88CTgSN5bM796AkVf0kBD4zp0CCIs= -github.com/bluele/gcache v0.0.2 h1:WcbfdXICg7G/DGBh1PFfcirkWOQV+v077yF1pSy3DGw= -github.com/bluele/gcache v0.0.2/go.mod h1:m15KV+ECjptwSPxKhOhQoAFQVtUFjTVkc3H8o0t/fp0= github.com/bufbuild/protocompile v0.4.0 h1:LbFKd2XowZvQ/kajzguUp2DC9UEIQhIq77fZZlaQsNA= github.com/bytedance/sonic v1.5.0/go.mod h1:ED5hyg4y6t3/9Ku1R6dU/4KyJ48DZ4jPhfY1O2AihPM= github.com/bytedance/sonic v1.10.0-rc/go.mod h1:ElCzW+ufi8qKqNW0FY314xriJhyJhuoJ3gFZdAHF7NM= diff --git a/internal/db/db.go b/internal/db/db.go index dbd68b6..0d7f4d7 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -1,6 +1,7 @@ package db import ( + "errors" "fmt" log "github.com/sirupsen/logrus" @@ -35,9 +36,23 @@ func AutoMigrate(dst ...any) error { if err != nil { return err } + err = initRootUser() + if err != nil { + return err + } return upgradeDatabase() } +func initRootUser() error { + user := model.User{} + err := db.Where("role = ?", model.RoleRoot).First(&user).Error + if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { + return err + } + _, err = CreateUser("root", "root", WithRole(model.RoleRoot)) + return err +} + func DB() *gorm.DB { return db } @@ -102,17 +117,13 @@ func WithUser(db *gorm.DB) *gorm.DB { return db.Preload("User") } -func WithUserAndProvider(db *gorm.DB) *gorm.DB { - return db.Preload("User").Preload("User.Provider") -} - func WhereRoomID(roomID string) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { return db.Where("room_id = ?", roomID) } } -func PreloadRoomUserRelation(scopes ...func(*gorm.DB) *gorm.DB) func(db *gorm.DB) *gorm.DB { +func PreloadRoomUserRelations(scopes ...func(*gorm.DB) *gorm.DB) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { return db.Preload("RoomUserRelations", func(db *gorm.DB) *gorm.DB { return db.Scopes(scopes...) @@ -120,6 +131,14 @@ func PreloadRoomUserRelation(scopes ...func(*gorm.DB) *gorm.DB) func(db *gorm.DB } } +func PreloadUserProviders(scopes ...func(*gorm.DB) *gorm.DB) func(db *gorm.DB) *gorm.DB { + return func(db *gorm.DB) *gorm.DB { + return db.Preload("UserProviders", func(db *gorm.DB) *gorm.DB { + return db.Scopes(scopes...) + }) + } +} + func WhereUserID(userID string) func(db *gorm.DB) *gorm.DB { return func(db *gorm.DB) *gorm.DB { return db.Where("user_id = ?", userID) diff --git a/internal/db/user.go b/internal/db/user.go index 51d33e1..0c0d73a 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -7,6 +7,8 @@ import ( "github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/utils" + "github.com/zijiren233/stream" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" "gorm.io/gorm/clause" ) @@ -19,20 +21,60 @@ func WithRole(role model.Role) CreateUserConfig { } } -func CreateUser(username string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) { +func WithAppendProvider(p provider.OAuth2Provider, puid string) CreateUserConfig { + return func(u *model.User) { + u.UserProviders = append(u.UserProviders, model.UserProvider{ + Provider: p, + ProviderUserID: puid, + }) + } +} + +func WithSetProvider(p provider.OAuth2Provider, puid string) CreateUserConfig { + return func(u *model.User) { + u.UserProviders = []model.UserProvider{{ + Provider: p, + ProviderUserID: puid, + }} + } +} + +func WithAppendProviders(providers []model.UserProvider) CreateUserConfig { + return func(u *model.User) { + u.UserProviders = append(u.UserProviders, providers...) + } +} + +func WithSetProviders(providers []model.UserProvider) CreateUserConfig { + return func(u *model.User) { + u.UserProviders = providers + } +} + +func WithRegisteredByProvider(b bool) CreateUserConfig { + return func(u *model.User) { + u.RegisteredByProvider = b + } +} + +func CreateUserWithHashedPassword(username string, hashedPassword []byte, conf ...CreateUserConfig) (*model.User, error) { + if username == "" { + return nil, errors.New("username cannot be empty") + } + if len(hashedPassword) == 0 { + return nil, errors.New("password cannot be empty") + } u := &model.User{ - Username: username, - Role: model.RoleUser, - Providers: []model.UserProvider{ - { - Provider: p, - ProviderUserID: puid, - }, - }, + Username: username, + Role: model.RoleUser, + HashedPassword: hashedPassword, } for _, c := range conf { c(u) } + if u.Role == 0 { + return nil, errors.New("role cannot be empty") + } err := db.Create(u).Error if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) { return u, errors.New("user already exists") @@ -40,29 +82,80 @@ func CreateUser(username string, p provider.OAuth2Provider, puid string, conf .. return u, err } -// 只有当provider和puid没有找到对应的user时才会创建 -func CreateOrLoadUser(username string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) { +func CreateUser(username string, password string, conf ...CreateUserConfig) (*model.User, error) { + if username == "" { + return nil, errors.New("username cannot be empty") + } + if password == "" { + return nil, errors.New("password cannot be empty") + } + hashedPassword, err := bcrypt.GenerateFromPassword(stream.StringToBytes(password), bcrypt.DefaultCost) + if err != nil { + return nil, err + } + return CreateUserWithHashedPassword(username, hashedPassword, conf...) +} + +func CreateOrLoadUser(username string, password string, conf ...CreateUserConfig) (*model.User, error) { + if username == "" { + return nil, errors.New("username cannot be empty") + } var user model.User - var userProvider model.UserProvider + if err := db.Where("username = ?", username).First(&user).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return CreateUser(username, password, conf...) + } else { + return nil, err + } + } + return &user, nil +} - if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).First(&userProvider).Error; err != nil { +func CreateOrLoadUserWithHashedPassword(username string, hashedPassword []byte, conf ...CreateUserConfig) (*model.User, error) { + if username == "" { + return nil, errors.New("username cannot be empty") + } + var user model.User + if err := db.Where("username = ?", username).First(&user).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { - return CreateUser(username, p, puid, conf...) + return CreateUserWithHashedPassword(username, hashedPassword, conf...) } else { return nil, err } - } else { - if err := db.Where("id = ?", userProvider.UserID).First(&user).Error; err != nil { + } + return &user, nil +} + +// 只有当provider和puid没有找到对应的user时才会创建 +func CreateOrLoadUserWithProvider(username, password string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) { + var user model.User + + if err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return CreateUser(username, password, append(conf, WithSetProvider(p, puid), WithRegisteredByProvider(true))...) + } else { return nil, err } + } else { + return &user, nil } +} +func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, error) { + var user model.User + if err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error; err != nil { + if errors.Is(err, gorm.ErrRecordNotFound) { + return &user, errors.New("user not found") + } else { + return &user, err + } + } return &user, nil } func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) { var userProvider model.UserProvider - if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).First(&userProvider).Error; err != nil { + if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id").First(&userProvider).Error; err != nil { if errors.Is(err, gorm.ErrRecordNotFound) { return "", errors.New("user not found") } else { @@ -84,12 +177,29 @@ func BindProvider(uid string, p provider.OAuth2Provider, puid string) error { return err } +// 当用户是通过provider注册的时候,则最少保留一个provider,否则禁止解除绑定 func UnBindProvider(uid string, p provider.OAuth2Provider) error { - err := db.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error - if errors.Is(err, gorm.ErrRecordNotFound) { - return errors.New("user could not bind provider") + tx := db.Begin() + user := model.User{} + if err := tx.Scopes(PreloadUserProviders()).Where("id = ?", uid).First(&user).Error; err != nil { + tx.Rollback() + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("user not found") + } + return err } - return err + if user.RegisteredByProvider && len(user.UserProviders) == 1 { + tx.Rollback() + return errors.New("user must have at least one provider") + } + if err := tx.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error; err != nil { + tx.Rollback() + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("user could not bind provider") + } + return err + } + return tx.Commit().Error } func GetBindProviders(uid string) ([]*model.UserProvider, error) { @@ -312,3 +422,11 @@ func GetAllUsers(scopes ...func(*gorm.DB) *gorm.DB) []*model.User { db.Scopes(scopes...).Find(&users) return users } + +func SetUserHashedPassword(id string, hashedPassword []byte) error { + err := db.Model(&model.User{}).Where("id = ?", id).Update("hashed_password", hashedPassword).Error + if err != nil && errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("user not found") + } + return err +} diff --git a/internal/model/oauth2.go b/internal/model/oauth2.go index fbd2437..3d185a3 100644 --- a/internal/model/oauth2.go +++ b/internal/model/oauth2.go @@ -7,9 +7,9 @@ import ( ) type UserProvider struct { - Provider provider.OAuth2Provider `gorm:"not null;primarykey"` - ProviderUserID string `gorm:"not null;primarykey;autoIncrement:false"` + Provider provider.OAuth2Provider `gorm:"not null;primarykey;uniqueIndex:idx_provider_user_id"` + ProviderUserID string `gorm:"not null;primarykey"` CreatedAt time.Time UpdatedAt time.Time - UserID string `gorm:"not null;index"` + UserID string `gorm:"not null;uniqueIndex:idx_provider_user_id"` } diff --git a/internal/model/user.go b/internal/model/user.go index 8c1a4dc..8fdcdbf 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -6,6 +6,8 @@ import ( "time" "github.com/synctv-org/synctv/utils" + "github.com/zijiren233/stream" + "golang.org/x/crypto/bcrypt" "gorm.io/gorm" ) @@ -40,8 +42,10 @@ type User struct { ID string `gorm:"primaryKey;type:varchar(32)" json:"id"` CreatedAt time.Time UpdatedAt time.Time - Providers []UserProvider `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` + RegisteredByProvider bool `gorm:"not null;default:false"` + UserProviders []UserProvider `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` Username string `gorm:"not null;uniqueIndex"` + HashedPassword []byte `gorm:"not null"` Role Role `gorm:"not null;default:2"` RoomUserRelations []RoomUserRelation `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` Rooms []Room `gorm:"foreignKey:CreatorID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` @@ -49,6 +53,10 @@ type User struct { StreamingVendorInfos []StreamingVendorInfo `gorm:"foreignKey:UserID;constraint:OnUpdate:CASCADE,OnDelete:CASCADE"` } +func (u *User) CheckPassword(password string) bool { + return bcrypt.CompareHashAndPassword(u.HashedPassword, stream.StringToBytes(password)) == nil +} + func (u *User) BeforeCreate(tx *gorm.DB) error { var existingUser User err := tx.Where("username = ?", u.Username).First(&existingUser).Error diff --git a/internal/op/op.go b/internal/op/op.go index 0ac2af4..f6309c4 100644 --- a/internal/op/op.go +++ b/internal/op/op.go @@ -3,7 +3,6 @@ package op import ( "time" - "github.com/bluele/gcache" "github.com/zijiren233/gencontainer/synccache" ) @@ -11,9 +10,7 @@ func Init(size int) error { roomCache = synccache.NewSyncCache[string, *Room](time.Minute*5, synccache.WithDeletedCallback[string, *Room](func(v *Room) { v.close() })) - userCache = gcache.New(size). - LRU(). - Build() + userCache = synccache.NewSyncCache[string, *User](time.Minute * 5) return nil } diff --git a/internal/op/rooms.go b/internal/op/rooms.go index 9a55ce5..d063294 100644 --- a/internal/op/rooms.go +++ b/internal/op/rooms.go @@ -50,18 +50,18 @@ func DeleteRoomByID(roomID string) error { if err != nil { return err } - return CloseRoomByID(roomID) + return CloseRoomById(roomID) } func CompareAndDeleteRoom(room *Room) error { - err := CompareAndCloseRoom(room) + err := db.DeleteRoomByID(room.ID) if err != nil { return err } - return db.DeleteRoomByID(room.ID) + return CompareAndCloseRoom(room) } -func CloseRoomByID(roomID string) error { +func CloseRoomById(roomID string) error { r, loaded := roomCache.LoadAndDelete(roomID) if loaded { r.Value().close() @@ -69,6 +69,13 @@ func CloseRoomByID(roomID string) error { return nil } +func CompareAndCloseRoomEntry(id string, room *synccache.Entry[*Room]) error { + if roomCache.CompareAndDelete(id, room) { + room.Value().close() + } + return nil +} + func CompareAndCloseRoom(room *Room) error { r, loaded := roomCache.Load(room.ID) if loaded { @@ -135,14 +142,6 @@ func HasRoomByName(name string) bool { return ok } -func SetRoomPassword(roomID, password string) error { - r, err := LoadOrInitRoomByID(roomID) - if err != nil { - return err - } - return r.SetPassword(password) -} - func GetAllRoomsInCacheWithNoNeedPassword() []*Room { rooms := make([]*Room, 0) roomCache.Range(func(key string, value *synccache.Entry[*Room]) bool { diff --git a/internal/op/user.go b/internal/op/user.go index 31a50d9..a373b91 100644 --- a/internal/op/user.go +++ b/internal/op/user.go @@ -2,14 +2,41 @@ package op import ( "errors" + "hash/crc32" + "sync/atomic" "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/model" + "github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/settings" + "github.com/zijiren233/stream" + "golang.org/x/crypto/bcrypt" ) type User struct { model.User + version uint32 +} + +func (u *User) Version() uint32 { + return atomic.LoadUint32(&u.version) +} + +func (u *User) CheckVersion(version uint32) bool { + return atomic.LoadUint32(&u.version) == version +} + +func (u *User) SetPassword(password string) error { + if u.CheckPassword(password) { + return errors.New("password is the same") + } + hashedPassword, err := bcrypt.GenerateFromPassword(stream.StringToBytes(password), bcrypt.DefaultCost) + if err != nil { + return err + } + atomic.StoreUint32(&u.version, crc32.ChecksumIEEE(hashedPassword)) + u.HashedPassword = hashedPassword + return db.SetUserHashedPassword(u.ID, hashedPassword) } func (u *User) CreateRoom(name, password string, conf ...db.CreateRoomConfig) (*Room, error) { @@ -191,3 +218,11 @@ func (u *User) SetCurrentMovieByID(room *Room, movieID string, play bool) error } return u.SetCurrentMovie(room, m, play) } + +func (u *User) BindProvider(p provider.OAuth2Provider, pid string) error { + err := db.BindProvider(u.ID, p, pid) + if err != nil { + return err + } + return nil +} diff --git a/internal/op/users.go b/internal/op/users.go index 4e92a9b..747396b 100644 --- a/internal/op/users.go +++ b/internal/op/users.go @@ -2,121 +2,140 @@ package op import ( "errors" + "hash/crc32" "time" - "github.com/bluele/gcache" "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/internal/provider" "github.com/zijiren233/gencontainer/synccache" ) -var userCache gcache.Cache +var userCache *synccache.SyncCache[string, *User] var ( ErrUserBanned = errors.New("user banned") ErrUserPending = errors.New("user pending, please wait for admin to approve") ) -func GetUserById(id string) (*User, error) { - i, err := userCache.Get(id) - if err == nil { - return i.(*User), nil - } - - u, err := db.GetUserByID(id) - if err != nil { - return nil, err - } - +func LoadOrInitUser(u *model.User) (*User, error) { switch u.Role { case model.RoleBanned: return nil, ErrUserBanned case model.RolePending: return nil, ErrUserPending } + i, _ := userCache.LoadOrStore(u.ID, &User{ + User: *u, + version: crc32.ChecksumIEEE(u.HashedPassword), + }, time.Hour) + return i.Value(), nil +} + +func LoadOrInitUserByID(id string) (*User, error) { + u, ok := userCache.Load(id) + if ok { + u.SetExpiration(time.Now().Add(time.Hour)) + return u.Value(), nil + } - u2 := &User{ - User: *u, + user, err := db.GetUserByID(id) + if err != nil { + return nil, err + } + + return LoadOrInitUser(user) +} + +func LoadUserByUsername(username string) (*User, error) { + u, err := db.GetUserByUsername(username) + if err != nil { + return nil, err } - return u2, userCache.SetWithExpire(id, u2, time.Hour) + return LoadOrInitUser(u) } -func CreateUser(username string, p provider.OAuth2Provider, pid string, conf ...db.CreateUserConfig) (*User, error) { +func CreateOrLoadUser(username string, password string, conf ...db.CreateUserConfig) (*User, error) { if username == "" { return nil, errors.New("username cannot be empty") } - u, err := db.CreateUser(username, p, pid, conf...) + u, err := db.CreateOrLoadUser(username, password, conf...) if err != nil { return nil, err } - u2 := &User{ - User: *u, - } - - return u2, userCache.SetWithExpire(u.ID, u2, time.Hour) + return LoadOrInitUser(u) } -func CreateOrLoadUser(username string, p provider.OAuth2Provider, pid string, conf ...db.CreateUserConfig) (*User, error) { +func CreateOrLoadUserWithProvider(username, password string, p provider.OAuth2Provider, pid string, conf ...db.CreateUserConfig) (*User, error) { if username == "" { return nil, errors.New("username cannot be empty") } - u, err := db.CreateOrLoadUser(username, p, pid, conf...) + u, err := db.CreateOrLoadUserWithProvider(username, password, p, pid, conf...) if err != nil { return nil, err } - u2 := &User{ - User: *u, - } - - return u2, userCache.SetWithExpire(u.ID, u2, time.Hour) + return LoadOrInitUser(u) } func GetUserByProvider(p provider.OAuth2Provider, pid string) (*User, error) { - uid, err := db.GetProviderUserID(p, pid) + u, err := db.GetUserByProvider(p, pid) if err != nil { return nil, err } - return GetUserById(uid) + return LoadOrInitUser(u) } -func BindProvider(uid string, p provider.OAuth2Provider, pid string) error { - err := db.BindProvider(uid, p, pid) +func CompareAndDeleteUser(user *User) error { + err := db.DeleteUserByID(user.ID) if err != nil { return err } - return nil + return CompareAndCloseUser(user) } -func DeleteUserByID(userID string) error { - err := db.DeleteUserByID(userID) +func DeleteUserByID(id string) error { + err := db.DeleteUserByID(id) if err != nil { return err } - userCache.Remove(userID) + return CloseUserById(id) +} +func CloseUserById(id string) error { + userCache.Delete(id) roomCache.Range(func(key string, value *synccache.Entry[*Room]) bool { - v := value.Value() - if v.CreatorID == userID { - roomCache.CompareAndDelete(key, value) + if value.Value().CreatorID == id { + CompareAndCloseRoomEntry(key, value) } return true }) - return nil } -func SaveUser(u *model.User) error { - defer userCache.Remove(u.ID) - return db.SaveUser(u) +func CompareAndCloseUser(user *User) error { + u, loaded := userCache.LoadAndDelete(user.ID) + if loaded { + if u.Value() != user { + return errors.New("user compare failed") + } + if userCache.CompareAndDelete(user.ID, u) { + roomCache.Range(func(key string, value *synccache.Entry[*Room]) bool { + if value.Value().CreatorID == user.ID { + CompareAndCloseRoomEntry(key, value) + } + return true + }) + } + } + return nil } func GetUserName(userID string) string { - u, err := GetUserById(userID) + u, err := LoadOrInitUserByID(userID) if err != nil { return "" } @@ -128,21 +147,18 @@ func SetRoleByID(userID string, role model.Role) error { if err != nil { return err } - userCache.Remove(userID) - err = db.SetRoomStatusByCreator(userID, model.RoomStatusBanned) - if err != nil { - return err - } + userCache.Delete(userID) switch role { case model.RoleBanned: + err = db.SetRoomStatusByCreator(userID, model.RoomStatusBanned) + if err != nil { + return err + } roomCache.Range(func(key string, value *synccache.Entry[*Room]) bool { - v := value.Value() - if v.CreatorID == userID { - if roomCache.CompareAndDelete(key, value) { - v.close() - } + if value.Value().CreatorID == userID { + CompareAndCloseRoomEntry(key, value) } return true }) diff --git a/server/handlers/admin.go b/server/handlers/admin.go index c8ab446..38fc386 100644 --- a/server/handlers/admin.go +++ b/server/handlers/admin.go @@ -152,7 +152,7 @@ func GetRoomUsers(ctx *gin.Context) { var desc = ctx.DefaultQuery("order", "desc") == "desc" scopes := []func(db *gorm.DB) *gorm.DB{ - db.PreloadRoomUserRelation(db.WhereRoomID(id)), + db.PreloadRoomUserRelations(db.WhereRoomID(id)), } switch ctx.DefaultQuery("sort", "name") { @@ -508,3 +508,49 @@ func UnBanRoom(ctx *gin.Context) { ctx.Status(http.StatusNoContent) } + +func AddUser(ctx *gin.Context) { + // user := ctx.MustGet("user").(*op.User) + + req := model.AddUserReq{} + if err := model.Decode(ctx, &req); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + _, err := op.CreateOrLoadUser(req.Username, req.Password, db.WithRole(req.Role)) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + ctx.Status(http.StatusNoContent) +} + +func DeleteUser(ctx *gin.Context) { + user := ctx.MustGet("user").(*op.User) + + req := model.UserIDReq{} + if err := model.Decode(ctx, &req); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + u, err := db.GetUserByID(req.ID) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + if u.IsAdmin() && !user.IsRoot() { + ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorStringResp("cannot delete admin")) + return + } + + if err := op.DeleteUserByID(req.ID); err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + ctx.Status(http.StatusNoContent) +} diff --git a/server/handlers/init.go b/server/handlers/init.go index 18c0a4c..b98b9d8 100644 --- a/server/handlers/init.go +++ b/server/handlers/init.go @@ -39,6 +39,10 @@ func Init(e *gin.Engine) { { user := admin.Group("/user") + user.POST("/add", AddUser) + + user.POST("/delete", DeleteUser) + // 查找用户 user.GET("/list", Users) @@ -141,9 +145,11 @@ func Init(e *gin.Engine) { } { - // user := api.Group("/user") + user := api.Group("/user") needAuthUser := needAuthUserApi.Group("/user") + user.POST("/login", LoginUser) + needAuthUser.POST("/logout", LogoutUser) needAuthUser.GET("/me", Me) @@ -151,6 +157,8 @@ func Init(e *gin.Engine) { needAuthUser.GET("/rooms", UserRooms) needAuthUser.POST("/username", SetUsername) + + needAuthUser.POST("/password", SetUserPassword) } { diff --git a/server/handlers/room.go b/server/handlers/room.go index 7e7d4ac..3e2cad4 100644 --- a/server/handlers/room.go +++ b/server/handlers/room.go @@ -177,6 +177,10 @@ func LoginRoom(ctx *gin.Context) { room, err := op.LoadOrInitRoomByID(req.RoomId) if err != nil { + if err == op.ErrRoomBanned || err == op.ErrRoomPending { + ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorResp(err)) + return + } ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewApiErrorResp(err)) return } @@ -325,7 +329,7 @@ func RoomUsers(ctx *gin.Context) { scopes = append(scopes, db.WhereIDIn(db.GerUsersIDByIDLike(keyword))) } } - scopes = append(scopes, db.PreloadRoomUserRelation(preloadScopes...)) + scopes = append(scopes, db.PreloadRoomUserRelations(preloadScopes...)) ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ "total": db.GetAllUserCount(scopes...), diff --git a/server/handlers/root.go b/server/handlers/root.go index 6278b2e..97dac0c 100644 --- a/server/handlers/root.go +++ b/server/handlers/root.go @@ -22,7 +22,7 @@ func AddAdmin(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("cannot add yourself")) return } - u, err := op.GetUserById(req.Id) + u, err := op.LoadOrInitUserByID(req.Id) if err != nil { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorStringResp("user not found")) return @@ -53,7 +53,7 @@ func DeleteAdmin(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("cannot remove yourself")) return } - u, err := op.GetUserById(req.Id) + u, err := op.LoadOrInitUserByID(req.Id) if err != nil { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorStringResp("user not found")) return diff --git a/server/handlers/user.go b/server/handlers/user.go index ef1ca04..76f2962 100644 --- a/server/handlers/user.go +++ b/server/handlers/user.go @@ -7,6 +7,7 @@ import ( "github.com/synctv-org/synctv/internal/db" dbModel "github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/internal/op" + "github.com/synctv-org/synctv/server/middlewares" "github.com/synctv-org/synctv/server/model" "gorm.io/gorm" ) @@ -22,10 +23,43 @@ func Me(ctx *gin.Context) { })) } +func LoginUser(ctx *gin.Context) { + req := model.LoginUserReq{} + if err := model.Decode(ctx, &req); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + user, err := op.LoadUserByUsername(req.Username) + if err != nil { + if err == op.ErrUserBanned || err == op.ErrUserPending { + ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorResp(err)) + return + } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + if ok := user.CheckPassword(req.Password); !ok { + ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorStringResp("password incorrect")) + return + } + + token, err := middlewares.NewAuthUserToken(user) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "token": token, + })) +} + func LogoutUser(ctx *gin.Context) { user := ctx.MustGet("user").(*op.User) - err := op.DeleteUserByID(user.ID) + err := op.CompareAndDeleteUser(user) if err != nil { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) return @@ -112,6 +146,24 @@ func SetUsername(ctx *gin.Context) { ctx.Status(http.StatusNoContent) } +func SetUserPassword(ctx *gin.Context) { + user := ctx.MustGet("user").(*op.User) + + var req model.SetUserPasswordReq + if err := model.Decode(ctx, &req); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + err := user.SetPassword(req.Password) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + ctx.Status(http.StatusNoContent) +} + func UserBindProviders(ctx *gin.Context) { user := ctx.MustGet("user").(*op.User) @@ -121,9 +173,9 @@ func UserBindProviders(ctx *gin.Context) { return } - resp := make([]model.UserBindProviderReq, len(up)) + resp := make([]model.UserBindProviderResp, len(up)) for i, v := range up { - resp[i] = model.UserBindProviderReq{ + resp[i] = model.UserBindProviderResp{ Provider: v.Provider, ProviderUserID: v.ProviderUserID, CreatedAt: v.CreatedAt.UnixMilli(), diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index 5cd30d6..f6f7c3a 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -20,14 +20,15 @@ var ( ) type AuthClaims struct { - UserId string `json:"u"` + UserId string `json:"u"` + UserVersion uint32 `json:"uv"` jwt.RegisteredClaims } type AuthRoomClaims struct { AuthClaims - RoomId string `json:"r"` - Version uint32 `json:"rv"` + RoomId string `json:"r"` + RoomVersion uint32 `json:"rv"` } func authRoom(Authorization string) (*AuthRoomClaims, error) { @@ -72,16 +73,21 @@ func AuthRoom(Authorization string) (*op.User, *op.Room, error) { return nil, nil, ErrAuthFailed } - u, err := op.GetUserById(claims.UserId) + u, err := op.LoadOrInitUserByID(claims.UserId) if err != nil { return nil, nil, err } + if !u.CheckVersion(claims.UserVersion) { + return nil, nil, ErrAuthExpired + } + r, err := op.LoadOrInitRoomByID(claims.RoomId) if err != nil { return nil, nil, err } - if !r.CheckVersion(claims.Version) { + + if !r.CheckVersion(claims.RoomVersion) { return nil, nil, ErrAuthExpired } @@ -98,11 +104,15 @@ func AuthUser(Authorization string) (*op.User, error) { return nil, ErrAuthFailed } - u, err := op.GetUserById(claims.UserId) + u, err := op.LoadOrInitUserByID(claims.UserId) if err != nil { return nil, err } + if !u.CheckVersion(claims.UserVersion) { + return nil, ErrAuthExpired + } + return u, nil } @@ -118,7 +128,8 @@ func NewAuthUserToken(user *op.User) (string, error) { return "", err } claims := &AuthClaims{ - UserId: user.ID, + UserId: user.ID, + UserVersion: user.Version(), RegisteredClaims: jwt.RegisteredClaims{ NotBefore: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(time.Now().Add(t)), @@ -154,14 +165,15 @@ func NewAuthRoomToken(user *op.User, room *op.Room) (string, error) { } claims := &AuthRoomClaims{ AuthClaims: AuthClaims{ - UserId: user.ID, + UserId: user.ID, + UserVersion: user.Version(), RegisteredClaims: jwt.RegisteredClaims{ NotBefore: jwt.NewNumericDate(time.Now()), ExpiresAt: jwt.NewNumericDate(time.Now().Add(t)), }, }, - RoomId: room.ID, - Version: room.Version(), + RoomId: room.ID, + RoomVersion: room.Version(), } return jwt.NewWithClaims(jwt.SigningMethodHS256, claims).SignedString(stream.StringToBytes(conf.Conf.Jwt.Secret)) } diff --git a/server/model/admin.go b/server/model/admin.go index e8c4ad1..6a3d59a 100644 --- a/server/model/admin.go +++ b/server/model/admin.go @@ -5,6 +5,7 @@ import ( "github.com/gin-gonic/gin" json "github.com/json-iterator/go" + dbModel "github.com/synctv-org/synctv/internal/model" ) var ( @@ -22,3 +23,33 @@ func (asr *AdminSettingsReq) Decode(ctx *gin.Context) error { } type AdminSettingsResp map[string]map[string]any + +type AddUserReq struct { + Username string `json:"username"` + Password string `json:"password"` + Role dbModel.Role `json:"role"` +} + +func (aur *AddUserReq) Validate() error { + if aur.Username == "" { + return errors.New("username is empty") + } else if len(aur.Username) > 32 { + return ErrUsernameTooLong + } else if !alnumPrintHanReg.MatchString(aur.Username) { + return ErrUsernameHasInvalidChar + } + + if aur.Password == "" { + return FormatEmptyPasswordError("user") + } else if len(aur.Password) > 32 { + return ErrPasswordTooLong + } else if !alnumPrintReg.MatchString(aur.Password) { + return ErrPasswordHasInvalidChar + } + + return nil +} + +func (aur *AddUserReq) Decode(ctx *gin.Context) error { + return json.NewDecoder(ctx.Request.Body).Decode(aur) +} diff --git a/server/model/user.go b/server/model/user.go index ce79a07..451607c 100644 --- a/server/model/user.go +++ b/server/model/user.go @@ -42,12 +42,16 @@ func (l *LoginUserReq) Validate() error { return errors.New("username is empty") } else if len(l.Username) > 32 { return ErrUsernameTooLong + } else if !alnumPrintHanReg.MatchString(l.Username) { + return ErrUsernameHasInvalidChar } if l.Password == "" { return FormatEmptyPasswordError("user") } else if len(l.Password) > 32 { return ErrPasswordTooLong + } else if !alnumPrintReg.MatchString(l.Password) { + return ErrPasswordHasInvalidChar } return nil } @@ -68,7 +72,7 @@ func (s *SetUsernameReq) Validate() error { return errors.New("username is empty") } else if len(s.Username) > 32 { return ErrUsernameTooLong - } else if !alnumPrintReg.MatchString(s.Username) { + } else if !alnumPrintHanReg.MatchString(s.Username) { return ErrUsernameHasInvalidChar } return nil @@ -93,7 +97,7 @@ func (u *UserIDReq) Validate() error { return nil } -type UserBindProviderReq struct { +type UserBindProviderResp struct { Provider provider.OAuth2Provider `json:"provider"` ProviderUserID string `json:"providerUserID"` CreatedAt int64 `json:"createdAt"` diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index dba032a..f105575 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -77,6 +77,10 @@ func OAuth2Callback(ctx *gin.Context) { ld, err := login(ctx, ctx.Query("state"), code, pi) if err != nil { + if err == op.ErrUserBanned || err == op.ErrUserPending { + ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorResp(err)) + return + } ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } @@ -100,6 +104,10 @@ func OAuth2CallbackApi(ctx *gin.Context) { ld, err := login(ctx, req.State, req.Code, pi) if err != nil { + if err == op.ErrUserBanned || err == op.ErrUserPending { + ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorResp(err)) + return + } ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } @@ -132,14 +140,14 @@ func login(ctx context.Context, state, code string, pi provider.ProviderInterfac var user *op.User if meta.Value().BindUserId != "" { - user, err = op.GetUserById(meta.Value().BindUserId) + user, err = op.LoadOrInitUserByID(meta.Value().BindUserId) } else if settings.DisableUserSignup.Get() { user, err = op.GetUserByProvider(pi.Provider(), ui.ProviderUserID) } else { if settings.SignupNeedReview.Get() { - user, err = op.CreateOrLoadUser(ui.Username, pi.Provider(), ui.ProviderUserID, db.WithRole(dbModel.RolePending)) + user, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID, db.WithRole(dbModel.RolePending)) } else { - user, err = op.CreateOrLoadUser(ui.Username, pi.Provider(), ui.ProviderUserID) + user, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID) } } if err != nil { @@ -147,7 +155,7 @@ func login(ctx context.Context, state, code string, pi provider.ProviderInterfac } if meta.Value().BindUserId != "" { - err = op.BindProvider(meta.Value().BindUserId, pi.Provider(), ui.ProviderUserID) + err = user.BindProvider(pi.Provider(), ui.ProviderUserID) if err != nil { return nil, err }