Feat: user bind unbind providers

pull/39/head
zijiren233 1 year ago
parent 82ca47be58
commit 7d4a60a5b3

@ -72,6 +72,35 @@ func GetProviderUserID(p provider.OAuth2Provider, puid string) (string, error) {
return userProvider.UserID, nil
}
func BindProvider(uid string, p provider.OAuth2Provider, puid string) error {
err := db.Create(&model.UserProvider{
UserID: uid,
Provider: p,
ProviderUserID: puid,
}).Error
if err != nil && errors.Is(err, gorm.ErrDuplicatedKey) {
return errors.New("provider already bind")
}
return err
}
func UnBindProvider(uid string, p provider.OAuth2Provider) error {
err := db.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return errors.New("user could not bind provider")
}
return err
}
func GetBindProviders(uid string) ([]*model.UserProvider, error) {
var providers []*model.UserProvider
err := db.Where("user_id = ?", uid).Find(&providers).Error
if errors.Is(err, gorm.ErrRecordNotFound) {
return providers, errors.New("user not found")
}
return providers, err
}
func GetUserByUsername(username string) (*model.User, error) {
u := &model.User{}
err := db.Where("username = ?", username).First(u).Error

@ -84,6 +84,14 @@ func GetUserByProvider(p provider.OAuth2Provider, pid string) (*User, error) {
return GetUserById(uid)
}
func BindProvider(uid string, p provider.OAuth2Provider, pid string) error {
err := db.BindProvider(uid, p, pid)
if err != nil {
return err
}
return nil
}
func DeleteUserByID(userID string) error {
err := db.DeleteUserByID(userID)
if err != nil {

@ -111,3 +111,24 @@ func SetUsername(ctx *gin.Context) {
ctx.Status(http.StatusNoContent)
}
func UserBindProviders(ctx *gin.Context) {
user := ctx.MustGet("user").(*op.User)
up, err := db.GetBindProviders(user.ID)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
}
resp := make([]model.UserBindProviderReq, len(up))
for i, v := range up {
resp[i] = model.UserBindProviderReq{
Provider: v.Provider,
ProviderUserID: v.ProviderUserID,
CreatedAt: v.CreatedAt.UnixMilli(),
}
}
ctx.JSON(http.StatusOK, resp)
}

@ -30,3 +30,15 @@ func (o *OAuth2CallbackReq) Validate() error {
func (o *OAuth2CallbackReq) Decode(ctx *gin.Context) error {
return json.NewDecoder(ctx.Request.Body).Decode(o)
}
type OAuth2Req struct {
Redirect string `json:"redirect"`
}
func (o *OAuth2Req) Validate() error {
return nil
}
func (o *OAuth2Req) Decode(ctx *gin.Context) error {
return json.NewDecoder(ctx.Request.Body).Decode(o)
}

@ -6,6 +6,7 @@ import (
"github.com/gin-gonic/gin"
json "github.com/json-iterator/go"
dbModel "github.com/synctv-org/synctv/internal/model"
"github.com/synctv-org/synctv/internal/provider"
)
type SetUserPasswordReq struct {
@ -91,3 +92,9 @@ func (u *UserIDReq) Validate() error {
}
return nil
}
type UserBindProviderReq struct {
Provider provider.OAuth2Provider `json:"provider"`
ProviderUserID string `json:"providerUserID"`
CreatedAt int64 `json:"createdAt"`
}

@ -1,6 +1,8 @@
package auth
import (
"context"
"errors"
"net/http"
"time"
@ -16,37 +18,49 @@ import (
"github.com/synctv-org/synctv/utils"
)
// GET
// /oauth2/login/:type
func OAuth2(ctx *gin.Context) {
t := ctx.Param("type")
pi, err := providers.GetProvider(provider.OAuth2Provider(t))
pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type")))
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
state := utils.RandString(16)
states.Store(state, struct{}{}, time.Minute*5)
states.Store(state, stateMeta{
OAuth2Req: model.OAuth2Req{
Redirect: ctx.Query("redirect"),
},
}, time.Minute*5)
RenderRedirect(ctx, pi.NewAuthURL(state))
}
// POST
func OAuth2Api(ctx *gin.Context) {
t := ctx.Param("type")
pi, err := providers.GetProvider(provider.OAuth2Provider(t))
pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type")))
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
}
meta := model.OAuth2Req{}
if err := model.Decode(ctx, &meta); err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
state := utils.RandString(16)
states.Store(state, struct{}{}, time.Minute*5)
states.Store(state, stateMeta{
OAuth2Req: meta,
}, time.Minute*5)
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
"url": pi.NewAuthURL(state),
}))
}
// GET
// /oauth2/callback/:type
func OAuth2Callback(ctx *gin.Context) {
code := ctx.Query("code")
@ -55,111 +69,102 @@ func OAuth2Callback(ctx *gin.Context) {
return
}
state := ctx.Query("state")
if state == "" {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
return
}
_, loaded := states.LoadAndDelete(state)
if !loaded {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
return
}
p := provider.OAuth2Provider(ctx.Param("type"))
pi, err := providers.GetProvider(p)
pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type")))
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
t, err := pi.GetToken(ctx, code)
ld, err := login(ctx, ctx.Query("state"), code, pi)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
ui, err := pi.GetUserInfo(ctx, t)
if err != nil {
RenderToken(ctx, ld.redirect, ld.token)
}
// POST
// /oauth2/callback/:type
func OAuth2CallbackApi(ctx *gin.Context) {
req := model.OAuth2CallbackReq{}
if err := req.Decode(ctx); err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
var user *op.User
if settings.DisableUserSignup.Get() {
user, err = op.GetUserByProvider(p, ui.ProviderUserID)
} else {
if settings.SignupNeedReview.Get() {
user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID, db.WithRole(dbModel.RolePending))
} else {
user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID)
}
}
pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type")))
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
}
token, err := middlewares.NewAuthUserToken(user)
ld, err := login(ctx, req.State, req.Code, pi)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
RenderToken(ctx, "/web/", token)
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
"token": ld.token,
"redirect": ld.redirect,
}))
}
// /oauth2/callback/:type
func OAuth2CallbackApi(ctx *gin.Context) {
req := model.OAuth2CallbackReq{}
if err := req.Decode(ctx); err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
type loginData struct {
token, redirect string
}
_, loaded := states.LoadAndDelete(req.State)
func login(ctx context.Context, state, code string, pi provider.ProviderInterface) (*loginData, error) {
meta, loaded := states.LoadAndDelete(state)
if !loaded {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid oauth2 state"))
return
}
p := provider.OAuth2Provider(ctx.Param("type"))
pi, err := providers.GetProvider(p)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return nil, errors.New("invalid oauth2 state")
}
t, err := pi.GetToken(ctx, req.Code)
t, err := pi.GetToken(ctx, code)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
return nil, err
}
ui, err := pi.GetUserInfo(ctx, t)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
return nil, err
}
var user *op.User
if settings.DisableUserSignup.Get() {
user, err = op.GetUserByProvider(p, ui.ProviderUserID)
if meta.Value().BindUserId != "" {
user, err = op.GetUserById(meta.Value().BindUserId)
} else if settings.DisableUserSignup.Get() {
user, err = op.GetUserByProvider(pi.Provider(), ui.ProviderUserID)
} else {
if settings.SignupNeedReview.Get() {
user, err = op.CreateOrLoadUser(ui.Username, pi.Provider(), ui.ProviderUserID, db.WithRole(dbModel.RolePending))
} else {
user, err = op.CreateOrLoadUser(ui.Username, p, ui.ProviderUserID)
user, err = op.CreateOrLoadUser(ui.Username, pi.Provider(), ui.ProviderUserID)
}
}
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
return nil, err
}
if meta.Value().BindUserId != "" {
err = op.BindProvider(meta.Value().BindUserId, pi.Provider(), ui.ProviderUserID)
if err != nil {
return nil, err
}
}
token, err := middlewares.NewAuthUserToken(user)
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
return nil, err
}
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
"token": token,
}))
redirect := "/web/"
if meta.Value().Redirect != "" {
redirect = meta.Value().Redirect
}
return &loginData{
token: token,
redirect: redirect,
}, nil
}

