diff --git a/internal/bootstrap/provider.go b/internal/bootstrap/provider.go index 2c2ddbf..51661d5 100644 --- a/internal/bootstrap/provider.go +++ b/internal/bootstrap/provider.go @@ -182,6 +182,10 @@ func InitProviderSetting(pi provider.Provider) { groupSettings.DisableUserSignup = settings.NewBoolSetting(group+"_disable_user_signup", false, group) groupSettings.SignupNeedReview = settings.NewBoolSetting(group+"_signup_need_review", false, group) + + if registerSetting, ok := pi.(provider.ProviderRegistSetting); ok { + registerSetting.RegistSetting(group) + } } func InitAggregationProviderSetting(pi provider.Provider) { diff --git a/internal/db/user.go b/internal/db/user.go index 482fe6f..83c71c4 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -5,7 +5,6 @@ import ( "fmt" "github.com/synctv-org/synctv/internal/model" - "github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/utils" "github.com/zijiren233/stream" "golang.org/x/crypto/bcrypt" @@ -95,7 +94,7 @@ func CreateUser(username string, password string, conf ...CreateUserConfig) (*mo return CreateUserWithHashedPassword(username, hashedPassword, conf...) } -func CreateOrLoadUserWithProvider(username, password string, p provider.OAuth2Provider, puid string, conf ...CreateUserConfig) (*model.User, error) { +func CreateOrLoadUserWithProvider(username, password string, p string, puid string, conf ...CreateUserConfig) (*model.User, error) { if puid == "" { return nil, errors.New("provider user id cannot be empty") } @@ -139,7 +138,7 @@ func CreateUserWithEmail(username, password, email string, conf ...CreateUserCon )...) } -func GetUserByProvider(p provider.OAuth2Provider, puid string) (*model.User, error) { +func GetUserByProvider(p string, puid string) (*model.User, error) { var user model.User err := db.Joins("JOIN user_providers ON users.id = user_providers.user_id"). Where("user_providers.provider = ? AND user_providers.provider_user_id = ?", p, puid). @@ -153,7 +152,7 @@ func GetUserByEmail(email string) (*model.User, error) { return &user, HandleNotFound(err, ErrUserNotFound) } -func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) { +func GetProviderUserID(p string, puid string) (string, error) { var userID string err := db.Model(&model.UserProvider{}). Where("provider = ? AND provider_user_id = ?", p, puid). @@ -162,7 +161,7 @@ func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) { return userID, HandleNotFound(err, ErrUserNotFound) } -func BindProvider(uid string, p provider.OAuth2Provider, puid string) error { +func BindProvider(uid string, p string, puid string) error { err := db.Create(&model.UserProvider{ UserID: uid, Provider: p, @@ -177,7 +176,7 @@ func BindProvider(uid string, p provider.OAuth2Provider, puid string) error { return nil } -func UnBindProvider(uid string, p provider.OAuth2Provider) error { +func UnBindProvider(uid string, p string) error { return Transactional(func(tx *gorm.DB) error { var user model.User if err := tx.Preload("UserProviders").Where("id = ?", uid).First(&user).Error; err != nil { diff --git a/internal/model/oauth2.go b/internal/model/oauth2.go index b109b67..bbee812 100644 --- a/internal/model/oauth2.go +++ b/internal/model/oauth2.go @@ -2,13 +2,11 @@ package model import ( "time" - - "github.com/synctv-org/synctv/internal/provider" ) type UserProvider struct { - Provider provider.OAuth2Provider `gorm:"primarykey;type:varchar(32);uniqueIndex:idx_provider_user_id"` - ProviderUserID string `gorm:"primarykey;type:varchar(64)"` + Provider string `gorm:"primarykey;type:varchar(32);uniqueIndex:idx_provider_user_id"` + ProviderUserID string `gorm:"primarykey;type:varchar(64)"` CreatedAt time.Time UpdatedAt time.Time UserID string `gorm:"not null;type:char(32);uniqueIndex:idx_provider_user_id"` diff --git a/internal/provider/provider.go b/internal/provider/provider.go index 18a3409..8b3f6a4 100644 --- a/internal/provider/provider.go +++ b/internal/provider/provider.go @@ -22,6 +22,10 @@ type Provider interface { Provider() OAuth2Provider } +type ProviderRegistSetting interface { + RegistSetting(group string) +} + type Interface interface { Provider NewAuthURL(context.Context, string) (string, error) diff --git a/internal/provider/providers/logto.go b/internal/provider/providers/logto.go new file mode 100644 index 0000000..04d2801 --- /dev/null +++ b/internal/provider/providers/logto.go @@ -0,0 +1,115 @@ +package providers + +import ( + "context" + "encoding/json" + "net/http" + "strings" + + "github.com/synctv-org/synctv/internal/provider" + "github.com/synctv-org/synctv/internal/settings" + "golang.org/x/oauth2" +) + +// https://openapi.logto.io/authentication +type logtoProvider struct { + config oauth2.Config + endpoint string +} + +func newLogtoProvider() provider.Interface { + return &logtoProvider{ + config: oauth2.Config{ + Scopes: []string{"profile", "email", "phone", "name", "openid"}, + }, + } +} + +func (p *logtoProvider) Init(opt provider.Oauth2Option) { + p.config.ClientID = opt.ClientID + p.config.ClientSecret = opt.ClientSecret + p.config.RedirectURL = opt.RedirectURL +} + +func (p *logtoProvider) NewAuthURL(ctx context.Context, state string) (string, error) { + return p.config.AuthCodeURL(state, oauth2.AccessTypeOnline), nil +} + +func (p *logtoProvider) GetToken(ctx context.Context, code string) (*oauth2.Token, error) { + return p.config.Exchange(ctx, code) +} + +func (p *logtoProvider) RefreshToken(ctx context.Context, token string) (*oauth2.Token, error) { + return p.config.TokenSource(ctx, &oauth2.Token{RefreshToken: token}).Token() +} + +func (p *logtoProvider) GetUserInfo(ctx context.Context, code string) (*provider.UserInfo, error) { + tk, err := p.GetToken(ctx, code) + if err != nil { + return nil, err + } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.endpoint+"/oidc/me", nil) + if err != nil { + return nil, err + } + resp, err := client.Do(req) + if err != nil { + return nil, err + } + defer resp.Body.Close() + var ui logtoUserInfo + err = json.NewDecoder(resp.Body).Decode(&ui) + if err != nil { + return nil, err + } + un := ui.Username + if un == "" { + un = ui.Name + } + return &provider.UserInfo{ + ProviderUserID: ui.Sub, + Username: un, + }, nil +} + +type logtoUserInfo struct { + Sub string `json:"sub"` + Username string `json:"username"` + PrimaryEmail string `json:"primaryEmail"` + PrimaryPhone string `json:"primaryPhone"` + Name string `json:"name"` + Email string `json:"email"` +} + +func (p *logtoProvider) RegistSetting(group string) { + settings.NewStringSetting( + group+"_endpoint", "", group, + settings.WithAfterInitString(func(ss settings.StringSetting, s string) { + s = strings.TrimSuffix(s, "/") + s = strings.TrimSuffix(s, "/oidc") + p.endpoint = s + p.config.Endpoint = oauth2.Endpoint{ + AuthURL: s + "/oidc/auth", + TokenURL: s + "/oidc/token", + } + }), + settings.WithAfterSetString(func(ss settings.StringSetting, s string) { + s = strings.TrimSuffix(s, "/") + s = strings.TrimSuffix(s, "/oidc") + p.endpoint = s + p.config.Endpoint = oauth2.Endpoint{ + AuthURL: s + "/oidc/auth", + TokenURL: s + "/oidc/token", + } + }), + ) +} + +func (p *logtoProvider) Provider() provider.OAuth2Provider { + return "logto" +} + +func init() { + RegisterProvider(newLogtoProvider()) +} diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index b08a6d6..00c8a54 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -157,6 +157,17 @@ func newAuthFunc(redirect string) stateHandler { return } + if ui.ProviderUserID == "" { + log.Errorf("invalid oauth2 provider user id") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 provider user id")) + return + } + if ui.Username == "" { + log.Errorf("invalid oauth2 username") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 username")) + return + } + pgs, loaded := bootstrap.ProviderGroupSettings[fmt.Sprintf("%s_%s", dbModel.SettingGroupOauth2, pi.Provider())] if !loaded { log.Errorf("invalid oauth2 provider") diff --git a/server/oauth2/bind.go b/server/oauth2/bind.go index af3860f..8351b88 100644 --- a/server/oauth2/bind.go +++ b/server/oauth2/bind.go @@ -78,6 +78,17 @@ func newBindFunc(userID, redirect string) stateHandler { return } + if ui.ProviderUserID == "" { + log.Errorf("invalid oauth2 provider user id") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 provider user id")) + return + } + if ui.Username == "" { + log.Errorf("invalid oauth2 username") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 username")) + return + } + user, err := op.LoadOrInitUserByID(userID) if err != nil { log.Errorf("failed to load user: %v", err)