Feat: settings

pull/21/head
zijiren233 2 years ago
parent d3abf4bc45
commit aa6fa5411f

@ -10,6 +10,8 @@ import (
"github.com/synctv-org/synctv/cmd/flags" "github.com/synctv-org/synctv/cmd/flags"
"github.com/synctv-org/synctv/internal/conf" "github.com/synctv-org/synctv/internal/conf"
"github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/plugins"
"github.com/synctv-org/synctv/internal/provider/providers"
"github.com/synctv-org/synctv/utils" "github.com/synctv-org/synctv/utils"
) )
@ -27,7 +29,7 @@ func InitProvider(ctx context.Context) error {
log.Errorf("create plugin dir: %s failed: %s", filepath.Dir(op.PluginFile), err) log.Errorf("create plugin dir: %s failed: %s", filepath.Dir(op.PluginFile), err)
return err return err
} }
err = provider.InitProviderPlugins(op.PluginFile, op.Arges, hclog.New(&hclog.LoggerOptions{ err = plugins.InitProviderPlugins(op.PluginFile, op.Arges, hclog.New(&hclog.LoggerOptions{
Name: op.PluginFile, Name: op.PluginFile,
Level: logLevle, Level: logLevle,
Output: logOur, Output: logOur,
@ -39,7 +41,7 @@ func InitProvider(ctx context.Context) error {
} }
} }
for op, v := range conf.Conf.OAuth2.Providers { for op, v := range conf.Conf.OAuth2.Providers {
err := provider.InitProvider(op, provider.Oauth2Option{ err := providers.InitProvider(op, provider.Oauth2Option{
ClientID: v.ClientID, ClientID: v.ClientID,
ClientSecret: v.ClientSecret, ClientSecret: v.ClientSecret,
RedirectURL: v.RedirectURL, RedirectURL: v.RedirectURL,

@ -2,7 +2,6 @@ package conf
import ( import (
"github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/providers"
) )
type OAuth2Config struct { type OAuth2Config struct {
@ -23,12 +22,7 @@ type OAuth2ProviderConfig struct {
func DefaultOAuth2Config() OAuth2Config { func DefaultOAuth2Config() OAuth2Config {
return OAuth2Config{ return OAuth2Config{
Providers: map[provider.OAuth2Provider]OAuth2ProviderConfig{ Providers: map[provider.OAuth2Provider]OAuth2ProviderConfig{},
(&providers.GithubProvider{}).Provider(): { Plugins: []Oauth2Plugin{},
ClientID: "",
ClientSecret: "",
RedirectURL: "",
},
},
} }
} }

@ -1,9 +1,9 @@
package conf package conf
type ServerConfig struct { type ServerConfig struct {
Listen string `yaml:"listen" lc:"default: 0.0.0.0" env:"SERVER_LISTEN"` Listen string `yaml:"listen" env:"SERVER_LISTEN"`
Port uint16 `yaml:"port" lc:"default: 8080" env:"SERVER_PORT"` Port uint16 `yaml:"port" env:"SERVER_PORT"`
Quic bool `yaml:"quic" hc:"enable http3/quic, need set cert and key file" env:"SERVER_QUIC"` Quic bool `yaml:"quic" hc:"enable http3/quic need set cert and key file" env:"SERVER_QUIC"`
CertPath string `yaml:"cert_path" env:"SERVER_CERT_PATH"` CertPath string `yaml:"cert_path" env:"SERVER_CERT_PATH"`
KeyPath string `yaml:"key_path" env:"SERVER_KEY_PATH"` KeyPath string `yaml:"key_path" env:"SERVER_KEY_PATH"`

@ -19,7 +19,7 @@ var (
func Init(d *gorm.DB, t conf.DatabaseType) error { func Init(d *gorm.DB, t conf.DatabaseType) error {
db = d db = d
dbType = t dbType = t
return AutoMigrate(new(model.Movie), new(model.Room), new(model.User), new(model.RoomUserRelation), new(model.UserProvider), new(model.SettingItem)) return AutoMigrate(new(model.Movie), new(model.Room), new(model.User), new(model.RoomUserRelation), new(model.UserProvider), new(model.Setting))
} }
func AutoMigrate(dst ...any) error { func AutoMigrate(dst ...any) error {

@ -5,38 +5,38 @@ import (
"gorm.io/gorm/clause" "gorm.io/gorm/clause"
) )
func GetSettingItems() ([]*model.SettingItem, error) { func GetSettingItems() ([]*model.Setting, error) {
var items []*model.SettingItem var items []*model.Setting
err := db.Find(&items).Error err := db.Find(&items).Error
return items, err return items, err
} }
func GetSettingItemByName(name string) (*model.SettingItem, error) { func GetSettingItemByName(name string) (*model.Setting, error) {
var item model.SettingItem var item model.Setting
err := db.Where("name = ?", name).First(&item).Error err := db.Where("name = ?", name).First(&item).Error
return &item, err return &item, err
} }
func SaveSettingItem(item *model.SettingItem) error { func SaveSettingItem(item *model.Setting) error {
return db.Clauses(clause.OnConflict{ return db.Clauses(clause.OnConflict{
UpdateAll: true, UpdateAll: true,
}).Save(item).Error }).Save(item).Error
} }
func DeleteSettingItem(item *model.SettingItem) error { func DeleteSettingItem(item *model.Setting) error {
return db.Delete(item).Error return db.Delete(item).Error
} }
func DeleteSettingItemByName(name string) error { func DeleteSettingItemByName(name string) error {
return db.Where("name = ?", name).Delete(&model.SettingItem{}).Error return db.Where("name = ?", name).Delete(&model.Setting{}).Error
} }
func GetSettingItemValue(name string) (string, error) { func GetSettingItemValue(name string) (string, error) {
var value string var value string
err := db.Model(&model.SettingItem{}).Where("name = ?", name).Select("value").First(&value).Error err := db.Model(&model.Setting{}).Where("name = ?", name).Select("value").First(&value).Error
return value, err return value, err
} }
func SetSettingItemValue(name, value string) error { func SetSettingItemValue(name, value string) error {
return db.Model(&model.SettingItem{}).Where("name = ?", name).Update("value", value).Error return db.Model(&model.Setting{}).Where("name = ?", name).Assign("value", value).FirstOrCreate(&model.Setting{}).Error
} }

@ -1,6 +1,16 @@
package model package model
type SettingItem struct { type SettingType string
const (
SettingTypeBool SettingType = "bool"
SettingTypeInt64 SettingType = "int64"
SettingTypeFloat64 SettingType = "float64"
SettingTypeString SettingType = "string"
)
type Setting struct {
Name string `gorm:"primaryKey"` Name string `gorm:"primaryKey"`
Value string Value string
Type SettingType `gorm:"not null;default:string"`
} }

@ -49,3 +49,11 @@ func (u *User) BeforeCreate(tx *gorm.DB) error {
} }
return nil return nil
} }
func (u *User) IsAdmin() bool {
return u.Role == RoleAdmin
}
func (u *User) IsBanned() bool {
return u.Role == RoleBanned
}

@ -2,6 +2,8 @@ package op
import ( import (
"github.com/bluele/gcache" "github.com/bluele/gcache"
"github.com/synctv-org/synctv/internal/db"
"github.com/synctv-org/synctv/internal/model"
) )
func Init(size int) error { func Init(size int) error {
@ -9,5 +11,20 @@ func Init(size int) error {
LRU(). LRU().
Build() Build()
si, err := db.GetSettingItems()
if err != nil {
panic(err)
}
for _, si2 := range si {
switch si2.Type {
case model.SettingTypeBool:
b, ok := boolSettings[si2.Name]
if ok {
b.value = si2.Value
}
}
}
cleanReg()
return nil return nil
} }

@ -0,0 +1,92 @@
package op
import (
"github.com/synctv-org/synctv/internal/db"
"github.com/synctv-org/synctv/internal/model"
)
var boolSettings map[string]*Bool
type Setting interface {
Name() string
Raw() string
Type() model.SettingType
}
type BoolSetting interface {
Setting
Set(value bool) error
Get() (bool, error)
}
type Bool struct {
name string
value string
}
func NewBool(name, value string) *Bool {
return &Bool{
name: name,
value: value,
}
}
func (b *Bool) Name() string {
return b.name
}
func (b *Bool) Set(value bool) error {
if value {
b.value = "1"
} else {
b.value = "0"
}
return db.SetSettingItemValue(b.name, b.value)
}
func (b *Bool) Get() (bool, error) {
return b.value == "1", nil
}
func (b *Bool) Raw() string {
return b.value
}
func (b *Bool) Type() model.SettingType {
return model.SettingTypeBool
}
type Int64Setting interface {
Set(value int64) error
Get() (int64, error)
Raw() string
}
type Float64Setting interface {
Set(value float64) error
Get() (float64, error)
Raw() string
}
type StringSetting interface {
Set(value string) error
Get() (string, error)
Raw() string
}
func cleanReg() {
boolSettings = nil
}
func newRegBoolSetting(k, v string) BoolSetting {
b := NewBool(k, v)
if boolSettings == nil {
boolSettings = make(map[string]*Bool)
}
boolSettings[k] = b
return b
}
var (
DisableCreateRoom = newRegBoolSetting("disable_create_room", "0")
)

@ -22,6 +22,14 @@ func (u *User) NewMovie(movie model.MovieInfo) model.Movie {
} }
} }
func (u *User) IsAdmin() bool {
return u.Role == model.RoleAdmin
}
func (u *User) IsBanned() bool {
return u.Role == model.RoleBanned
}
func (u *User) HasPermission(roomID uint, permission model.Permission) bool { func (u *User) HasPermission(roomID uint, permission model.Permission) bool {
if u.Role == model.RoleAdmin { if u.Role == model.RoleAdmin {
return true return true

@ -1,18 +1,19 @@
package provider package plugins
import ( import (
"context" "context"
"time" "time"
"github.com/synctv-org/synctv/internal/provider"
providerpb "github.com/synctv-org/synctv/proto/provider" providerpb "github.com/synctv-org/synctv/proto/provider"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type GRPCClient struct{ client providerpb.Oauth2PluginClient } type GRPCClient struct{ client providerpb.Oauth2PluginClient }
var _ ProviderInterface = (*GRPCClient)(nil) var _ provider.ProviderInterface = (*GRPCClient)(nil)
func (c *GRPCClient) Init(o Oauth2Option) { func (c *GRPCClient) Init(o provider.Oauth2Option) {
c.client.Init(context.Background(), &providerpb.InitReq{ c.client.Init(context.Background(), &providerpb.InitReq{
ClientId: o.ClientID, ClientId: o.ClientID,
ClientSecret: o.ClientSecret, ClientSecret: o.ClientSecret,
@ -20,12 +21,12 @@ func (c *GRPCClient) Init(o Oauth2Option) {
}) })
} }
func (c *GRPCClient) Provider() OAuth2Provider { func (c *GRPCClient) Provider() provider.OAuth2Provider {
resp, err := c.client.Provider(context.Background(), &providerpb.Enpty{}) resp, err := c.client.Provider(context.Background(), &providerpb.Enpty{})
if err != nil { if err != nil {
return "" return ""
} }
return OAuth2Provider(resp.Name) return provider.OAuth2Provider(resp.Name)
} }
func (c *GRPCClient) NewAuthURL(state string) string { func (c *GRPCClient) NewAuthURL(state string) string {
@ -64,7 +65,7 @@ func (c *GRPCClient) RefreshToken(ctx context.Context, tk string) (*oauth2.Token
}, nil }, nil
} }
func (c *GRPCClient) GetUserInfo(ctx context.Context, tk *oauth2.Token) (*UserInfo, error) { func (c *GRPCClient) GetUserInfo(ctx context.Context, tk *oauth2.Token) (*provider.UserInfo, error) {
resp, err := c.client.GetUserInfo(ctx, &providerpb.GetUserInfoReq{ resp, err := c.client.GetUserInfo(ctx, &providerpb.GetUserInfoReq{
Token: &providerpb.Token{ Token: &providerpb.Token{
AccessToken: tk.AccessToken, AccessToken: tk.AccessToken,
@ -76,7 +77,7 @@ func (c *GRPCClient) GetUserInfo(ctx context.Context, tk *oauth2.Token) (*UserIn
if err != nil { if err != nil {
return nil, err return nil, err
} }
return &UserInfo{ return &provider.UserInfo{
Username: resp.Username, Username: resp.Username,
ProviderUserID: uint(resp.ProviderUserId), ProviderUserID: uint(resp.ProviderUserId),
}, nil }, nil

@ -7,6 +7,7 @@ import (
plugin "github.com/hashicorp/go-plugin" plugin "github.com/hashicorp/go-plugin"
"github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/plugins"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -85,10 +86,10 @@ type giteeUserInfo struct {
func main() { func main() {
var pluginMap = map[string]plugin.Plugin{ var pluginMap = map[string]plugin.Plugin{
"Provider": &provider.ProviderPlugin{Impl: &GiteeProvider{}}, "Provider": &plugins.ProviderPlugin{Impl: &GiteeProvider{}},
} }
plugin.Serve(&plugin.ServeConfig{ plugin.Serve(&plugin.ServeConfig{
HandshakeConfig: provider.HandshakeConfig, HandshakeConfig: plugins.HandshakeConfig,
Plugins: pluginMap, Plugins: pluginMap,
GRPCServer: plugin.DefaultGRPCServer, GRPCServer: plugin.DefaultGRPCServer,
}) })

@ -1,4 +1,4 @@
package provider package plugins
import ( import (
"context" "context"
@ -7,6 +7,8 @@ import (
"github.com/hashicorp/go-hclog" "github.com/hashicorp/go-hclog"
"github.com/hashicorp/go-plugin" "github.com/hashicorp/go-plugin"
"github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/providers"
sysnotify "github.com/synctv-org/synctv/internal/sysNotify" sysnotify "github.com/synctv-org/synctv/internal/sysNotify"
providerpb "github.com/synctv-org/synctv/proto/provider" providerpb "github.com/synctv-org/synctv/proto/provider"
"google.golang.org/grpc" "google.golang.org/grpc"
@ -26,11 +28,11 @@ func InitProviderPlugins(name string, arg []string, Logger hclog.Logger) error {
if err != nil { if err != nil {
return err return err
} }
provider, ok := i.(ProviderInterface) provider, ok := i.(provider.ProviderInterface)
if !ok { if !ok {
return fmt.Errorf("%s not implement ProviderInterface", name) return fmt.Errorf("%s not implement ProviderInterface", name)
} }
RegisterProvider(provider) providers.RegisterProvider(provider)
return nil return nil
} }
@ -46,7 +48,7 @@ var pluginMap = map[string]plugin.Plugin{
type ProviderPlugin struct { type ProviderPlugin struct {
plugin.Plugin plugin.Plugin
Impl ProviderInterface Impl provider.ProviderInterface
} }
func (p *ProviderPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error { func (p *ProviderPlugin) GRPCServer(broker *plugin.GRPCBroker, s *grpc.Server) error {

@ -1,20 +1,21 @@
package provider package plugins
import ( import (
"context" "context"
"time" "time"
"github.com/synctv-org/synctv/internal/provider"
providerpb "github.com/synctv-org/synctv/proto/provider" providerpb "github.com/synctv-org/synctv/proto/provider"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
type GRPCServer struct { type GRPCServer struct {
providerpb.UnimplementedOauth2PluginServer providerpb.UnimplementedOauth2PluginServer
Impl ProviderInterface Impl provider.ProviderInterface
} }
func (s *GRPCServer) Init(ctx context.Context, req *providerpb.InitReq) (*providerpb.Enpty, error) { func (s *GRPCServer) Init(ctx context.Context, req *providerpb.InitReq) (*providerpb.Enpty, error) {
s.Impl.Init(Oauth2Option{ s.Impl.Init(provider.Oauth2Option{
ClientID: req.ClientId, ClientID: req.ClientId,
ClientSecret: req.ClientSecret, ClientSecret: req.ClientSecret,
RedirectURL: req.RedirectUrl, RedirectURL: req.RedirectUrl,

@ -2,7 +2,6 @@ package provider
import ( import (
"context" "context"
"fmt"
"golang.org/x/oauth2" "golang.org/x/oauth2"
) )
@ -28,49 +27,3 @@ type ProviderInterface interface {
RefreshToken(context.Context, string) (*oauth2.Token, error) RefreshToken(context.Context, string) (*oauth2.Token, error)
GetUserInfo(context.Context, *oauth2.Token) (*UserInfo, error) GetUserInfo(context.Context, *oauth2.Token) (*UserInfo, error)
} }
var (
enabledProviders map[OAuth2Provider]ProviderInterface
allowedProviders = make(map[OAuth2Provider]ProviderInterface)
)
func InitProvider(p OAuth2Provider, c Oauth2Option) error {
pi, ok := allowedProviders[p]
if !ok {
return FormatErrNotImplemented(p)
}
pi.Init(c)
if enabledProviders == nil {
enabledProviders = make(map[OAuth2Provider]ProviderInterface)
}
enabledProviders[pi.Provider()] = pi
return nil
}
func RegisterProvider(ps ...ProviderInterface) {
for _, p := range ps {
allowedProviders[p.Provider()] = p
}
}
func GetProvider(p OAuth2Provider) (ProviderInterface, error) {
pi, ok := enabledProviders[p]
if !ok {
return nil, FormatErrNotImplemented(p)
}
return pi, nil
}
func AllowedProvider() map[OAuth2Provider]ProviderInterface {
return allowedProviders
}
func EnabledProvider() map[OAuth2Provider]ProviderInterface {
return enabledProviders
}
type FormatErrNotImplemented string
func (f FormatErrNotImplemented) Error() string {
return fmt.Sprintf("%s not implemented", string(f))
}

@ -67,13 +67,13 @@ func (p *BaiduNetDiskProvider) GetUserInfo(ctx context.Context, tk *oauth2.Token
}, nil }, nil
} }
func init() {
provider.RegisterProvider(new(BaiduNetDiskProvider))
}
type baiduNetDiskProviderUserInfo struct { type baiduNetDiskProviderUserInfo struct {
BaiduName string `json:"baidu_name"` BaiduName string `json:"baidu_name"`
Errmsg string `json:"errmsg"` Errmsg string `json:"errmsg"`
Errno int `json:"errno"` Errno int `json:"errno"`
Uk uint `json:"uk"` Uk uint `json:"uk"`
} }
func init() {
RegisterProvider(new(BaiduNetDiskProvider))
}

@ -67,7 +67,7 @@ func (p *BaiduProvider) GetUserInfo(ctx context.Context, tk *oauth2.Token) (*pro
} }
func init() { func init() {
provider.RegisterProvider(new(BaiduProvider)) RegisterProvider(new(BaiduProvider))
} }
type baiduProviderUserInfo struct { type baiduProviderUserInfo struct {

@ -68,5 +68,5 @@ type giteeUserInfo struct {
} }
func init() { func init() {
provider.RegisterProvider(new(GiteeProvider)) RegisterProvider(new(GiteeProvider))
} }

@ -66,5 +66,5 @@ type githubUserInfo struct {
} }
func init() { func init() {
provider.RegisterProvider(new(GithubProvider)) RegisterProvider(new(GithubProvider))
} }

@ -48,9 +48,9 @@ func (g *GitlabProvider) GetUserInfo(ctx context.Context, tk *oauth2.Token) (*pr
return nil, err return nil, err
} }
defer resp.Body.Close() defer resp.Body.Close()
return nil, provider.FormatErrNotImplemented("gitlab") return nil, FormatErrNotImplemented("gitlab")
} }
func init() { func init() {
provider.RegisterProvider(new(GitlabProvider)) RegisterProvider(new(GitlabProvider))
} }

@ -61,7 +61,7 @@ func (g *GoogleProvider) GetUserInfo(ctx context.Context, tk *oauth2.Token) (*pr
} }
func init() { func init() {
provider.RegisterProvider(new(GoogleProvider)) RegisterProvider(new(GoogleProvider))
} }
type googleUserInfo struct { type googleUserInfo struct {

@ -68,5 +68,5 @@ type microsoftUserInfo struct {
} }
func init() { func init() {
provider.RegisterProvider(new(MicrosoftProvider)) RegisterProvider(new(MicrosoftProvider))
} }

@ -0,0 +1,53 @@
package providers
import (
"fmt"
"github.com/synctv-org/synctv/internal/provider"
)
var (
enabledProviders map[provider.OAuth2Provider]provider.ProviderInterface
allowedProviders = make(map[provider.OAuth2Provider]provider.ProviderInterface)
)
func InitProvider(p provider.OAuth2Provider, c provider.Oauth2Option) error {
pi, ok := allowedProviders[p]
if !ok {
return FormatErrNotImplemented(p)
}
pi.Init(c)
if enabledProviders == nil {
enabledProviders = make(map[provider.OAuth2Provider]provider.ProviderInterface)
}
enabledProviders[pi.Provider()] = pi
return nil
}
func RegisterProvider(ps ...provider.ProviderInterface) {
for _, p := range ps {
allowedProviders[p.Provider()] = p
}
}
func GetProvider(p provider.OAuth2Provider) (provider.ProviderInterface, error) {
pi, ok := enabledProviders[p]
if !ok {
return nil, FormatErrNotImplemented(p)
}
return pi, nil
}
func AllowedProvider() map[provider.OAuth2Provider]provider.ProviderInterface {
return allowedProviders
}
func EnabledProvider() map[provider.OAuth2Provider]provider.ProviderInterface {
return enabledProviders
}
type FormatErrNotImplemented string
func (f FormatErrNotImplemented) Error() string {
return fmt.Sprintf("%s not implemented", string(f))
}

@ -0,0 +1,7 @@
package handlers
import "github.com/gin-gonic/gin"
func AdminSettings(ctx *gin.Context) {
}

@ -30,6 +30,17 @@ func (e FormatErrNotSupportPosition) Error() string {
func CreateRoom(ctx *gin.Context) { func CreateRoom(ctx *gin.Context) {
user := ctx.MustGet("user").(*op.User) user := ctx.MustGet("user").(*op.User)
v, err := op.DisableCreateRoom.Get()
if err != nil {
ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err))
return
}
if v && !user.IsAdmin() {
ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewApiErrorStringResp("create room is disabled"))
return
}
req := model.CreateRoomReq{} req := model.CreateRoomReq{}
if err := model.Decode(ctx, &req); err != nil { if err := model.Decode(ctx, &req); err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))

@ -8,7 +8,6 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/golang-jwt/jwt/v5" "github.com/golang-jwt/jwt/v5"
"github.com/synctv-org/synctv/internal/conf" "github.com/synctv-org/synctv/internal/conf"
dbModel "github.com/synctv-org/synctv/internal/model"
"github.com/synctv-org/synctv/internal/op" "github.com/synctv-org/synctv/internal/op"
"github.com/synctv-org/synctv/server/model" "github.com/synctv-org/synctv/server/model"
"github.com/zijiren233/stream" "github.com/zijiren233/stream"
@ -76,7 +75,7 @@ func AuthRoom(Authorization string) (*op.User, *op.Room, error) {
if err != nil { if err != nil {
return nil, nil, err return nil, nil, err
} }
if u.Role == dbModel.RoleBanned { if u.IsBanned() {
return nil, nil, errors.New("user banned") return nil, nil, errors.New("user banned")
} }
@ -105,7 +104,7 @@ func AuthUser(Authorization string) (*op.User, error) {
if err != nil { if err != nil {
return nil, err return nil, err
} }
if u.Role == dbModel.RoleBanned { if u.IsBanned() {
return nil, errors.New("user banned") return nil, errors.New("user banned")
} }
@ -113,7 +112,7 @@ func AuthUser(Authorization string) (*op.User, error) {
} }
func NewAuthUserToken(user *op.User) (string, error) { func NewAuthUserToken(user *op.User) (string, error) {
if user.Role == dbModel.RoleBanned { if user.IsBanned() {
return "", errors.New("user banned") return "", errors.New("user banned")
} }
t, err := time.ParseDuration(conf.Conf.Jwt.Expire) t, err := time.ParseDuration(conf.Conf.Jwt.Expire)
@ -131,7 +130,7 @@ func NewAuthUserToken(user *op.User) (string, error) {
} }
func NewAuthRoomToken(user *op.User, room *op.Room) (string, error) { func NewAuthRoomToken(user *op.User, room *op.Room) (string, error) {
if user.Role == dbModel.RoleBanned { if user.IsBanned() {
return "", errors.New("user banned") return "", errors.New("user banned")
} }
t, err := time.ParseDuration(conf.Conf.Jwt.Expire) t, err := time.ParseDuration(conf.Conf.Jwt.Expire)

@ -7,6 +7,7 @@ import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/internal/op" "github.com/synctv-org/synctv/internal/op"
"github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider"
"github.com/synctv-org/synctv/internal/provider/providers"
"github.com/synctv-org/synctv/server/middlewares" "github.com/synctv-org/synctv/server/middlewares"
"github.com/synctv-org/synctv/server/model" "github.com/synctv-org/synctv/server/model"
"github.com/synctv-org/synctv/utils" "github.com/synctv-org/synctv/utils"
@ -16,7 +17,7 @@ import (
func OAuth2(ctx *gin.Context) { func OAuth2(ctx *gin.Context) {
t := ctx.Param("type") t := ctx.Param("type")
pi, err := provider.GetProvider(provider.OAuth2Provider(t)) pi, err := providers.GetProvider(provider.OAuth2Provider(t))
if err != nil { if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return return
@ -30,7 +31,7 @@ func OAuth2(ctx *gin.Context) {
func OAuth2Api(ctx *gin.Context) { func OAuth2Api(ctx *gin.Context) {
t := ctx.Param("type") t := ctx.Param("type")
pi, err := provider.GetProvider(provider.OAuth2Provider(t)) pi, err := providers.GetProvider(provider.OAuth2Provider(t))
if err != nil { if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
} }
@ -64,7 +65,7 @@ func OAuth2Callback(ctx *gin.Context) {
} }
p := provider.OAuth2Provider(ctx.Param("type")) p := provider.OAuth2Provider(ctx.Param("type"))
pi, err := provider.GetProvider(p) pi, err := providers.GetProvider(p)
if err != nil { if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
return return
@ -112,7 +113,7 @@ func OAuth2CallbackApi(ctx *gin.Context) {
} }
p := provider.OAuth2Provider(ctx.Param("type")) p := provider.OAuth2Provider(ctx.Param("type"))
pi, err := provider.GetProvider(p) pi, err := providers.GetProvider(p)
if err != nil { if err != nil {
ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err))
} }

@ -2,12 +2,12 @@ package auth
import ( import (
"github.com/gin-gonic/gin" "github.com/gin-gonic/gin"
"github.com/synctv-org/synctv/internal/provider" "github.com/synctv-org/synctv/internal/provider/providers"
"golang.org/x/exp/maps" "golang.org/x/exp/maps"
) )
func OAuth2EnabledApi(ctx *gin.Context) { func OAuth2EnabledApi(ctx *gin.Context) {
ctx.JSON(200, gin.H{ ctx.JSON(200, gin.H{
"enabled": maps.Keys(provider.EnabledProvider()), "enabled": maps.Keys(providers.EnabledProvider()),
}) })
} }

Loading…
Cancel
Save