From 688f093558b6d7efb1520667005cda3ac39934cb Mon Sep 17 00:00:00 2001 From: zijiren233 Date: Wed, 6 Dec 2023 20:22:02 +0800 Subject: [PATCH] Opt: oauth2 state handler --- server/oauth2/auth.go | 130 +++++++++++++++++----------------------- server/oauth2/bind.go | 45 ++++++++++++-- server/oauth2/render.go | 11 ++-- 3 files changed, 99 insertions(+), 87 deletions(-) diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index d8e50a7..ded717b 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -1,8 +1,6 @@ package auth import ( - "context" - "errors" "fmt" "net/http" "time" @@ -30,11 +28,7 @@ func OAuth2(ctx *gin.Context) { } state := utils.RandString(16) - states.Store(state, stateMeta{ - OAuth2Req: model.OAuth2Req{ - Redirect: ctx.Query("redirect"), - }, - }, time.Minute*5) + states.Store(state, newAuthFunc(ctx.Query("redirect")), time.Minute*5) RenderRedirect(ctx, pi.NewAuthURL(state)) } @@ -53,9 +47,7 @@ func OAuth2Api(ctx *gin.Context) { } state := utils.RandString(16) - states.Store(state, stateMeta{ - OAuth2Req: meta, - }, time.Minute*5) + states.Store(state, newAuthFunc(meta.Redirect), time.Minute*5) ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ "url": pi.NewAuthURL(state), @@ -77,17 +69,15 @@ func OAuth2Callback(ctx *gin.Context) { return } - ld, err := login(ctx, ctx.Query("state"), code, pi) - if err != nil { - if err == op.ErrUserBanned || err == op.ErrUserPending { - ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorResp(err)) - return - } - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + meta, loaded := states.LoadAndDelete(ctx.Query("state")) + if !loaded { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) return } - RenderToken(ctx, ld.redirect, ld.token) + if meta.Value() != nil { + meta.Value()(ctx, pi, code) + } } // POST @@ -104,77 +94,65 @@ func OAuth2CallbackApi(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) } - ld, err := login(ctx, req.State, req.Code, pi) - if err != nil { - if err == op.ErrUserBanned || err == op.ErrUserPending { - ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorResp(err)) - return - } - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + meta, loaded := states.LoadAndDelete(req.State) + if !loaded { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state")) return } - ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ - "token": ld.token, - "redirect": ld.redirect, - })) -} - -type loginData struct { - token, redirect string -} - -func login(ctx context.Context, state, code string, pi provider.ProviderInterface) (*loginData, error) { - meta, loaded := states.LoadAndDelete(state) - if !loaded { - return nil, errors.New("invalid oauth2 state") + if meta.Value() != nil { + meta.Value()(ctx, pi, req.Code) } +} - t, err := pi.GetToken(ctx, code) - if err != nil { - return nil, err - } +func newAuthFunc(redirect string) stateHandler { + return func(ctx *gin.Context, pi provider.ProviderInterface, code string) { + t, err := pi.GetToken(ctx, code) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } - ui, err := pi.GetUserInfo(ctx, t) - if err != nil { - return nil, err - } + ui, err := pi.GetUserInfo(ctx, t) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } - pgs, loaded := bootstrap.ProviderGroupSettings[dbModel.SettingGroup(fmt.Sprintf("%s_%s", dbModel.SettingGroupOauth2, pi.Provider()))] - if !loaded { - return nil, errors.New("invalid oauth2 provider") - } + pgs, loaded := bootstrap.ProviderGroupSettings[dbModel.SettingGroup(fmt.Sprintf("%s_%s", dbModel.SettingGroupOauth2, pi.Provider()))] + if !loaded { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 provider")) + return + } - var user *op.User - if meta.Value().BindUserId != "" { - user, err = op.LoadOrInitUserByID(meta.Value().BindUserId) - } else if settings.DisableUserSignup.Get() || pgs.DisableUserSignup.Get() { - user, err = op.GetUserByProvider(pi.Provider(), ui.ProviderUserID) - } else { - if settings.SignupNeedReview.Get() || pgs.SignupNeedReview.Get() { - user, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID, db.WithRole(dbModel.RolePending)) + var user *op.User + if settings.DisableUserSignup.Get() || pgs.DisableUserSignup.Get() { + user, err = op.GetUserByProvider(pi.Provider(), ui.ProviderUserID) } else { - user, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID) + if settings.SignupNeedReview.Get() || pgs.SignupNeedReview.Get() { + user, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID, db.WithRole(dbModel.RolePending)) + } else { + user, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID) + } + } + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return } - } - if err != nil { - return nil, err - } - if meta.Value().BindUserId != "" { - err = user.BindProvider(pi.Provider(), ui.ProviderUserID) + token, err := middlewares.NewAuthUserToken(user) if err != nil { - return nil, err + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return } - } - token, err := middlewares.NewAuthUserToken(user) - if err != nil { - return nil, err + if ctx.Request.Method == http.MethodGet { + RenderToken(ctx, redirect, token) + } else if ctx.Request.Method == http.MethodPost { + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "token": token, + "redirect": redirect, + })) + } } - - return &loginData{ - token: token, - redirect: meta.Value().Redirect, - }, nil } diff --git a/server/oauth2/bind.go b/server/oauth2/bind.go index c250661..51914cd 100644 --- a/server/oauth2/bind.go +++ b/server/oauth2/bind.go @@ -9,6 +9,7 @@ import ( "github.com/synctv-org/synctv/internal/op" "github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider/providers" + "github.com/synctv-org/synctv/server/middlewares" "github.com/synctv-org/synctv/server/model" "github.com/synctv-org/synctv/utils" ) @@ -28,10 +29,7 @@ func BindApi(ctx *gin.Context) { } state := utils.RandString(16) - states.Store(state, stateMeta{ - OAuth2Req: meta, - BindUserId: user.ID, - }, time.Minute*5) + states.Store(state, newBindFunc(user.ID, meta.Redirect), time.Minute*5) ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ "url": pi.NewAuthURL(state), @@ -55,3 +53,42 @@ func UnBindApi(ctx *gin.Context) { ctx.Status(http.StatusNoContent) } + +func newBindFunc(userID, redirect string) stateHandler { + return func(ctx *gin.Context, pi provider.ProviderInterface, code string) { + t, err := pi.GetToken(ctx, code) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + ui, err := pi.GetUserInfo(ctx, t) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + user, err := op.LoadOrInitUserByID(userID) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + err = user.BindProvider(pi.Provider(), ui.ProviderUserID) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + token, err := middlewares.NewAuthUserToken(user) + if err != nil { + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) + return + } + + ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{ + "token": token, + "redirect": redirect, + })) + } +} diff --git a/server/oauth2/render.go b/server/oauth2/render.go index 972d56d..0a475d0 100644 --- a/server/oauth2/render.go +++ b/server/oauth2/render.go @@ -6,7 +6,7 @@ import ( "time" "github.com/gin-gonic/gin" - "github.com/synctv-org/synctv/server/model" + "github.com/synctv-org/synctv/internal/provider" "github.com/zijiren233/gencontainer/synccache" ) @@ -16,13 +16,10 @@ var temp embed.FS var ( redirectTemplate *template.Template tokenTemplate *template.Template - states *synccache.SyncCache[string, stateMeta] + states *synccache.SyncCache[string, stateHandler] ) -type stateMeta struct { - model.OAuth2Req - BindUserId string -} +type stateHandler func(ctx *gin.Context, pi provider.ProviderInterface, code string) func RenderRedirect(ctx *gin.Context, url string) error { ctx.Header("Content-Type", "text/html; charset=utf-8") @@ -37,5 +34,5 @@ func RenderToken(ctx *gin.Context, url, token string) error { func init() { redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html")) tokenTemplate = template.Must(template.ParseFS(temp, "templates/token.html")) - states = synccache.NewSyncCache[string, stateMeta](time.Minute * 10) + states = synccache.NewSyncCache[string, stateHandler](time.Minute * 10) }