|
|
|
@ -7,6 +7,8 @@ import (
|
|
|
|
|
"github.com/synctv-org/synctv/internal/model"
|
|
|
|
|
"github.com/synctv-org/synctv/internal/provider"
|
|
|
|
|
"github.com/synctv-org/synctv/utils"
|
|
|
|
|
"github.com/zijiren233/stream"
|
|
|
|
|
"golang.org/x/crypto/bcrypt"
|
|
|
|
|
"gorm.io/gorm"
|
|
|
|
|
"gorm.io/gorm/clause"
|
|
|
|
|
)
|
|
|
|
@ -19,20 +21,60 @@ func WithRole(role model.Role) CreateUserConfig {
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func CreateUser(username string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
func WithAppendProvider(p provider.OAuth2Provider, puid string) CreateUserConfig {
|
|
|
|
|
return func(u *model.User) {
|
|
|
|
|
u.UserProviders = append(u.UserProviders, model.UserProvider{
|
|
|
|
|
Provider: p,
|
|
|
|
|
ProviderUserID: puid,
|
|
|
|
|
})
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithSetProvider(p provider.OAuth2Provider, puid string) CreateUserConfig {
|
|
|
|
|
return func(u *model.User) {
|
|
|
|
|
u.UserProviders = []model.UserProvider{{
|
|
|
|
|
Provider: p,
|
|
|
|
|
ProviderUserID: puid,
|
|
|
|
|
}}
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithAppendProviders(providers []model.UserProvider) CreateUserConfig {
|
|
|
|
|
return func(u *model.User) {
|
|
|
|
|
u.UserProviders = append(u.UserProviders, providers...)
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithSetProviders(providers []model.UserProvider) CreateUserConfig {
|
|
|
|
|
return func(u *model.User) {
|
|
|
|
|
u.UserProviders = providers
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func WithRegisteredByProvider(b bool) CreateUserConfig {
|
|
|
|
|
return func(u *model.User) {
|
|
|
|
|
u.RegisteredByProvider = b
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func CreateUserWithHashedPassword(username string, hashedPassword []byte, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
if username == "" {
|
|
|
|
|
return nil, errors.New("username cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
if len(hashedPassword) == 0 {
|
|
|
|
|
return nil, errors.New("password cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
u := &model.User{
|
|
|
|
|
Username: username,
|
|
|
|
|
Role: model.RoleUser,
|
|
|
|
|
Providers: []model.UserProvider{
|
|
|
|
|
{
|
|
|
|
|
Provider: p,
|
|
|
|
|
ProviderUserID: puid,
|
|
|
|
|
},
|
|
|
|
|
},
|
|
|
|
|
Username: username,
|
|
|
|
|
Role: model.RoleUser,
|
|
|
|
|
HashedPassword: hashedPassword,
|
|
|
|
|
}
|
|
|
|
|
for _, c := range conf {
|
|
|
|
|
c(u)
|
|
|
|
|
}
|
|
|
|
|
if u.Role == 0 {
|
|
|
|
|
return nil, errors.New("role cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
err := db.Create(u).Error
|
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
|
|
|
|
|
return u, errors.New("user already exists")
|
|
|
|
@ -40,29 +82,80 @@ func CreateUser(username string, p provider.OAuth2Provider, puid string, conf ..
|
|
|
|
|
return u, err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 只有当provider和puid没有找到对应的user时才会创建
|
|
|
|
|
func CreateOrLoadUser(username string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
func CreateUser(username string, password string, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
if username == "" {
|
|
|
|
|
return nil, errors.New("username cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
if password == "" {
|
|
|
|
|
return nil, errors.New("password cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
hashedPassword, err := bcrypt.GenerateFromPassword(stream.StringToBytes(password), bcrypt.DefaultCost)
|
|
|
|
|
if err != nil {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
return CreateUserWithHashedPassword(username, hashedPassword, conf...)
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func CreateOrLoadUser(username string, password string, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
if username == "" {
|
|
|
|
|
return nil, errors.New("username cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
var user model.User
|
|
|
|
|
var userProvider model.UserProvider
|
|
|
|
|
if err := db.Where("username = ?", username).First(&user).Error; err != nil {
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return CreateUser(username, password, conf...)
|
|
|
|
|
} else {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return &user, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).First(&userProvider).Error; err != nil {
|
|
|
|
|
func CreateOrLoadUserWithHashedPassword(username string, hashedPassword []byte, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
if username == "" {
|
|
|
|
|
return nil, errors.New("username cannot be empty")
|
|
|
|
|
}
|
|
|
|
|
var user model.User
|
|
|
|
|
if err := db.Where("username = ?", username).First(&user).Error; err != nil {
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return CreateUser(username, p, puid, conf...)
|
|
|
|
|
return CreateUserWithHashedPassword(username, hashedPassword, conf...)
|
|
|
|
|
} else {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
if err := db.Where("id = ?", userProvider.UserID).First(&user).Error; err != nil {
|
|
|
|
|
}
|
|
|
|
|
return &user, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 只有当provider和puid没有找到对应的user时才会创建
|
|
|
|
|
func CreateOrLoadUserWithProvider(username, password string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) {
|
|
|
|
|
var user model.User
|
|
|
|
|
|
|
|
|
|
if err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error; err != nil {
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return CreateUser(username, password, append(conf, WithSetProvider(p, puid), WithRegisteredByProvider(true))...)
|
|
|
|
|
} else {
|
|
|
|
|
return nil, err
|
|
|
|
|
}
|
|
|
|
|
} else {
|
|
|
|
|
return &user, nil
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, error) {
|
|
|
|
|
var user model.User
|
|
|
|
|
if err := db.Where("id = (?)", db.Table("user_providers").Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id")).First(&user).Error; err != nil {
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return &user, errors.New("user not found")
|
|
|
|
|
} else {
|
|
|
|
|
return &user, err
|
|
|
|
|
}
|
|
|
|
|
}
|
|
|
|
|
return &user, nil
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) {
|
|
|
|
|
var userProvider model.UserProvider
|
|
|
|
|
if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).First(&userProvider).Error; err != nil {
|
|
|
|
|
if err := db.Where("provider = ? AND provider_user_id = ?", p, puid).Select("user_id").First(&userProvider).Error; err != nil {
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return "", errors.New("user not found")
|
|
|
|
|
} else {
|
|
|
|
@ -84,12 +177,29 @@ func BindProvider(uid string, p provider.OAuth2Provider, puid string) error {
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
// 当用户是通过provider注册的时候,则最少保留一个provider,否则禁止解除绑定
|
|
|
|
|
func UnBindProvider(uid string, p provider.OAuth2Provider) error {
|
|
|
|
|
err := db.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return errors.New("user could not bind provider")
|
|
|
|
|
tx := db.Begin()
|
|
|
|
|
user := model.User{}
|
|
|
|
|
if err := tx.Scopes(PreloadUserProviders()).Where("id = ?", uid).First(&user).Error; err != nil {
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return errors.New("user not found")
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
if user.RegisteredByProvider && len(user.UserProviders) == 1 {
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
return errors.New("user must have at least one provider")
|
|
|
|
|
}
|
|
|
|
|
if err := tx.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error; err != nil {
|
|
|
|
|
tx.Rollback()
|
|
|
|
|
if errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return errors.New("user could not bind provider")
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|
return tx.Commit().Error
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func GetBindProviders(uid string) ([]*model.UserProvider, error) {
|
|
|
|
@ -312,3 +422,11 @@ func GetAllUsers(scopes ...func(*gorm.DB) *gorm.DB) []*model.User {
|
|
|
|
|
db.Scopes(scopes...).Find(&users)
|
|
|
|
|
return users
|
|
|
|
|
}
|
|
|
|
|
|
|
|
|
|
func SetUserHashedPassword(id string, hashedPassword []byte) error {
|
|
|
|
|
err := db.Model(&model.User{}).Where("id = ?", id).Update("hashed_password", hashedPassword).Error
|
|
|
|
|
if err != nil && errors.Is(err, gorm.ErrRecordNotFound) {
|
|
|
|
|
return errors.New("user not found")
|
|
|
|
|
}
|
|
|
|
|
return err
|
|
|
|
|
}
|
|
|
|
|