diff --git a/internal/db/user.go b/internal/db/user.go index 711fbd4..51d33e1 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -72,6 +72,35 @@ func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) { return userProvider.UserID, nil } +func BindProvider(uid string, p provider.OAuth2Provider, puid string) error { + err := db.Create(&model.UserProvider{ + UserID: uid, + Provider: p, + ProviderUserID: puid, + }).Error + if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) { + return errors.New("provider already bind") + } + return err +} + +func UnBindProvider(uid string, p provider.OAuth2Provider) error { + err := db.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return errors.New("user could not bind provider") + } + return err +} + +func GetBindProviders(uid string) ([]*model.UserProvider, error) { + var providers []*model.UserProvider + err := db.Where("user_id = ?", uid).Find(&providers).Error + if errors.Is(err, gorm.ErrRecordNotFound) { + return providers, errors.New("user not found") + } + return providers, err +} + func GetUserByUsername(username string) (*model.User, error) { u := &model.User{} err := db.Where("username = ?", username).First(u).Error diff --git a/internal/op/users.go b/internal/op/users.go index ececede..4e92a9b 100644 --- a/internal/op/users.go +++ b/internal/op/users.go @@ -84,6 +84,14 @@ func GetUserByProvider(p provider.OAuth2Provider, pid string) (*User, error) { return GetUserById(uid) } +func BindProvider(uid string, p provider.OAuth2Provider, pid string) error { + err := db.BindProvider(uid, p, pid) + if err != nil { + return err + } + return nil +} + func DeleteUserByID(userID string) error { err := db.DeleteUserByID(userID) if err != nil { diff --git a/server/handlers/user.go b/server/handlers/user.go index daa0419..ef1ca04 100644 --- a/server/handlers/user.go +++ b/server/handlers/user.go @@ -111,3 +111,24 @@ func SetUsername(ctx *gin.Context) { ctx.Status(http.StatusNoContent) } + +func UserBindProviders(ctx *gin.Context) { + user := ctx.MustGet("user").(*op.User) + + up, err := db.GetBindProviders(user.ID) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + return + } + + resp := make([]model.UserBindProviderReq, len(up)) + for i, v := range up { + resp[i] = model.UserBindProviderReq{ + Provider: v.Provider, + ProviderUserID: v.ProviderUserID, + CreatedAt: v.CreatedAt.UnixMilli(), + } + } + + ctx.JSON(http.StatusOK, resp) +} diff --git a/server/model/auth.go b/server/model/auth.go index b4e2c9e..2d148c6 100644 --- a/server/model/auth.go +++ b/server/model/auth.go @@ -30,3 +30,15 @@ func (o *OAuth2CallbackReq) Validate() error { func (o *OAuth2CallbackReq) Decode(ctx *gin.Context) error { return json.NewDecoder(ctx.Request.Body).Decode(o) } + +type OAuth2Req struct { + Redirect string `json:"redirect"` +} + +func (o *OAuth2Req) Validate() error { + return nil +} + +func (o *OAuth2Req) Decode(ctx *gin.Context) error { + return json.NewDecoder(ctx.Request.Body).Decode(o) +} diff --git a/server/model/user.go b/server/model/user.go index 5937326..ce79a07 100644 --- a/server/model/user.go +++ b/server/model/user.go @@ -6,6 +6,7 @@ import ( "github.com/gin-gonic/gin" json "github.com/json-iterator/go" dbModel "github.com/synctv-org/synctv/internal/model" + "github.com/synctv-org/synctv/internal/provider" ) type SetUserPasswordReq struct { @@ -91,3 +92,9 @@ func (u *UserIDReq) Validate() error { } return nil } + +type UserBindProviderReq struct { + Provider provider.OAuth2Provider `json:"provider"` + ProviderUserID string `json:"providerUserID"` + CreatedAt int64 `json:"createdAt"` +} diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index 18181b9..290786a 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -1,6 +1,8 @@ package auth import ( + "context" + "errors" "net/http" "time" @@ -16,37 +18,49 @@ import ( "github.com/synctv-org/synctv/utils" ) +// GET // /oauth2/login/:type func OAuth2(ctx *gin.Context) { - t := ctx.Param("type") - - pi, err := providers.GetProvider(provider.OAuth2Provider(t)) + pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type"))) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } state := utils.RandString(16) - states.Store(state, struct{}{}, time.Minute*5) + states.Store(state, stateMeta{ + OAuth2Req: model.OAuth2Req{ + Redirect: ctx.Query("redirect"), + }, + }, time.Minute*5) RenderRedirect(ctx, pi.NewAuthURL(state)) } +// POST func OAuth2Api(ctx *gin.Context) { - t := ctx.Param("type") - pi, err := providers.GetProvider(provider.OAuth2Provider(t)) + pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type"))) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } + meta := model.OAuth2Req{} + if err := model.Decode(ctx, &meta); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + state := utils.RandString(16) - states.Store(state, struct{}{}, time.Minute*5) + states.Store(state, stateMeta{ + OAuth2Req: meta, + }, time.Minute*5) ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ "url": pi.NewAuthURL(state), })) } +// GET // /oauth2/callback/:type func OAuth2Callback(ctx *gin.Context) { code := ctx.Query("code") @@ -55,111 +69,102 @@ func OAuth2Callback(ctx *gin.Context) { return } - state := ctx.Query("state") - if state == "" { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) - return - } - - _, loaded := states.LoadAndDelete(state) - if !loaded { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) - return - } - - p := provider.OAuth2Provider(ctx.Param("type")) - pi, err := providers.GetProvider(p) + pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type"))) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } - t, err := pi.GetToken(ctx, code) + ld, err := login(ctx, ctx.Query("state"), code, pi) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } - ui, err := pi.GetUserInfo(ctx, t) - if err != nil { + RenderToken(ctx, ld.redirect, ld.token) +} + +// POST +// /oauth2/callback/:type +func OAuth2CallbackApi(ctx *gin.Context) { + req := model.OAuth2CallbackReq{} + if err := req.Decode(ctx); err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } - var user *op.User - if settings.DisableUserSignup.Get() { - user, err = op.GetUserByProvider(p, ui.ProviderUserID) - } else { - if settings.SignupNeedReview.Get() { - user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID, db.WithRole(dbModel.RolePending)) - } else { - user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID) - } - } + pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type"))) if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) - return + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - token, err := middlewares.NewAuthUserToken(user) + ld, err := login(ctx, req.State, req.Code, pi) if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } - RenderToken(ctx, "/web/", token) + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "token": ld.token, + "redirect": ld.redirect, + })) } -// /oauth2/callback/:type -func OAuth2CallbackApi(ctx *gin.Context) { - req := model.OAuth2CallbackReq{} - if err := req.Decode(ctx); err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } +type loginData struct { + token, redirect string +} - _, loaded := states.LoadAndDelete(req.State) +func login(ctx context.Context, state, code string, pi provider.ProviderInterface) (*loginData, error) { + meta, loaded := states.LoadAndDelete(state) if !loaded { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) - return + return nil, errors.New("invalid oauth2 state") } - p := provider.OAuth2Provider(ctx.Param("type")) - pi, err := providers.GetProvider(p) + t, err := pi.GetToken(ctx, code) if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - } - - t, err := pi.GetToken(ctx, req.Code) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return + return nil, err } ui, err := pi.GetUserInfo(ctx, t) if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return + return nil, err } var user *op.User - if settings.DisableUserSignup.Get() { - user, err = op.GetUserByProvider(p, ui.ProviderUserID) + if meta.Value().BindUserId != "" { + user, err = op.GetUserById(meta.Value().BindUserId) + } else if settings.DisableUserSignup.Get() { + user, err = op.GetUserByProvider(pi.Provider(), ui.ProviderUserID) } else { - user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID) + if settings.SignupNeedReview.Get() { + user, err = op.CreateOrLoadUser(ui.Username, pi.Provider(), ui.ProviderUserID, db.WithRole(dbModel.RolePending)) + } else { + user, err = op.CreateOrLoadUser(ui.Username, pi.Provider(), ui.ProviderUserID) + } } if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) - return + return nil, err + } + + if meta.Value().BindUserId != "" { + err = op.BindProvider(meta.Value().BindUserId, pi.Provider(), ui.ProviderUserID) + if err != nil { + return nil, err + } } token, err := middlewares.NewAuthUserToken(user) if err != nil { - ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) - return + return nil, err } - ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ - "token": token, - })) + redirect := "/web/" + if meta.Value().Redirect != "" { + redirect = meta.Value().Redirect + } + + return &loginData{ + token: token, + redirect: redirect, + }, nil } diff --git a/server/oauth2/bind.go b/server/oauth2/bind.go new file mode 100644 index 0000000..c250661 --- /dev/null +++ b/server/oauth2/bind.go @@ -0,0 +1,57 @@ +package auth + +import ( + "net/http" + "time" + + "github.com/gin-gonic/gin" + "github.com/synctv-org/synctv/internal/db" + "github.com/synctv-org/synctv/internal/op" + "github.com/synctv-org/synctv/internal/provider" + "github.com/synctv-org/synctv/internal/provider/providers" + "github.com/synctv-org/synctv/server/model" + "github.com/synctv-org/synctv/utils" +) + +func BindApi(ctx *gin.Context) { + user := ctx.MustGet("user").(*op.User) + + pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type"))) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + } + + meta := model.OAuth2Req{} + if err := model.Decode(ctx, &meta); err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + state := utils.RandString(16) + states.Store(state, stateMeta{ + OAuth2Req: meta, + BindUserId: user.ID, + }, time.Minute*5) + + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "url": pi.NewAuthURL(state), + })) +} + +func UnBindApi(ctx *gin.Context) { + user := ctx.MustGet("user").(*op.User) + + pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type"))) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + err = db.UnBindProvider(user.ID, pi.Provider()) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + ctx.Status(http.StatusNoContent) +} diff --git a/server/oauth2/init.go b/server/oauth2/init.go index 3f199fc..c6d8f17 100644 --- a/server/oauth2/init.go +++ b/server/oauth2/init.go @@ -1,19 +1,28 @@ package auth -import "github.com/gin-gonic/gin" +import ( + "github.com/gin-gonic/gin" + "github.com/synctv-org/synctv/server/middlewares" +) func Init(e *gin.Engine) { { - auth := e.Group("/oauth2") + oauth2 := e.Group("/oauth2") + needAuthOauth2 := oauth2.Group("") + needAuthOauth2.Use(middlewares.AuthUserMiddleware) - auth.GET("/enabled", OAuth2EnabledApi) + oauth2.GET("/enabled", OAuth2EnabledApi) - auth.GET("/login/:type", OAuth2) + oauth2.GET("/login/:type", OAuth2) - auth.POST("/login/:type", OAuth2Api) + oauth2.POST("/login/:type", OAuth2Api) - auth.GET("/callback/:type", OAuth2Callback) + oauth2.GET("/callback/:type", OAuth2Callback) - auth.POST("/callback/:type", OAuth2CallbackApi) + oauth2.POST("/callback/:type", OAuth2CallbackApi) + + needAuthOauth2.POST("/bind/:type", BindApi) + + needAuthOauth2.POST("/unbind/:type", UnBindApi) } } diff --git a/server/oauth2/render.go b/server/oauth2/render.go index 3a3f6ac..972d56d 100644 --- a/server/oauth2/render.go +++ b/server/oauth2/render.go @@ -6,6 +6,7 @@ import ( "time" "github.com/gin-gonic/gin" + "github.com/synctv-org/synctv/server/model" "github.com/zijiren233/gencontainer/synccache" ) @@ -15,9 +16,14 @@ var temp embed.FS var ( redirectTemplate *template.Template tokenTemplate *template.Template - states *synccache.SyncCache[string, struct{}] + states *synccache.SyncCache[string, stateMeta] ) +type stateMeta struct { + model.OAuth2Req + BindUserId string +} + func RenderRedirect(ctx *gin.Context, url string) error { ctx.Header("Content-Type", "text/html; charset=utf-8") return redirectTemplate.Execute(ctx.Writer, url) @@ -31,5 +37,5 @@ func RenderToken(ctx *gin.Context, url, token string) error { func init() { redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html")) tokenTemplate = template.Must(template.ParseFS(temp, "templates/token.html")) - states = synccache.NewSyncCache[string, struct{}](time.Minute * 10) + states = synccache.NewSyncCache[string, stateMeta](time.Minute * 10) }