@ -0,0 +1,57 @@
package auth
import (
"net/http"
"time"
"github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/internal/db"
"github.com/synctv-org/synctv/internal/op"
"github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/providers"
"github.com/synctv-org/synctv/server/model"
"github.com/synctv-org/synctv/utils"
)
func BindApi(ctx *gin.Context) {
user := ctx.MustGet("user").(*op.User)
pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type")))
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
}
meta := model.OAuth2Req{}
if err := model.Decode(ctx, &meta); err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
state := utils.RandString(16)
states.Store(state, stateMeta{
OAuth2Req: meta,
BindUserId: user.ID,
}, time.Minute*5)
ctx.JSON(http.StatusOK, model.NewApiDataResp(gin.H{
"url": pi.NewAuthURL(state),
}))
}
func UnBindApi(ctx *gin.Context) {
user := ctx.MustGet("user").(*op.User)
pi, err := providers.GetProvider(provider.OAuth2Provider(ctx.Param("type")))
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
err = db.UnBindProvider(user.ID, pi.Provider())
if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return
}
ctx.Status(http.StatusNoContent)
}

@ -1,19 +1,28 @@
package auth
import "github.com/gin-gonic/gin"
import (
"github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/server/middlewares"
)
func Init(e *gin.Engine) {
{
auth := e.Group("/oauth2")
oauth2 := e.Group("/oauth2")
needAuthOauth2 := oauth2.Group("")
needAuthOauth2.Use(middlewares.AuthUserMiddleware)
auth.GET("/enabled", OAuth2EnabledApi)
oauth2.GET("/enabled", OAuth2EnabledApi)
auth.GET("/login/:type", OAuth2)
oauth2.GET("/login/:type", OAuth2)
auth.POST("/login/:type", OAuth2Api)
oauth2.POST("/login/:type", OAuth2Api)
auth.GET("/callback/:type", OAuth2Callback)
oauth2.GET("/callback/:type", OAuth2Callback)
auth.POST("/callback/:type", OAuth2CallbackApi)
oauth2.POST("/callback/:type", OAuth2CallbackApi)
needAuthOauth2.POST("/bind/:type", BindApi)
needAuthOauth2.POST("/unbind/:type", UnBindApi)
}
}

@ -6,6 +6,7 @@ import (
"time"
"github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/server/model"
"github.com/zijiren233/gencontainer/synccache"
)
@ -15,9 +16,14 @@ var temp embed.FS
var (
redirectTemplate *template.Template
tokenTemplate *template.Template
states *synccache.SyncCache[string, struct{}]
states *synccache.SyncCache[string, stateMeta]
)
type stateMeta struct {
model.OAuth2Req
BindUserId string
}
func RenderRedirect(ctx *gin.Context, url string) error {
ctx.Header("Content-Type", "text/html; charset=utf-8")
return redirectTemplate.Execute(ctx.Writer, url)
@ -31,5 +37,5 @@ func RenderToken(ctx *gin.Context, url, token string) error {
func init() {
redirectTemplate = template.Must(template.ParseFS(temp, "templates/redirect.html"))
tokenTemplate = template.Must(template.ParseFS(temp, "templates/token.html"))
states = synccache.NewSyncCache[string, struct{}](time.Minute * 10)
states = synccache.NewSyncCache[string, stateMeta](time.Minute * 10)
}

Loading…
Cancel
Save