diff --git a/.github/workflows/ci.yml b/.github/workflows/ci.yml index 4db8f6b..7e67d22 100644 --- a/.github/workflows/ci.yml +++ b/.github/workflows/ci.yml @@ -41,14 +41,14 @@ jobs: uses: golangci/golangci-lint-action@v8 with: version: latest - args: --color always + args: --color=always - name: Run Fix Linter uses: golangci/golangci-lint-action@v8 if: ${{ failure() }} with: install-mode: none - args: --fix --color always + args: --fix --color=always - name: Auto Fix Diff Content if: ${{ failure() }} diff --git a/.golangci.yml b/.golangci.yml index a678d04..73a73e3 100644 --- a/.golangci.yml +++ b/.golangci.yml @@ -1,7 +1,7 @@ version: "2" run: - go: "1.25.0" + go: "1.24" relative-path-mode: gomod modules-download-mode: readonly @@ -69,6 +69,7 @@ linters: - usetesting - wastedassign - whitespace + - wsl_v5 exclusions: generated: lax presets: @@ -81,6 +82,13 @@ linters: - builtin$ - examples$ settings: + revive: + rules: + - name: var-naming + arguments: + - [] + - [] + - [{ skipPackageNameChecks: true }] copyloopvar: check-alias: true cyclop: @@ -88,6 +96,8 @@ linters: errcheck: check-type-assertions: true forbidigo: + forbid: + - pattern: ^print(ln)?$ analyze-types: true prealloc: for-loops: true @@ -95,6 +105,7 @@ linters: dot-import-whitelist: [] http-status-code-whitelist: [] usestdlibvars: + time-date-month: true time-month: true time-layout: true crypto-hash: true @@ -107,6 +118,9 @@ linters: gosec: excludes: - G404 + wsl_v5: + allow-whole-block: true + branch-max-lines: 4 formatters: enable: @@ -129,5 +143,3 @@ formatters: replacement: "a[b:]" gofumpt: extra-rules: true - golines: - shorten-comments: true diff --git a/cmd/root.go b/cmd/root.go index 6fd6c09..5cbd753 100644 --- a/cmd/root.go +++ b/cmd/root.go @@ -102,10 +102,12 @@ func init() { RootCmd.PersistentFlags().BoolVar(&flags.SkipEnvFlag, "skip-env-flag", true, "skip env flag") RootCmd.PersistentFlags(). StringVar(&flags.Global.GitHubBaseURL, "github-base-url", "https://api.github.com/", "github api base url") + home, err := homedir.Dir() if err != nil { home = "~" } + RootCmd.PersistentFlags(). StringVar(&flags.Global.DataDir, "data-dir", filepath.Join(home, ".synctv"), "data dir") RootCmd.PersistentFlags(). diff --git a/cmd/self-update.go b/cmd/self-update.go index d7dd970..4d4f554 100644 --- a/cmd/self-update.go +++ b/cmd/self-update.go @@ -34,6 +34,7 @@ func SelfUpdate(cmd *cobra.Command, _ []string) error { log.Errorf("get version info error: %v", err) return fmt.Errorf("get version info error: %w", err) } + return v.SelfUpdate(cmd.Context()) } diff --git a/cmd/server.go b/cmd/server.go index a4bbc4b..5b948b2 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -56,6 +56,7 @@ func setupAddresses() (tcpHTTPAddr, tcpRTMPAddr *net.TCPAddr, err error) { if conf.Conf.Server.RTMP.Listen == "" { conf.Conf.Server.RTMP.Listen = conf.Conf.Server.HTTP.Listen } + if conf.Conf.Server.RTMP.Port == 0 { conf.Conf.Server.RTMP.Port = conf.Conf.Server.HTTP.Port } @@ -64,7 +65,8 @@ func setupAddresses() (tcpHTTPAddr, tcpRTMPAddr *net.TCPAddr, err error) { "tcp", fmt.Sprintf("%s:%d", conf.Conf.Server.RTMP.Listen, conf.Conf.Server.RTMP.Port), ) - return + + return tcpHTTPAddr, tcpRTMPAddr, err } func startHTTPServer(e *gin.Engine, listener net.Listener) { @@ -72,6 +74,7 @@ func startHTTPServer(e *gin.Engine, listener net.Listener) { case conf.Conf.Server.HTTP.CertPath != "" && conf.Conf.Server.HTTP.KeyPath != "": go func() { srv := http.Server{Handler: e.Handler(), ReadHeaderTimeout: 3 * time.Second} + err := srv.ServeTLS( listener, conf.Conf.Server.HTTP.CertPath, @@ -84,6 +87,7 @@ func startHTTPServer(e *gin.Engine, listener net.Listener) { case conf.Conf.Server.HTTP.CertPath == "" && conf.Conf.Server.HTTP.KeyPath == "": go func() { srv := http.Server{Handler: e.Handler(), ReadHeaderTimeout: 3 * time.Second} + err := srv.Serve(listener) if err != nil { log.Panicf("http server error: %v", err) @@ -121,6 +125,7 @@ func Server(_ *cobra.Command, _ []string) { } else { httpListener = muxer.Match(cmux.HTTP1Fast()) } + startHTTPServer(e, httpListener) // Setup RTMP @@ -145,6 +150,7 @@ func Server(_ *cobra.Command, _ []string) { if err != nil { log.Fatal(err) } + go func() { err := rtmp.Server().Serve(rtmpListener) if err != nil { @@ -160,6 +166,7 @@ func Server(_ *cobra.Command, _ []string) { if conf.Conf.Server.RTMP.Enable { log.Infof("rtmp run on tcp://%s:%d", tcpRTMPAddr.IP, tcpRTMPAddr.Port) } + if conf.Conf.Server.HTTP.CertPath != "" && conf.Conf.Server.HTTP.KeyPath != "" { log.Infof("website run on https://%s:%d", tcpHTTPAddr.IP, tcpHTTPAddr.Port) } else { diff --git a/cmd/version.go b/cmd/version.go index 1197f9d..b1732b2 100644 --- a/cmd/version.go +++ b/cmd/version.go @@ -9,7 +9,6 @@ import ( "github.com/synctv-org/synctv/internal/version" ) -//nolint:forbidigo var VersionCmd = &cobra.Command{ Use: "version", Short: "Print the version number of Sync TV Server", diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 773a1f2..922c180 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -23,64 +23,79 @@ func InitConfig(_ context.Context) (err error) { log.Fatal("skip config and skip env at the same time") return errors.New("skip config and skip env at the same time") } + conf.Conf = conf.DefaultConfig() if !flags.Server.SkipConfig { configFile, err := utils.OptFilePath(filepath.Join(flags.Global.DataDir, "config.yaml")) if err != nil { log.Fatalf("config file path error: %v", err) } + err = confFromConfig(configFile, conf.Conf) if err != nil { log.Fatalf("load config from file error: %v", err) } + log.Infof("load config success from file: %s", configFile) + if err = restoreConfig(configFile, conf.Conf); err != nil { log.Warnf("restore config error: %v", err) } else { log.Info("restore config success") } } + if !flags.Server.SkipEnvConfig { prefix := "SYNCTV_" if flags.EnvNoPrefix { prefix = "" + log.Info("load config from env without prefix") } else { log.Infof("load config from env with prefix: %s", prefix) } + err := confFromEnv(prefix, conf.Conf) if err != nil { log.Fatalf("load config from env error: %v", err) } + log.Info("load config success from env") } + return optConfigPath(conf.Conf) } func optConfigPath(conf *conf.Config) error { var err error + conf.Server.ProxyCachePath, err = utils.OptFilePath(conf.Server.ProxyCachePath) if err != nil { return fmt.Errorf("get proxy cache path error: %w", err) } + conf.Server.HTTP.CertPath, err = utils.OptFilePath(conf.Server.HTTP.CertPath) if err != nil { return fmt.Errorf("get http cert path error: %w", err) } + conf.Server.HTTP.KeyPath, err = utils.OptFilePath(conf.Server.HTTP.KeyPath) if err != nil { return fmt.Errorf("get http key path error: %w", err) } + conf.Log.FilePath, err = utils.OptFilePath(conf.Log.FilePath) if err != nil { return fmt.Errorf("get log file path error: %w", err) } + for _, op := range conf.Oauth2Plugins { op.PluginFile, err = utils.OptFilePath(op.PluginFile) if err != nil { return fmt.Errorf("get oauth2 plugin file path error: %w", err) } } + return nil } @@ -88,8 +103,10 @@ func confFromConfig(filePath string, conf *conf.Config) error { if filePath == "" { return errors.New("config file path is empty") } + if !utils.Exists(filePath) { log.Infof("config file not exists, create new config file: %s", filePath) + err := conf.Save(filePath) if err != nil { return err @@ -100,6 +117,7 @@ func confFromConfig(filePath string, conf *conf.Config) error { return err } } + return nil } diff --git a/internal/bootstrap/db.go b/internal/bootstrap/db.go index 6246783..3d5d8fc 100644 --- a/internal/bootstrap/db.go +++ b/internal/bootstrap/db.go @@ -27,6 +27,7 @@ func InitDatabase(_ context.Context) (err error) { } var opts []gorm.Option + opts = append(opts, &gorm.Config{ TranslateError: true, Logger: newDBLogger(), @@ -34,14 +35,17 @@ func InitDatabase(_ context.Context) (err error) { DisableForeignKeyConstraintWhenMigrating: false, IgnoreRelationshipsWhenMigrating: false, }) + d, err := gorm.Open(dialector, opts...) if err != nil { log.Fatalf("failed to connect database: %s", err.Error()) } + sqlDB, err := d.DB() if err != nil { log.Fatalf("failed to get sqlDB: %s", err.Error()) } + err = sysnotify.RegisterSysNotifyTask( 0, sysnotify.NewSysNotifyTask("database", sysnotify.NotifyTypeEXIT, func() error { @@ -51,9 +55,11 @@ func InitDatabase(_ context.Context) (err error) { if err != nil { log.Fatalf("failed to register sysnotify task: %s", err.Error()) } + if conf.Conf.Database.Type != conf.DatabaseTypeSqlite3 { initRawDB(sqlDB) } + return db.Init(d, conf.Conf.Database.Type) } @@ -86,6 +92,7 @@ func createDialector(dbConf conf.DatabaseConfig) (dialector gorm.Dialector, err ) log.Infof("mysql database tcp: %s:%d", dbConf.Host, dbConf.Port) } + dialector = mysql.New(mysql.Config{ DSN: dsn, DefaultStringSize: 256, @@ -100,18 +107,22 @@ func createDialector(dbConf conf.DatabaseConfig) (dialector gorm.Dialector, err dsn = dbConf.CustomDSN case dbConf.Name == "memory" || strings.HasPrefix(dbConf.Name, ":memory:"): dsn = "file::memory:?cache=shared&_journal_mode=WAL&_vacuum=incremental&_pragma=foreign_keys(1)" + log.Infof("sqlite3 database memory") default: if !strings.HasSuffix(dbConf.Name, ".db") { dbConf.Name += ".db" } + dbConf.Name, err = utils.OptFilePath(dbConf.Name) if err != nil { log.Fatalf("sqlite3 database file path error: %v", err) } + dsn = dbConf.Name + "?_journal_mode=WAL&_vacuum=incremental&_pragma=foreign_keys(1)" log.Infof("sqlite3 database file: %s", dbConf.Name) } + dialector = openSqlite(dsn) case conf.DatabaseTypePostgres: switch { @@ -137,6 +148,7 @@ func createDialector(dbConf conf.DatabaseConfig) (dialector gorm.Dialector, err ) log.Infof("postgres database tcp: %s:%d", dbConf.Host, dbConf.Port) } + dialector = postgres.New(postgres.Config{ DSN: dsn, PreferSimpleProtocol: true, @@ -144,6 +156,7 @@ func createDialector(dbConf conf.DatabaseConfig) (dialector gorm.Dialector, err default: log.Fatalf("unknown database type: %s", dbConf.Type) } + return dialector, err } @@ -154,6 +167,7 @@ func newDBLogger() logger.Interface { } else { logLevel = logger.Warn } + return logger.New( log.StandardLogger(), logger.Config{ @@ -169,14 +183,18 @@ func newDBLogger() logger.Interface { func initRawDB(db *sql.DB) { db.SetMaxOpenConns(conf.Conf.Database.MaxOpenConns) db.SetMaxIdleConns(conf.Conf.Database.MaxIdleConns) + d, err := time.ParseDuration(conf.Conf.Database.ConnMaxLifetime) if err != nil { log.Fatalf("failed to parse conn_max_lifetime: %s", err.Error()) } + db.SetConnMaxLifetime(d) + d, err = time.ParseDuration(conf.Conf.Database.ConnMaxIdleTime) if err != nil { log.Fatalf("failed to parse conn_max_idle_time: %s", err.Error()) } + db.SetConnMaxIdleTime(d) } diff --git a/internal/bootstrap/gin.go b/internal/bootstrap/gin.go index 8135245..86a98da 100644 --- a/internal/bootstrap/gin.go +++ b/internal/bootstrap/gin.go @@ -14,6 +14,7 @@ func InitGinMode(_ context.Context) error { } else { gin.SetMode(gin.ReleaseMode) } + if utils.ForceColor() { gin.ForceConsoleColor() } else { diff --git a/internal/bootstrap/init.go b/internal/bootstrap/init.go index 6a7d28d..303d495 100644 --- a/internal/bootstrap/init.go +++ b/internal/bootstrap/init.go @@ -21,6 +21,7 @@ func New(conf ...Conf) *Bootstrap { for _, c := range conf { c(b) } + return b } @@ -37,5 +38,6 @@ func (b *Bootstrap) Run(ctx context.Context) error { return err } } + return nil } diff --git a/internal/bootstrap/log.go b/internal/bootstrap/log.go index 920409a..bf761db 100644 --- a/internal/bootstrap/log.go +++ b/internal/bootstrap/log.go @@ -33,6 +33,7 @@ var logCallerIgnoreFuncs = map[string]struct{}{ func InitLog(_ context.Context) (err error) { setLog(logrus.StandardLogger()) + forceColor := utils.ForceColor() if conf.Conf.Log.Enable { l := &lumberjack.Logger{ @@ -45,12 +46,14 @@ func InitLog(_ context.Context) (err error) { if err := l.Rotate(); err != nil { logrus.Fatalf("log: rotate log file error: %v", err) } + var w io.Writer if forceColor { w = colorable.NewNonColorableWriter(l) } else { w = l } + if flags.Global.Dev || flags.Global.LogStd { logrus.SetOutput(io.MultiWriter(os.Stdout, w)) logrus.Infof("log: enable log to stdout and file: %s", conf.Conf.Log.FilePath) @@ -59,6 +62,7 @@ func InitLog(_ context.Context) (err error) { logrus.Infof("log: disable log to stdout, only log to file: %s", conf.Conf.Log.FilePath) } } + switch conf.Conf.Log.LogFormat { case "json": logrus.SetFormatter(&logrus.JSONFormatter{ @@ -74,6 +78,7 @@ func InitLog(_ context.Context) (err error) { if conf.Conf.Log.LogFormat != "text" { logrus.Warnf("unknown log format: %s, use default: text", conf.Conf.Log.LogFormat) } + logrus.SetFormatter(&logrus.TextFormatter{ ForceColors: forceColor, DisableColors: !forceColor, @@ -91,7 +96,9 @@ func InitLog(_ context.Context) (err error) { }, }) } + log.SetOutput(logrus.StandardLogger().Writer()) + return nil } diff --git a/internal/bootstrap/provider.go b/internal/bootstrap/provider.go index db8e55f..3eeaf28 100644 --- a/internal/bootstrap/provider.go +++ b/internal/bootstrap/provider.go @@ -82,17 +82,21 @@ var Oauth2SignupEnabledCache = refreshcache0.NewRefreshCache( func InitProvider(_ context.Context) (err error) { logOur := log.StandardLogger().Writer() + logLevle := hclog.Info if flags.Global.Dev { logLevle = hclog.Debug } + for _, op := range conf.Conf.Oauth2Plugins { log.Infof("load oauth2 plugin: %s", op.PluginFile) + err := os.MkdirAll(filepath.Dir(op.PluginFile), 0o755) if err != nil { log.Fatalf("create plugin dir: %s failed: %s", filepath.Dir(op.PluginFile), err) return err } + err = plugins.InitProviderPlugins(op.PluginFile, op.Args, hclog.New(&hclog.LoggerOptions{ Name: op.PluginFile, Level: logLevle, @@ -112,6 +116,7 @@ func InitProvider(_ context.Context) (err error) { for _, api := range aggregations.AllAggregation() { InitAggregationSetting(api) } + return nil } @@ -123,17 +128,21 @@ func InitProviderSetting(pi provider.Provider) { groupSettings.Enabled = settings.NewBoolSetting(group+"_enabled", false, group, settings.WithBeforeInitBool(func(_ settings.BoolSetting, b bool) (bool, error) { defer func() { _, _ = Oauth2EnabledCache.Refresh(context.Background()) }() + if b { return b, providers.EnableProvider(pi.Provider()) } + return b, providers.DisableProvider(pi.Provider()) }), settings.WithInitPriorityBool(1), settings.WithBeforeSetBool(func(_ settings.BoolSetting, b bool) (bool, error) { defer func() { _, _ = Oauth2EnabledCache.Refresh(context.Background()) }() + if b { return b, providers.EnableProvider(pi.Provider()) } + return b, providers.DisableProvider(pi.Provider()) }), ) @@ -212,9 +221,11 @@ func InitAggregationProviderSetting(pi provider.Provider) { groupSettings.Enabled = settings.LoadOrNewBoolSetting(group+"_enabled", false, group, settings.WithBeforeSetBool(func(_ settings.BoolSetting, b bool) (bool, error) { defer func() { _, _ = Oauth2EnabledCache.Refresh(context.Background()) }() + if b { return b, providers.EnableProvider(pi.Provider()) } + return b, providers.DisableProvider(pi.Provider()) }), ) @@ -305,12 +316,14 @@ func InitAggregationSetting(pi provider.AggregationProviderInterface) { if s == "" { return s, nil } + list := strings.Split(s, ",") for _, p := range list { if slices.Index(pi.Providers(), p) == -1 { return s, fmt.Errorf("provider %s not found", p) } } + return s, nil }), ) @@ -325,8 +338,10 @@ func InitAggregationSetting(pi provider.AggregationProviderInterface) { pi.Provider(), ) } + all := pi.Providers() list := strings.Split(s, ",") + enabled := make([]provider.OAuth2Provider, 0, len(list)) for _, p := range list { if slices.Index(all, p) != -1 { @@ -343,13 +358,16 @@ func InitAggregationSetting(pi provider.AggregationProviderInterface) { pi.Provider(), err, ) + return b, nil } + for _, pi2 := range pi2 { providers.RegisterProvider(pi2) InitAggregationProviderSetting(pi2) } } + return b, nil }), settings.WithBeforeSetBool(func(_ settings.BoolSetting, b bool) (bool, error) { diff --git a/internal/bootstrap/rtmp.go b/internal/bootstrap/rtmp.go index 4e15f8c..995e43a 100644 --- a/internal/bootstrap/rtmp.go +++ b/internal/bootstrap/rtmp.go @@ -24,6 +24,7 @@ func auth(reqAppName, reqChannelName string, isPublisher bool) (*rtmps.Channel, log.Errorf("rtmp: get room by id error: %v", err) return nil, err } + room := roomE.Value() if err := validateRoom(room); err != nil { @@ -41,9 +42,11 @@ func validateRoom(room *op.Room) error { if room.IsBanned() { return fmt.Errorf("rtmp: room %s is banned", room.ID) } + if room.IsPending() { return fmt.Errorf("rtmp: room %s is pending, need admin approval", room.ID) } + return nil } @@ -53,7 +56,9 @@ func handlePublisher(reqAppName, reqChannelName string, room *op.Room) (*rtmps.C log.Errorf("rtmp: publish auth to %s error: %v", reqAppName, err) return nil, err } + log.Infof("rtmp: publisher login success: %s/%s", reqAppName, channelName) + return room.GetChannel(channelName) } @@ -63,5 +68,6 @@ func handlePlayer(reqAppName, reqChannelName string, room *op.Room) (*rtmps.Chan log.Warnf("rtmp: dial to %s/%s error: %s", reqAppName, reqChannelName, err) return nil, err } + return room.GetChannel(reqChannelName) } diff --git a/internal/bootstrap/setting.go b/internal/bootstrap/setting.go index 15a99a3..bbb22a4 100644 --- a/internal/bootstrap/setting.go +++ b/internal/bootstrap/setting.go @@ -21,6 +21,7 @@ func initAndFixSettings() error { if err != nil { return err } + var setting *model.Setting for { @@ -38,11 +39,13 @@ func initAndFixSettings() error { Type: b.Type(), Group: b.Group(), } + err := db.FirstOrCreateSettingItemValue(setting) if err != nil { return err } } + err = b.Init(setting.Value) if err != nil { // auto fix diff --git a/internal/bootstrap/update.go b/internal/bootstrap/update.go index 2b2ffb6..ed67012 100644 --- a/internal/bootstrap/update.go +++ b/internal/bootstrap/update.go @@ -26,6 +26,7 @@ func InitCheckUpdate(ctx context.Context) error { latest string url string ) + need, latest, url, err = check(ctx, v) if err != nil { log.Errorf("check update error: %v", err) @@ -42,6 +43,7 @@ func InitCheckUpdate(ctx context.Context) error { log.Infof("new version (%s) available: %s", latest, url) log.Infof("run '%s self-update' to auto update", execFile) } + return nil }, )) @@ -51,6 +53,7 @@ func InitCheckUpdate(ctx context.Context) error { t := time.NewTicker(time.Hour * 6) defer t.Stop() + for range t.C { func() { defer func() { @@ -58,6 +61,7 @@ func InitCheckUpdate(ctx context.Context) error { log.Errorf("check update panic: %v", err) } }() + need, latest, url, err = check(ctx, v) if err != nil { log.Errorf("check update error: %v", err) @@ -74,18 +78,23 @@ func check(ctx context.Context, v *version.Info) (need bool, latest, url string, if err != nil { return false, "", "", err } + latest = l + b, err := v.NeedUpdate(ctx) if err != nil { return false, "", "", err } + need = b if b { u, err := v.LatestBinaryURL(ctx) if err != nil { return false, "", "", err } + url = u } + return need, latest, url, nil } diff --git a/internal/cache/alist.go b/internal/cache/alist.go index 80ff3f5..231bad9 100644 --- a/internal/cache/alist.go +++ b/internal/cache/alist.go @@ -47,6 +47,7 @@ func AlistAuthorizationCacheWithUserIDInitFunc( if err != nil { return nil, err } + return AlistAuthorizationCacheWithConfigInitFunc(ctx, v) } @@ -64,12 +65,14 @@ func AlistAuthorizationCacheWithConfigInitFunc( if err != nil { return nil, err } + return &AlistUserCacheData{ Host: v.Host, ServerID: v.ServerID, Backend: v.Backend, }, nil } + resp, err := cli.Login(ctx, &alist.LoginReq{ Host: v.Host, Username: v.Username, @@ -133,6 +136,7 @@ func newAliSubtitles( if v.GetStatus() != "finished" { return nil } + url := v.GetUrl() caches[i] = &AlistSubtitle{ Cache: refreshcache0.NewRefreshCache(func(ctx context.Context) ([]byte, error) { @@ -140,14 +144,17 @@ func newAliSubtitles( if err != nil { return nil, err } + resp, err := uhc.Do(r) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("status code: %d", resp.StatusCode) } + if resp.ContentLength > subtitleMaxLength { return nil, fmt.Errorf( "subtitle too large, got: %d, max: %d", @@ -155,6 +162,7 @@ func newAliSubtitles( subtitleMaxLength, ) } + return io.ReadAll(io.LimitReader(resp.Body, subtitleMaxLength)) }, -1), Name: v.GetLanguage(), @@ -162,6 +170,7 @@ func newAliSubtitles( Type: utils.GetFileExtension(v.GetUrl()), } } + return caches } @@ -171,10 +180,12 @@ func genAliM3U8ListFile( buf := bytes.NewBuffer(nil) buf.WriteString("#EXTM3U\n") buf.WriteString("#EXT-X-VERSION:3\n") + for _, v := range urls { if v.GetStatus() != "finished" { return nil } + fmt.Fprintf( buf, "#EXT-X-STREAM-INF:BANDWIDTH=%d,RESOLUTION=%dx%d,NAME=\"%d\"\n", @@ -185,6 +196,7 @@ func genAliM3U8ListFile( ) buf.WriteString(v.GetUrl() + "\n") } + return buf.Bytes() } @@ -211,11 +223,13 @@ func NewAlistMovieCacheInitFunc( if err != nil { return nil, err } + if aucd.Host == "" { return nil, errors.New("not bind alist vendor") } cli := vendor.LoadAlistClient(movie.VendorInfo.Backend) + fg, err := getFsGet( ctx, cli, @@ -261,12 +275,15 @@ func validateArgs(args *AlistMovieCacheFuncArgs, movie *model.Movie, subPath str if args == nil { return errors.New("need alist user cache") } + if args.UserCache == nil { return errors.New("need alist user cache") } + if movie.IsFolder && subPath == "" { return errors.New("sub path is empty") } + return nil } @@ -281,6 +298,7 @@ func getServerIDAndPath(movie *model.Movie, subPath string) (string, string, err if !strings.HasPrefix(newPath, truePath) { return "", "", errors.New("sub path is not in parent path") } + truePath = newPath } @@ -315,6 +333,7 @@ func processSubtitles( if related.GetType() != 4 { continue } + if utils.GetFileExtension(related.GetName()) == "xml" { continue } @@ -334,6 +353,7 @@ func processSubtitles( } cache.Subtitles = append(cache.Subtitles, subtitle) } + return nil } @@ -342,14 +362,17 @@ func fetchSubtitleContent(ctx context.Context, url string) ([]byte, error) { if err != nil { return nil, err } + resp, err := uhc.Do(r) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("status code: %d", resp.StatusCode) } + if resp.ContentLength > subtitleMaxLength { return nil, fmt.Errorf( "subtitle too large, got: %d, max: %d", @@ -357,6 +380,7 @@ func fetchSubtitleContent(ctx context.Context, url string) ([]byte, error) { subtitleMaxLength, ) } + return io.ReadAll(io.LimitReader(resp.Body, subtitleMaxLength)) } @@ -383,8 +407,10 @@ func processAliProvider( if err != nil { return nil, err } + url = u.GetRawUrl() } + fo, err := cli.FsOther(ctx, &alist.FsOtherReq{ Host: aucd.Host, Token: aucd.Token, @@ -395,6 +421,7 @@ func processAliProvider( if err != nil { return nil, err } + return &AlistAliCache{ URL: url, M3U8ListFile: genAliM3U8ListFile( diff --git a/internal/cache/bilibili.go b/internal/cache/bilibili.go index ebc9863..2e44021 100644 --- a/internal/cache/bilibili.go +++ b/internal/cache/bilibili.go @@ -60,6 +60,7 @@ func BilibiliSharedMpdCacheInitFunc( } cli := vendor.LoadBilibiliClient(movie.VendorInfo.Backend) + m, hevcM, err := getBilibiliMpd(ctx, cli, movie.VendorInfo.Bilibili, cookies) if err != nil { return nil, err @@ -83,6 +84,7 @@ func getBilibiliCookies(ctx context.Context, args *BilibiliUserCache) ([]*http.C } return nil, nil } + return vendorInfo.Cookies, nil } @@ -103,6 +105,7 @@ func getBilibiliMpd( if err != nil { return nil, nil, err } + return parseMpdResponse(resp.GetMpd(), resp.GetHevcMpd()) case biliInfo.Bvid != "": @@ -114,6 +117,7 @@ func getBilibiliMpd( if err != nil { return nil, nil, err } + return parseMpdResponse(resp.GetMpd(), resp.GetHevcMpd()) default: @@ -170,11 +174,13 @@ func processMpdUrls(m, hevcM *mpd.MPD, movieID, roomID string) []string { func BilibiliMpdToString(mpdRaw *mpd.MPD, token string) (string, error) { newMpdRaw := *mpdRaw + newPeriods := make([]*mpd.Period, len(mpdRaw.Periods)) for i, p := range mpdRaw.Periods { n := *p newPeriods[i] = &n } + newMpdRaw.Periods = newPeriods for _, p := range newMpdRaw.Periods { newAdaptationSets := make([]*mpd.AdaptationSet, len(p.AdaptationSets)) @@ -182,6 +188,7 @@ func BilibiliMpdToString(mpdRaw *mpd.MPD, token string) (string, error) { n := *as newAdaptationSets[i] = &n } + p.AdaptationSets = newAdaptationSets for _, as := range p.AdaptationSets { newRepresentations := make([]*mpd.Representation, len(as.Representations)) @@ -189,10 +196,12 @@ func BilibiliMpdToString(mpdRaw *mpd.MPD, token string) (string, error) { n := *r newRepresentations[i] = &n } + as.Representations = newRepresentations for _, r := range as.Representations { newBaseURL := make([]string, len(r.BaseURL)) copy(newBaseURL, r.BaseURL) + r.BaseURL = newBaseURL for i := range r.BaseURL { r.BaseURL[i] = fmt.Sprintf("%s&token=%s", r.BaseURL[i], token) @@ -200,6 +209,7 @@ func BilibiliMpdToString(mpdRaw *mpd.MPD, token string) (string, error) { } } } + return newMpdRaw.WriteToString() } @@ -219,7 +229,9 @@ func BilibiliNoSharedMovieCacheInitFunc( if len(args) == 0 { return "", errors.New("no bilibili user cache data") } + var cookies []*http.Cookie + vendorInfo, err := args[0].Get(ctx) if err != nil { if !errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -228,8 +240,11 @@ func BilibiliNoSharedMovieCacheInitFunc( } else { cookies = vendorInfo.Cookies } + cli := vendor.LoadBilibiliClient(movie.VendorInfo.Backend) + var u string + biliInfo := movie.VendorInfo.Bilibili switch { case biliInfo.Epid != 0: @@ -240,6 +255,7 @@ func BilibiliNoSharedMovieCacheInitFunc( if err != nil { return "", err } + u = resp.GetUrl() case biliInfo.Bvid != "": @@ -251,6 +267,7 @@ func BilibiliNoSharedMovieCacheInitFunc( if err != nil { return "", err } + u = resp.GetUrl() default: @@ -303,6 +320,7 @@ func BilibiliSubtitleCacheInitFunc( // must login var cookies []*http.Cookie + vendorInfo, err := args.Get(ctx) if err != nil { if errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -310,9 +328,11 @@ func BilibiliSubtitleCacheInitFunc( } return nil, err } + cookies = vendorInfo.Cookies cli := vendor.LoadBilibiliClient(movie.VendorInfo.Backend) + resp, err := cli.GetSubtitles(ctx, &bilibili.GetSubtitlesReq{ Cookies: utils.HTTPCookieToMap(cookies), Bvid: biliInfo.Bvid, @@ -321,6 +341,7 @@ func BilibiliSubtitleCacheInitFunc( if err != nil { return nil, err } + subtitleCache := make(BilibiliSubtitleCache, len(resp.GetSubtitles())) for k, v := range resp.GetSubtitles() { subtitleCache[k] = &BilibiliSubtitleCacheItem{ @@ -336,6 +357,7 @@ func BilibiliSubtitleCacheInitFunc( func convertToSRT(subtitles *bilibiliSubtitleResp) []byte { srt := bytes.NewBuffer(nil) + counter := 0 for _, subtitle := range subtitles.Body { fmt.Fprintf(srt, @@ -346,6 +368,7 @@ func convertToSRT(subtitles *bilibiliSubtitleResp) []byte { subtitle.Content) counter++ } + return srt.Bytes() } @@ -355,6 +378,7 @@ func formatTime(seconds float64) string { minutes := int(seconds) / 60 seconds = math.Mod(seconds, 60) milliseconds := int((seconds - float64(int(seconds))) * 1000) + return fmt.Sprintf("%02d:%02d:%02d,%03d", hours, minutes, int(seconds), milliseconds) } @@ -363,18 +387,23 @@ func translateBilibiliSubtitleToSrt(ctx context.Context, url string) ([]byte, er if err != nil { return nil, err } + r.Header.Set("User-Agent", utils.UA) r.Header.Set("Referer", "https://www.bilibili.com") + resp, err := uhc.Do(r) if err != nil { return nil, err } defer resp.Body.Close() + var srt bilibiliSubtitleResp + err = json.NewDecoder(resp.Body).Decode(&srt) if err != nil { return nil, err } + return convertToSRT(&srt), nil } @@ -388,10 +417,12 @@ func genBilibiliLiveM3U8ListFile(urls []*bilibili.LiveStream) []byte { buf := bytes.NewBuffer(nil) buf.WriteString("#EXTM3U\n") buf.WriteString("#EXT-X-VERSION:3\n") + for _, v := range urls { if len(v.GetUrls()) == 0 { continue } + fmt.Fprintf( buf, "#EXT-X-STREAM-INF:BANDWIDTH=%d,NAME=\"%s\"\n", @@ -400,11 +431,13 @@ func genBilibiliLiveM3U8ListFile(urls []*bilibili.LiveStream) []byte { ) buf.WriteString(v.GetUrls()[0] + "\n") } + return buf.Bytes() } func BilibiliLiveCacheInitFunc(ctx context.Context, movie *model.Movie) ([]byte, error) { cli := vendor.LoadBilibiliClient(movie.VendorInfo.Backend) + resp, err := cli.GetLiveStreams(ctx, &bilibili.GetLiveStreamsReq{ Cid: movie.VendorInfo.Bilibili.Cid, Hls: true, @@ -412,6 +445,7 @@ func BilibiliLiveCacheInitFunc(ctx context.Context, movie *model.Movie) ([]byte, if err != nil { return nil, err } + return genBilibiliLiveM3U8ListFile(resp.GetLiveStreams()), nil } @@ -423,24 +457,30 @@ func NewBilibiliDanmuCacheInitFunc(movie *model.Movie) func(ctx context.Context) func BilibiliDanmuCacheInitFunc(ctx context.Context, movie *model.Movie) ([]byte, error) { u := fmt.Sprintf("https://comment.bilibili.com/%d.xml", movie.VendorInfo.Bilibili.Cid) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) if err != nil { return nil, err } + resp, err := http.DefaultClient.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, fmt.Errorf("status code: %d", resp.StatusCode) } + gz := flate.NewReader(resp.Body) defer gz.Close() + data, err := io.ReadAll(gz) if err != nil { return nil, err } + return data, nil } @@ -477,6 +517,7 @@ type BilibiliUserCacheData struct { func NewBilibiliUserCache(userID string) *BilibiliUserCache { f := BilibiliAuthorizationCacheWithUserIDInitFunc(userID) + return refreshcache.NewRefreshCache( func(ctx context.Context, _ ...struct{}) (*BilibiliUserCacheData, error) { return f(ctx) @@ -493,6 +534,7 @@ func BilibiliAuthorizationCacheWithUserIDInitFunc( if err != nil { return nil, err } + return &BilibiliUserCacheData{ Cookies: utils.MapToHTTPCookie(v.Cookies), Backend: v.Backend, diff --git a/internal/cache/cache.go b/internal/cache/cache.go index c0b08b2..7a7cdb1 100644 --- a/internal/cache/cache.go +++ b/internal/cache/cache.go @@ -28,6 +28,7 @@ func newMapCache[T, A any](refreshFunc MapRefreshFunc[T, A], maxAge time.Duratio func (b *MapCache[T, A]) Clear() { b.lock.Lock() defer b.lock.Unlock() + b.clear() } @@ -38,23 +39,28 @@ func (b *MapCache[T, A]) clear() { func (b *MapCache[T, A]) Delete(key string) { b.lock.Lock() defer b.lock.Unlock() + delete(b.cache, key) } func (b *MapCache[T, A]) LoadOrStore(ctx context.Context, key string, args ...A) (T, error) { b.lock.RLock() + c, loaded := b.cache[key] if loaded { b.lock.RUnlock() return c.Get(ctx, args...) } + b.lock.RUnlock() b.lock.Lock() + c, loaded = b.cache[key] if loaded { b.lock.Unlock() return c.Get(ctx, args...) } + c = refreshcache.NewRefreshCache( refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return b.refreshFunc(ctx, key, args...) @@ -63,23 +69,28 @@ func (b *MapCache[T, A]) LoadOrStore(ctx context.Context, key string, args ...A) ) b.cache[key] = c b.lock.Unlock() + return c.Get(ctx, args...) } func (b *MapCache[T, A]) StoreOrRefresh(ctx context.Context, key string, args ...A) (T, error) { b.lock.RLock() + c, ok := b.cache[key] if ok { b.lock.RUnlock() return c.Refresh(ctx, args...) } + b.lock.RUnlock() b.lock.Lock() + c, ok = b.cache[key] if ok { b.lock.Unlock() return c.Refresh(ctx, args...) } + c = refreshcache.NewRefreshCache( refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return b.refreshFunc(ctx, key, args...) @@ -88,6 +99,7 @@ func (b *MapCache[T, A]) StoreOrRefresh(ctx context.Context, key string, args .. ) b.cache[key] = c b.lock.Unlock() + return c.Refresh(ctx, args...) } @@ -100,18 +112,22 @@ func (b *MapCache[T, A]) LoadCache(key string) (*refreshcache.RefreshCache[T, A] func (b *MapCache[T, A]) LoadOrNewCache(key string) *refreshcache.RefreshCache[T, A] { b.lock.RLock() + c, ok := b.cache[key] if ok { b.lock.RUnlock() return c } + b.lock.RUnlock() b.lock.Lock() + c, ok = b.cache[key] if ok { b.lock.Unlock() return c } + c = refreshcache.NewRefreshCache( refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return b.refreshFunc(ctx, key, args...) @@ -120,6 +136,7 @@ func (b *MapCache[T, A]) LoadOrNewCache(key string) *refreshcache.RefreshCache[T ) b.cache[key] = c b.lock.Unlock() + return c } @@ -130,24 +147,30 @@ func (b *MapCache[T, A]) LoadOrStoreWithDynamicFunc( args ...A, ) (T, error) { b.lock.RLock() + c, loaded := b.cache[key] if loaded { b.lock.RUnlock() + return c.Data(). Get(ctx, refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return refreshFunc(ctx, key, args...) }), args...) } + b.lock.RUnlock() b.lock.Lock() + c, loaded = b.cache[key] if loaded { b.lock.Unlock() + return c.Data(). Get(ctx, refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return refreshFunc(ctx, key, args...) }), args...) } + c = refreshcache.NewRefreshCache( refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return b.refreshFunc(ctx, key, args...) @@ -156,6 +179,7 @@ func (b *MapCache[T, A]) LoadOrStoreWithDynamicFunc( ) b.cache[key] = c b.lock.Unlock() + return c.Data(). Get(ctx, refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return refreshFunc(ctx, key, args...) @@ -169,24 +193,30 @@ func (b *MapCache[T, A]) StoreOrRefreshWithDynamicFunc( args ...A, ) (T, error) { b.lock.RLock() + c, ok := b.cache[key] if ok { b.lock.RUnlock() + return c.Data(). Refresh(ctx, refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return refreshFunc(ctx, key, args...) }), args...) } + b.lock.RUnlock() b.lock.Lock() + c, ok = b.cache[key] if ok { b.lock.Unlock() + return c.Data(). Refresh(ctx, refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return refreshFunc(ctx, key, args...) }), args...) } + c = refreshcache.NewRefreshCache( refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return b.refreshFunc(ctx, key, args...) @@ -195,6 +225,7 @@ func (b *MapCache[T, A]) StoreOrRefreshWithDynamicFunc( ) b.cache[key] = c b.lock.Unlock() + return c.Data(). Refresh(ctx, refreshcache.RefreshFunc[T, A](func(ctx context.Context, args ...A) (T, error) { return refreshFunc(ctx, key, args...) diff --git a/internal/cache/cache0.go b/internal/cache/cache0.go index 05318e6..9b969b2 100644 --- a/internal/cache/cache0.go +++ b/internal/cache/cache0.go @@ -28,6 +28,7 @@ func newMapCache0[T any](refreshFunc MapRefreshFunc0[T], maxAge time.Duration) * func (b *MapCache0[T]) Clear() { b.lock.Lock() defer b.lock.Unlock() + b.clear() } @@ -38,50 +39,61 @@ func (b *MapCache0[T]) clear() { func (b *MapCache0[T]) Delete(key string) { b.lock.Lock() defer b.lock.Unlock() + delete(b.cache, key) } func (b *MapCache0[T]) LoadOrStore(ctx context.Context, key string) (T, error) { b.lock.RLock() + c, loaded := b.cache[key] if loaded { b.lock.RUnlock() return c.Get(ctx) } + b.lock.RUnlock() b.lock.Lock() + c, loaded = b.cache[key] if loaded { b.lock.Unlock() return c.Get(ctx) } + c = refreshcache0.NewRefreshCache(func(ctx context.Context) (T, error) { return b.refreshFunc(ctx, key) }, b.maxAge) b.cache[key] = c b.lock.Unlock() + return c.Get(ctx) } func (b *MapCache0[T]) StoreOrRefresh(ctx context.Context, key string) (T, error) { b.lock.RLock() + c, ok := b.cache[key] if ok { b.lock.RUnlock() return c.Refresh(ctx) } + b.lock.RUnlock() b.lock.Lock() + c, ok = b.cache[key] if ok { b.lock.Unlock() return c.Refresh(ctx) } + c = refreshcache0.NewRefreshCache(func(ctx context.Context) (T, error) { return b.refreshFunc(ctx, key) }, b.maxAge) b.cache[key] = c b.lock.Unlock() + return c.Refresh(ctx) } @@ -94,23 +106,28 @@ func (b *MapCache0[T]) LoadCache(key string) (*refreshcache0.RefreshCache[T], bo func (b *MapCache0[T]) LoadOrNewCache(key string) *refreshcache0.RefreshCache[T] { b.lock.RLock() + c, ok := b.cache[key] if ok { b.lock.RUnlock() return c } + b.lock.RUnlock() b.lock.Lock() + c, ok = b.cache[key] if ok { b.lock.Unlock() return c } + c = refreshcache0.NewRefreshCache(func(ctx context.Context) (T, error) { return b.refreshFunc(ctx, key) }, b.maxAge) b.cache[key] = c b.lock.Unlock() + return c } @@ -120,6 +137,7 @@ func (b *MapCache0[T]) LoadOrStoreWithDynamicFunc( refreshFunc MapRefreshFunc0[T], ) (T, error) { b.lock.RLock() + c, loaded := b.cache[key] if loaded { b.lock.RUnlock() @@ -127,8 +145,10 @@ func (b *MapCache0[T]) LoadOrStoreWithDynamicFunc( return refreshFunc(ctx, key) }) } + b.lock.RUnlock() b.lock.Lock() + c, loaded = b.cache[key] if loaded { b.lock.Unlock() @@ -136,11 +156,13 @@ func (b *MapCache0[T]) LoadOrStoreWithDynamicFunc( return refreshFunc(ctx, key) }) } + c = refreshcache0.NewRefreshCache(func(ctx context.Context) (T, error) { return b.refreshFunc(ctx, key) }, b.maxAge) b.cache[key] = c b.lock.Unlock() + return c.Data().Get(ctx, func(ctx context.Context) (T, error) { return refreshFunc(ctx, key) }) @@ -152,6 +174,7 @@ func (b *MapCache0[T]) StoreOrRefreshWithDynamicFunc( refreshFunc MapRefreshFunc0[T], ) (T, error) { b.lock.RLock() + c, ok := b.cache[key] if ok { b.lock.RUnlock() @@ -159,8 +182,10 @@ func (b *MapCache0[T]) StoreOrRefreshWithDynamicFunc( return refreshFunc(ctx, key) }) } + b.lock.RUnlock() b.lock.Lock() + c, ok = b.cache[key] if ok { b.lock.Unlock() @@ -168,11 +193,13 @@ func (b *MapCache0[T]) StoreOrRefreshWithDynamicFunc( return refreshFunc(ctx, key) }) } + c = refreshcache0.NewRefreshCache(func(ctx context.Context) (T, error) { return b.refreshFunc(ctx, key) }, b.maxAge) b.cache[key] = c b.lock.Unlock() + return c.Data().Refresh(ctx, func(ctx context.Context) (T, error) { return refreshFunc(ctx, key) }) diff --git a/internal/cache/emby.go b/internal/cache/emby.go index b0a55b9..91e6cc3 100644 --- a/internal/cache/emby.go +++ b/internal/cache/emby.go @@ -41,13 +41,16 @@ func EmbyAuthorizationCacheWithUserIDInitFunc(userID, serverID string) (*EmbyUse if serverID == "" { return nil, errors.New("serverID is required") } + v, err := db.GetEmbyVendor(userID, serverID) if err != nil { return nil, err } + if v.APIKey == "" || v.Host == "" { return nil, db.NotFoundError(db.ErrVendorNotFound) } + return &EmbyUserCacheData{ Host: v.Host, ServerID: v.ServerID, @@ -92,6 +95,7 @@ func NewEmbyMovieClearCacheFunc( if !movie.VendorInfo.Emby.Transcode { return nil } + if args == nil { return errors.New("need emby user cache") } @@ -110,10 +114,13 @@ func NewEmbyMovieClearCacheFunc( if err != nil { return err } + if aucd.Host == "" || aucd.APIKey == "" { return errors.New("not bind emby vendor") } + cli := vendor.LoadEmbyClient(aucd.Backend) + _, err = cli.DeleteActiveEncodeings(ctx, &emby.DeleteActiveEncodeingsReq{ Host: aucd.Host, Token: aucd.APIKey, @@ -122,6 +129,7 @@ func NewEmbyMovieClearCacheFunc( if err != nil { log.Errorf("delete active encodeings: %v", err) } + return nil } } @@ -144,6 +152,7 @@ func NewEmbyMovieCacheInitFunc( if err != nil { return nil, err } + if aucd.Host == "" || aucd.APIKey == "" { return nil, errors.New("not bind emby vendor") } @@ -168,6 +177,7 @@ func NewEmbyMovieCacheInitFunc( if err != nil { return nil, err } + if source != nil { resp.Sources[i] = *source resp.Sources[i].Subtitles = processEmbySubtitles(v, truePath, u) @@ -182,9 +192,11 @@ func validateEmbyArgs(args *EmbyUserCache, movie *model.Movie, subPath string) e if args == nil { return errors.New("need emby user cache") } + if movie.IsFolder && subPath == "" { return errors.New("sub path is empty") } + return nil } @@ -193,9 +205,11 @@ func getEmbyServerIDAndPath(movie *model.Movie, subPath string) (string, string, if err != nil { return "", "", err } + if movie.IsFolder { truePath = subPath } + return serverID, truePath, nil } @@ -205,6 +219,7 @@ func getPlaybackInfo( truePath string, ) (*emby.PlaybackInfoResp, error) { cli := vendor.LoadEmbyClient(aucd.Backend) + data, err := cli.PlaybackInfo(ctx, &emby.PlaybackInfoReq{ Host: aucd.Host, Token: aucd.APIKey, @@ -214,6 +229,7 @@ func getPlaybackInfo( if err != nil { return nil, fmt.Errorf("playback info: %w", err) } + return data, nil } @@ -237,10 +253,12 @@ func processMediaSource( if v.GetContainer() == "" { return nil, nil } + result, err := url.JoinPath("emby", "Videos", truePath, "stream."+v.GetContainer()) if err != nil { return nil, err } + u.Path = result query := url.Values{} query.Set("api_key", aucd.APIKey) @@ -265,6 +283,7 @@ func processEmbySubtitles( } subtutleType := "srt" + result, err := url.JoinPath( "emby", "Videos", @@ -277,6 +296,7 @@ func processEmbySubtitles( if err != nil { continue } + u.Path = result u.RawQuery = "" url := u.String() @@ -297,6 +317,7 @@ func processEmbySubtitles( Cache: refreshcache0.NewRefreshCache(newEmbySubtitleCacheInitFunc(url), -1), }) } + return subtitles } @@ -306,16 +327,20 @@ func newEmbySubtitleCacheInitFunc(url string) func(ctx context.Context) ([]byte, if err != nil { return nil, err } + req.Header.Set("User-Agent", utils.UA) req.Header.Set("Referer", req.URL.Host) + resp, err := uhc.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + if resp.StatusCode != http.StatusOK { return nil, errors.New("bad status code") } + return io.ReadAll(resp.Body) } } diff --git a/internal/db/db.go b/internal/db/db.go index 347bd4d..5528fd1 100644 --- a/internal/db/db.go +++ b/internal/db/db.go @@ -21,25 +21,31 @@ var ( func Init(d *gorm.DB, t conf.DatabaseType) error { db = d dbType = t + err := UpgradeDatabase() if err != nil { return err } + err = initGuestUser() if err != nil { return err } + return initRootUser() } func initRootUser() error { user := model.User{} + err := db.Where("role = ?", model.RoleRoot).First(&user).Error if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { return err } + u, err := CreateUser("root", "root", WithRole(model.RoleRoot)) log.Infof("init root user:\nid: %s\nusername: %s\npassword: %s", u.ID, u.Username, "root") + return err } @@ -52,10 +58,12 @@ func initGuestUser() error { user := model.User{ ID: GuestUserID, } + err := db.First(&user).Error if err == nil || !errors.Is(err, gorm.ErrRecordNotFound) { return err } + u, err := CreateUser( "guest", utils.RandString(32), @@ -63,6 +71,7 @@ func initGuestUser() error { WithID(GuestUserID), ) log.Infof("init guest user:\nid: %s\nusername: %s", u.ID, u.Username) + return err } @@ -72,11 +81,13 @@ func DB() *gorm.DB { func Close() { log.Info("closing db") + sqlDB, err := db.DB() if err != nil { log.Errorf("failed to get db: %s", err.Error()) return } + err = sqlDB.Close() if err != nil { log.Errorf("failed to close db: %s", err.Error()) @@ -89,11 +100,13 @@ func Paginate(page, pageSize int) func(db *gorm.DB) *gorm.DB { if page <= 0 { page = 1 } + if pageSize <= 0 { pageSize = 10 } offset := (page - 1) * pageSize + return db.Offset(offset).Limit(pageSize) } } @@ -404,8 +417,10 @@ func Transactional(txFunc func(*gorm.DB) error) (err error) { tx.Commit() } }() + err = txFunc(tx) - return + + return err } // Helper function to handle update results @@ -413,8 +428,10 @@ func HandleUpdateResult(result *gorm.DB, entityName string) error { if result.Error != nil { return HandleNotFound(result.Error, entityName) } + if result.RowsAffected == 0 { return NotFoundError(entityName) } + return nil } diff --git a/internal/db/member.go b/internal/db/member.go index 5b49a7c..787009a 100644 --- a/internal/db/member.go +++ b/internal/db/member.go @@ -44,6 +44,7 @@ func FirstOrCreateRoomMemberRelation( conf ...CreateRoomMemberRelationConfig, ) (*model.RoomMember, error) { roomMemberRelation := &model.RoomMember{} + d := &model.RoomMember{ RoomID: roomID, UserID: userID, @@ -55,10 +56,12 @@ func FirstOrCreateRoomMemberRelation( for _, c := range conf { c(d) } + err := db.Where("room_id = ? AND user_id = ?", roomID, userID). Attrs(d). FirstOrCreate(roomMemberRelation). Error + return roomMemberRelation, err } @@ -164,6 +167,7 @@ func RoomSetAdmin(roomID, userID string, permissions model.RoomAdminPermission) "permissions": model.AllPermissions, "admin_permissions": permissions, }) + return HandleUpdateResult(result, ErrRoomMemberNotFound) } @@ -175,24 +179,29 @@ func RoomSetMember(roomID, userID string, permissions model.RoomMemberPermission "permissions": permissions, "admin_permissions": model.NoAdminPermission, }) + return HandleUpdateResult(result, ErrRoomMemberNotFound) } func GetRoomMembers(roomID string, scopes ...func(*gorm.DB) *gorm.DB) ([]*model.RoomMember, error) { var members []*model.RoomMember + err := db. Where("room_id = ?", roomID). Scopes(scopes...). Find(&members).Error + return members, err } func GetRoomMembersCount(roomID string, scopes ...func(*gorm.DB) *gorm.DB) (int64, error) { var count int64 + err := db. Model(&model.RoomMember{}). Where("room_id = ?", roomID). Scopes(scopes...). Count(&count).Error + return count, err } diff --git a/internal/db/movie.go b/internal/db/movie.go index d9b93e3..68ce03a 100644 --- a/internal/db/movie.go +++ b/internal/db/movie.go @@ -29,26 +29,31 @@ func WithParentMovieID(parentMovieID string) func(*gorm.DB) *gorm.DB { func GetMoviesByRoomID(roomID string, scopes ...func(*gorm.DB) *gorm.DB) ([]*model.Movie, error) { var movies []*model.Movie + err := db.Where("room_id = ?", roomID). Order("position ASC"). Scopes(scopes...). Find(&movies). Error + return movies, err } func GetMoviesCountByRoomID(roomID string, scopes ...func(*gorm.DB) *gorm.DB) (int64, error) { var count int64 + err := db.Model(&model.Movie{}). Where("room_id = ?", roomID). Scopes(scopes...). Count(&count). Error + return count, err } func GetMovieByID(roomID, id string, scopes ...func(*gorm.DB) *gorm.DB) (*model.Movie, error) { var movie model.Movie + err := db.Where("room_id = ? AND id = ?", roomID, id).Scopes(scopes...).First(&movie).Error return &movie, HandleNotFound(err, ErrRoomOrMovieNotFound) } @@ -77,6 +82,7 @@ func UpdateMovie(movie *model.Movie, columns ...clause.Column) error { Clauses(clause.Returning{Columns: columns}). Where("room_id = ? AND id = ?", movie.RoomID, movie.ID). Updates(movie) + return HandleUpdateResult(result, ErrRoomOrMovieNotFound) } @@ -86,6 +92,7 @@ func SaveMovie(movie *model.Movie, columns ...clause.Column) error { Where("room_id = ? AND id = ?", movie.RoomID, movie.ID). Omit("created_at"). Save(movie) + return HandleUpdateResult(result, ErrRoomOrMovieNotFound) } @@ -95,6 +102,7 @@ func SwapMoviePositions(roomID, movie1ID, movie2ID string) error { if err := tx.Where("room_id = ? AND id = ?", roomID, movie1ID).First(&movie1).Error; err != nil { return HandleNotFound(err, ErrRoomOrMovieNotFound) } + if err := tx.Where("room_id = ? AND id = ?", roomID, movie2ID).First(&movie2).Error; err != nil { return HandleNotFound(err, ErrRoomOrMovieNotFound) } @@ -107,9 +115,11 @@ func SwapMoviePositions(roomID, movie1ID, movie2ID string) error { if err := HandleUpdateResult(result1, ErrRoomOrMovieNotFound); err != nil { return err } + result2 := tx.Model(&movie2). Where("room_id = ? AND id = ?", roomID, movie2ID). Update("position", movie2.Position) + return HandleUpdateResult(result2, ErrRoomOrMovieNotFound) }) } diff --git a/internal/db/room.go b/internal/db/room.go index e106fe3..4a64119 100644 --- a/internal/db/room.go +++ b/internal/db/room.go @@ -55,6 +55,7 @@ func WithSettingHidden(hidden bool) CreateRoomConfig { if r.Settings == nil { r.Settings = model.DefaultRoomSettings() } + r.Settings.Hidden = hidden } } @@ -72,6 +73,7 @@ func CreateRoom( for _, c := range conf { c(r) } + if password != "" { hashedPassword, err := bcrypt.GenerateFromPassword( stream.StringToBytes(password), @@ -80,6 +82,7 @@ func CreateRoom( if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } + r.HashedPassword = hashedPassword } @@ -89,21 +92,25 @@ func CreateRoom( if err := tx.Model(&model.Room{}).Where("creator_id = ?", r.CreatorID).Count(&count).Error; err != nil { return fmt.Errorf("failed to count rooms: %w", err) } + if count >= maxCount { return errors.New("room count exceeds limit") } } + if err := tx.Create(r).Error; err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { return errors.New("room already exists") } return fmt.Errorf("failed to create room: %w", err) } + return nil }) if err != nil { return nil, err } + return r, nil } @@ -111,17 +118,22 @@ func GetRoomByID(id string) (*model.Room, error) { if len(id) != 32 { return nil, errors.New("room id is not 32 bit") } + var r model.Room + err := db.Where("id = ?", id).First(&r).Error + return &r, HandleNotFound(err, ErrRoomNotFound) } func CreateOrLoadRoomSettings(roomID string) (*model.RoomSettings, error) { var rs model.RoomSettings + err := db.Where(model.RoomSettings{ID: roomID}). Attrs(model.DefaultRoomSettings()). FirstOrCreate(&rs). Error + return &rs, err } @@ -132,10 +144,12 @@ func SaveRoomSettings(roomID string, settings *model.RoomSettings) error { func UpdateRoomSettings(roomID string, settings map[string]any) (*model.RoomSettings, error) { var rs model.RoomSettings + err := db.Model(&model.RoomSettings{ID: roomID}). Clauses(clause.Returning{}). Updates(settings). First(&rs).Error + return &rs, HandleNotFound(err, "room settings") } @@ -145,8 +159,11 @@ func DeleteRoomByID(roomID string) error { } func SetRoomPassword(roomID, password string) error { - var hashedPassword []byte - var err error + var ( + hashedPassword []byte + err error + ) + if password != "" { hashedPassword, err = bcrypt.GenerateFromPassword( stream.StringToBytes(password), @@ -156,6 +173,7 @@ func SetRoomPassword(roomID, password string) error { return fmt.Errorf("failed to hash password: %w", err) } } + return SetRoomHashedPassword(roomID, hashedPassword) } @@ -168,24 +186,28 @@ func SetRoomHashedPassword(roomID string, hashedPassword []byte) error { func GetAllRooms(scopes ...func(*gorm.DB) *gorm.DB) ([]*model.Room, error) { var rooms []*model.Room + err := db.Scopes(scopes...).Find(&rooms).Error return rooms, err } func GetAllRoomsCount(scopes ...func(*gorm.DB) *gorm.DB) (int64, error) { var count int64 + err := db.Model(&model.Room{}).Scopes(scopes...).Count(&count).Error return count, err } func GetAllRoomsAndCreator(scopes ...func(*gorm.DB) *gorm.DB) ([]*model.Room, error) { var rooms []*model.Room + err := db.Preload("Creator").Scopes(scopes...).Find(&rooms).Error return rooms, err } func GetAllRoomsByUserID(userID string) ([]*model.Room, error) { var rooms []*model.Room + err := db.Where("creator_id = ?", userID).Find(&rooms).Error return rooms, err } @@ -208,5 +230,6 @@ func SetRoomCurrent(roomID string, current *model.Current) error { Where("id = ?", roomID). Select("Current"). Updates(r) + return HandleUpdateResult(result, ErrRoomNotFound) } diff --git a/internal/db/setting.go b/internal/db/setting.go index ab526ea..483a352 100644 --- a/internal/db/setting.go +++ b/internal/db/setting.go @@ -19,19 +19,23 @@ func GetSettingItemsToMap() (map[string]*model.Setting, error) { if err != nil { return nil, err } + m := make(map[string]*model.Setting, len(items)) for _, item := range items { m[item.Name] = item } + return m, nil } func GetSettingItemByName(name string) (*model.Setting, error) { var item model.Setting + err := db.Where("name = ?", name).First(&item).Error if err != nil { return nil, err } + return &item, nil } @@ -54,10 +58,12 @@ func DeleteSettingItemByName(name string) error { func GetSettingItemValue(name string) (string, error) { var value string + err := db.Model(&model.Setting{}).Where("name = ?", name).Select("value").Take(&value).Error if err != nil { return "", err } + return value, nil } diff --git a/internal/db/update.go b/internal/db/update.go index d3cd434..8389fe4 100644 --- a/internal/db/update.go +++ b/internal/db/update.go @@ -96,49 +96,60 @@ func UpgradeDatabase() error { if !db.Migrator().HasTable(&model.Setting{}) { return autoMigrate(models...) } + setting := model.Setting{ Name: "database_version", Type: model.SettingTypeString, Group: model.SettingGroupDatabase, Value: CurrentVersion, } + err := FirstOrCreateSettingItemValue(&setting) if err != nil { return err } + currentVersion := setting.Value log.Infof("current database version: %s", currentVersion) + if flags.Global.ForceAutoMigrate || currentVersion != CurrentVersion { err = autoMigrate(models...) if err != nil { log.Fatalf("failed to auto migrate database: %s", err.Error()) } } + for currentVersion != "" { version, ok := dbVersions[currentVersion] if !ok { break } + if version.NextVersion != "" { log.Infof("Upgrading database to version %s", version.NextVersion) + if version.Upgrade != nil { err := version.Upgrade(db) if err != nil { return err } } + err := UpdateSettingItemValue("database_version", version.NextVersion) if err != nil { return err } } + currentVersion = version.NextVersion } + return nil } func autoMigrate(dst ...any) error { log.Info("migrating database...") + switch conf.Conf.Database.Type { case conf.DatabaseTypeMysql: if err := db.Exec("SET FOREIGN_KEY_CHECKS = 0").Error; err != nil { @@ -150,6 +161,7 @@ func autoMigrate(dst ...any) error { log.Fatalf("failed to set foreign key checks: %s", err.Error()) } }() + return db.Set("gorm:table_options", "ENGINE=InnoDB CHARSET=utf8mb4").AutoMigrate(dst...) case conf.DatabaseTypeSqlite3: if err := db.Exec("PRAGMA foreign_keys = OFF").Error; err != nil { @@ -161,6 +173,7 @@ func autoMigrate(dst ...any) error { log.Fatalf("failed to set foreign key checks: %s", err.Error()) } }() + return db.AutoMigrate(dst...) case conf.DatabaseTypePostgres: if err := db.Exec("SET CONSTRAINTS ALL DEFERRED").Error; err != nil { @@ -172,6 +185,7 @@ func autoMigrate(dst ...any) error { log.Fatalf("failed to set foreign key checks: %s", err.Error()) } }() + return db.AutoMigrate(dst...) default: return fmt.Errorf("unknown database type: %s", conf.Conf.Database.Type) diff --git a/internal/db/user.go b/internal/db/user.go index 5337117..6a56c44 100644 --- a/internal/db/user.go +++ b/internal/db/user.go @@ -57,9 +57,11 @@ func CreateUserWithHashedPassword( if username == "" { return nil, errors.New("username cannot be empty") } + if len(hashedPassword) == 0 { return nil, errors.New("password cannot be empty") } + u := &model.User{ Username: username, Role: model.RoleUser, @@ -68,12 +70,15 @@ func CreateUserWithHashedPassword( for _, c := range conf { c(u) } + if u.RegisteredByEmail && u.Email.String() == "" { return nil, errors.New("email cannot be empty") } + if u.Role == 0 { return nil, errors.New("role cannot be empty") } + err := db.Create(u).Error if err != nil { if errors.Is(err, gorm.ErrDuplicatedKey) { @@ -81,6 +86,7 @@ func CreateUserWithHashedPassword( } return nil, fmt.Errorf("failed to create user: %w", err) } + return u, nil } @@ -88,9 +94,11 @@ func CreateUser(username, password string, conf ...CreateUserConfig) (*model.Use if username == "" { return nil, errors.New("username cannot be empty") } + if password == "" { return nil, errors.New("password cannot be empty") } + hashedPassword, err := bcrypt.GenerateFromPassword( stream.StringToBytes(password), bcrypt.DefaultCost, @@ -98,6 +106,7 @@ func CreateUser(username, password string, conf ...CreateUserConfig) (*model.Use if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } + return CreateUserWithHashedPassword(username, hashedPassword, conf...) } @@ -108,6 +117,7 @@ func CreateOrLoadUserWithProvider( if puid == "" { return nil, errors.New("provider user id cannot be empty") } + hashedPassword, err := bcrypt.GenerateFromPassword( stream.StringToBytes(password), bcrypt.DefaultCost, @@ -115,6 +125,7 @@ func CreateOrLoadUserWithProvider( if err != nil { return nil, fmt.Errorf("failed to hash password: %w", err) } + user := &model.User{ Username: username, HashedPassword: hashedPassword, @@ -128,16 +139,20 @@ func CreateOrLoadUserWithProvider( if user.Role == 0 { return nil, errors.New("role cannot be empty") } + for _, c := range conf { c(user) } + user.EnableAutoAddUsernameSuffix() + err = db.Joins("JOIN user_providers ON users.id = user_providers.user_id"). Where("user_providers.provider = ? AND user_providers.provider_user_id = ?", p, puid). FirstOrCreate(user).Error if err != nil { return nil, fmt.Errorf("failed to create or load user: %w", err) } + return user, nil } @@ -148,6 +163,7 @@ func CreateUserWithEmail( if email == "" { return nil, errors.New("email cannot be empty") } + return CreateUser(username, password, append(conf, WithRegisteredByEmail(email), WithEnableAutoAddUsernameSuffix(), @@ -156,24 +172,29 @@ func CreateUserWithEmail( func GetUserByProvider(p, puid string) (*model.User, error) { var user model.User + err := db.Joins("JOIN user_providers ON users.id = user_providers.user_id"). Where("user_providers.provider = ? AND user_providers.provider_user_id = ?", p, puid). First(&user).Error + return &user, HandleNotFound(err, ErrUserNotFound) } func GetUserByEmail(email string) (*model.User, error) { var user model.User + err := db.Where("email = ?", email).First(&user).Error return &user, HandleNotFound(err, ErrUserNotFound) } func GetProviderUserID(p, puid string) (string, error) { var userID string + err := db.Model(&model.UserProvider{}). Where("provider = ? AND provider_user_id = ?", p, puid). Select("user_id"). First(&userID).Error + return userID, HandleNotFound(err, ErrUserNotFound) } @@ -189,6 +210,7 @@ func BindProvider(uid, p, puid string) error { } return fmt.Errorf("failed to bind provider: %w", err) } + return nil } @@ -198,10 +220,13 @@ func UnBindProvider(uid, p string) error { if err := tx.Preload("UserProviders").Where("id = ?", uid).First(&user).Error; err != nil { return HandleNotFound(err, ErrUserNotFound) } + if user.RegisteredByProvider && len(user.UserProviders) <= 1 { return errors.New("user must have at least one provider") } + result := tx.Where("user_id = ? AND provider = ?", uid, p).Delete(&model.UserProvider{}) + return HandleUpdateResult(result, "provider") }) } @@ -219,30 +244,37 @@ func UnbindEmail(uid string) error { if err := tx.Select("email", "registered_by_email").Where("id = ?", uid).First(&user).Error; err != nil { return HandleNotFound(err, ErrUserNotFound) } + if user.RegisteredByEmail { return errors.New("user must have one email") } + if user.Email.String() == "" { return nil } + result := tx.Model(&model.User{}). Where("id = ?", uid). Update("email", model.EmptyNullString("")) + return HandleUpdateResult(result, ErrUserNotFound) }) } func GetBindProviders(uid string) ([]*model.UserProvider, error) { var providers []*model.UserProvider + err := db.Where("user_id = ?", uid).Find(&providers).Error if err != nil { return nil, fmt.Errorf("failed to get bind providers: %w", err) } + return providers, nil } func GetUserByUsername(username string) (*model.User, error) { var user model.User + err := db.Where("username = ?", username).First(&user).Error return &user, HandleNotFound(err, ErrUserNotFound) } @@ -252,6 +284,7 @@ func GetUserByUsernameLike( scopes ...func(*gorm.DB) *gorm.DB, ) ([]*model.User, error) { var users []*model.User + err := db.Where("username LIKE ?", fmt.Sprintf("%%%s%%", username)). Scopes(scopes...). Find(&users). @@ -259,6 +292,7 @@ func GetUserByUsernameLike( if err != nil { return nil, fmt.Errorf("failed to get users by username like: %w", err) } + return users, nil } @@ -267,6 +301,7 @@ func GerUsersIDByUsernameLike( scopes ...func(*gorm.DB) *gorm.DB, ) ([]string, error) { var ids []string + err := db.Model(&model.User{}). Where("username LIKE ?", fmt.Sprintf("%%%s%%", username)). Scopes(scopes...). @@ -275,11 +310,13 @@ func GerUsersIDByUsernameLike( if err != nil { return nil, fmt.Errorf("failed to get user IDs by username like: %w", err) } + return ids, nil } func GerUsersIDByIDLike(id string, scopes ...func(*gorm.DB) *gorm.DB) ([]string, error) { var ids []string + err := db.Model(&model.User{}). Where("id LIKE ?", utils.LIKE(id)). Scopes(scopes...). @@ -288,6 +325,7 @@ func GerUsersIDByIDLike(id string, scopes ...func(*gorm.DB) *gorm.DB) ([]string, if err != nil { return nil, fmt.Errorf("failed to get user IDs by ID like: %w", err) } + return ids, nil } @@ -296,6 +334,7 @@ func GetUserByIDOrUsernameLike( scopes ...func(*gorm.DB) *gorm.DB, ) ([]*model.User, error) { var users []*model.User + err := db.Where("id = ? OR username LIKE ?", idOrUsername, fmt.Sprintf("%%%s%%", idOrUsername)). Scopes(scopes...). Find(&users). @@ -303,6 +342,7 @@ func GetUserByIDOrUsernameLike( if err != nil { return nil, fmt.Errorf("failed to get users by ID or username like: %w", err) } + return users, nil } @@ -310,8 +350,11 @@ func GetUserByID(id string) (*model.User, error) { if len(id) != 32 { return nil, errors.New("user id is not 32 bit") } + var user model.User + err := db.Where("id = ?", id).First(&user).Error + return &user, HandleNotFound(err, ErrUserNotFound) } @@ -319,7 +362,9 @@ func BanUser(u *model.User) error { if u.Role == model.RoleBanned { return nil } + u.Role = model.RoleBanned + return SaveUser(u) } @@ -332,7 +377,9 @@ func UnbanUser(u *model.User) error { if u.Role != model.RoleBanned { return errors.New("user is not banned") } + u.Role = model.RoleUser + return SaveUser(u) } @@ -348,11 +395,13 @@ func DeleteUserByID(userID string) error { func LoadAndDeleteUserByID(userID string, columns ...clause.Column) (*model.User, error) { var user model.User + result := db.Unscoped(). Clauses(clause.Returning{Columns: columns}). Select(clause.Associations). Where("id = ?", userID). Delete(&user) + return &user, HandleUpdateResult(result, ErrUserNotFound) } @@ -365,7 +414,9 @@ func AddAdmin(u *model.User) error { if u.Role >= model.RoleAdmin { return nil } + u.Role = model.RoleAdmin + return SaveUser(u) } @@ -373,16 +424,20 @@ func RemoveAdmin(u *model.User) error { if u.Role < model.RoleAdmin { return nil } + u.Role = model.RoleUser + return SaveUser(u) } func GetAdmins() ([]*model.User, error) { var users []*model.User + err := db.Where("role = ?", model.RoleAdmin).Find(&users).Error if err != nil { return nil, fmt.Errorf("failed to get admins: %w", err) } + return users, nil } @@ -400,7 +455,9 @@ func AddRoot(u *model.User) error { if u.Role == model.RoleRoot { return nil } + u.Role = model.RoleRoot + return SaveUser(u) } @@ -408,7 +465,9 @@ func RemoveRoot(u *model.User) error { if u.Role != model.RoleRoot { return nil } + u.Role = model.RoleUser + return SaveUser(u) } @@ -450,19 +509,23 @@ func SetUsernameByID(userID, username string) error { func GetUserCount(scopes ...func(*gorm.DB) *gorm.DB) (int64, error) { var count int64 + err := db.Model(&model.User{}).Scopes(scopes...).Count(&count).Error if err != nil { return 0, fmt.Errorf("failed to get user count: %w", err) } + return count, nil } func GetUsers(scopes ...func(*gorm.DB) *gorm.DB) ([]*model.User, error) { var users []*model.User + err := db.Scopes(scopes...).Find(&users).Error if err != nil { return nil, fmt.Errorf("failed to get users: %w", err) } + return users, nil } diff --git a/internal/db/vendorBackend.go b/internal/db/vendorBackend.go index afb6c7e..cc4bf06 100644 --- a/internal/db/vendorBackend.go +++ b/internal/db/vendorBackend.go @@ -9,6 +9,7 @@ import ( func GetAllVendorBackend() ([]*model.VendorBackend, error) { var backends []*model.VendorBackend + err := db.Find(&backends).Error return backends, HandleNotFound(err, "backends") } @@ -58,6 +59,7 @@ func DeleteVendorBackends(endpoints []string) error { func GetVendorBackend(endpoint string) (*model.VendorBackend, error) { var backend model.VendorBackend + err := db.Where("backend_endpoint = ?", endpoint).First(&backend).Error return &backend, HandleNotFound(err, "backend") } @@ -65,6 +67,7 @@ func GetVendorBackend(endpoint string) (*model.VendorBackend, error) { func CreateOrSaveVendorBackend(backend *model.VendorBackend) (*model.VendorBackend, error) { return backend, Transactional(func(tx *gorm.DB) error { var existingBackend model.VendorBackend + err := tx.Where("backend_endpoint = ?", backend.Backend.Endpoint). First(&existingBackend). Error @@ -73,7 +76,9 @@ func CreateOrSaveVendorBackend(backend *model.VendorBackend) (*model.VendorBacke } else if err != nil { return err } + result := tx.Model(&existingBackend).Omit("created_at").Updates(backend) + return HandleUpdateResult(result, "vendor backend") }) } diff --git a/internal/db/vendorRecord.go b/internal/db/vendorRecord.go index e067004..1ae6287 100644 --- a/internal/db/vendorRecord.go +++ b/internal/db/vendorRecord.go @@ -13,6 +13,7 @@ const ( func GetBilibiliVendor(userID string) (*model.BilibiliVendor, error) { var vendor model.BilibiliVendor + err := db.Where("user_id = ?", userID).First(&vendor).Error return &vendor, HandleNotFound(err, ErrVendorNotFound) } @@ -21,13 +22,16 @@ func CreateOrSaveBilibiliVendor(vendorInfo *model.BilibiliVendor) (*model.Bilibi if vendorInfo.UserID == "" { return nil, errors.New("user_id must not be empty") } + return vendorInfo, Transactional(func(tx *gorm.DB) error { if errors.Is(tx.First(&model.BilibiliVendor{ UserID: vendorInfo.UserID, }).Error, gorm.ErrRecordNotFound) { return tx.Create(&vendorInfo).Error } + result := tx.Omit("created_at").Save(&vendorInfo) + return HandleUpdateResult(result, ErrVendorNotFound) }) } @@ -42,22 +46,26 @@ func GetAlistVendors( scopes ...func(*gorm.DB) *gorm.DB, ) ([]*model.AlistVendor, error) { var vendors []*model.AlistVendor + err := db.Scopes(scopes...).Where("user_id = ?", userID).Find(&vendors).Error return vendors, err } func GetAlistVendorsCount(userID string, scopes ...func(*gorm.DB) *gorm.DB) (int64, error) { var count int64 + err := db.Scopes(scopes...). Where("user_id = ?", userID). Model(&model.AlistVendor{}). Count(&count). Error + return count, err } func GetAlistVendor(userID, serverID string) (*model.AlistVendor, error) { var vendor model.AlistVendor + err := db.Where("user_id = ? AND server_id = ?", userID, serverID).First(&vendor).Error return &vendor, HandleNotFound(err, ErrVendorNotFound) } @@ -66,6 +74,7 @@ func CreateOrSaveAlistVendor(vendorInfo *model.AlistVendor) (*model.AlistVendor, if vendorInfo.UserID == "" || vendorInfo.ServerID == "" { return nil, errors.New("user_id and server_id must not be empty") } + return vendorInfo, Transactional(func(tx *gorm.DB) error { if errors.Is(tx.First(&model.AlistVendor{ UserID: vendorInfo.UserID, @@ -73,7 +82,9 @@ func CreateOrSaveAlistVendor(vendorInfo *model.AlistVendor) (*model.AlistVendor, }).Error, gorm.ErrRecordNotFound) { return tx.Create(&vendorInfo).Error } + result := tx.Omit("created_at").Save(&vendorInfo) + return HandleUpdateResult(result, ErrVendorNotFound) }) } @@ -86,28 +97,33 @@ func DeleteAlistVendor(userID, serverID string) error { func GetEmbyVendors(userID string, scopes ...func(*gorm.DB) *gorm.DB) ([]*model.EmbyVendor, error) { var vendors []*model.EmbyVendor + err := db.Scopes(scopes...).Where("user_id = ?", userID).Find(&vendors).Error return vendors, err } func GetEmbyVendorsCount(userID string, scopes ...func(*gorm.DB) *gorm.DB) (int64, error) { var count int64 + err := db.Scopes(scopes...). Where("user_id = ?", userID). Model(&model.EmbyVendor{}). Count(&count). Error + return count, err } func GetEmbyVendor(userID, serverID string) (*model.EmbyVendor, error) { var vendor model.EmbyVendor + err := db.Where("user_id = ? AND server_id = ?", userID, serverID).First(&vendor).Error return &vendor, HandleNotFound(err, ErrVendorNotFound) } func GetEmbyFirstVendor(userID string) (*model.EmbyVendor, error) { var vendor model.EmbyVendor + err := db.Where("user_id = ?", userID).First(&vendor).Error return &vendor, HandleNotFound(err, ErrVendorNotFound) } @@ -116,6 +132,7 @@ func CreateOrSaveEmbyVendor(vendorInfo *model.EmbyVendor) (*model.EmbyVendor, er if vendorInfo.UserID == "" || vendorInfo.ServerID == "" { return nil, errors.New("user_id and server_id must not be empty") } + return vendorInfo, Transactional(func(tx *gorm.DB) error { if errors.Is(tx.First(&model.EmbyVendor{ UserID: vendorInfo.UserID, @@ -123,7 +140,9 @@ func CreateOrSaveEmbyVendor(vendorInfo *model.EmbyVendor) (*model.EmbyVendor, er }).Error, gorm.ErrRecordNotFound) { return tx.Create(&vendorInfo).Error } + result := tx.Omit("created_at").Save(&vendorInfo) + return HandleUpdateResult(result, ErrVendorNotFound) }) } diff --git a/internal/email/email.go b/internal/email/email.go index cafe2d5..05a96f8 100644 --- a/internal/email/email.go +++ b/internal/email/email.go @@ -75,10 +75,12 @@ func init() { if err != nil { log.Fatalf("mjml test template error: %v", err) } + t, err := template.New("").Parse(body) if err != nil { log.Fatalf("parse test template error: %v", err) } + testTemplate = t body, err = mjml.ToHTML( @@ -89,10 +91,12 @@ func init() { if err != nil { log.Fatalf("mjml captcha template error: %v", err) } + t, err = template.New("").Parse(body) if err != nil { log.Fatalf("parse captcha template error: %v", err) } + captchaTemplate = t body, err = mjml.ToHTML( @@ -103,10 +107,12 @@ func init() { if err != nil { log.Fatalf("mjml retrieve password template error: %v", err) } + t, err = template.New("").Parse(body) if err != nil { log.Fatalf("parse retrieve password template error: %v", err) } + retrievePasswordTemplate = t } @@ -157,6 +163,7 @@ func SendBindCaptchaEmail(userID, userEmail string) error { } out := bytes.NewBuffer(nil) + err = captchaTemplate.Execute(out, captchaPayload{ Captcha: entry.Value(), Year: time.Now().Year(), @@ -212,6 +219,7 @@ func SendTestEmail(username, email string) error { } out := bytes.NewBuffer(nil) + err = testTemplate.Execute(out, testPayload{ Username: username, Year: time.Now().Year(), @@ -251,6 +259,7 @@ func SendSignupCaptchaEmail(email string) error { } out := bytes.NewBuffer(nil) + err = captchaTemplate.Execute(out, captchaPayload{ Captcha: entry.Value(), Year: time.Now().Year(), @@ -297,12 +306,15 @@ func SendRetrievePasswordCaptchaEmail(userID, email, host string) error { if userID == "" { return errors.New("user id is empty") } + if email == "" { return errors.New("email is empty") } + if host == "" { return errors.New("host is empty") } + if !strings.HasPrefix(host, "http://") && !strings.HasPrefix(host, "https://") { log.Errorf("host: %s must start with http:// or https://", host) return errors.New("get host error") @@ -312,6 +324,7 @@ func SendRetrievePasswordCaptchaEmail(userID, email, host string) error { if err != nil { return err } + u.Path = `web/auth/reset` pool, err := getSMTPPool() @@ -334,6 +347,7 @@ func SendRetrievePasswordCaptchaEmail(userID, email, host string) error { u.RawQuery = q.Encode() out := bytes.NewBuffer(nil) + err = retrievePasswordTemplate.Execute(out, retrievePasswordPayload{ Captcha: entry.Value(), Host: host, diff --git a/internal/email/smtp.go b/internal/email/smtp.go index 356180f..43f380f 100644 --- a/internal/email/smtp.go +++ b/internal/email/smtp.go @@ -142,6 +142,7 @@ func getSMTPPool() (*smtp.Pool, error) { if configChanged { configChanged = false + if smtpPool != nil { smtpPool.Close() smtpPool = nil @@ -153,6 +154,7 @@ func getSMTPPool() (*smtp.Pool, error) { if err != nil { return nil, err } + smtpPool = pool } diff --git a/internal/model/current.go b/internal/model/current.go index b3bd532..3c45fbe 100644 --- a/internal/model/current.go +++ b/internal/model/current.go @@ -33,10 +33,13 @@ func (c *Current) UpdateStatus() Status { c.Status.LastUpdate = time.Now() return c.Status } + if c.Status.IsPlaying { c.Status.CurrentTime += time.Since(c.Status.LastUpdate).Seconds() * c.Status.PlaybackRate } + c.Status.LastUpdate = time.Now() + return c.Status } @@ -45,6 +48,7 @@ func (c *Current) setLiveStatus() Status { c.Status.PlaybackRate = 1.0 c.Status.CurrentTime = 0 c.Status.LastUpdate = time.Now() + return c.Status } @@ -52,14 +56,18 @@ func (c *Current) SetStatus(playing bool, seek, rate, timeDiff float64) Status { if c.Movie.IsLive { return c.setLiveStatus() } + c.Status.IsPlaying = playing + c.Status.PlaybackRate = rate if playing { c.Status.CurrentTime = seek + (timeDiff * rate) } else { c.Status.CurrentTime = seek } + c.Status.LastUpdate = time.Now() + return c.Status } @@ -67,13 +75,16 @@ func (c *Current) SetSeekRate(seek, rate, timeDiff float64) Status { if c.Movie.IsLive { return c.setLiveStatus() } + if c.Status.IsPlaying { c.Status.CurrentTime = seek + (timeDiff * rate) } else { c.Status.CurrentTime = seek } + c.Status.PlaybackRate = rate c.Status.LastUpdate = time.Now() + return c.Status } @@ -81,11 +92,14 @@ func (c *Current) SetSeek(seek, timeDiff float64) Status { if c.Movie.IsLive { return c.setLiveStatus() } + if c.Status.IsPlaying { c.Status.CurrentTime = seek + (timeDiff * c.Status.PlaybackRate) } else { c.Status.CurrentTime = seek } + c.Status.LastUpdate = time.Now() + return c.Status } diff --git a/internal/model/member.go b/internal/model/member.go index 26c8893..f56f68a 100644 --- a/internal/model/member.go +++ b/internal/model/member.go @@ -157,12 +157,15 @@ func (r *RoomMember) HasPermission(permission RoomMemberPermission) bool { if r.Role.IsAdmin() { return true } + if !r.Role.IsMember() { return false } + if r.Status != RoomMemberStatusActive { return false } + return r.Permissions.Has(permission) } @@ -170,11 +173,14 @@ func (r *RoomMember) HasAdminPermission(permission RoomAdminPermission) bool { if r.Role.IsCreator() { return true } + if !r.Role.IsAdmin() { return false } + if r.Status != RoomMemberStatusActive { return false } + return r.AdminPermissions.Has(permission) } diff --git a/internal/model/movie.go b/internal/model/movie.go index 4b43181..eb1defc 100644 --- a/internal/model/movie.go +++ b/internal/model/movie.go @@ -45,17 +45,21 @@ func (m *Movie) BeforeCreate(_ *gorm.DB) error { func (m *Movie) BeforeSave(tx *gorm.DB) error { if m.ParentID != "" { mv := &Movie{} + err := tx.Where("id = ?", m.ParentID).First(mv).Error if err != nil { return fmt.Errorf("load parent movie failed: %w", err) } + if !mv.IsFolder { return errors.New("parent is not a folder") } + if mv.IsDynamicFolder() { return errors.New("parent is a dynamic folder, cannot add child") } } + return nil } @@ -95,10 +99,12 @@ func (m *MovieBase) Clone() *MovieBase { URL: ms.URL, } } + hds := make(map[string]string, len(m.Headers)) for k, v := range m.Headers { hds[k] = v } + sbs := make(map[string]*Subtitle, len(m.Subtitles)) for k, v := range m.Subtitles { sbs[k] = &Subtitle{ @@ -106,6 +112,7 @@ func (m *MovieBase) Clone() *MovieBase { Type: v.Type, } } + return &MovieBase{ URL: m.URL, MoreSources: mss, @@ -138,6 +145,7 @@ func (ns *EmptyNullString) Scan(value any) error { *ns = "" return nil } + switch v := value.(type) { case []byte: *ns = EmptyNullString(v) @@ -146,6 +154,7 @@ func (ns *EmptyNullString) Scan(value any) error { default: return fmt.Errorf("unsupported type: %T", v) } + return nil } @@ -217,6 +226,7 @@ func GetAlistServerIDFromPath(path string) (serverID, filePath string, err error if !found { return "", path, errors.New("path is invalid") } + return before, after, nil } @@ -255,8 +265,10 @@ func (a *AlistStreamingInfo) BeforeSave(_ *gorm.DB) error { if err != nil { return err } + a.Password = s } + return nil } @@ -266,8 +278,10 @@ func (a *AlistStreamingInfo) AfterSave(_ *gorm.DB) error { if err != nil { return err } + a.Password = string(b) } + return nil } diff --git a/internal/model/user.go b/internal/model/user.go index 632521f..92a3e3e 100644 --- a/internal/model/user.go +++ b/internal/model/user.go @@ -73,14 +73,17 @@ func (u *User) CheckPassword(password string) bool { func (u *User) BeforeCreate(tx *gorm.DB) error { if u.autoAddUsernameSuffix { var existingUser User + err := tx.Select("username").Where("username = ?", u.Username).First(&existingUser).Error if err == nil { u.Username = fmt.Sprintf("%s#%d", u.Username, rand.IntN(9999)) } } + if u.ID == "" { u.ID = utils.SortUUID() } + return nil } diff --git a/internal/model/vendorBackend.go b/internal/model/vendorBackend.go index af7f661..b9389aa 100644 --- a/internal/model/vendorBackend.go +++ b/internal/model/vendorBackend.go @@ -38,14 +38,17 @@ func (b *Backend) Validate() error { if b.Endpoint == "" { return errors.New("new http client failed, endpoint is empty") } + if b.Consul.ServiceName != "" && b.Etcd.ServiceName != "" { return errors.New("new grpc client failed, consul and etcd can't be used at the same time") } + if b.TimeOut != "" { if _, err := time.ParseDuration(b.TimeOut); err != nil { return err } } + return nil } @@ -68,27 +71,32 @@ type BackendUsedBy struct { func (v *VendorBackend) BeforeSave(_ *gorm.DB) error { key := utils.GenCryptoKey(v.Backend.Endpoint) + var err error if v.Backend.JwtSecret != "" { if v.Backend.JwtSecret, err = utils.CryptoToBase64([]byte(v.Backend.JwtSecret), key); err != nil { return err } } + if v.Backend.Consul.Token != "" { if v.Backend.Consul.Token, err = utils.CryptoToBase64([]byte(v.Backend.Consul.Token), key); err != nil { return err } } + if v.Backend.Etcd.Password != "" { if v.Backend.Etcd.Password, err = utils.CryptoToBase64([]byte(v.Backend.Etcd.Password), key); err != nil { return err } } + if v.Backend.CustomCa != "" { if v.Backend.CustomCa, err = utils.CryptoToBase64([]byte(v.Backend.CustomCa), key); err != nil { return err } } + return nil } @@ -99,29 +107,37 @@ func (v *VendorBackend) AfterSave(_ *gorm.DB) error { if err != nil { return err } + v.Backend.JwtSecret = stream.BytesToString(jwtSecret) } + if v.Backend.Consul.Token != "" { token, err := utils.DecryptoFromBase64(v.Backend.Consul.Token, key) if err != nil { return err } + v.Backend.Consul.Token = stream.BytesToString(token) } + if v.Backend.Etcd.Password != "" { password, err := utils.DecryptoFromBase64(v.Backend.Etcd.Password, key) if err != nil { return err } + v.Backend.Etcd.Password = stream.BytesToString(password) } + if v.Backend.CustomCa != "" { customCa, err := utils.DecryptoFromBase64(v.Backend.CustomCa, key) if err != nil { return err } + v.Backend.CustomCa = stream.BytesToString(customCa) } + return nil } diff --git a/internal/model/vendorRecord.go b/internal/model/vendorRecord.go index 022aff5..6ca6100 100644 --- a/internal/model/vendorRecord.go +++ b/internal/model/vendorRecord.go @@ -24,8 +24,10 @@ func (b *BilibiliVendor) BeforeSave(_ *gorm.DB) error { if err != nil { return err } + b.Cookies[k] = value } + return nil } @@ -36,8 +38,10 @@ func (b *BilibiliVendor) AfterSave(_ *gorm.DB) error { if err != nil { return err } + b.Cookies[k] = stream.BytesToString(value) } + return nil } @@ -64,36 +68,47 @@ func GenAlistServerID(a *AlistVendor) { func (a *AlistVendor) BeforeSave(_ *gorm.DB) error { key := utils.GenCryptoKey(a.UserID) + var err error if a.Host, err = utils.CryptoToBase64([]byte(a.Host), key); err != nil { return err } + if a.Username, err = utils.CryptoToBase64([]byte(a.Username), key); err != nil { return err } + if a.HashedPassword, err = utils.Crypto(a.HashedPassword, key); err != nil { return err } + return nil } func (a *AlistVendor) AfterSave(_ *gorm.DB) error { key := utils.GenCryptoKey(a.UserID) + host, err := utils.DecryptoFromBase64(a.Host, key) if err != nil { return err } + a.Host = stream.BytesToString(host) + username, err := utils.DecryptoFromBase64(a.Username, key) if err != nil { return err } + a.Username = stream.BytesToString(username) + hashedPassword, err := utils.Decrypto(a.HashedPassword, key) if err != nil { return err } + a.HashedPassword = hashedPassword + return nil } @@ -114,28 +129,36 @@ type EmbyVendor struct { func (e *EmbyVendor) BeforeSave(_ *gorm.DB) error { key := utils.GenCryptoKey(e.ServerID) + var err error if e.Host, err = utils.CryptoToBase64(stream.StringToBytes(e.Host), key); err != nil { return err } + if e.APIKey, err = utils.CryptoToBase64(stream.StringToBytes(e.APIKey), key); err != nil { return err } + return nil } func (e *EmbyVendor) AfterSave(_ *gorm.DB) error { key := utils.GenCryptoKey(e.ServerID) + host, err := utils.DecryptoFromBase64(e.Host, key) if err != nil { return err } + e.Host = stream.BytesToString(host) + apiKey, err := utils.DecryptoFromBase64(e.APIKey, key) if err != nil { return err } + e.APIKey = stream.BytesToString(apiKey) + return nil } diff --git a/internal/op/client.go b/internal/op/client.go index 318f707..8a4b9f3 100644 --- a/internal/op/client.go +++ b/internal/op/client.go @@ -65,6 +65,7 @@ func (c *Client) SendChatMessage(message string) error { if !c.u.HasRoomPermission(c.r, model.PermissionSendChatMessage) { return model.ErrNoPermission } + return c.Broadcast(&pb.Message{ Type: pb.MessageType_CHAT, Timestamp: time.Now().UnixMilli(), @@ -81,10 +82,13 @@ func (c *Client) SendChatMessage(message string) error { func (c *Client) Send(msg Message) error { c.wg.Add(1) defer c.wg.Done() + if c.Closed() { return ErrAlreadyClosed } + c.c <- msg + return nil } @@ -92,8 +96,10 @@ func (c *Client) Close() error { if !atomic.CompareAndSwapUint32(&c.closed, 0, 1) { return ErrAlreadyClosed } + c.wg.Wait() close(c.c) + return nil } @@ -118,6 +124,7 @@ func (c *Client) SetStatus(playing bool, seek, rate, timeDiff float64) error { if err != nil { return err } + return c.Broadcast(&pb.Message{ Type: pb.MessageType_STATUS, Sender: &pb.Sender{ diff --git a/internal/op/current.go b/internal/op/current.go index eccdff2..19c2182 100644 --- a/internal/op/current.go +++ b/internal/op/current.go @@ -23,6 +23,7 @@ func newCurrent(roomID string, c *model.Current) *current { }, } } + return ¤t{ roomID: roomID, current: *c, @@ -32,7 +33,9 @@ func newCurrent(roomID string, c *model.Current) *current { func (c *current) Current() model.Current { c.lock.RLock() defer c.lock.RUnlock() + c.current.UpdateStatus() + return c.current } @@ -59,7 +62,9 @@ func (c *current) SetMovie(movie model.CurrentMovie, play bool) { func (c *current) Status() model.Status { c.lock.RLock() defer c.lock.RUnlock() + c.current.UpdateStatus() + return c.current.Status } @@ -73,6 +78,7 @@ func (c *current) SetStatus(playing bool, seek, rate, timeDiff float64) *model.S }() s := c.current.SetStatus(playing, seek, rate, timeDiff) + return &s } @@ -86,5 +92,6 @@ func (c *current) SetSeekRate(seek, rate, timeDiff float64) *model.Status { }() s := c.current.SetSeekRate(seek, rate, timeDiff) + return &s } diff --git a/internal/op/hub.go b/internal/op/hub.go index 926bdfd..ccacf62 100644 --- a/internal/op/hub.go +++ b/internal/op/hub.go @@ -68,6 +68,7 @@ func (h *Hub) Start() error { go h.serve() go h.ping() }) + return nil } @@ -79,14 +80,17 @@ func (h *Hub) serve() { h.clients.Range(func(_ string, clients *clients) bool { clients.lock.RLock() defer clients.lock.RUnlock() + for _, c := range clients.m { if utils.In(message.ignoreUserID, c.u.ID) || utils.In(message.ignoreConnID, c.ConnID()) { continue } + if message.rtcJoined && !c.RTCJoined() { continue } + if err := c.Send(message.data); err != nil { c.Close() } @@ -104,6 +108,7 @@ func (h *Hub) serve() { func (h *Hub) ping() { ticker := time.NewTicker(time.Second * 5) defer ticker.Stop() + var ( pre int64 current int64 @@ -121,6 +126,7 @@ func (h *Hub) ping() { }); err != nil { continue } + pre = current } else { if err := h.Broadcast(&PingMessage{}); err != nil { @@ -151,33 +157,42 @@ func (h *Hub) Close() error { if !atomic.CompareAndSwapUint32(&h.closed, 0, 1) { return ErrAlreadyClosed } + close(h.exit) h.clients.Range(func(id string, clients *clients) bool { h.clients.CompareAndDelete(id, clients) + clients.lock.Lock() defer clients.lock.Unlock() + for id, c := range clients.m { delete(clients.m, id) c.Close() } + return true }) h.wg.Wait() close(h.broadcast) + return nil } func (h *Hub) Broadcast(data Message, conf ...BroadcastConf) error { h.wg.Add(1) defer h.wg.Done() + if h.Closed() { return ErrAlreadyClosed } + h.once.Done() + msg := &broadcastMessage{data: data} for _, c := range conf { c(msg) } + select { case h.broadcast <- msg: return nil @@ -190,23 +205,30 @@ func (h *Hub) RegClient(cli *Client) error { if h.Closed() { return ErrAlreadyClosed } + err := h.Start() if err != nil { return err } + c, _ := h.clients.LoadOrStore(cli.u.ID, &clients{}) + c.lock.Lock() defer c.lock.Unlock() + newC, loaded := h.clients.Load(cli.u.ID) if !loaded || c != newC { return h.RegClient(cli) } + if c.m == nil { c.m = make(map[string]*Client) } else if _, ok := c.m[cli.ConnID()]; ok { return errors.New("client already exists") } + c.m[cli.ConnID()] = cli + return nil } @@ -214,22 +236,29 @@ func (h *Hub) UnRegClient(cli *Client) error { if h.Closed() { return ErrAlreadyClosed } + if cli == nil { return errors.New("user is nil") } + c, loaded := h.clients.Load(cli.u.ID) if !loaded { return errors.New("client not found") } + c.lock.Lock() defer c.lock.Unlock() + if _, ok := c.m[cli.ConnID()]; !ok { return errors.New("client not found") } + delete(c.m, cli.ConnID()) + if len(c.m) == 0 { h.clients.CompareAndDelete(cli.u.ID, c) } + return nil } @@ -241,18 +270,22 @@ func (h *Hub) SendToUser(userID string, data Message) (err error) { if h.Closed() { return ErrAlreadyClosed } + cli, ok := h.clients.Load(userID) if !ok { return nil } + cli.lock.RLock() defer cli.lock.RUnlock() + for _, c := range cli.m { if err = c.Send(data); err != nil { c.Close() } } - return + + return err } func (h *Hub) SendToConnID(userID, connID string, data Message) error { @@ -260,6 +293,7 @@ func (h *Hub) SendToConnID(userID, connID string, data Message) error { if !ok { return nil } + return cli.Send(data) } @@ -268,7 +302,9 @@ func (h *Hub) GetClientByConnID(userID, connID string) (*Client, bool) { if !ok { return nil, false } + client, ok := c.m[connID] + return client, ok } @@ -282,11 +318,14 @@ func (h *Hub) OnlineCount(userID string) int { if !ok { return 0 } + c.lock.RLock() defer c.lock.RUnlock() + if len(c.m) == 0 { h.clients.CompareAndDelete(userID, c) } + return len(c.m) } @@ -294,14 +333,18 @@ func (h *Hub) KickUser(userID string) error { if h.Closed() { return ErrAlreadyClosed } + cli, ok := h.clients.Load(userID) if !ok { return nil } + cli.lock.RLock() defer cli.lock.RUnlock() + for _, c := range cli.m { c.Close() } + return nil } diff --git a/internal/op/movie.go b/internal/op/movie.go index 7f07478..7b092a8 100644 --- a/internal/op/movie.go +++ b/internal/op/movie.go @@ -48,12 +48,15 @@ func (m *Movie) ExpireID(ctx context.Context) (uint64, error) { } case m.Live && m.VendorInfo.Vendor == model.VendorBilibili: liveCache := m.BilibiliCache().Live + _, err := liveCache.Get(ctx) if err != nil { return 0, err } + return uint64(liveCache.Last()), nil } + return uint64(crc32.ChecksumIEEE([]byte(m.ID))), nil } @@ -68,10 +71,12 @@ func (m *Movie) CheckExpired(ctx context.Context, expireID uint64) (bool, error) case m.Live && m.VendorInfo.Vendor == model.VendorBilibili: return time.Now().UnixNano()-int64(expireID) > m.BilibiliCache().Live.MaxAge(), nil } + id, err := m.ExpireID(ctx) if err != nil { return false, err } + return expireID != id, nil } @@ -89,6 +94,7 @@ func (m *Movie) ClearCache() error { if err != nil { return err } + err = emc.Clear(context.Background(), u.Value().EmbyCache()) if err != nil { return err @@ -106,6 +112,7 @@ func (m *Movie) AlistCache() *cache.AlistMovieCache { return m.AlistCache() } } + return c } @@ -117,6 +124,7 @@ func (m *Movie) BilibiliCache() *cache.BilibiliMovieCache { return m.BilibiliCache() } } + return c } @@ -128,6 +136,7 @@ func (m *Movie) EmbyCache() *cache.EmbyMovieCache { return m.EmbyCache() } } + return c } @@ -135,10 +144,12 @@ func (m *Movie) Channel() (*rtmps.Channel, error) { if m.IsFolder { return nil, errors.New("this is a folder") } + c, err := m.initChannel() if err != nil { return nil, err } + return c, nil } @@ -153,8 +164,10 @@ func (m *Movie) compareAndSwapInitChannel() (*rtmps.Channel, bool) { if !m.channel.CompareAndSwap(nil, c) { return m.compareAndSwapInitChannel() } + return c, true } + return c, false } @@ -188,10 +201,12 @@ func (m *Movie) initRtmpSourceChannel() (*rtmps.Channel, error) { if !init { return c, nil } + err := c.InitHlsPlayer(hls.WithGenTsNameFunc(genTSName)) if err != nil { return nil, fmt.Errorf("init rtmp hls player error: %w", err) } + return c, nil } @@ -200,12 +215,14 @@ func (m *Movie) initRtmpProxyChannel() (*rtmps.Channel, error) { if !init { return c, nil } + err := c.InitHlsPlayer(hls.WithGenTsNameFunc(genTSName)) if err != nil { return nil, fmt.Errorf("init rtmp hls player error: %w", err) } go m.handleRtmpProxy(c) + return c, nil } @@ -214,6 +231,7 @@ func (m *Movie) handleRtmpProxy(c *rtmps.Channel) { if c.Closed() { return } + cli := core.NewConnClient() if err := cli.Start(m.URL, av.PLAY); err != nil { log.Errorf("push live error: %v", err) @@ -221,6 +239,7 @@ func (m *Movie) handleRtmpProxy(c *rtmps.Channel) { time.Sleep(time.Second) continue } + if err := c.PushStart(rtmpProto.NewReader(cli)); err != nil { log.Errorf("push live error: %v", err) cli.Close() @@ -238,12 +257,14 @@ func (m *Movie) initHTTPProxyChannel() (*rtmps.Channel, error) { if !init { return c, nil } + err := c.InitHlsPlayer(hls.WithGenTsNameFunc(genTSName)) if err != nil { return nil, fmt.Errorf("init http hls player error: %w", err) } go m.handleHTTPProxy(c) + return c, nil } @@ -252,18 +273,22 @@ func (m *Movie) handleHTTPProxy(c *rtmps.Channel) { if c.Closed() { return } + req, err := http.NewRequestWithContext(context.Background(), http.MethodGet, m.URL, nil) if err != nil { log.Errorf("get live error: %v", err) time.Sleep(time.Second) continue } + for k, v := range m.Headers { req.Header.Set(k, v) } + if req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", utils.UA) } + resp, err := uhc.Do(req) if err != nil { log.Errorf("get live error: %v", err) @@ -271,6 +296,7 @@ func (m *Movie) handleHTTPProxy(c *rtmps.Channel) { time.Sleep(time.Second) continue } + if err := c.PushStart(flv.NewReader(resp.Body)); err != nil { log.Errorf("push live error: %v", err) resp.Body.Close() @@ -308,6 +334,7 @@ func (m *Movie) validateRTMPSource() error { case !m.Live && m.RtmpSource: return errors.New("rtmp source can't be true when movie is not live") } + return nil } @@ -335,9 +362,11 @@ func (m *Movie) validateLiveProxy(u *url.URL) error { if !settings.LiveProxy.Get() { return errors.New("live proxy is not enabled") } + if !settings.AllowProxyToLocal.Get() && utils.IsLocalIP(u.Host) { return errors.New("local ip is not allowed") } + switch u.Scheme { case "rtmp", "http", "https": return nil @@ -350,12 +379,15 @@ func (m *Movie) validateMovieProxy(u *url.URL) error { if !settings.MovieProxy.Get() { return errors.New("movie proxy is not enabled") } + if !settings.AllowProxyToLocal.Get() && utils.IsLocalIP(u.Host) { return errors.New("local ip is not allowed") } + if u.Scheme != "http" && u.Scheme != "https" { return fmt.Errorf("unsupported scheme: %s", u.Scheme) } + return nil } @@ -389,6 +421,7 @@ func (m *Movie) Terminate() error { if m.IsFolder { return nil } + c := m.channel.Swap(nil) if c != nil { err := c.Close() @@ -396,6 +429,7 @@ func (m *Movie) Terminate() error { return err } } + return nil } @@ -404,9 +438,11 @@ func (m *Movie) Close() error { if err != nil { return err } + err = m.ClearCache() if err != nil { return err } + return nil } diff --git a/internal/op/movies.go b/internal/op/movies.go index f5d4ab1..a6869fa 100644 --- a/internal/op/movies.go +++ b/internal/op/movies.go @@ -40,6 +40,7 @@ func (m *movies) AddMovie(mo *model.Movie) error { if ok { _ = old.Close() } + return nil } @@ -80,10 +81,12 @@ func (m *movies) GetChannel(id string) (*rtmps.Channel, error) { if id == "" { return nil, errors.New("channel name is nil") } + movie, err := m.GetMovieByID(id) if err != nil { return nil, err } + return movie.Channel() } @@ -92,15 +95,19 @@ func (m *movies) Update(movieID string, movie *model.MovieBase) error { if err != nil { return err } + mv.MovieBase = *movie + err = db.SaveMovie(mv) if err != nil { return err } + mm, loaded := m.cache.LoadAndDelete(mv.ID) if loaded { _ = mm.Close() } + return nil } @@ -126,7 +133,9 @@ func (m *movies) DeleteMovieByParentID(parentID string) error { if err != nil { return err } + m.DeleteMovieAndChiledCache(parentID) + return nil } @@ -135,7 +144,9 @@ func (m *movies) DeleteMovieByID(id string) error { if err != nil { return err } + m.DeleteMovieAndChiledCache(id) + return nil } @@ -144,10 +155,12 @@ func (m *movies) DeleteMovieAndChiledCache(id ...string) { for _, id := range id { idm[model.EmptyNullString(id)] = struct{}{} } + if _, ok := idm[model.EmptyNullString("")]; ok { m.ClearCache() return } + m.deleteMovieAndChiledCache(idm) } @@ -158,11 +171,14 @@ func (m *movies) deleteMovieAndChiledCache(ids map[model.EmptyNullString]struct{ if value.IsFolder { next[model.EmptyNullString(value.ID)] = struct{}{} } + m.cache.CompareAndDelete(key, value) value.Close() } + return true }) + if len(next) > 0 { m.deleteMovieAndChiledCache(next) } @@ -173,7 +189,9 @@ func (m *movies) DeleteMoviesByID(ids []string) error { if err != nil { return err } + m.DeleteMovieAndChiledCache(ids...) + return nil } @@ -181,18 +199,22 @@ func (m *movies) GetMovieByID(id string) (*Movie, error) { if id == "" { return nil, errors.New("movie id is nil") } + mm, ok := m.cache.Load(id) if ok { return mm, nil } + mv, err := db.GetMovieByID(m.roomID, id) if err != nil { return nil, err } + mm, _ = m.cache.LoadOrStore(mv.ID, &Movie{ room: m.room, Movie: mv, }) + return mm, nil } @@ -211,16 +233,19 @@ func (m *movies) GetMoviesWithPage( if keyword != "" { scopes = append(scopes, db.WhereMovieNameLikeOrURLLike(keyword, keyword)) } + count, err := db.GetMoviesCountByRoomID( m.roomID, append(scopes, db.Paginate(page, pageSize))...) if err != nil { return nil, 0, err } + movies, err := db.GetMoviesByRoomID(m.roomID, scopes...) if err != nil { return nil, 0, err } + return movies, count, nil } @@ -229,13 +254,16 @@ func (m *movies) IsParentOf(id, parentID string) (bool, error) { if parentID == "" { return id != "", nil } + mv, err := m.GetMovieByID(parentID) if err != nil { return false, fmt.Errorf("get parent movie failed: %w", err) } + if !mv.IsFolder { return false, nil } + return m.isParentOf(id, parentID, true) } @@ -243,16 +271,19 @@ func (m *movies) IsParentFolder(id, parentID string) (bool, error) { if parentID == "" { return id != "", nil } + mv, err := m.GetMovieByID(parentID) if err != nil { return false, fmt.Errorf("get parent movie failed: %w", err) } + firstCheck := true if mv.IsFolder { firstCheck = false } else { parentID = mv.ParentID.String() } + return m.isParentOf(id, parentID, firstCheck) } @@ -261,11 +292,14 @@ func (m *movies) isParentOf(id, parentID string, firstCheck bool) (bool, error) if err != nil { return false, err } + if mv.ParentID == "" { return false, nil } + if mv.ParentID == model.EmptyNullString(parentID) { return !firstCheck, nil } + return m.isParentOf(string(mv.ParentID), parentID, false) } diff --git a/internal/op/room.go b/internal/op/room.go index ed39fcb..23d35af 100644 --- a/internal/op/room.go +++ b/internal/op/room.go @@ -34,6 +34,7 @@ func (r *Room) lazyInitHub() *Hub { return r.lazyInitHub() } } + return h } @@ -90,6 +91,7 @@ func (r *Room) close() { h.Close() } } + r.movies.Close() r.members.Clear() } @@ -99,6 +101,7 @@ func (r *Room) UpdateMovie(movieID string, movie *model.MovieBase) error { if err != nil { return err } + return r.movies.Update(movieID, movie) } @@ -118,10 +121,12 @@ func (r *Room) UserRole(userID string) (model.RoomMemberRole, error) { if r.IsCreator(userID) { return model.RoomMemberRoleCreator, nil } + rur, err := r.LoadMember(userID) if err != nil { return model.RoomMemberRoleUnknown, err } + return rur.Role, nil } @@ -131,6 +136,7 @@ func (r *Room) IsAdmin(userID string) bool { if err != nil { return false } + return role.IsAdmin() } @@ -187,10 +193,12 @@ func (r *Room) LoadOrCreateMemberStatus(userID string) (model.RoomMemberStatus, if r.IsCreator(userID) { return model.RoomMemberStatusActive, nil } + rur, err := r.LoadOrCreateMember(userID) if err != nil { return model.RoomMemberStatusNotJoined, err } + return rur.Status, nil } @@ -198,6 +206,7 @@ func (r *Room) LoadMemberStatus(userID string) (model.RoomMemberStatus, error) { if r.IsCreator(userID) { return model.RoomMemberStatusActive, nil } + rur, err := r.LoadMember(userID) if err != nil { if errors.Is(err, db.NotFoundError(db.ErrRoomMemberNotFound)) { @@ -205,6 +214,7 @@ func (r *Room) LoadMemberStatus(userID string) (model.RoomMemberStatus, error) { } return model.RoomMemberStatusNotJoined, err } + return rur.Status, nil } @@ -212,13 +222,16 @@ func (r *Room) LoadOrCreateMember(userID string) (*model.RoomMember, error) { if r.Settings.DisableJoinNewUser { return r.LoadMember(userID) } + if r.IsGuest(userID) && (r.Settings.DisableGuest || !settings.EnableGuest.Get()) { return nil, errors.New("guest is disabled") } + member, ok := r.members.Load(userID) if ok { return member, nil } + var conf []db.CreateRoomMemberRelationConfig if r.IsCreator(userID) { conf = append( @@ -244,16 +257,19 @@ func (r *Room) LoadOrCreateMember(userID string) (*model.RoomMember, error) { db.WithRoomMemberAdminPermissions(model.NoAdminPermission), ) } + if r.Settings.JoinNeedReview { conf = append(conf, db.WithRoomMemberStatus(model.RoomMemberStatusPending)) } else { conf = append(conf, db.WithRoomMemberStatus(model.RoomMemberStatusActive)) } } + member, err := db.FirstOrCreateRoomMemberRelation(r.ID, userID, conf...) if err != nil { return nil, err } + return r.storeMember(userID, member), nil } @@ -261,14 +277,17 @@ func (r *Room) LoadMember(userID string) (*model.RoomMember, error) { if r.IsGuest(userID) && (r.Settings.DisableGuest || !settings.EnableGuest.Get()) { return nil, errors.New("guest is disabled") } + member, ok := r.members.Load(userID) if ok { return member, nil } + member, err := db.GetRoomMember(r.ID, userID) if err != nil { return nil, fmt.Errorf("get room member failed: %w", err) } + return r.storeMember(userID, member), nil } @@ -282,6 +301,7 @@ func (r *Room) storeMember(userID string, member *model.RoomMember) *model.RoomM case r.IsGuest(userID): member.Role = model.RoomMemberRoleMember member.Permissions = r.Settings.GuestPermissions + member.AdminPermissions = model.NoAdminPermission if member.Status.IsBanned() { member.Status = model.RoomMemberStatusActive @@ -289,7 +309,9 @@ func (r *Room) storeMember(userID string, member *model.RoomMember) *model.RoomM case member.Role.IsAdmin(): member.Permissions = model.AllPermissions } + member, _ = r.members.LoadOrStore(userID, member) + return member } @@ -297,10 +319,12 @@ func (r *Room) LoadRoomMemberPermission(userID string) (model.RoomMemberPermissi if r.IsCreator(userID) { return model.AllPermissions, nil } + member, err := r.LoadMember(userID) if err != nil { return model.NoPermission, err } + return member.Permissions, nil } @@ -308,10 +332,12 @@ func (r *Room) LoadRoomAdminPermission(userID string) (model.RoomAdminPermission if r.IsCreator(userID) { return model.AllAdminPermissions, nil } + member, err := r.LoadMember(userID) if err != nil { return model.NoAdminPermission, err } + return member.AdminPermissions, nil } @@ -323,9 +349,11 @@ func (r *Room) SetPassword(password string) error { if r.CheckPassword(password) && r.NeedPassword() { return errors.New("password is the same") } + var hashedPassword []byte if password != "" { var err error + hashedPassword, err = bcrypt.GenerateFromPassword( stream.StringToBytes(password), bcrypt.DefaultCost, @@ -334,7 +362,9 @@ func (r *Room) SetPassword(password string) error { return err } } + r.HashedPassword = hashedPassword + return db.SetRoomHashedPassword(r.ID, hashedPassword) } @@ -345,19 +375,23 @@ func (r *Room) checkCanModifyMovie(id string) error { } return nil } + cid := r.current.current.Movie.ID if cid != "" { if cid == id { return errors.New("cannot modify current movie") } + ok, err := r.movies.IsParentFolder(cid, id) if err != nil { return fmt.Errorf("check parent failed: %w", err) } + if ok { return errors.New("cannot modify current movie's parent") } } + return nil } @@ -365,6 +399,7 @@ func (r *Room) checkCanModifyMovies(ids []string) error { if len(ids) == 0 { return errors.New("ids is nil") } + cid := r.current.current.Movie.ID for _, id := range ids { if id == "" { @@ -372,19 +407,23 @@ func (r *Room) checkCanModifyMovies(ids []string) error { return errors.New("cannot modify current movie") } } + if cid != "" { if id == cid { return errors.New("cannot modify current movie") } + ok, err := r.movies.IsParentFolder(cid, id) if err != nil { return fmt.Errorf("check parent failed: %w", err) } + if ok { return errors.New("cannot modify current movie's parent") } } } + return nil } @@ -393,6 +432,7 @@ func (r *Room) DeleteMovieByID(id string) error { if err != nil { return err } + return r.movies.DeleteMovieByID(id) } @@ -401,6 +441,7 @@ func (r *Room) DeleteMoviesByID(ids []string) error { if err != nil { return err } + return r.movies.DeleteMoviesByID(ids) } @@ -413,6 +454,7 @@ func (r *Room) ClearMoviesByParentID(parentID string) error { if err != nil { return err } + return r.movies.DeleteMovieByParentID(parentID) } @@ -436,6 +478,7 @@ func (r *Room) LoadCurrentMovie() (*Movie, error) { if id == "" { return nil, ErrNoCurrentMovie } + return r.GetMovieByID(id) } @@ -444,6 +487,7 @@ func (r *Room) CheckCurrentExpired(ctx context.Context, expireID uint64) (bool, if err != nil { return false, err } + return m.CheckExpired(ctx, expireID) } @@ -459,26 +503,32 @@ func (r *Room) SetCurrentMovie(movieID, subPath string, play bool) error { } else { err = currentMovie.ClearCache() } + if err != nil { logrus.Errorf("clear current movie cache failed: %v", err) } } + if movieID == "" { r.current.SetMovie(model.CurrentMovie{}, false) return nil } + m, err := r.GetMovieByID(movieID) if err != nil { return err } + if m.IsFolder && !m.IsDynamicFolder() { return errors.New("cannot set static folder as current movie") } + r.current.SetMovie(model.CurrentMovie{ ID: m.ID, IsLive: m.Live, SubPath: subPath, }, play) + return m.ClearCache() } @@ -487,6 +537,7 @@ func (r *Room) SubPath(id string) string { if m.ID == id { return m.SubPath } + return "" } @@ -505,10 +556,12 @@ func (r *Room) GetMoviesWithPage( func (r *Room) NewClient(user *User, conn *websocket.Conn) (*Client, error) { h := r.lazyInitHub() cli := newClient(user, r, h, conn) + err := h.RegClient(cli) if err != nil { return nil, err } + return cli, nil } @@ -541,6 +594,7 @@ func (r *Room) SetSettings(settings *model.RoomSettings) error { if err != nil { return err } + return r.afterUpdateSettings(settings) } @@ -549,6 +603,7 @@ func (r *Room) UpdateSettings(settings map[string]any) error { if err != nil { return err } + return r.afterUpdateSettings(rs) } @@ -556,10 +611,12 @@ func (r *Room) afterUpdateSettings(rs *model.RoomSettings) error { if r.Settings.GuestPermissions != rs.GuestPermissions { r.members.Delete(db.GuestUserID) } + r.Settings = rs if rs.DisableGuest { return r.KickUser(db.GuestUserID) } + return nil } @@ -577,10 +634,12 @@ func (r *Room) SetMemberPermissions(userID string, permissions model.RoomMemberP if r.IsCreator(userID) { return errors.New("you are creator, cannot set permissions") } + if r.IsGuest(userID) { return r.SetGuestPermissions(permissions) } defer r.members.Delete(userID) + return db.SetMemberPermissions(r.ID, userID, permissions) } @@ -588,10 +647,12 @@ func (r *Room) AddMemberPermissions(userID string, permissions model.RoomMemberP if r.IsGuest(userID) { return r.SetGuestPermissions(r.Settings.GuestPermissions.Add(permissions)) } + if r.IsAdmin(userID) { return errors.New("cannot add permissions to admin") } defer r.members.Delete(userID) + return db.AddMemberPermissions(r.ID, userID, permissions) } @@ -602,10 +663,12 @@ func (r *Room) RemoveMemberPermissions( if r.IsGuest(userID) { return r.SetGuestPermissions(r.Settings.GuestPermissions.Remove(permissions)) } + if r.IsAdmin(userID) { return errors.New("cannot remove permissions from admin") } defer r.members.Delete(userID) + return db.RemoveMemberPermissions(r.ID, userID, permissions) } @@ -614,6 +677,7 @@ func (r *Room) ApprovePendingMember(userID string) error { return errors.New("creator cannot be approved as a pending member") } defer r.members.Delete(userID) + return db.RoomApprovePendingMember(r.ID, userID) } @@ -621,6 +685,7 @@ func (r *Room) BanMember(userID string) error { if r.IsCreator(userID) { return errors.New("creator cannot be banned") } + if r.IsGuest(userID) { return errors.New("please set whether to disable guest users in the room settings") } @@ -628,6 +693,7 @@ func (r *Room) BanMember(userID string) error { r.members.Delete(userID) _ = r.KickUser(userID) }() + return db.RoomBanMember(r.ID, userID) } @@ -635,10 +701,12 @@ func (r *Room) UnbanMember(userID string) error { if r.IsCreator(userID) { return errors.New("creator cannot be unbanned") } + if r.IsGuest(userID) { return errors.New("please set whether to enable guest users in the room settings") } defer r.members.Delete(userID) + return db.RoomUnbanMember(r.ID, userID) } @@ -650,6 +718,7 @@ func (r *Room) DeleteMember(userID string) error { r.members.Delete(userID) _ = r.KickUser(userID) }() + return db.DeleteRoomMember(r.ID, userID) } @@ -661,15 +730,18 @@ func (r *Room) SetAdminPermissions(userID string, permissions model.RoomAdminPer if r.IsCreator(userID) { return errors.New("creator cannot set admin permissions") } + if r.IsGuest(userID) { return errors.New("cannot set admin permissions to guest") } + if member, err := r.LoadMember(userID); err != nil { return err } else if !member.Role.IsAdmin() { return errors.New("not admin") } defer r.members.Delete(userID) + return db.RoomSetAdminPermissions(r.ID, userID, permissions) } @@ -677,15 +749,18 @@ func (r *Room) AddAdminPermissions(userID string, permissions model.RoomAdminPer if r.IsCreator(userID) { return errors.New("creator cannot add admin permissions") } + if r.IsGuest(userID) { return errors.New("cannot add admin permissions to guest") } + if member, err := r.LoadMember(userID); err != nil { return err } else if !member.Role.IsAdmin() { return errors.New("not admin") } defer r.members.Delete(userID) + return db.RoomAddAdminPermissions(r.ID, userID, permissions) } @@ -693,15 +768,18 @@ func (r *Room) RemoveAdminPermissions(userID string, permissions model.RoomAdmin if r.IsCreator(userID) { return errors.New("creator cannot remove admin permissions") } + if r.IsGuest(userID) { return errors.New("cannot remove admin permissions from guest") } + if member, err := r.LoadMember(userID); err != nil { return err } else if !member.Role.IsAdmin() { return errors.New("not admin") } defer r.members.Delete(userID) + return db.RoomRemoveAdminPermissions(r.ID, userID, permissions) } @@ -709,10 +787,12 @@ func (r *Room) SetAdmin(userID string, permissions model.RoomAdminPermission) er if r.IsCreator(userID) { return errors.New("creator cannot set admin") } + if r.IsGuest(userID) { return errors.New("cannot set guest as admin") } defer r.members.Delete(userID) + return db.RoomSetAdmin(r.ID, userID, permissions) } @@ -721,6 +801,7 @@ func (r *Room) SetMember(userID string, permissions model.RoomMemberPermission) return errors.New("creator cannot set member") } defer r.members.Delete(userID) + return db.RoomSetMember(r.ID, userID, permissions) } @@ -732,9 +813,11 @@ func (r *Room) SetStatus(status model.RoomStatus) error { if err := db.SetRoomStatus(r.ID, status); err != nil { return err } + r.Status = status if status == model.RoomStatusBanned || status == model.RoomStatusPending { r.close() } + return nil } diff --git a/internal/op/rooms.go b/internal/op/rooms.go index cea15f6..6014e6e 100644 --- a/internal/op/rooms.go +++ b/internal/op/rooms.go @@ -36,6 +36,7 @@ func CreateRoom( if err != nil { return nil, err } + return LoadOrInitRoom(r) } @@ -49,9 +50,11 @@ func checkRoomCreatorStatus(creatorID string) error { if user.IsBanned() { return ErrRoomCreatorBanned } + if user.IsPending() { return ErrRoomCreatorPending } + return nil } @@ -68,6 +71,7 @@ func LoadOrInitRoom(room *model.Room) (*RoomEntry, error) { r.movies.room = r i, _ := roomCache.LoadOrStore(room.ID, r, time.Duration(settings.RoomTTL.Get())*time.Hour) + return i, nil } @@ -96,7 +100,9 @@ func CompareAndDeleteRoom(room *RoomEntry) error { if err := db.DeleteRoomByID(room.Value().ID); err != nil { return err } + CompareAndCloseRoom(room) + return nil } @@ -124,6 +130,7 @@ func CompareAndCloseRoom(room *RoomEntry) bool { room.Value().close() return true } + return false } @@ -141,6 +148,7 @@ func LoadRoomByID(id string) (*RoomEntry, error) { } r2.SetExpiration(time.Now().Add(time.Duration(settings.RoomTTL.Get()) * time.Hour)) + return r2, nil } @@ -156,7 +164,9 @@ func LoadOrInitRoomByID(id string) (*RoomEntry, error) { } return nil, err } + i.SetExpiration(time.Now().Add(time.Duration(settings.RoomTTL.Get()) * time.Hour)) + return i, nil } @@ -164,11 +174,14 @@ func LoadOrInitRoomByID(id string) (*RoomEntry, error) { if err != nil { return nil, err } + settings, err := db.CreateOrLoadRoomSettings(room.ID) if err != nil { return nil, err } + room.Settings = settings + return LoadOrInitRoom(room) } @@ -184,5 +197,6 @@ func SetRoomStatusByID(roomID string, status model.RoomStatus) error { if err != nil { return err } + return room.Value().SetStatus(status) } diff --git a/internal/op/user.go b/internal/op/user.go index 431b0dc..342f80c 100644 --- a/internal/op/user.go +++ b/internal/op/user.go @@ -32,6 +32,7 @@ func (u *User) AlistCache() *cache.AlistUserCache { return u.AlistCache() } } + return c } @@ -43,6 +44,7 @@ func (u *User) BilibiliCache() *cache.BilibiliUserCache { return u.BilibiliCache() } } + return c } @@ -54,6 +56,7 @@ func (u *User) EmbyCache() *cache.EmbyUserCache { return u.EmbyCache() } } + return c } @@ -69,9 +72,11 @@ func (u *User) SetPassword(password string) error { if u.IsGuest() { return errors.New("guest cannot set password") } + if u.CheckPassword(password) { return errors.New("password is the same") } + hashedPassword, err := bcrypt.GenerateFromPassword( stream.StringToBytes(password), bcrypt.DefaultCost, @@ -79,8 +84,10 @@ func (u *User) SetPassword(password string) error { if err != nil { return err } + atomic.StoreUint32(&u.version, crc32.ChecksumIEEE(hashedPassword)) u.HashedPassword = hashedPassword + return db.SetUserHashedPassword(u.ID, hashedPassword) } @@ -91,9 +98,11 @@ func (u *User) CreateRoom(name, password string, conf ...db.CreateRoomConfig) (* if password == "" && settings.RoomMustNeedPwd.Get() { return nil, errors.New("room must need password") } + if password != "" && settings.RoomMustNoNeedPwd.Get() { return nil, errors.New("room must no need password") } + if settings.CreateRoomNeedReview.Get() { conf = append(conf, db.WithStatus(model.RoomStatusPending)) } else { @@ -113,6 +122,7 @@ func (u *User) NewMovie(movie *model.MovieBase) (*model.Movie, error) { if movie == nil { return nil, errors.New("movie is nil") } + switch movie.VendorInfo.Vendor { case model.VendorBilibili: if movie.VendorInfo.Bilibili == nil { @@ -123,6 +133,7 @@ func (u *User) NewMovie(movie *model.MovieBase) (*model.Movie, error) { return nil, errors.New("alist payload is nil") } } + return &model.Movie{ MovieBase: *movie, CreatorID: u.ID, @@ -133,14 +144,17 @@ func (u *User) AddRoomMovie(room *Room, movie *model.MovieBase) (*model.Movie, e if !u.HasRoomPermission(room, model.PermissionAddMovie) { return nil, model.ErrNoPermission } + m, err := u.NewMovie(movie) if err != nil { return nil, err } + err = room.AddMovie(m) if err != nil { return nil, err } + return m, room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -157,8 +171,10 @@ func (u *User) NewMovies(movies []*model.MovieBase) ([]*model.Movie, error) { if err != nil { return nil, err } + ms[i] = movie } + return ms, nil } @@ -166,14 +182,17 @@ func (u *User) AddRoomMovies(room *Room, movies []*model.MovieBase) ([]*model.Mo if !u.HasRoomPermission(room, model.PermissionAddMovie) { return nil, model.ErrNoPermission } + m, err := u.NewMovies(movies) if err != nil { return nil, err } + err = room.AddMovies(m) if err != nil { return nil, err } + return m, room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -214,9 +233,11 @@ func (u *User) HasRoomAdminPermission(room *Room, permission model.RoomAdminPerm if u.IsAdmin() { return true } + if u.IsGuest() { return false } + return room.HasAdminPermission(u.ID, permission) } @@ -243,14 +264,17 @@ func (u *User) SetRoomPassword(room *Room, password string) error { if !u.HasRoomAdminPermission(room, model.PermissionSetRoomPassword) { return model.ErrNoPermission } + if !u.IsAdmin() { if password == "" && settings.RoomMustNeedPwd.Get() { return errors.New("room must need password") } + if password != "" && settings.RoomMustNoNeedPwd.Get() { return errors.New("room must no need password") } } + return room.SetPassword(password) } @@ -258,10 +282,13 @@ func (u *User) SetUserRole() error { if u.IsGuest() { return errors.New("cannot set guest role") } + if err := db.SetUserRoleByID(u.ID); err != nil { return err } + u.Role = model.RoleUser + return nil } @@ -269,10 +296,13 @@ func (u *User) SetAdminRole() error { if u.IsGuest() { return errors.New("guest cannot be admin") } + if err := db.SetAdminRoleByID(u.ID); err != nil { return err } + u.Role = model.RoleAdmin + return nil } @@ -280,10 +310,13 @@ func (u *User) SetRootRole() error { if u.IsGuest() { return errors.New("guest cannot be root") } + if err := db.SetRootRoleByID(u.ID); err != nil { return err } + u.Role = model.RoleRoot + return nil } @@ -291,10 +324,13 @@ func (u *User) Ban() error { if u.IsGuest() { return errors.New("guest cannot be banned") } + if err := db.BanUserByID(u.ID); err != nil { return err } + u.Role = model.RoleBanned + return nil } @@ -302,7 +338,9 @@ func (u *User) Unban() error { if err := db.UnbanUserByID(u.ID); err != nil { return err } + u.Role = model.RoleUser + return nil } @@ -310,7 +348,9 @@ func (u *User) SetUsername(username string) error { if err := db.SetUsernameByID(u.ID, username); err != nil { return err } + u.Username = username + return nil } @@ -318,10 +358,12 @@ func (u *User) UpdateRoomMovie(room *Room, movieID string, movie *model.MovieBas if !u.HasRoomPermission(room, model.PermissionEditMovie) { return model.ErrNoPermission } + err := room.UpdateMovie(movieID, movie) if err != nil { return err } + return room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -350,9 +392,11 @@ func (u *User) DeleteRoomMovieByID(room *Room, movieID string) error { if err != nil { return err } + if m.CreatorID != u.ID && !u.HasRoomPermission(room, model.PermissionDeleteMovie) { return model.ErrNoPermission } + return room.DeleteMovieByID(movieID) } @@ -362,13 +406,16 @@ func (u *User) DeleteRoomMoviesByID(room *Room, movieIDs []string) error { if err != nil { return err } + if m.CreatorID != u.ID && !u.HasRoomPermission(room, model.PermissionDeleteMovie) { return model.ErrNoPermission } } + if err := room.DeleteMoviesByID(movieIDs); err != nil { return err } + return room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -382,10 +429,12 @@ func (u *User) ClearRoomMovies(room *Room) error { if !u.HasRoomPermission(room, model.PermissionDeleteMovie) { return model.ErrNoPermission } + err := room.ClearMovies() if err != nil { return err } + return room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -399,10 +448,12 @@ func (u *User) ClearRoomMoviesByParentID(room *Room, parentID string) error { if !u.HasRoomPermission(room, model.PermissionDeleteMovie) { return model.ErrNoPermission } + err := room.ClearMoviesByParentID(parentID) if err != nil { return err } + return room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -416,10 +467,12 @@ func (u *User) SwapRoomMoviePositions(room *Room, id1, id2 string) error { if !u.HasRoomPermission(room, model.PermissionEditMovie) { return model.ErrNoPermission } + err := room.SwapMoviePositions(id1, id2) if err != nil { return err } + return room.Broadcast(&pb.Message{ Type: pb.MessageType_MOVIES, Sender: &pb.Sender{ @@ -433,10 +486,12 @@ func (u *User) SetRoomCurrentMovie(room *Room, movieID, subPath string, play boo if !u.HasRoomPermission(room, model.PermissionSetCurrentMovie) { return model.ErrNoPermission } + err := room.SetCurrentMovie(movieID, subPath, play) if err != nil { return err } + return room.Broadcast(&pb.Message{ Type: pb.MessageType_CURRENT, Sender: &pb.Sender{ @@ -451,6 +506,7 @@ func (u *User) BindProvider(p provider.OAuth2Provider, pid string) error { if err != nil { return err } + return nil } @@ -467,7 +523,9 @@ func (u *User) BindEmail(e string) error { if err != nil { return err } + u.Email = model.EmptyNullString(e) + return nil } @@ -476,7 +534,9 @@ func (u *User) UnbindEmail() error { if err != nil { return err } + u.Email = "" + return nil } @@ -532,12 +592,15 @@ func (u *User) BanRoomMember(room *Room, userID string) error { if !u.HasRoomAdminPermission(room, model.PermissionBanRoomMember) { return model.ErrNoPermission } + if u.ID == userID { return errors.New("cannot ban yourself") } + if room.IsAdmin(userID) && !u.IsRoomCreator(room) { return errors.New("cannot ban admin") } + return room.BanMember(userID) } @@ -545,9 +608,11 @@ func (u *User) UnbanRoomMember(room *Room, userID string) error { if !u.HasRoomAdminPermission(room, model.PermissionBanRoomMember) { return model.ErrNoPermission } + if u.ID == userID { return errors.New("cannot unban yourself") } + return room.UnbanMember(userID) } @@ -566,13 +631,16 @@ func (u *User) SetMemberPermissions( if !u.HasRoomAdminPermission(room, model.PermissionSetUserPermission) { return model.ErrNoPermission } + if room.IsAdmin(userID) && !u.IsRoomCreator(room) { return errors.New("cannot set admin permissions") } + err := room.SetMemberPermissions(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -590,13 +658,16 @@ func (u *User) AddMemberPermissions( if !u.HasRoomAdminPermission(room, model.PermissionSetUserPermission) { return model.ErrNoPermission } + if room.IsAdmin(userID) && !u.IsRoomCreator(room) { return errors.New("cannot add admin permissions") } + err := room.AddMemberPermissions(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -614,13 +685,16 @@ func (u *User) RemoveMemberPermissions( if !u.HasRoomAdminPermission(room, model.PermissionSetUserPermission) { return model.ErrNoPermission } + if room.IsAdmin(userID) && !u.IsRoomCreator(room) { return errors.New("cannot remove admin permissions") } + err := room.RemoveMemberPermissions(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -634,13 +708,16 @@ func (u *User) ResetMemberPermissions(room *Room, userID string) error { if !u.HasRoomAdminPermission(room, model.PermissionSetUserPermission) { return model.ErrNoPermission } + if room.IsAdmin(userID) && !u.IsRoomCreator(room) { return errors.New("cannot reset admin permissions") } + err := room.ResetMemberPermissions(userID) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -665,10 +742,12 @@ func (u *User) SetRoomAdmin( if !u.IsRoomCreator(room) { return model.ErrNoPermission } + err := room.SetAdmin(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -686,10 +765,12 @@ func (u *User) SetRoomMember( if !u.IsRoomCreator(room) { return model.ErrNoPermission } + err := room.SetMember(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -707,10 +788,12 @@ func (u *User) SetRoomAdminPermissions( if !u.IsRoomCreator(room) { return model.ErrNoPermission } + err := room.SetAdminPermissions(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -728,10 +811,12 @@ func (u *User) AddRoomAdminPermissions( if !u.IsRoomCreator(room) { return model.ErrNoPermission } + err := room.AddAdminPermissions(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -749,10 +834,12 @@ func (u *User) RemoveRoomAdminPermissions( if !u.IsRoomCreator(room) { return model.ErrNoPermission } + err := room.RemoveAdminPermissions(userID, permissions) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ @@ -766,10 +853,12 @@ func (u *User) ResetRoomAdminPermissions(room *Room, userID string) error { if !u.IsRoomCreator(room) { return model.ErrNoPermission } + err := room.ResetAdminPermissions(userID) if err != nil { return err } + return room.SendToUserWithID(userID, &pb.Message{ Type: pb.MessageType_MY_STATUS, Sender: &pb.Sender{ diff --git a/internal/op/users.go b/internal/op/users.go index 9bffefd..2849340 100644 --- a/internal/op/users.go +++ b/internal/op/users.go @@ -27,6 +27,7 @@ func LoadOrInitUser(u *model.User) (*UserEntry, error) { User: *u, version: crc32.ChecksumIEEE(u.HashedPassword), }, time.Hour) + return i, nil } @@ -67,6 +68,7 @@ func CreateUser(username, password string, conf ...db.CreateUserConfig) (*UserEn if username == "" { return nil, errors.New("username cannot be empty") } + u, err := db.CreateUser(username, password, conf...) if err != nil { return nil, err @@ -115,10 +117,12 @@ func CompareAndDeleteUser(user *UserEntry) error { if id == db.GuestUserID { return errors.New("cannot delete guest user") } + err := db.DeleteUserByID(id) if err != nil { return err } + return CompareAndCloseUser(user) } @@ -126,10 +130,12 @@ func DeleteUserByID(id string) error { if id == db.GuestUserID { return errors.New("cannot delete guest user") } + err := db.DeleteUserByID(id) if err != nil { return err } + return CloseUserByID(id) } @@ -141,6 +147,7 @@ func CloseUserByID(id string) error { } return true }) + return nil } @@ -148,12 +155,14 @@ func CompareAndCloseUser(user *UserEntry) error { if !userCache.CompareAndDelete(user.Value().ID, user) { return nil } + roomCache.Range(func(_ string, value *synccache.Entry[*Room]) bool { if value.Value().CreatorID == user.Value().ID { CompareAndCloseRoom(value) } return true }) + return nil } @@ -162,6 +171,7 @@ func GetUserName(userID string) string { if err != nil { return "" } + return u.Value().Username } diff --git a/internal/provider/aggregation.go b/internal/provider/aggregation.go index 81f45e0..2377cf4 100644 --- a/internal/provider/aggregation.go +++ b/internal/provider/aggregation.go @@ -13,13 +13,16 @@ func ExtractProviders( if len(providers) == 0 { providers = p.Providers() } + pi := make([]Interface, len(providers)) for i, provider := range providers { pi2, err := p.ExtractProvider(provider) if err != nil { return nil, err } + pi[i] = pi2 } + return pi, nil } diff --git a/internal/provider/aggregations/rainbow.go b/internal/provider/aggregations/rainbow.go index 8fd19ac..37a11b8 100644 --- a/internal/provider/aggregations/rainbow.go +++ b/internal/provider/aggregations/rainbow.go @@ -68,10 +68,12 @@ func (p *rainbowGenericProvider) NewAuthURL(ctx context.Context, state string) ( if err != nil { return "", err } + u, err := url.Parse(result) if err != nil { return "", err } + query := url.Values{} query.Set("act", "login") query.Set("appid", p.conf.ClientID) @@ -85,19 +87,24 @@ func (p *rainbowGenericProvider) NewAuthURL(ctx context.Context, state string) ( if err != nil { return "", err } + resp, err := uhc.Do(req) if err != nil { return "", err } defer resp.Body.Close() + data := rainbowNewAuthURLResp{} + err = json.NewDecoder(resp.Body).Decode(&data) if err != nil { return "", err } + if data.Code != 0 { return "", fmt.Errorf("error code: %d, msg: %s", data.ErrCode, data.Msg) } + return data.URL, nil } @@ -117,10 +124,12 @@ func (p *rainbowGenericProvider) GetUserInfo( if err != nil { return nil, err } + u, err := url.Parse(result) if err != nil { return nil, err } + query := url.Values{} query.Set("act", "callback") query.Set("appid", p.conf.ClientID) @@ -133,19 +142,24 @@ func (p *rainbowGenericProvider) GetUserInfo( if err != nil { return nil, err } + resp, err := uhc.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + data := rainbowUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&data) if err != nil { return nil, err } + if data.Code != 0 { return nil, fmt.Errorf("error code: %d, msg: %s", data.ErrCode, data.Msg) } + return &provider.UserInfo{ Username: data.Nickname, ProviderUserID: data.SocialUID, diff --git a/internal/provider/plugins/client.go b/internal/provider/plugins/client.go index 2814afe..e992867 100644 --- a/internal/provider/plugins/client.go +++ b/internal/provider/plugins/client.go @@ -25,6 +25,7 @@ func (c *GRPCClient) Provider() provider.OAuth2Provider { if err != nil { return "" } + return resp.GetName() } @@ -33,6 +34,7 @@ func (c *GRPCClient) NewAuthURL(ctx context.Context, state string) (string, erro if err != nil { return "", err } + return resp.GetUrl(), nil } @@ -43,6 +45,7 @@ func (c *GRPCClient) GetUserInfo(ctx context.Context, code string) (*provider.Us if err != nil { return nil, err } + return &provider.UserInfo{ Username: resp.GetUsername(), ProviderUserID: resp.GetProviderUserId(), diff --git a/internal/provider/plugins/example/example_authing/example_authing.go b/internal/provider/plugins/example/example_authing/example_authing.go index 17da4d9..814b18b 100644 --- a/internal/provider/plugins/example/example_authing/example_authing.go +++ b/internal/provider/plugins/example/example_authing/example_authing.go @@ -70,7 +70,9 @@ func (p *AuthingProvider) GetUserInfo( if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -80,16 +82,20 @@ func (p *AuthingProvider) GetUserInfo( if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := AuthingUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Name, ProviderUserID: ui.UnionID, diff --git a/internal/provider/plugins/example/example_feishu-sso/example_feishu-sso.go b/internal/provider/plugins/example/example_feishu-sso/example_feishu-sso.go index b4cdea7..04d8028 100644 --- a/internal/provider/plugins/example/example_feishu-sso/example_feishu-sso.go +++ b/internal/provider/plugins/example/example_feishu-sso/example_feishu-sso.go @@ -84,7 +84,9 @@ func (p *FeishuSSOProvider) GetUserInfo( if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -94,16 +96,20 @@ func (p *FeishuSSOProvider) GetUserInfo( if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := FeishuSSOUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Name, ProviderUserID: ui.ID, diff --git a/internal/provider/plugins/example/example_gitee/example_gitee.go b/internal/provider/plugins/example/example_gitee/example_gitee.go index c52292b..999f605 100644 --- a/internal/provider/plugins/example/example_gitee/example_gitee.go +++ b/internal/provider/plugins/example/example_gitee/example_gitee.go @@ -63,7 +63,9 @@ func (p *GiteeProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -73,16 +75,20 @@ func (p *GiteeProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := giteeUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Login, ProviderUserID: strconv.FormatUint(ui.ID, 10), diff --git a/internal/provider/plugins/plugin.go b/internal/provider/plugins/plugin.go index a026a6f..ff84be4 100644 --- a/internal/provider/plugins/plugin.go +++ b/internal/provider/plugins/plugin.go @@ -16,6 +16,7 @@ import ( func InitProviderPlugins(name string, arg []string, logger hclog.Logger) error { client := NewProviderPlugin(name, arg, logger) + err := sysnotify.RegisterSysNotifyTask( 0, sysnotify.NewSysNotifyTask("plugin", sysnotify.NotifyTypeEXIT, func() error { @@ -26,19 +27,24 @@ func InitProviderPlugins(name string, arg []string, logger hclog.Logger) error { if err != nil { return err } + c, err := client.Client() if err != nil { return err } + i, err := c.Dispense("Provider") if err != nil { return err } + provider, ok := i.(provider.Interface) if !ok { return fmt.Errorf("%s not implement ProviderInterface", name) } + providers.RegisterProvider(provider) + return nil } @@ -74,7 +80,7 @@ func NewProviderPlugin(name string, arg []string, logger hclog.Logger) *plugin.C return plugin.NewClient(&plugin.ClientConfig{ HandshakeConfig: HandshakeConfig, Plugins: pluginMap, - Cmd: exec.Command(name, arg...), + Cmd: exec.CommandContext(context.Background(), name, arg...), AllowedProtocols: []plugin.Protocol{ plugin.ProtocolGRPC, }, diff --git a/internal/provider/plugins/server.go b/internal/provider/plugins/server.go index 0c4d8d2..afc8a10 100644 --- a/internal/provider/plugins/server.go +++ b/internal/provider/plugins/server.go @@ -19,6 +19,7 @@ func (s *GRPCServer) Init(_ context.Context, req *providerpb.InitReq) (*provider RedirectURL: req.GetRedirectUrl(), } s.Impl.Init(opt) + return &providerpb.Enpty{}, nil } @@ -37,6 +38,7 @@ func (s *GRPCServer) NewAuthURL( if err != nil { return nil, err } + return &providerpb.NewAuthURLResp{Url: s2}, nil } @@ -48,9 +50,11 @@ func (s *GRPCServer) GetUserInfo( if err != nil { return nil, err } + resp := &providerpb.GetUserInfoResp{ Username: userInfo.Username, ProviderUserId: userInfo.ProviderUserID, } + return resp, nil } diff --git a/internal/provider/providers/baidu-netdisk.go b/internal/provider/providers/baidu-netdisk.go index 5d3df26..f495498 100644 --- a/internal/provider/providers/baidu-netdisk.go +++ b/internal/provider/providers/baidu-netdisk.go @@ -58,7 +58,9 @@ func (p *BaiduNetDiskProvider) GetUserInfo( if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -68,19 +70,24 @@ func (p *BaiduNetDiskProvider) GetUserInfo( if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := baiduNetDiskProviderUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + if ui.Errno != 0 { return nil, fmt.Errorf("baidu oauth2 get user info error: %s", ui.Errmsg) } + return &provider.UserInfo{ Username: ui.BaiduName, ProviderUserID: strconv.FormatUint(ui.Uk, 10), diff --git a/internal/provider/providers/baidu.go b/internal/provider/providers/baidu.go index e34408a..cd56e4d 100644 --- a/internal/provider/providers/baidu.go +++ b/internal/provider/providers/baidu.go @@ -53,7 +53,9 @@ func (p *BaiduProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -63,16 +65,20 @@ func (p *BaiduProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := baiduProviderUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Uname, ProviderUserID: ui.Openid, diff --git a/internal/provider/providers/casdoor.go b/internal/provider/providers/casdoor.go index 66d98b9..cc4668d 100644 --- a/internal/provider/providers/casdoor.go +++ b/internal/provider/providers/casdoor.go @@ -52,25 +52,32 @@ func (p *casdoorProvider) GetUserInfo( if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.endpoint+"/api/userinfo", nil) if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + var ui casdoorUserInfo + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + un := ui.PreferredUsername if un == "" { un = ui.Name } + return &provider.UserInfo{ ProviderUserID: ui.Sub, Username: un, @@ -100,6 +107,7 @@ func (p *casdoorProvider) RegistSetting(group string) { if err != nil { return "", err } + return fmt.Sprintf("%s://%s", u.Scheme, u.Host), nil }), settings.WithAfterSetString(func(_ settings.StringSetting, s string) { diff --git a/internal/provider/providers/discord.go b/internal/provider/providers/discord.go index a84fb75..d55fd46 100644 --- a/internal/provider/providers/discord.go +++ b/internal/provider/providers/discord.go @@ -55,7 +55,9 @@ func (p *DiscordProvider) GetUserInfo( if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -65,16 +67,20 @@ func (p *DiscordProvider) GetUserInfo( if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := discordUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Data.Name, ProviderUserID: ui.Data.ID, diff --git a/internal/provider/providers/gitee.go b/internal/provider/providers/gitee.go index 8d6dbdd..9469fac 100644 --- a/internal/provider/providers/gitee.go +++ b/internal/provider/providers/gitee.go @@ -53,7 +53,9 @@ func (p *GiteeProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -63,16 +65,20 @@ func (p *GiteeProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := giteeUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Login, ProviderUserID: strconv.FormatUint(ui.ID, 10), diff --git a/internal/provider/providers/github.go b/internal/provider/providers/github.go index e5fa10a..2a87813 100644 --- a/internal/provider/providers/github.go +++ b/internal/provider/providers/github.go @@ -51,21 +51,27 @@ func (p *GithubProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, "https://api.github.com/user", nil) if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := githubUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Login, ProviderUserID: strconv.FormatUint(ui.ID, 10), diff --git a/internal/provider/providers/gitlab.go b/internal/provider/providers/gitlab.go index 7b281eb..1a53cdb 100644 --- a/internal/provider/providers/gitlab.go +++ b/internal/provider/providers/gitlab.go @@ -49,7 +49,9 @@ func (g *GitlabProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + client := g.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -59,11 +61,13 @@ func (g *GitlabProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + return nil, FormatNotImplementedError("gitlab") } diff --git a/internal/provider/providers/google.go b/internal/provider/providers/google.go index e494c89..e1d8c8c 100644 --- a/internal/provider/providers/google.go +++ b/internal/provider/providers/google.go @@ -50,7 +50,9 @@ func (g *GoogleProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + client := g.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -60,16 +62,20 @@ func (g *GoogleProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := googleUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Name, ProviderUserID: ui.ID, diff --git a/internal/provider/providers/logto.go b/internal/provider/providers/logto.go index 5734138..0bb5c3a 100644 --- a/internal/provider/providers/logto.go +++ b/internal/provider/providers/logto.go @@ -49,25 +49,32 @@ func (p *logtoProvider) GetUserInfo(ctx context.Context, code string) (*provider if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext(ctx, http.MethodGet, p.endpoint+"/oidc/me", nil) if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + var ui logtoUserInfo + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + un := ui.Username if un == "" { un = ui.Name } + return &provider.UserInfo{ ProviderUserID: ui.Sub, Username: un, @@ -98,6 +105,7 @@ func (p *logtoProvider) RegistSetting(group string) { if err != nil { return "", err } + return fmt.Sprintf("%s://%s", u.Scheme, u.Host), nil }), settings.WithAfterSetString(func(_ settings.StringSetting, s string) { diff --git a/internal/provider/providers/microsoft.go b/internal/provider/providers/microsoft.go index f3a9932..4e48cad 100644 --- a/internal/provider/providers/microsoft.go +++ b/internal/provider/providers/microsoft.go @@ -53,7 +53,9 @@ func (p *MicrosoftProvider) GetUserInfo( if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -63,16 +65,20 @@ func (p *MicrosoftProvider) GetUserInfo( if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := microsoftUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.DisplayName, ProviderUserID: ui.ID, diff --git a/internal/provider/providers/providers.go b/internal/provider/providers/providers.go index d41350c..32caa49 100644 --- a/internal/provider/providers/providers.go +++ b/internal/provider/providers/providers.go @@ -15,7 +15,9 @@ func InitProvider(p provider.OAuth2Provider, c provider.Oauth2Option) (provider. if !ok { return nil, FormatNotImplementedError(p) } + pi.Init(c) + return pi, nil } @@ -30,10 +32,12 @@ func GetProvider(p provider.OAuth2Provider) (provider.Interface, error) { if !ok { return nil, FormatNotImplementedError(p) } + pi, ok := allProviders.Load(p) if !ok { return nil, FormatNotImplementedError(p) } + return pi, nil } @@ -43,6 +47,7 @@ func AllProvider() map[provider.OAuth2Provider]provider.Interface { m[key] = value return true }) + return m } @@ -55,7 +60,9 @@ func EnableProvider(p provider.OAuth2Provider) error { if !ok { return FormatNotImplementedError(p) } + enabledProviders.Store(p, struct{}{}) + return nil } @@ -64,7 +71,9 @@ func DisableProvider(p provider.OAuth2Provider) error { if !ok { return FormatNotImplementedError(p) } + enabledProviders.Delete(p) + return nil } diff --git a/internal/provider/providers/qq.go b/internal/provider/providers/qq.go index d5b9806..5109469 100644 --- a/internal/provider/providers/qq.go +++ b/internal/provider/providers/qq.go @@ -52,6 +52,7 @@ func (p *QQProvider) GetToken(ctx context.Context, code string) (*oauth2.Token, params.Set("client_id", p.config.ClientID) params.Set("client_secret", p.config.ClientSecret) params.Set("fmt", "json") + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -61,18 +62,19 @@ func (p *QQProvider) GetToken(ctx context.Context, code string) (*oauth2.Token, if err != nil { return nil, err } + resp, err := uhc.Do(req) if err != nil { return nil, err } defer resp.Body.Close() - + // 使用自定义的qqToken结构体解析QQ的响应 qqTk := &qqToken{} if err := json.NewDecoder(resp.Body).Decode(qqTk); err != nil { return nil, err } - + // 转换为标准的oauth2.Token return qqTk.toOAuth2Token() } @@ -84,6 +86,7 @@ func (p *QQProvider) RefreshToken(ctx context.Context, tk string) (*oauth2.Token params.Set("client_id", p.config.ClientID) params.Set("client_secret", p.config.ClientSecret) params.Set("fmt", "json") + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -93,18 +96,19 @@ func (p *QQProvider) RefreshToken(ctx context.Context, tk string) (*oauth2.Token if err != nil { return nil, err } + resp, err := uhc.Do(req) if err != nil { return nil, err } defer resp.Body.Close() - + // 使用自定义的qqToken结构体解析QQ的响应 qqTk := &qqToken{} if err := json.NewDecoder(resp.Body).Decode(qqTk); err != nil { return nil, err } - + // 转换为标准的oauth2.Token return qqTk.toOAuth2Token() } @@ -114,6 +118,7 @@ func (p *QQProvider) GetUserInfo(ctx context.Context, code string) (*provider.Us if err != nil { return nil, err } + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -123,16 +128,20 @@ func (p *QQProvider) GetUserInfo(ctx context.Context, code string) (*provider.Us if err != nil { return nil, err } + resp, err := uhc.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ume := qqProviderMe{} + err = json.NewDecoder(resp.Body).Decode(&ume) if err != nil { return nil, err } + req, err = http.NewRequestWithContext( ctx, http.MethodGet, @@ -147,16 +156,20 @@ func (p *QQProvider) GetUserInfo(ctx context.Context, code string) (*provider.Us if err != nil { return nil, err } + resp2, err := uhc.Do(req) if err != nil { return nil, err } defer resp2.Body.Close() + ui := qqUserInfo{} + err = json.NewDecoder(resp2.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Nickname, ProviderUserID: ume.Openid, @@ -166,7 +179,7 @@ func (p *QQProvider) GetUserInfo(ctx context.Context, code string) (*provider.Us //nolint:tagliatelle type qqToken struct { AccessToken string `json:"access_token"` - ExpiresIn string `json:"expires_in"` // QQ返回字符串格式 + ExpiresIn string `json:"expires_in"` // QQ返回字符串格式 RefreshToken string `json:"refresh_token"` } @@ -176,7 +189,7 @@ func (qt *qqToken) toOAuth2Token() (*oauth2.Token, error) { if err != nil { return nil, fmt.Errorf("failed to parse expires_in: %w", err) } - + return &oauth2.Token{ AccessToken: qt.AccessToken, RefreshToken: qt.RefreshToken, diff --git a/internal/provider/providers/xiaomi.go b/internal/provider/providers/xiaomi.go index 7e2ae7a..ced03bb 100644 --- a/internal/provider/providers/xiaomi.go +++ b/internal/provider/providers/xiaomi.go @@ -53,7 +53,9 @@ func (p *XiaomiProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + client := p.config.Client(ctx, tk) + req, err := http.NewRequestWithContext( ctx, http.MethodGet, @@ -67,16 +69,20 @@ func (p *XiaomiProvider) GetUserInfo(ctx context.Context, code string) (*provide if err != nil { return nil, err } + resp, err := client.Do(req) if err != nil { return nil, err } defer resp.Body.Close() + ui := xiaomiUserInfo{} + err = json.NewDecoder(resp.Body).Decode(&ui) if err != nil { return nil, err } + return &provider.UserInfo{ Username: ui.Data.Name, ProviderUserID: ui.Data.UnionID, diff --git a/internal/rtmp/rtmp.go b/internal/rtmp/rtmp.go index 84122c4..71331bb 100644 --- a/internal/rtmp/rtmp.go +++ b/internal/rtmp/rtmp.go @@ -29,10 +29,12 @@ func AuthRtmpPublish(authorization string) (movieID string, err error) { if err != nil { return "", errors.New("auth failed") } + claims, ok := t.Claims.(*Claims) if !ok { return "", errors.New("auth failed") } + return claims.MovieID, nil } @@ -43,6 +45,7 @@ func NewRtmpAuthorization(movieID string) (string, error) { NotBefore: jwt.NewNumericDate(time.Now()), }, } + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims). SignedString(stream.StringToBytes(conf.Conf.Jwt.Secret)) } diff --git a/internal/settings/bool.go b/internal/settings/bool.go index 74c79dc..21e90b2 100644 --- a/internal/settings/bool.go +++ b/internal/settings/bool.go @@ -89,7 +89,9 @@ func newBool( for _, option := range options { option(b) } + b.set(value) + return b } @@ -122,6 +124,7 @@ func (b *Bool) Get() bool { if b.afterGet != nil { v = b.afterGet(b, v) } + return v } @@ -223,7 +226,7 @@ func (b *Bool) Set(v bool) (err error) { b.afterSet(b, v) } - return + return err } func (b *Bool) Interface() any { @@ -240,6 +243,7 @@ func NewBoolSetting( if loaded { panic(fmt.Sprintf("setting %s already exists", k)) } + return CoverBoolSetting(k, v, g, options...) } @@ -251,11 +255,14 @@ func CoverBoolSetting( ) BoolSetting { b := newBool(k, v, g, options...) Settings[k] = b + if GroupSettings[g] == nil { GroupSettings[g] = make(map[model.SettingGroup]Setting) } + GroupSettings[g][k] = b pushNeedInit(b) + return b } @@ -264,7 +271,9 @@ func LoadBoolSetting(k string) (BoolSetting, bool) { if !ok { return nil, false } + b, ok := s.(BoolSetting) + return b, ok } diff --git a/internal/settings/floate64.go b/internal/settings/floate64.go index 15f036c..5168fc0 100644 --- a/internal/settings/floate64.go +++ b/internal/settings/floate64.go @@ -101,7 +101,9 @@ func newFloat64( for _, option := range options { option(f) } + f.set(value) + return f } @@ -130,9 +132,11 @@ func (f *Float64) Parse(value string) (float64, error) { if err != nil { return 0, err } + if f.validator != nil { return v, f.validator(v) } + return v, nil } @@ -241,7 +245,7 @@ func (f *Float64) Set(v float64) (err error) { f.afterSet(f, v) } - return + return err } func (f *Float64) Get() float64 { @@ -249,6 +253,7 @@ func (f *Float64) Get() float64 { if f.afterGet != nil { v = f.afterGet(f, v) } + return v } @@ -266,6 +271,7 @@ func NewFloat64Setting( if loaded { panic(fmt.Sprintf("setting %s already exists", k)) } + return CoverFloat64Setting(k, v, g, options...) } @@ -277,11 +283,14 @@ func CoverFloat64Setting( ) Float64Setting { f := newFloat64(k, v, g, options...) Settings[k] = f + if GroupSettings[g] == nil { GroupSettings[g] = make(map[model.SettingGroup]Setting) } + GroupSettings[g][k] = f pushNeedInit(f) + return f } @@ -290,7 +299,9 @@ func LoadFloat64Setting(k string) (Float64Setting, bool) { if !ok { return nil, false } + f, ok := s.(Float64Setting) + return f, ok } @@ -304,5 +315,6 @@ func LoadOrNewFloat64Setting( if ok { return s } + return CoverFloat64Setting(k, v, g, options...) } diff --git a/internal/settings/int64.go b/internal/settings/int64.go index 2c129a4..8b8e61c 100644 --- a/internal/settings/int64.go +++ b/internal/settings/int64.go @@ -97,6 +97,7 @@ func newInt64( for _, option := range options { option(i) } + return i } @@ -125,9 +126,11 @@ func (i *Int64) Parse(value string) (int64, error) { if err != nil { return 0, err } + if i.validator != nil { return v, i.validator(v) } + return v, nil } @@ -236,7 +239,7 @@ func (i *Int64) Set(v int64) (err error) { i.afterSet(i, v) } - return + return err } func (i *Int64) Get() int64 { @@ -244,6 +247,7 @@ func (i *Int64) Get() int64 { if i.afterGet != nil { v = i.afterGet(i, v) } + return v } @@ -261,6 +265,7 @@ func NewInt64Setting( if loaded { panic(fmt.Sprintf("setting %s already exists", k)) } + return CoverInt64Setting(k, v, g, options...) } @@ -272,11 +277,14 @@ func CoverInt64Setting( ) Int64Setting { i := newInt64(k, v, g, options...) Settings[k] = i + if GroupSettings[g] == nil { GroupSettings[g] = make(map[model.SettingGroup]Setting) } + GroupSettings[g][k] = i pushNeedInit(i) + return i } @@ -285,7 +293,9 @@ func LoadInt64Setting(k string) (Int64Setting, bool) { if !ok { return nil, false } + i, ok := s.(Int64Setting) + return i, ok } @@ -299,5 +309,6 @@ func LoadOrNewInt64Setting( if ok { return s } + return CoverInt64Setting(k, v, g, options...) } diff --git a/internal/settings/setting.go b/internal/settings/setting.go index dd4bbc8..54e8fae 100644 --- a/internal/settings/setting.go +++ b/internal/settings/setting.go @@ -55,12 +55,14 @@ func pushNeedInit(s Setting) { if s == nil { panic("push need init failed, setting is nil") } + for i, item := range needInit.items { if item.Name() == s.Name() { heap.Remove(needInit, i) break } } + heap.Push(needInit, maxHeapItem{ priority: s.InitPriority(), Setting: s, @@ -74,12 +76,15 @@ func hasNeedInit() bool { func PopNeedInit() (Setting, bool) { for hasNeedInit() { item := heap.Pop(needInit) + s := item.Setting if s.Inited() { continue } + return s, true } + return nil, false } @@ -104,6 +109,7 @@ func SetValue(name string, value any) error { if !ok { return fmt.Errorf("setting %s not found", name) } + switch s.Type() { case model.SettingTypeBool: return s.(BoolSetting).Set(json.Wrap(value).ToBool()) @@ -114,6 +120,7 @@ func SetValue(name string, value any) error { case model.SettingTypeString: return s.(StringSetting).Set(json.Wrap(value).ToString()) } + return s.SetString(json.Wrap(value).ToString()) } diff --git a/internal/settings/string.go b/internal/settings/string.go index 7015956..60fe45d 100644 --- a/internal/settings/string.go +++ b/internal/settings/string.go @@ -100,6 +100,7 @@ func newString( for _, option := range options { option(s) } + return s } @@ -208,6 +209,7 @@ func (s *String) SetString(value string) error { func (s *String) set(value string) { s.lock.Lock() defer s.lock.Unlock() + s.value = value } @@ -237,16 +239,18 @@ func (s *String) Set(v string) (err error) { s.afterSet(s, v) } - return + return err } func (s *String) Get() string { s.lock.RLock() defer s.lock.RUnlock() + v := s.value if s.afterGet != nil { v = s.afterGet(s, v) } + return v } @@ -263,6 +267,7 @@ func NewStringSetting( if loaded { panic(fmt.Sprintf("setting %s already exists", k)) } + return CoverStringSetting(k, v, g, options...) } @@ -273,11 +278,14 @@ func CoverStringSetting( ) StringSetting { s := newString(k, v, g, options...) Settings[k] = s + if GroupSettings[g] == nil { GroupSettings[g] = make(map[model.SettingGroup]Setting) } + GroupSettings[g][k] = s pushNeedInit(s) + return s } @@ -286,7 +294,9 @@ func LoadStringSetting(k string) (StringSetting, bool) { if !ok { return nil, false } + ss, ok := s.(StringSetting) + return ss, ok } @@ -299,5 +309,6 @@ func LoadOrNewStringSetting( if ok { return s } + return NewStringSetting(k, v, g, options...) } diff --git a/internal/settings/var.go b/internal/settings/var.go index e4d1fea..83bdb1d 100644 --- a/internal/settings/var.go +++ b/internal/settings/var.go @@ -39,6 +39,7 @@ func init() { "room_must_need_pwd and room_must_no_need_pwd can't be true at the same time", ) } + return b, nil }), ) @@ -52,6 +53,7 @@ func init() { "room_must_need_pwd and room_must_no_need_pwd can't be true at the same time", ) } + return b, nil }), ) diff --git a/internal/sysNotify/sysNotify.go b/internal/sysNotify/sysNotify.go index cf6e5e2..2e67821 100644 --- a/internal/sysNotify/sysNotify.go +++ b/internal/sysNotify/sysNotify.go @@ -59,6 +59,7 @@ func NewSysNotifyTask(name string, notifyType NotifyType, task func() error) *Ta func runTask(tq *taskQueue) { tq.notifyTaskLock.Lock() defer tq.notifyTaskLock.Unlock() + for tq.notifyTaskQueue.Len() > 0 { _, task := tq.notifyTaskQueue.Pop() func() { @@ -67,10 +68,13 @@ func runTask(tq *taskQueue) { log.Errorf("task: %s panic has returned: %v", task.Name, err) } }() + log.Infof("task: %s running", task.Name) + if err := task.Task(); err != nil { log.Errorf("task: %s an error occurred: %v", task.Name, err) } + log.Infof("task: %s done", task.Name) }() } @@ -80,22 +84,29 @@ func (sn *SysNotify) RegisterSysNotifyTask(priority int, task *Task) error { if task == nil || task.Task == nil { return errors.New("task is nil") } + if task.NotifyType == 0 { panic("task notify type is 0") } + tasks, _ := sn.taskGroup.LoadOrStore(task.NotifyType, &taskQueue{ notifyTaskQueue: pqueue.NewMinPriorityQueue[*Task](), }) + tasks.notifyTaskLock.Lock() defer tasks.notifyTaskLock.Unlock() + tasks.notifyTaskQueue.Push(priority, task) + return nil } func (sn *SysNotify) waitCbk() { log.Info("wait sys notify") + for s := range sn.c { log.Infof("receive sys notify: %v", s) + switch parseSysNotifyType(s) { case NotifyTypeEXIT: tq, ok := sn.taskGroup.Load(NotifyTypeEXIT) @@ -103,6 +114,7 @@ func (sn *SysNotify) waitCbk() { log.Info("task: NotifyTypeEXIT running...") runTask(tq) } + return case NotifyTypeRELOAD: tq, ok := sn.taskGroup.Load(NotifyTypeRELOAD) @@ -112,6 +124,7 @@ func (sn *SysNotify) waitCbk() { } } } + log.Info("task: all done") } diff --git a/internal/sysnotify/sysnotify.go b/internal/sysnotify/sysnotify.go index cf6e5e2..2e67821 100644 --- a/internal/sysnotify/sysnotify.go +++ b/internal/sysnotify/sysnotify.go @@ -59,6 +59,7 @@ func NewSysNotifyTask(name string, notifyType NotifyType, task func() error) *Ta func runTask(tq *taskQueue) { tq.notifyTaskLock.Lock() defer tq.notifyTaskLock.Unlock() + for tq.notifyTaskQueue.Len() > 0 { _, task := tq.notifyTaskQueue.Pop() func() { @@ -67,10 +68,13 @@ func runTask(tq *taskQueue) { log.Errorf("task: %s panic has returned: %v", task.Name, err) } }() + log.Infof("task: %s running", task.Name) + if err := task.Task(); err != nil { log.Errorf("task: %s an error occurred: %v", task.Name, err) } + log.Infof("task: %s done", task.Name) }() } @@ -80,22 +84,29 @@ func (sn *SysNotify) RegisterSysNotifyTask(priority int, task *Task) error { if task == nil || task.Task == nil { return errors.New("task is nil") } + if task.NotifyType == 0 { panic("task notify type is 0") } + tasks, _ := sn.taskGroup.LoadOrStore(task.NotifyType, &taskQueue{ notifyTaskQueue: pqueue.NewMinPriorityQueue[*Task](), }) + tasks.notifyTaskLock.Lock() defer tasks.notifyTaskLock.Unlock() + tasks.notifyTaskQueue.Push(priority, task) + return nil } func (sn *SysNotify) waitCbk() { log.Info("wait sys notify") + for s := range sn.c { log.Infof("receive sys notify: %v", s) + switch parseSysNotifyType(s) { case NotifyTypeEXIT: tq, ok := sn.taskGroup.Load(NotifyTypeEXIT) @@ -103,6 +114,7 @@ func (sn *SysNotify) waitCbk() { log.Info("task: NotifyTypeEXIT running...") runTask(tq) } + return case NotifyTypeRELOAD: tq, ok := sn.taskGroup.Load(NotifyTypeRELOAD) @@ -112,6 +124,7 @@ func (sn *SysNotify) waitCbk() { } } } + log.Info("task: all done") } diff --git a/internal/vendor/alist.go b/internal/vendor/alist.go index bf5b692..5a9d52b 100644 --- a/internal/vendor/alist.go +++ b/internal/vendor/alist.go @@ -32,7 +32,9 @@ func NewAlistGrpcClient(conn *grpc.ClientConn) (AlistInterface, error) { if conn == nil { return nil, errors.New("grpc client conn is nil") } + conn.GetState() + return newGrpcAlist(alist.NewAlistClient(conn)), nil } diff --git a/internal/vendor/emby.go b/internal/vendor/emby.go index 8ff0513..a953502 100644 --- a/internal/vendor/emby.go +++ b/internal/vendor/emby.go @@ -32,7 +32,9 @@ func NewEmbyGrpcClient(conn *grpc.ClientConn) (EmbyInterface, error) { if conn == nil { return nil, errors.New("grpc client conn is nil") } + conn.GetState() + return newGrpcEmby(emby.NewEmbyClient(conn)), nil } diff --git a/internal/vendor/vendor.go b/internal/vendor/vendor.go index 1176c9e..b84d1ef 100644 --- a/internal/vendor/vendor.go +++ b/internal/vendor/vendor.go @@ -70,15 +70,19 @@ func Init(ctx context.Context) error { if err != nil { return err } + bc, err := newBackendConns(ctx, vb) if err != nil { return err } + vc, err := newVendorClients(bc) if err != nil { return err } + storeBackends(bc, vc) + return nil } @@ -124,6 +128,7 @@ func EnableVendorBackends(_ context.Context, endpoints []string) (err error) { defer lock.Unlock() raw := LoadConns() + needChangeEndpoints := make([]string, 0, len(endpoints)) for _, endpoint := range endpoints { if v, ok := raw[endpoint]; !ok { @@ -202,6 +207,7 @@ func DisableVendorBackends(_ context.Context, endpoints []string) (err error) { defer lock.Unlock() raw := LoadConns() + needChangeEndpoints := make([]string, 0, len(endpoints)) for _, endpoint := range endpoints { if v, ok := raw[endpoint]; !ok { @@ -319,7 +325,9 @@ func DeleteVendorBackends(_ context.Context, endpoints []string) error { if !ok { return fmt.Errorf("endpoint not found: %s", endpoint) } + beforeConn[i] = conn.Conn + delete(m, endpoint) } @@ -334,6 +342,7 @@ func DeleteVendorBackends(_ context.Context, endpoints []string) error { } storeBackends(m, vc) + for _, conn := range beforeConn { conn.Close() } @@ -410,6 +419,7 @@ func newBackendConn( if err != nil { return conns, err } + return &BackendConn{ Conn: cc, Info: conf, @@ -429,14 +439,17 @@ func newBackendConns( } } }() + for _, vb := range conf { if _, ok := conns[vb.Backend.Endpoint]; ok { return conns, fmt.Errorf("duplicate endpoint: %s", vb.Backend.Endpoint) } + cc, err := newBackendConn(ctx, vb) if err != nil { return conns, err } + conns[vb.Backend.Endpoint] = cc } @@ -453,6 +466,7 @@ func newVendorClients(conns map[string]*BackendConn) (*Clients, error) { if !conn.Info.UsedBy.Enabled { continue } + if conn.Info.UsedBy.Bilibili { if _, ok := clients.bilibili[conn.Info.UsedBy.BilibiliBackendName]; ok { return nil, fmt.Errorf( @@ -460,12 +474,15 @@ func newVendorClients(conns map[string]*BackendConn) (*Clients, error) { conn.Info.UsedBy.BilibiliBackendName, ) } + cli, err := NewBilibiliGrpcClient(conn.Conn) if err != nil { return nil, err } + clients.bilibili[conn.Info.UsedBy.BilibiliBackendName] = cli } + if conn.Info.UsedBy.Alist { if _, ok := clients.alist[conn.Info.UsedBy.AlistBackendName]; ok { return nil, fmt.Errorf( @@ -473,12 +490,15 @@ func newVendorClients(conns map[string]*BackendConn) (*Clients, error) { conn.Info.UsedBy.AlistBackendName, ) } + cli, err := NewAlistGrpcClient(conn.Conn) if err != nil { return nil, err } + clients.alist[conn.Info.UsedBy.AlistBackendName] = cli } + if conn.Info.UsedBy.Emby { if _, ok := clients.emby[conn.Info.UsedBy.EmbyBackendName]; ok { return nil, fmt.Errorf( @@ -486,10 +506,12 @@ func newVendorClients(conns map[string]*BackendConn) (*Clients, error) { conn.Info.UsedBy.EmbyBackendName, ) } + cli, err := NewEmbyGrpcClient(conn.Conn) if err != nil { return nil, err } + clients.emby[conn.Info.UsedBy.EmbyBackendName] = cli } } @@ -501,17 +523,20 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er if err := conf.Validate(); err != nil { return nil, err } + _, _, err := net.SplitHostPort(conf.Endpoint) if err != nil { if !strings.Contains(err.Error(), "missing port in address") { return nil, err } + if conf.TLS { conf.Endpoint += ":443" } else { conf.Endpoint += ":80" } } + middlewares := []middleware.Middleware{ kcircuitbreaker.Client( kcircuitbreaker.WithCircuitBreaker(func() circuitbreaker.CircuitBreaker { @@ -525,6 +550,7 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er if conf.JwtSecret != "" { key := []byte(conf.JwtSecret) + middlewares = append(middlewares, jwt.Client(func(_ *jwtv5.Token) (any, error) { return key, nil }, jwt.WithSigningMethod(jwtv5.SigningMethodHS256))) @@ -540,6 +566,7 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er if err != nil { return nil, err } + opts = append(opts, ggrpc.WithTimeout(timeout)) } @@ -551,16 +578,20 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er c.PathPrefix = conf.Consul.PathPrefix c.Namespace = conf.Consul.Namespace c.Partition = conf.Consul.Partition + client, err := api.NewClient(c) if err != nil { return nil, err } + endpoint := "discovery:///" + conf.Consul.ServiceName dis := consul.New(client) opts = append(opts, ggrpc.WithEndpoint(endpoint), ggrpc.WithDiscovery(dis)) + log.Infof("new grpc client with consul: %s", conf.Endpoint) case conf.Etcd.ServiceName != "": endpoint := "discovery:///" + conf.Etcd.ServiceName + cli, err := clientv3.New(clientv3.Config{ Endpoints: []string{conf.Endpoint}, Username: conf.Etcd.Username, @@ -569,8 +600,10 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er if err != nil { return nil, err } + dis := etcd.New(cli) opts = append(opts, ggrpc.WithEndpoint(endpoint), ggrpc.WithDiscovery(dis)) + log.Infof("new grpc client with etcd: %v", conf.Endpoint) default: opts = append(opts, ggrpc.WithEndpoint(conf.Endpoint)) @@ -580,13 +613,16 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er var con *grpc.ClientConn if conf.TLS { var rootCAs *x509.CertPool + rootCAs, err = x509.SystemCertPool() if err != nil { return nil, err } + if conf.CustomCa != "" { rootCAs.AppendCertsFromPEM([]byte(conf.CustomCa)) } + opts = append(opts, ggrpc.WithTLSConfig(&tls.Config{ RootCAs: rootCAs, MinVersion: tls.VersionTLS12, @@ -602,9 +638,11 @@ func NewGrpcConn(ctx context.Context, conf *model.Backend) (*grpc.ClientConn, er opts..., ) } + if err != nil { return nil, err } + return con, nil } @@ -612,17 +650,20 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, if err := conf.Validate(); err != nil { return nil, err } + _, _, err := net.SplitHostPort(conf.Endpoint) if err != nil { if !strings.Contains(err.Error(), "missing port in address") { return nil, err } + if conf.TLS { conf.Endpoint += ":443" } else { conf.Endpoint += ":80" } } + middlewares := []middleware.Middleware{ kcircuitbreaker.Client( kcircuitbreaker.WithCircuitBreaker(func() circuitbreaker.CircuitBreaker { @@ -636,6 +677,7 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, if conf.JwtSecret != "" { key := []byte(conf.JwtSecret) + middlewares = append(middlewares, jwt.Client(func(_ *jwtv5.Token) (any, error) { return key, nil }, jwt.WithSigningMethod(jwtv5.SigningMethodHS256))) @@ -650,6 +692,7 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, if err != nil { return nil, err } + opts = append(opts, http.WithTimeout(timeout)) } else { opts = append(opts, http.WithTimeout(time.Second*10)) @@ -660,13 +703,16 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, if err != nil { return nil, err } + if conf.CustomCa != "" { b, err := os.ReadFile(conf.CustomCa) if err != nil { return nil, err } + rootCAs.AppendCertsFromPEM(b) } + opts = append(opts, http.WithTLSConfig(&tls.Config{ RootCAs: rootCAs, MinVersion: tls.VersionTLS12, @@ -681,16 +727,20 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, c.PathPrefix = conf.Consul.PathPrefix c.Namespace = conf.Consul.Namespace c.Partition = conf.Consul.Partition + client, err := api.NewClient(c) if err != nil { return nil, err } + endpoint := "discovery:///" + conf.Consul.ServiceName dis := consul.New(client) opts = append(opts, http.WithEndpoint(endpoint), http.WithDiscovery(dis)) + log.Infof("new http client with consul: %s", conf.Endpoint) case conf.Etcd.ServiceName != "": endpoint := "discovery:///" + conf.Etcd.ServiceName + cli, err := clientv3.New(clientv3.Config{ Endpoints: []string{conf.Endpoint}, Username: conf.Etcd.Username, @@ -699,8 +749,10 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, if err != nil { return nil, err } + dis := etcd.New(cli) opts = append(opts, http.WithEndpoint(endpoint), http.WithDiscovery(dis)) + log.Infof("new http client with etcd: %v", conf.Endpoint) default: opts = append(opts, http.WithEndpoint(conf.Endpoint)) @@ -714,5 +766,6 @@ func NewHTTPClientConn(ctx context.Context, conf *model.Backend) (*http.Client, if err != nil { return nil, err } + return con, nil } diff --git a/internal/version/update.go b/internal/version/update.go index ef35253..3674f08 100644 --- a/internal/version/update.go +++ b/internal/version/update.go @@ -13,11 +13,13 @@ import ( func SelfUpdate(ctx context.Context, url string) error { now := time.Now().UnixNano() + currentExecFile, err := ExecutableFile() if err != nil { log.Errorf("self update: get current executable file error: %v", err) return err } + log.Debugf("self update: current executable file: %s", currentExecFile) tmp := filepath.Join(os.TempDir(), "synctv-server", fmt.Sprintf("self-update-%d", now)) @@ -25,24 +27,29 @@ func SelfUpdate(ctx context.Context, url string) error { log.Errorf("self update: mkdir %s error: %v", tmp, err) return err } + log.Infof("self update: temp path: %s", tmp) defer func() { log.Infof("self update: remove temp path: %s", tmp) + if err := os.RemoveAll(tmp); err != nil { log.Warnf("self update: remove temp path error: %v", err) } }() + file, err := DownloadWithProgress(ctx, url, tmp) if err != nil { log.Errorf("self update: download %s error: %v", url, err) return err } + log.Infof("self update: download success: %s", file) if err := os.Chmod(file, 0o755); err != nil { log.Errorf("self update: chmod %s error: %v", file, err) return err } + log.Debugf("self update: chmod success: %s", file) oldName := fmt.Sprintf("%s-%d.old", currentExecFile, now) @@ -50,11 +57,13 @@ func SelfUpdate(ctx context.Context, url string) error { log.Errorf("self update: rename %s -> %s error: %v", currentExecFile, oldName, err) return err } + log.Debugf("self update: rename success: %s -> %s", currentExecFile, oldName) defer func() { if err != nil { log.Warnf("self update: rollback: %s -> %s", oldName, currentExecFile) + if err := os.Rename(oldName, currentExecFile); err != nil { log.Errorf( "self update: rollback: rename %s -> %s error: %v", @@ -65,6 +74,7 @@ func SelfUpdate(ctx context.Context, url string) error { } } else { log.Debugf("self update: remove old executable file: %s", oldName) + if err := os.Remove(oldName); err != nil { log.Warnf("self update: remove old executable file error: %v", err) } @@ -87,8 +97,10 @@ func DownloadWithProgress(ctx context.Context, url, path string) (string, error) if err != nil { return "", err } + req = req.WithContext(ctx) resp := grab.NewClient().Do(req) + t := time.NewTicker(250 * time.Millisecond) defer t.Stop() @@ -112,5 +124,6 @@ func ExecutableFile() (string, error) { if err != nil { return "", err } + return filepath.EvalSymlinks(p) } diff --git a/internal/version/version.go b/internal/version/version.go index dc65a5a..2216c9e 100644 --- a/internal/version/version.go +++ b/internal/version/version.go @@ -60,6 +60,7 @@ func NewVersionInfo(conf ...InfoConf) (*Info, error) { for _, c := range conf { c(v) } + return v, v.fix() } @@ -67,7 +68,9 @@ func (v *Info) fix() (err error) { if v.baseURL == "" { v.baseURL = "https://api.github.com/" } + v.c, err = github.NewClient(nil).WithEnterpriseURLs(v.baseURL, "") + return err } @@ -75,16 +78,20 @@ func (v *Info) initLatest(ctx context.Context) (err error) { if v.latest != nil { return nil } + v.latest, _, err = v.c.Repositories.GetLatestRelease(ctx, owner, repo) - return + + return err } func (v *Info) initDev(ctx context.Context) (err error) { if v.dev != nil { return nil } + v.dev, _, err = v.c.Repositories.GetReleaseByTag(ctx, owner, repo, "dev") - return + + return err } func (v *Info) Current() string { @@ -103,7 +110,9 @@ func (v *Info) CheckLatest(ctx context.Context) (string, error) { if err != nil { return "", err } + v.latest = release + return release.GetTagName(), nil } @@ -128,6 +137,7 @@ func getBinaryURL(repo *github.RepositoryRelease) (string, error) { return a.GetBrowserDownloadURL(), nil } } + return "", errors.New("no binary found") } @@ -169,10 +179,12 @@ func (v *Info) SelfUpdate(ctx context.Context) (err error) { if err != nil { return err } + comp, err := utils.CompVersion(v.Current(), latest) if err != nil { return err } + switch comp { case utils.VersionEqual: log.Infof("self update: current version is latest: %s", v.Current()) @@ -189,6 +201,7 @@ func (v *Info) SelfUpdate(ctx context.Context) (err error) { v.Current(), latest, ) + return nil } default: @@ -201,6 +214,7 @@ func (v *Info) SelfUpdate(ctx context.Context) (err error) { } else { url, err = v.LatestBinaryURL(ctx) } + if err != nil { return err } diff --git a/internal/version/version_test.go b/internal/version/version_test.go index 0a3722e..e4a1db2 100644 --- a/internal/version/version_test.go +++ b/internal/version/version_test.go @@ -11,10 +11,12 @@ func TestCheckLatest(t *testing.T) { if err != nil { t.Fatal(err) } + s, err := v.CheckLatest(t.Context()) if err != nil { t.Fatal(err) } + t.Log(s) } @@ -23,9 +25,11 @@ func TestLatestBinaryURL(t *testing.T) { if err != nil { t.Fatal(err) } + s, err := v.LatestBinaryURL(t.Context()) if err != nil { t.Fatal(err) } + t.Log(s) } diff --git a/proto/message/message.go b/proto/message/message.go index 2dd478d..591a614 100644 --- a/proto/message/message.go +++ b/proto/message/message.go @@ -16,6 +16,8 @@ func (em *Message) Encode(w io.Writer) error { if err != nil { return err } + _, err = w.Write(b) + return err } diff --git a/server/handlers/admin.go b/server/handlers/admin.go index 4dd96c5..c49ca41 100644 --- a/server/handlers/admin.go +++ b/server/handlers/admin.go @@ -55,17 +55,20 @@ func AdminSettings(ctx *gin.Context) { switch group { case "oauth2": const groupPrefix = dbModel.SettingGroupOauth2 + settingGroups := make(map[string]map[string]settings.Setting) for sg, v := range settings.GroupSettings { if strings.HasPrefix(sg, groupPrefix) { settingGroups[sg] = v } } + resp := make(model.AdminSettingsResp, len(settingGroups)) for k, v := range settingGroups { if resp[k] == nil { resp[k] = make(gin.H, len(v)) } + for k2, s := range v { resp[k][k2] = s.Interface() } @@ -78,6 +81,7 @@ func AdminSettings(ctx *gin.Context) { if resp[sg] == nil { resp[sg] = make(gin.H, len(v)) } + for _, s2 := range v { resp[sg][s2.Name()] = s2.Interface() } @@ -92,8 +96,10 @@ func AdminSettings(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("group not found"), ) + return } + data := make(map[string]any, len(s)) for _, v := range s { data[v.Name()] = v.Interface() @@ -141,6 +147,7 @@ func AdminGetUsers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereUsernameLikeOrIDIn(keyword, ids)) case "name": scopes = append(scopes, db.WhereUsernameLike(keyword)) @@ -151,6 +158,7 @@ func AdminGetUsers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereIDIn(ids)) } } @@ -182,6 +190,7 @@ func AdminGetUsers(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -208,6 +217,7 @@ func genUserListResp(us []*dbModel.User) []*model.UserInfoResp { CreatedAt: v.CreatedAt.UnixMilli(), } } + return resp } @@ -252,6 +262,7 @@ func AdminGetRoomMembers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereUsernameLikeOrIDIn(keyword, ids)) case "name": scopes = append(scopes, db.WhereUsernameLike(keyword)) @@ -262,9 +273,11 @@ func AdminGetRoomMembers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereIDIn(ids)) } } + scopes = append(scopes, func(db *gorm.DB) *gorm.DB { return db. InnerJoins("JOIN room_members ON users.id = room_members.user_id"). @@ -300,6 +313,7 @@ func AdminGetRoomMembers(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -323,6 +337,7 @@ func genRoomMemberListResp(us []*dbModel.User, room *op.Room) []*model.RoomMembe if room.IsGuest(v.ID) { permissions = room.Settings.GuestPermissions } + resp[i] = &model.RoomMembersResp{ UserID: v.ID, Username: v.Username, @@ -335,6 +350,7 @@ func genRoomMemberListResp(us []*dbModel.User, room *op.Room) []*model.RoomMembe AdminPermissions: v.RoomMembers[0].AdminPermissions, } } + return resp } @@ -353,6 +369,7 @@ func AdminApprovePendingUser(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + user := userE.Value() if !user.IsPending() { @@ -361,6 +378,7 @@ func AdminApprovePendingUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user is not pending"), ) + return } @@ -390,6 +408,7 @@ func AdminBanUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot ban self"), ) + return } @@ -406,6 +425,7 @@ func AdminBanUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot ban root"), ) + return } @@ -415,6 +435,7 @@ func AdminBanUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot ban admin"), ) + return } @@ -451,6 +472,7 @@ func AdminUnBanUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user is not banned"), ) + return } @@ -496,6 +518,7 @@ func AdminGetRooms(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereRoomNameLikeOrCreatorInOrIDLike(keyword, ids, keyword)) case "name": scopes = append(scopes, db.WhereRoomNameLike(keyword)) @@ -506,6 +529,7 @@ func AdminGetRooms(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereCreatorIDIn(ids)) case "creatorId": scopes = append(scopes, db.WhereCreatorID(keyword)) @@ -541,6 +565,7 @@ func AdminGetRooms(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -566,6 +591,7 @@ func AdminGetUserRooms(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("user id error")) return } + page, pageSize, err := utils.GetPageAndMax(ctx) if err != nil { log.Errorf("get page and max error: %v", err) @@ -625,6 +651,7 @@ func AdminGetUserRooms(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -721,6 +748,7 @@ func AdminGetUserJoinedRooms(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -761,6 +789,7 @@ func AdminApprovePendingRoom(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("room is not pending"), ) + return } @@ -800,6 +829,7 @@ func AdminBanRoom(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + creator := creatorE.Value() if creator.IsRoot() { @@ -808,6 +838,7 @@ func AdminBanRoom(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot ban root"), ) + return } @@ -817,6 +848,7 @@ func AdminBanRoom(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("cannot ban admin"), ) + return } } @@ -856,6 +888,7 @@ func AdminUnBanRoom(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("room is not banned"), ) + return } @@ -885,6 +918,7 @@ func AdminAddUser(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("you cannot add root user"), ) + return } @@ -921,6 +955,7 @@ func AdminDeleteUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot delete yourself"), ) + return } @@ -930,6 +965,7 @@ func AdminDeleteUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot delete root"), ) + return } @@ -939,6 +975,7 @@ func AdminDeleteUser(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("cannot delete admin"), ) + return } @@ -977,6 +1014,7 @@ func AdminDeleteRoom(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + creator := u.Value() if creator.IsRoot() { @@ -985,6 +1023,7 @@ func AdminDeleteRoom(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot delete root's room"), ) + return } @@ -994,6 +1033,7 @@ func AdminDeleteRoom(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("cannot delete admin's room"), ) + return } } @@ -1024,6 +1064,7 @@ func AdminUserPassword(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user not found"), ) + return } @@ -1034,6 +1075,7 @@ func AdminUserPassword(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot change root password"), ) + return } @@ -1043,6 +1085,7 @@ func AdminUserPassword(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("cannot change admin password"), ) + return } } @@ -1053,6 +1096,7 @@ func AdminUserPassword(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp(err.Error()), ) + return } @@ -1076,6 +1120,7 @@ func AdminUsername(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user not found"), ) + return } @@ -1086,6 +1131,7 @@ func AdminUsername(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot change root username"), ) + return } @@ -1095,6 +1141,7 @@ func AdminUsername(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("cannot change admin username"), ) + return } } @@ -1105,6 +1152,7 @@ func AdminUsername(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp(err.Error()), ) + return } @@ -1128,6 +1176,7 @@ func AdminRoomPassword(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("room not found"), ) + return } @@ -1141,6 +1190,7 @@ func AdminRoomPassword(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("room creator not found"), ) + return } @@ -1150,6 +1200,7 @@ func AdminRoomPassword(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot change root room password"), ) + return } @@ -1159,6 +1210,7 @@ func AdminRoomPassword(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("cannot change admin room password"), ) + return } } @@ -1169,6 +1221,7 @@ func AdminRoomPassword(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp(err.Error()), ) + return } @@ -1180,29 +1233,35 @@ func AdminGetVendorBackends(ctx *gin.Context) { log := middlewares.GetLogger(ctx) conns := vendor.LoadConns() + page, size, err := utils.GetPageAndMax(ctx) if err != nil { log.Errorf("get page and max error: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + s := slices.Collect(maps.Keys(conns)) l := len(s) + var resp []*model.GetVendorBackendResp if (page-1)*size <= l { slices.SortStableFunc(s, func(a, b string) int { if a == b { return 0 } + if natural.Less(a, b) { return -1 } + return 1 }) if l > size { l = size } + resp = make([]*model.GetVendorBackendResp, 0, l) for _, v := range s[(page-1)*size : (page-1)*size+l] { resp = append(resp, &model.GetVendorBackendResp{ @@ -1291,9 +1350,11 @@ func AdminReconnectVendorBackends(ctx *gin.Context) { if s := c.Conn.GetState(); s != connectivity.Ready { c.Conn.Connect() c.Conn.ResetConnectBackoff() + if len(req.Endpoints) == 1 { ctx2, cf := context.WithTimeout(ctx, time.Second*5) defer cf() + c.Conn.WaitForStateChange(ctx2, s) } } @@ -1358,11 +1419,13 @@ func AdminSendTestEmail(ctx *gin.Context) { if req.Email == "" { if err := user.SendTestEmail(); err != nil { log.Errorf("failed to send test email: %v", err) + if errors.Is(err, op.ErrEmailUnbound) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) } else { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) } + return } } else { diff --git a/server/handlers/danmu.go b/server/handlers/danmu.go index a4702da..d572d54 100644 --- a/server/handlers/danmu.go +++ b/server/handlers/danmu.go @@ -37,6 +37,7 @@ func StreamDanmu(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("vendor not support danmu"), ) + return } @@ -45,10 +46,13 @@ func StreamDanmu(ctx *gin.Context) { err = danmu.StreamDanmu(c, func(danmu string) error { ctx.SSEvent("danmu", danmu) + if err := ctx.Err(); err != nil { return err } + ctx.Writer.Flush() + return nil }) if err != nil { diff --git a/server/handlers/init.go b/server/handlers/init.go index f307e9f..db31c7b 100644 --- a/server/handlers/init.go +++ b/server/handlers/init.go @@ -24,6 +24,7 @@ func Init(e *gin.Engine) { { admin := api.Group("/admin") root := api.Group("/admin") + admin.Use(middlewares.AuthAdminMiddleware) root.Use(middlewares.AuthRootMiddleware) @@ -194,7 +195,6 @@ func initRoom(room, needAuthUser, needAuthRoom, needAuthWithoutGuestRoom *gin.Ro func initMovie(movie, needAuthMovie *gin.RouterGroup) { // needAuthMovie.GET("/list", MovieList) - needAuthMovie.GET("/current", CurrentMovie) needAuthMovie.GET("/movies", Movies) diff --git a/server/handlers/member.go b/server/handlers/member.go index 3f425eb..825548e 100644 --- a/server/handlers/member.go +++ b/server/handlers/member.go @@ -44,6 +44,7 @@ func RoomMembers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereUsernameLikeOrIDIn(keyword, ids)) case "name": scopes = append(scopes, db.WhereUsernameLike(keyword)) @@ -54,9 +55,11 @@ func RoomMembers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereIDIn(ids)) } } + scopes = append(scopes, func(db *gorm.DB) *gorm.DB { return db. InnerJoins("JOIN room_members ON users.id = room_members.user_id"). @@ -92,6 +95,7 @@ func RoomMembers(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -149,6 +153,7 @@ func RoomAdminMembers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereUsernameLikeOrIDIn(keyword, ids)) case "name": scopes = append(scopes, db.WhereUsernameLike(keyword)) @@ -159,9 +164,11 @@ func RoomAdminMembers(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereIDIn(ids)) } } + scopes = append(scopes, func(db *gorm.DB) *gorm.DB { return db. Joins("JOIN room_members ON users.id = room_members.user_id"). @@ -197,6 +204,7 @@ func RoomAdminMembers(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } diff --git a/server/handlers/movie.go b/server/handlers/movie.go index d655a15..0954d70 100644 --- a/server/handlers/movie.go +++ b/server/handlers/movie.go @@ -50,21 +50,25 @@ func genMovieInfo( if opMovie == nil || opMovie.ID == "" { return &model.Movie{}, nil } + if opMovie.IsFolder { if !opMovie.IsDynamicFolder() { return nil, errors.New("movie is static folder, can't get movie info") } } + movie := opMovie.Clone() if movie.Type == "" && movie.URL != "" { movie.Type = utils.GetURLExtension(movie.URL) } + switch { case movie.VendorInfo.Vendor != "": vendor, err := vendors.NewVendorService(room, opMovie) if err != nil { return nil, err } + movie, err = vendor.GenMovieInfo(ctx, user, userAgent, userToken) if err != nil { return nil, err @@ -101,6 +105,7 @@ func genMovieInfo( Type: "flv", }) } + movie.URL = fmt.Sprintf( "/api/room/movie/live/hls/list/%s.m3u8?token=%s&roomId=%s", movie.ID, @@ -118,19 +123,23 @@ func genMovieInfo( ) movie.Headers = nil } + if movie.Type == "" && movie.URL != "" { movie.Type = utils.GetURLExtension(movie.URL) } + for _, v := range movie.MoreSources { if v.Type == "" { v.Type = utils.GetURLExtension(v.URL) } } + for _, v := range movie.Subtitles { if v.Type == "" { v.Type = utils.GetURLExtension(v.URL) } } + resp := &model.Movie{ ID: movie.ID, CreatedAt: movie.CreatedAt.UnixMilli(), @@ -139,6 +148,7 @@ func genMovieInfo( CreatorID: movie.CreatorID, SubPath: opMovie.SubPath(), } + return resp, nil } @@ -154,23 +164,28 @@ func genCurrentRespWithCurrent( Movie: &model.Movie{}, }, nil } + opMovie, err := room.GetMovieByID(current.Movie.ID) if err != nil { return nil, fmt.Errorf("get current movie error: %w", err) } + mr, err := genMovieInfo(ctx, room, user, opMovie, userAgent, userToken) if err != nil { return nil, fmt.Errorf("gen current movie info error: %w", err) } + expireID, err := opMovie.ExpireID(ctx) if err != nil { return nil, fmt.Errorf("get expire id error: %w", err) } + resp := &model.CurrentMovieResp{ Status: current.UpdateStatus(), Movie: mr, ExpireID: expireID, } + return resp, nil } @@ -205,6 +220,7 @@ func Movies(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorResp(dbModel.ErrNoPermission), ) + return } @@ -214,6 +230,7 @@ func Movies(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("id length must be 0 or 32"), ) + return } @@ -231,13 +248,16 @@ func Movies(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + if !mv.IsFolder { ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("parent id is not folder"), ) + return } + if mv.IsDynamicFolder() { resp, err := listVendorDynamicMovie( ctx, @@ -254,7 +274,9 @@ func Movies(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) + return } } @@ -306,17 +328,20 @@ func getParentMoviePath(room *op.Room, id string) ([]*model.MoviePath, error) { if err != nil { return nil, fmt.Errorf("get movie by id error: %w", err) } + paths = append(paths, &model.MoviePath{ Name: p.Name, ID: p.ID, }) id = p.ParentID.String() } + paths = append(paths, &model.MoviePath{ Name: "Home", ID: "", }) slices.Reverse(paths) + return paths, nil } @@ -336,19 +361,23 @@ func listVendorDynamicMovie( if err != nil { return nil, fmt.Errorf("get parent movie path error: %w", err) } + vendor, err := vendors.NewVendorService(room, movie) if err != nil { return nil, err } + dynamic, err := vendor.ListDynamicMovie(ctx, reqUser, subPath, keyword, page, _max) if err != nil { return nil, err } + dynamic.Paths = append(paths, dynamic.Paths...) resp := &model.MoviesResp{ MovieList: dynamic, Dynamic: true, } + return resp, nil } @@ -367,6 +396,7 @@ func PushMovie(ctx *gin.Context) { m, err := user.AddRoomMovie(room, (*dbModel.MovieBase)(&req)) if err != nil { log.Errorf("push movie error: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -374,9 +404,12 @@ func PushMovie(ctx *gin.Context) { fmt.Errorf("push movie error: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -403,6 +436,7 @@ func PushMovies(ctx *gin.Context) { m, err := user.AddRoomMovies(room, ms) if err != nil { log.Errorf("push movies error: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -410,9 +444,12 @@ func PushMovies(ctx *gin.Context) { fmt.Errorf("push movies error: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -428,6 +465,7 @@ func NewPublishKey(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("rtmp is not enabled"), ) + return } @@ -440,6 +478,7 @@ func NewPublishKey(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + movie, err := room.GetMovieByID(req.ID) if err != nil { log.Errorf("new publish key error: %v", err) @@ -455,6 +494,7 @@ func NewPublishKey(ctx *gin.Context) { fmt.Errorf("new publish key error: %w", dbModel.ErrNoPermission), ), ) + return } @@ -464,6 +504,7 @@ func NewPublishKey(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("only live movie can get publish key"), ) + return } @@ -482,8 +523,10 @@ func NewPublishKey(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + host = u.Host } + if host == "" { host = ctx.Request.Host } @@ -509,6 +552,7 @@ func EditMovie(ctx *gin.Context) { if err := user.UpdateRoomMovie(room, req.ID, (*dbModel.MovieBase)(&req.PushMovieReq)); err != nil { log.Errorf("edit movie error: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -516,9 +560,12 @@ func EditMovie(ctx *gin.Context) { fmt.Errorf("edit movie error: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -540,6 +587,7 @@ func DelMovie(ctx *gin.Context) { err := user.DeleteRoomMoviesByID(room, req.IDs) if err != nil { log.Errorf("del movie error: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -547,9 +595,12 @@ func DelMovie(ctx *gin.Context) { fmt.Errorf("del movie error: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -574,9 +625,12 @@ func ClearMovies(ctx *gin.Context) { fmt.Errorf("clear movies error: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -607,6 +661,7 @@ func ChangeCurrentMovie(ctx *gin.Context) { log := middlewares.GetLogger(ctx) req := model.SetRoomCurrentMovieReq{} + err := model.Decode(ctx, &req) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) @@ -616,6 +671,7 @@ func ChangeCurrentMovie(ctx *gin.Context) { err = user.SetRoomCurrentMovie(room, req.ID, req.SubPath, true) if err != nil { log.Errorf("change current movie error: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -623,9 +679,12 @@ func ChangeCurrentMovie(ctx *gin.Context) { fmt.Errorf("change current movie error: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -652,7 +711,9 @@ func ProxyMovie(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + vendor.ProxyMovie(ctx) + return } @@ -662,6 +723,7 @@ func ProxyMovie(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } @@ -670,6 +732,7 @@ func ProxyMovie(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("movie is not proxy"), ) + return } @@ -680,6 +743,7 @@ func ProxyMovie(ctx *gin.Context) { "this movie is live or rtmp source, not support use this method proxy", ), ) + return } @@ -713,6 +777,7 @@ func ServeM3u8(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("movie proxy is not enabled"), ) + return } @@ -732,6 +797,7 @@ func ServeM3u8(ctx *gin.Context) { "this movie is rtmp source, not support use this method proxy", ), ) + return } @@ -740,20 +806,24 @@ func ServeM3u8(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("movie is not proxy"), ) + return } targetToken := ctx.Param("targetToken") + claims, err := proxy.GetM3u8Target(targetToken) if err != nil { log.Errorf("auth m3u8 error: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + if claims.RoomID != room.ID || claims.MovieID != m.ID { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("invalid token")) return } + err = proxy.M3u8(ctx, claims.TargetURL, m.Headers, @@ -819,20 +889,24 @@ func JoinFlvLive(ctx *gin.Context) { ctx.Header("Cache-Control", "no-store") room := middlewares.GetRoomEntry(ctx).Value() movieID := strings.TrimSuffix(strings.Trim(ctx.Param("movieId"), "/"), ".flv") + m, err := room.GetMovieByID(movieID) if err != nil { log.Errorf("join flv live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + if !m.Live { log.Error("join hls live error: live is not enabled") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("live is not enabled"), ) + return } + if m.RtmpSource { if !conf.Conf.Server.RTMP.Enable { log.Error("join hls live error: rtmp is not enabled") @@ -840,6 +914,7 @@ func JoinFlvLive(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("rtmp is not enabled"), ) + return } } else if !settings.LiveProxy.Get() { @@ -847,6 +922,7 @@ func JoinFlvLive(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("live proxy is not enabled")) return } + channel, err := m.Channel() if err != nil { log.Errorf("join flv live error: %v", err) @@ -856,12 +932,14 @@ func JoinFlvLive(ctx *gin.Context) { w := httpflv.NewHttpFLVWriter(ctx.Writer) defer w.Close() + err = channel.AddPlayer(w) if err != nil { log.Errorf("join flv live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + err = w.SendPacket(ctx.Request.Context()) if err != nil { log.Errorf("join flv live error: %v", err) @@ -875,20 +953,24 @@ func JoinHlsLive(ctx *gin.Context) { ctx.Header("Cache-Control", "no-store") room := middlewares.GetRoomEntry(ctx).Value() movieID := strings.TrimSuffix(strings.Trim(ctx.Param("movieId"), "/"), ".m3u8") + m, err := room.GetMovieByID(movieID) if err != nil { log.Errorf("join hls live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + if !m.Live { log.Error("join hls live error: live is not enabled") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("live is not enabled"), ) + return } + if m.RtmpSource { if !conf.Conf.Server.RTMP.Enable { log.Error("join hls live error: rtmp is not enabled") @@ -896,6 +978,7 @@ func JoinHlsLive(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("rtmp is not enabled"), ) + return } } else if !settings.LiveProxy.Get() { @@ -917,8 +1000,10 @@ func JoinHlsLive(ctx *gin.Context) { if err != nil { log.Errorf("proxy m3u8 hls live error: %v", err) } + return } + channel, err := m.Channel() if err != nil { log.Errorf("join hls live error: %v", err) @@ -931,6 +1016,7 @@ func JoinHlsLive(ctx *gin.Context) { if settings.TSDisguisedAsPng.Get() { ext = "png" } + return fmt.Sprintf( "/api/room/movie/live/hls/data/%s/%s/%s.%s", room.ID, @@ -944,6 +1030,7 @@ func JoinHlsLive(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + ctx.Data(http.StatusOK, hls.M3U8ContentType, b) } @@ -951,31 +1038,37 @@ func JoinHlsLive(ctx *gin.Context) { func ServeHlsLive(ctx *gin.Context) { log := middlewares.GetLogger(ctx) roomID := ctx.Param("roomId") + roomE, err := op.LoadRoomByID(roomID) if err != nil { log.Errorf("serve hls live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + room := roomE.Value() ctx.Header("Cache-Control", "public, max-age=30, s-maxage=90") movieID := ctx.Param("movieId") + m, err := room.GetMovieByID(movieID) if err != nil { log.Errorf("serve hls live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + if !m.Live { log.Error("join hls live error: live is not enabled") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("live is not enabled"), ) + return } + if m.RtmpSource { if !conf.Conf.Server.RTMP.Enable { log.Error("join hls live error: rtmp is not enabled") @@ -983,6 +1076,7 @@ func ServeHlsLive(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("rtmp is not enabled"), ) + return } } else if !settings.LiveProxy.Get() { @@ -990,6 +1084,7 @@ func ServeHlsLive(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorStringResp("live proxy is not enabled")) return } + channel, err := m.Channel() if err != nil { log.Errorf("serve hls live error: %v", err) @@ -1006,14 +1101,17 @@ func ServeHlsLive(ctx *gin.Context) { http.StatusNotFound, model.NewAPIErrorResp(FormatNotSupportFileTypeError(fileExt)), ) + return } + b, err := channel.GetTsFile(strings.TrimSuffix(dataID, fileExt)) if err != nil { log.Errorf("serve hls live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + ctx.Header("Cache-Control", "public, max-age=90") ctx.Data(http.StatusOK, hls.TSContentType, b) case ".png": @@ -1023,24 +1121,31 @@ func ServeHlsLive(ctx *gin.Context) { http.StatusNotFound, model.NewAPIErrorResp(FormatNotSupportFileTypeError(fileExt)), ) + return } + b, err := channel.GetTsFile(strings.TrimSuffix(dataID, fileExt)) if err != nil { log.Errorf("serve hls live error: %v", err) ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + ctx.Header("Cache-Control", "public, max-age=90") + img := image.NewGray(image.Rect(0, 0, 1, 1)) img.Set(1, 1, color.Gray{uint8(rand.IntN(255))}) + cache := bytes.NewBuffer(make([]byte, 0, 71)) + err = png.Encode(cache, img) if err != nil { log.Errorf("serve hls live error: %v", err) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Data(http.StatusOK, "image/png", append(cache.Bytes(), b...)) default: ctx.Header("Cache-Control", "no-store") diff --git a/server/handlers/proxy/buffer.go b/server/handlers/proxy/buffer.go index 1b6dd88..1273746 100644 --- a/server/handlers/proxy/buffer.go +++ b/server/handlers/proxy/buffer.go @@ -22,6 +22,7 @@ func getBuffer() *[]byte { if !ok { panic("sharedBufferPool.Get() returned a non-[]byte value") } + return buf } @@ -32,26 +33,32 @@ func putBuffer(buffer *[]byte) { func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { buf := getBuffer() defer putBuffer(buf) + for { nr, er := src.Read(*buf) if nr > 0 { nw, ew := dst.Write((*buf)[0:nr]) if nw < 0 || nr < nw { nw = 0 + if ew == nil { ew = errors.New("invalid write result") } } + written += int64(nw) + if ew != nil { err = ew break } + if nr != nw { err = io.ErrShortWrite break } } + if er != nil { if er != io.EOF { err = er @@ -59,5 +66,6 @@ func copyBuffer(dst io.Writer, src io.Reader) (written int64, err error) { break } } + return written, err } diff --git a/server/handlers/proxy/cache.go b/server/handlers/proxy/cache.go index ac2092e..c62388a 100644 --- a/server/handlers/proxy/cache.go +++ b/server/handlers/proxy/cache.go @@ -63,9 +63,11 @@ func (i *CacheItem) WriteTo(w io.Writer) (int64, error) { if err := binary.Write(w, binary.BigEndian, int64(len(metadata))); err != nil { return written, fmt.Errorf("failed to write metadata length: %w", err) } + written += 8 n, err := w.Write(metadata) + written += int64(n) if err != nil { return written, fmt.Errorf("failed to write metadata bytes: %w", err) @@ -75,9 +77,11 @@ func (i *CacheItem) WriteTo(w io.Writer) (int64, error) { if err := binary.Write(w, binary.BigEndian, int64(len(i.Data))); err != nil { return written, fmt.Errorf("failed to write data length: %w", err) } + written += 8 n, err = w.Write(i.Data) + written += int64(n) if err != nil { return written, fmt.Errorf("failed to write data bytes: %w", err) @@ -99,6 +103,7 @@ func (i *CacheItem) ReadFrom(r io.Reader) (int64, error) { if err := binary.Read(r, binary.BigEndian, &metadataLen); err != nil { return read, fmt.Errorf("failed to read metadata length: %w", err) } + read += 8 if metadataLen <= 0 { @@ -107,6 +112,7 @@ func (i *CacheItem) ReadFrom(r io.Reader) (int64, error) { metadata := make([]byte, metadataLen) n, err := io.ReadFull(r, metadata) + read += int64(n) if err != nil { return read, fmt.Errorf("failed to read metadata bytes: %w", err) @@ -122,6 +128,7 @@ func (i *CacheItem) ReadFrom(r io.Reader) (int64, error) { if err := binary.Read(r, binary.BigEndian, &dataLen); err != nil { return read, fmt.Errorf("failed to read data length: %w", err) } + read += 8 if dataLen < 0 { @@ -130,6 +137,7 @@ func (i *CacheItem) ReadFrom(r io.Reader) (int64, error) { i.Data = make([]byte, dataLen) n, err = io.ReadFull(r, i.Data) + read += int64(n) if err != nil { return read, fmt.Errorf("failed to read data bytes: %w", err) @@ -186,6 +194,7 @@ func NewMemoryCache(capacity int, opts ...MemoryCacheOption) *MemoryCache { for _, opt := range opts { opt(mc) } + return mc } @@ -195,6 +204,7 @@ func (c *MemoryCache) Get(key string) (*CacheItem, bool, error) { } c.mu.RLock() + element, exists := c.m[key] if !exists { c.mu.RUnlock() @@ -206,6 +216,7 @@ func (c *MemoryCache) Get(key string) (*CacheItem, bool, error) { c.mu.Lock() c.lruList.MoveToFront(element) item := element.Value.item + c.mu.Unlock() return item, true, nil @@ -231,15 +242,18 @@ func (c *MemoryCache) GetAnyWithPrefix(prefix string) (*CacheItem, bool, error) // DFS to find first complete key var findKey func(*TrieNode) string + findKey = func(n *TrieNode) string { if n.isEnd { return n.key } + for _, child := range n.children { if key := findKey(child); key != "" { return key } } + return "" } @@ -256,6 +270,7 @@ func (c *MemoryCache) Set(key string, data *CacheItem) error { if key == "" { return errors.New("cache key cannot be empty") } + if data == nil { return errors.New("cannot cache nil CacheItem") } @@ -279,6 +294,7 @@ func (c *MemoryCache) Set(key string, data *CacheItem) error { c.lruList.MoveToFront(element) element.Value.item = data element.Value.size = newSize + return nil } @@ -297,6 +313,7 @@ func (c *MemoryCache) Set(key string, data *CacheItem) error { for _, ch := range entry.key { node = node.children[ch] } + node.isEnd = false node.key = "" } @@ -318,6 +335,7 @@ func (c *MemoryCache) Set(key string, data *CacheItem) error { node = node.children[ch] } } + node.isEnd = true node.key = key @@ -365,6 +383,7 @@ func NewFileCache(filePath string, opts ...FileCacheOption) *FileCache { } go fc.periodicCleanup() + return fc } @@ -408,8 +427,11 @@ func (c *FileCache) cleanup() { size int64 } - var files []fileInfo - var totalSize int64 + var ( + files []fileInfo + totalSize int64 + ) + cutoffTime := time.Now().Add(-c.maxAge) // Collect file information and remove expired files @@ -419,6 +441,7 @@ func (c *FileCache) cleanup() { } subdir := filepath.Join(c.filePath, entry.Name()) + subEntries, err := os.ReadDir(subdir) if err != nil { continue @@ -463,6 +486,7 @@ func (c *FileCache) cleanup() { if totalSize <= maxSize { break } + if err := os.Remove(file.path); err == nil { totalSize -= file.size } @@ -561,6 +585,7 @@ func (c *FileCache) Set(key string, data *CacheItem) error { if key == "" { return errors.New("cache key cannot be empty") } + if data == nil { return errors.New("cannot cache nil CacheItem") } @@ -579,6 +604,7 @@ func (c *FileCache) Set(key string, data *CacheItem) error { newSize += int64(len(metadataBytes)) } } + if c.currentSize.Load()+newSize > maxSize { c.cleanup() } diff --git a/server/handlers/proxy/m3u8.go b/server/handlers/proxy/m3u8.go index 3f6942c..3fa2ccd 100644 --- a/server/handlers/proxy/m3u8.go +++ b/server/handlers/proxy/m3u8.go @@ -34,10 +34,12 @@ func GetM3u8Target(token string) (*M3u8TargetClaims, error) { if err != nil || !t.Valid { return nil, errors.New("auth failed") } + claims, ok := t.Claims.(*M3u8TargetClaims) if !ok { return nil, errors.New("auth failed") } + return claims, nil } @@ -51,6 +53,7 @@ func NewM3u8TargetToken(targetURL, roomID, movieID string, isM3u8File bool) (str NotBefore: jwt.NewNumericDate(time.Now()), }, } + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims). SignedString(stream.StringToBytes(conf.Conf.Jwt.Secret)) } @@ -59,6 +62,7 @@ const maxM3u8FileSize = 3 * 1024 * 1024 // func M3u8Data(ctx *gin.Context, data []byte, baseURL, token, roomID, movieID string) error { hasM3u8File := false + err := m3u8.RangeM3u8SegmentsWithBaseURL( stream.BytesToString(data), baseURL, @@ -67,6 +71,7 @@ func M3u8Data(ctx *gin.Context, data []byte, baseURL, token, roomID, movieID str hasM3u8File = true return false, nil } + return true, nil }, ) @@ -76,8 +81,10 @@ func M3u8Data(ctx *gin.Context, data []byte, baseURL, token, roomID, movieID str fmt.Sprintf("range m3u8 segments with base url error: %v", err), ), ) + return fmt.Errorf("range m3u8 segments with base url error: %w", err) } + m3u8Str, err := m3u8.ReplaceM3u8SegmentsWithBaseURL( stream.BytesToString(data), baseURL, @@ -86,6 +93,7 @@ func M3u8Data(ctx *gin.Context, data []byte, baseURL, token, roomID, movieID str if err != nil { return "", err } + return fmt.Sprintf( "/api/room/movie/proxy/%s/m3u8/%s?token=%s&roomId=%s", movieID, @@ -101,9 +109,12 @@ func M3u8Data(ctx *gin.Context, data []byte, baseURL, token, roomID, movieID str fmt.Sprintf("replace m3u8 segments with base url error: %v", err), ), ) + return fmt.Errorf("replace m3u8 segments with base url error: %w", err) } + ctx.Data(http.StatusOK, hls.M3U8ContentType, stream.StringToBytes(m3u8Str)) + return nil } @@ -119,6 +130,7 @@ func M3u8( if !isM3u8File { return URL(ctx, u, headers, opts...) } + if flags.Global.Dev { ctx.Header(proxyURLHeader, u) } @@ -130,14 +142,18 @@ func M3u8( fmt.Sprintf("new request error: %v", err), ), ) + return fmt.Errorf("new request error: %w", err) } + for k, v := range headers { req.Header.Set(k, v) } + if req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", utils.UA) } + resp, err := uhc.Do(req) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -145,6 +161,7 @@ func M3u8( fmt.Sprintf("do request error: %v", err), ), ) + return fmt.Errorf("do request error: %w", err) } defer resp.Body.Close() @@ -162,12 +179,14 @@ func M3u8( ), ), ) + return fmt.Errorf( "m3u8 file is too large: %d, max: %d (3MB)", resp.ContentLength, maxM3u8FileSize, ) } + b, err := io.ReadAll(io.LimitReader(resp.Body, maxM3u8FileSize)) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -175,7 +194,9 @@ func M3u8( fmt.Sprintf("read response body error: %v", err), ), ) + return fmt.Errorf("read response body error: %w", err) } + return M3u8Data(ctx, b, u, token, roomID, movieID) } diff --git a/server/handlers/proxy/proxy.go b/server/handlers/proxy/proxy.go index 08c3f31..c94d1b1 100644 --- a/server/handlers/proxy/proxy.go +++ b/server/handlers/proxy/proxy.go @@ -35,6 +35,7 @@ func parseProxyCacheSize(sizeStr string) (int64, error) { if sizeStr == "" { return defaultCacheSize, nil } + sizeStr = strings.ToLower(sizeStr) sizeStr = strings.TrimSpace(sizeStr) @@ -66,20 +67,25 @@ func getCache() Cache { if err != nil { log.Fatalf("parse proxy cache size error: %v", err) } + if size == 0 { size = defaultCacheSize } + if conf.Conf.Server.ProxyCachePath == "" { log.Infof("proxy cache path is empty, use memory cache, size: %d", size) defaultCache = NewMemoryCache(0, WithMaxSizeBytes(size)) return } + log.Infof("proxy cache path: %s, size: %d", conf.Conf.Server.ProxyCachePath, size) fileCache = NewFileCache(conf.Conf.Server.ProxyCachePath, WithFileCacheMaxSizeBytes(size)) }) + if fileCache != nil { return fileCache } + return defaultCache } @@ -107,6 +113,7 @@ func NewProxyURLOptions(opts ...Option) *Options { for _, opt := range opts { opt(o) } + return o } @@ -119,7 +126,9 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) if flags.Global.Dev { ctx.Header(proxyURLHeader, u) } + o := NewProxyURLOptions(opts...) + if !settings.AllowProxyToLocal.Get() { if l, err := utils.ParseURLIsLocalIP(u); err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -127,6 +136,7 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) fmt.Sprintf("check url is local ip error: %v", err), ), ) + return fmt.Errorf("check url is local ip error: %w", err) } else if l { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -134,6 +144,7 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) "not allow proxy to local", ), ) + return errors.New("not allow proxy to local") } } @@ -141,21 +152,25 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) if o.Cache && settings.ProxyCacheEnable.Get() { c, cancel := context.WithCancel(ctx) defer cancel() + rsc := NewHTTPReadSeekCloser(u, WithContext(c), WithHeadersMap(headers), WithPerLength(sliceSize*3), ) defer rsc.Close() + if o.CacheKey == "" { o.CacheKey = u } + return NewSliceCacheProxy(o.CacheKey, sliceSize, rsc, getCache()). Proxy(ctx.Writer, ctx.Request) } ctx2, cf := context.WithCancel(ctx) defer cf() + req, err := http.NewRequestWithContext(ctx2, http.MethodGet, u, nil) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -163,32 +178,41 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) fmt.Sprintf("new request error: %v", err), ), ) + return fmt.Errorf("new request error: %w", err) } + for k, v := range headers { req.Header.Set(k, v) } + if r := ctx.GetHeader("Range"); r != "" { req.Header.Set("Range", r) } + if r := ctx.GetHeader("Accept-Encoding"); r != "" { req.Header.Set("Accept-Encoding", r) } + if req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", utils.UA) } + cli := http.Client{ Transport: uhc.DefaultTransport, CheckRedirect: func(req *http.Request, _ []*http.Request) error { for k, v := range headers { req.Header.Set(k, v) } + if req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", utils.UA) } + return nil }, } + resp, err := cli.Do(req) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -196,15 +220,18 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) fmt.Sprintf("request url error: %v", err), ), ) + return fmt.Errorf("request url error: %w", err) } defer resp.Body.Close() + ctx.Status(resp.StatusCode) ctx.Header("Accept-Ranges", resp.Header.Get("Accept-Ranges")) ctx.Header("Cache-Control", resp.Header.Get("Cache-Control")) ctx.Header("Content-Length", resp.Header.Get("Content-Length")) ctx.Header("Content-Range", resp.Header.Get("Content-Range")) ctx.Header("Content-Type", resp.Header.Get("Content-Type")) + _, err = copyBuffer(ctx.Writer, resp.Body) if err != nil && !errors.Is(err, io.EOF) { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -212,8 +239,10 @@ func URL(ctx *gin.Context, u string, headers map[string]string, opts ...Option) fmt.Sprintf("copy response body error: %v", err), ), ) + return fmt.Errorf("copy response body error: %w", err) } + return nil } diff --git a/server/handlers/proxy/readseeker.go b/server/handlers/proxy/readseeker.go index 1a284ea..4187cf5 100644 --- a/server/handlers/proxy/readseeker.go +++ b/server/handlers/proxy/readseeker.go @@ -153,18 +153,23 @@ func (h *HTTPReadSeekCloser) fix() *HTTPReadSeekCloser { if h.method == "" { h.method = http.MethodGet } + if h.headMethod == "" { h.headMethod = http.MethodHead } + if h.ctx == nil { h.ctx = context.Background() } + if h.perLength <= 0 { h.perLength = 1024 * 1024 } + if h.headers == nil { h.headers = make(http.Header) } + if h.client == nil { h.client = &http.Client{ Transport: uhc.DefaultTransport, @@ -172,13 +177,16 @@ func (h *HTTPReadSeekCloser) fix() *HTTPReadSeekCloser { for k, v := range h.headers { req.Header[k] = v } + if req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", utils.UA) } + return nil }, } } + return h } @@ -201,11 +209,14 @@ func (h *HTTPReadSeekCloser) Read(p []byte) (n int, err error) { if err == io.EOF { h.closeCurrentResp() + if n < len(p) { continue } + break } + if err != nil { if n > 0 { return n, nil @@ -246,7 +257,9 @@ func (h *HTTPReadSeekCloser) FetchNextChunk() error { if err == nil && contentTotalLength > 0 { h.contentTotalLength = contentTotalLength } + resp.Body.Close() + return fmt.Errorf( "requested range not satisfiable, content total length: %d, offset: %d", h.contentTotalLength, @@ -268,6 +281,7 @@ func (h *HTTPReadSeekCloser) FetchNextChunk() error { if h.notSupportSeekWhenNotSupportRange { return errors.New("not support seek when not support range") } + if _, err := io.CopyN(io.Discard, resp.Body, h.offset); err != nil { resp.Body.Close() return fmt.Errorf("failed to discard bytes: %w", err) @@ -278,6 +292,7 @@ func (h *HTTPReadSeekCloser) FetchNextChunk() error { h.contentTotalLength = resp.ContentLength h.currentRespMaxOffset = h.contentTotalLength - 1 h.currentResp = resp + return nil } @@ -285,17 +300,20 @@ func (h *HTTPReadSeekCloser) FetchNextChunk() error { if err == nil && contentTotalLength > 0 { h.contentTotalLength = contentTotalLength } + start, end, err := ParseContentRangeStartAndEnd(resp.Header.Get("Content-Range")) if err == nil { if end != -1 { h.currentRespMaxOffset = end } + if h.offset != start { return fmt.Errorf("offset mismatch, expected: %d, got: %d", start, h.offset) } } h.currentResp = resp + return nil } @@ -320,6 +338,7 @@ func (h *HTTPReadSeekCloser) createRequest() (*http.Request, error) { h.currentRespMaxOffset = end req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", h.offset, end)) + return req, nil } @@ -328,11 +347,14 @@ func (h *HTTPReadSeekCloser) createRequestWithoutRange() (*http.Request, error) if err != nil { return nil, fmt.Errorf("failed to create request: %w", err) } + req.Header = h.headers.Clone() req.Header.Del("Range") + if req.Header.Get("User-Agent") == "" { req.Header.Set("User-Agent", utils.UA) } + return req, nil } @@ -353,6 +375,7 @@ func (h *HTTPReadSeekCloser) checkContentType(ct string) error { ) } } + return nil } @@ -370,9 +393,11 @@ func (h *HTTPReadSeekCloser) Seek(offset int64, whence int) (int64, error) { if newOffset != h.offset { h.closeCurrentResp() + if h.notSupportRange && h.notSupportSeekWhenNotSupportRange { return 0, errors.New("seek is not supported when not support range") } + h.offset = newOffset } @@ -391,6 +416,7 @@ func (h *HTTPReadSeekCloser) calculateNewOffset(offset int64, whence int) (int64 return 0, fmt.Errorf("failed to fetch content length: %w", err) } } + return h.contentTotalLength - offset, nil default: return 0, fmt.Errorf("invalid seek whence value: %d", whence) @@ -402,6 +428,7 @@ func (h *HTTPReadSeekCloser) fetchContentLength() error { if err != nil { return err } + req.Method = h.headMethod resp, err := h.client.Do(req) @@ -426,6 +453,7 @@ func (h *HTTPReadSeekCloser) fetchContentLength() error { h.contentTotalLength = resp.ContentLength h.headHeaders = resp.Header.Clone() + return nil } @@ -455,6 +483,7 @@ func (h *HTTPReadSeekCloser) ContentTotalLength() (int64, error) { if h.contentTotalLength > 0 { return h.contentTotalLength, nil } + return 0, errors.New( "content total length is not available - no successful response received yet", ) @@ -494,6 +523,7 @@ func ParseContentRangeStartAndEnd(contentRange string) (int64, int64, error) { } rangeParts[1] = strings.TrimSpace(rangeParts[1]) + var end int64 if rangeParts[1] == "" || rangeParts[1] == "*" { end = -1 diff --git a/server/handlers/proxy/slice.go b/server/handlers/proxy/slice.go index 1b5d859..3cbb4e8 100644 --- a/server/handlers/proxy/slice.go +++ b/server/handlers/proxy/slice.go @@ -59,16 +59,19 @@ func fmtContentRange(start, end, total int64) string { if total == -1 && end == -1 { return "bytes */*" } + totalStr := "*" if total >= 0 { totalStr = strconv.FormatInt(total, 10) } + if end == -1 { if total >= 0 { end = total - 1 } return fmt.Sprintf("bytes %d-%d/%s", start, end, totalStr) } + return fmt.Sprintf("bytes %d-%d/%s", start, end, totalStr) } @@ -76,12 +79,15 @@ func contentLength(start, end, total int64) int64 { if total == -1 && end == -1 { return -1 } + if end == -1 { return total - start } + if end >= total && total != -1 { return total - start } + return end - start + 1 } @@ -90,6 +96,7 @@ func fmtContentLength(start, end, total int64) string { if length == -1 { return "" } + return strconv.FormatInt(length, 10) } @@ -106,6 +113,7 @@ func (c *SliceCacheProxy) Proxy(w http.ResponseWriter, r *http.Request) error { } alignedOffset := alignedOffset(byteRange.Start, c.sliceSize) + cacheItem, cached, err := c.getCacheItem(alignedOffset) if err != nil { http.Error( @@ -113,13 +121,16 @@ func (c *SliceCacheProxy) Proxy(w http.ResponseWriter, r *http.Request) error { fmt.Sprintf("Failed to get cache item: %v", err), http.StatusInternalServerError, ) + return fmt.Errorf("failed to get cache item: %w", err) } c.setResponseHeaders(w, byteRange, cacheItem, cached, r.Header.Get("Range") != "") + if err := c.writeResponse(w, byteRange, alignedOffset, cacheItem); err != nil { return fmt.Errorf("failed to write response: %w", err) } + return nil } @@ -146,10 +157,12 @@ func (c *SliceCacheProxy) setResponseHeaders( } else { w.Header().Set(cacheStatusHeader, "MISS") } + w.Header().Set("Accept-Ranges", "bytes") w.Header(). Set("Content-Length", fmtContentLength(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength)) w.Header().Set("Content-Type", cacheItem.Metadata.ContentType) + if isRangeRequest { w.Header(). Set("Content-Range", fmtContentRange(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength)) @@ -185,10 +198,12 @@ func (c *SliceCacheProxy) writeResponse( if n > remainingLength { n = remainingLength } + if n > 0 { if _, err := w.Write(cacheItem.Data[sliceOffset : sliceOffset+n]); err != nil { return fmt.Errorf("failed to write initial data slice: %w", err) } + remainingLength -= n } } @@ -205,12 +220,15 @@ func (c *SliceCacheProxy) writeResponse( if n > remainingLength { n = remainingLength } + if n > 0 { if _, err := w.Write(cacheItem.Data[:n]); err != nil { return fmt.Errorf("failed to write data slice at offset %d: %w", currentOffset, err) } + remainingLength -= n } + currentOffset += c.sliceSize } @@ -226,6 +244,7 @@ func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, bool, e } cacheKey := cacheKey(c.key, alignedOffset, c.sliceSize) + mu.Lock(cacheKey) defer mu.Unlock(cacheKey) @@ -234,6 +253,7 @@ func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, bool, e if err != nil { return nil, false, fmt.Errorf("failed to get item from cache: %w", err) } + if ok { return slice, true, nil } @@ -257,9 +277,11 @@ func (c *SliceCacheProxy) contentTotalLength() (int64, error) { if err != nil { return -1, fmt.Errorf("failed to get content total length from source: %w", err) } + if total == -1 { return -1, errors.New("source does not support range requests") } + return total, nil } @@ -267,6 +289,7 @@ func (c *SliceCacheProxy) fetchFromSource(offset int64) (*CacheItem, error) { if offset < 0 { return nil, fmt.Errorf("source offset cannot be negative, got: %d", offset) } + if _, err := c.r.Seek(offset, io.SeekStart); err != nil { return nil, fmt.Errorf("failed to seek to offset %d in source: %w", offset, err) } @@ -274,6 +297,7 @@ func (c *SliceCacheProxy) fetchFromSource(offset int64) (*CacheItem, error) { var total int64 = -1 buf := make([]byte, c.sliceSize) + n, err := io.ReadFull(c.r, buf) if err != nil { if !errors.Is(err, io.ErrUnexpectedEOF) { @@ -284,10 +308,12 @@ func (c *SliceCacheProxy) fetchFromSource(offset int64) (*CacheItem, error) { err, ) } + total, err = c.contentTotalLength() if err != nil { return nil, fmt.Errorf("failed to get content total length from source: %w", err) } + if total != offset+int64(n) { return nil, fmt.Errorf( "source content total length mismatch, got: %d, expected: %d, %w", @@ -350,6 +376,7 @@ func ParseByteRange(r string) (*ByteRange, error) { } r = strings.TrimPrefix(r, "bytes=") + parts := strings.Split(r, "-") if len(parts) != 2 { return nil, fmt.Errorf( @@ -365,14 +392,17 @@ func ParseByteRange(r string) (*ByteRange, error) { return nil, fmt.Errorf("range header cannot have empty start and end values: %s", r) } - var start, end int64 = 0, -1 - var err error + var ( + start, end int64 = 0, -1 + err error + ) if parts[0] != "" { start, err = strconv.ParseInt(parts[0], 10, 64) if err != nil { return nil, fmt.Errorf("failed to parse range start value '%s': %w", parts[0], err) } + if start < 0 { return nil, fmt.Errorf("range start value must be non-negative, got: %d", start) } @@ -383,9 +413,11 @@ func ParseByteRange(r string) (*ByteRange, error) { if err != nil { return nil, fmt.Errorf("failed to parse range end value '%s': %w", parts[1], err) } + if end < 0 { return nil, fmt.Errorf("range end value must be non-negative, got: %d", end) } + if start > end { return nil, fmt.Errorf( "range start value (%d) cannot be greater than end value (%d)", diff --git a/server/handlers/public.go b/server/handlers/public.go index be30f0b..c35dd07 100644 --- a/server/handlers/public.go +++ b/server/handlers/public.go @@ -32,6 +32,7 @@ func Settings(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(200, model.NewAPIDataResp( &publicSettings{ PasswordDisableSignup: settings.DisableUserSignup.Get() || diff --git a/server/handlers/room.go b/server/handlers/room.go index 4847efd..759e514 100644 --- a/server/handlers/room.go +++ b/server/handlers/room.go @@ -100,6 +100,7 @@ func CreateRoom(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("create room is disabled"), ) + return } @@ -217,6 +218,7 @@ func RoomList(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append( scopes, db.WhereRoomNameLikeOrCreatorInOrRoomsIDLike(keyword, ids, keyword), @@ -230,6 +232,7 @@ func RoomList(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereCreatorIDIn(ids)) case "id": scopes = append(scopes, db.WhereRoomsIDLike(keyword)) @@ -263,6 +266,7 @@ func RoomList(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -284,6 +288,7 @@ func genRoomListResp(scopes ...func(db *gorm.DB) *gorm.DB) ([]*model.RoomListRes if err != nil { return nil, err } + resp := make([]*model.RoomListResp, len(rs)) for i, r := range rs { resp[i] = &model.RoomListResp{ @@ -297,6 +302,7 @@ func genRoomListResp(scopes ...func(db *gorm.DB) *gorm.DB) ([]*model.RoomListRes Status: r.Status, } } + return resp, nil } @@ -305,11 +311,13 @@ func genJoinedRoomListResp(scopes ...func(db *gorm.DB) *gorm.DB) ([]*model.Joine if err != nil { return nil, err } + resp := make([]*model.JoinedRoomResp, len(rs)) for i, r := range rs { if len(r.RoomMembers) == 0 { return nil, fmt.Errorf("room %s load member failed", r.ID) } + resp[i] = &model.JoinedRoomResp{ RoomListResp: model.RoomListResp{ RoomID: r.ID, @@ -325,11 +333,13 @@ func genJoinedRoomListResp(scopes ...func(db *gorm.DB) *gorm.DB) ([]*model.Joine MemberRole: r.RoomMembers[0].Role, } } + return resp, nil } func CheckRoom(ctx *gin.Context) { log := middlewares.GetLogger(ctx) + roomID, err := middlewares.GetRoomIDFromContext(ctx) if err != nil { log.Errorf("check room failed: %v", err) @@ -343,6 +353,7 @@ func CheckRoom(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + room := roomE.Value() ctx.JSON(http.StatusOK, model.NewAPIDataResp(&model.CheckRoomResp{ @@ -373,6 +384,7 @@ func LoginRoom(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + room := roomE.Value() if room.IsBanned() { @@ -387,6 +399,7 @@ func LoginRoom(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("room is pending, please wait for admin to approve"), ) + return } @@ -397,6 +410,7 @@ func LoginRoom(ctx *gin.Context) { "permissions": member.Permissions, "adminPermissions": member.AdminPermissions, })) + return } @@ -416,10 +430,13 @@ func LoginRoom(ctx *gin.Context) { errors.New("this room was disabled join new user"), ), ) + return } + log.Errorf("login room failed: %v", err) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -454,6 +471,7 @@ func DeleteRoom(ctx *gin.Context) { if err := user.DeleteRoom(room); err != nil { log.Errorf("delete room failed: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -461,9 +479,12 @@ func DeleteRoom(ctx *gin.Context) { fmt.Errorf("delete room failed: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -484,6 +505,7 @@ func SetRoomPassword(ctx *gin.Context) { if err := user.SetRoomPassword(room, req.Password); err != nil { log.Errorf("set room password failed: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -491,9 +513,12 @@ func SetRoomPassword(ctx *gin.Context) { fmt.Errorf("set room password failed: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } @@ -521,6 +546,7 @@ func SetRoomSetting(ctx *gin.Context) { if err := user.UpdateRoomSettings(room, req); err != nil { log.Errorf("set room setting failed: %v", err) + if errors.Is(err, dbModel.ErrNoPermission) { ctx.AbortWithStatusJSON( http.StatusForbidden, @@ -528,9 +554,12 @@ func SetRoomSetting(ctx *gin.Context) { fmt.Errorf("set room setting failed: %w", err), ), ) + return } + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } diff --git a/server/handlers/root.go b/server/handlers/root.go index 782e865..87e966b 100644 --- a/server/handlers/root.go +++ b/server/handlers/root.go @@ -26,8 +26,10 @@ func RootAddAdmin(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot add yourself"), ) + return } + u, err := op.LoadOrInitUserByID(req.ID) if err != nil { log.Errorf("failed to load user: %v", err) @@ -35,14 +37,17 @@ func RootAddAdmin(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp("user not found"), ) + return } + if u.Value().IsAdmin() { log.Errorf("user is already admin") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("user is already admin"), ) + return } @@ -72,8 +77,10 @@ func RootDeleteAdmin(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("cannot remove yourself"), ) + return } + u, err := op.LoadOrInitUserByID(req.ID) if err != nil { log.Errorf("failed to load user: %v", err) @@ -81,14 +88,17 @@ func RootDeleteAdmin(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp("user not found"), ) + return } + if u.Value().IsRoot() { log.Errorf("cannot remove root") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("cannot remove root"), ) + return } diff --git a/server/handlers/user.go b/server/handlers/user.go index 1ed2d2c..c931439 100644 --- a/server/handlers/user.go +++ b/server/handlers/user.go @@ -47,8 +47,11 @@ func LoginUser(ctx *gin.Context) { return } - var user *synccache.Entry[*op.User] - var err error + var ( + user *synccache.Entry[*op.User] + err error + ) + switch { case req.Username != "": user, err = op.LoadOrInitUserByUsername(req.Username) @@ -59,20 +62,25 @@ func LoginUser(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("username or email is required"), ) + return } if err != nil { log.Errorf("failed to load user: %v", err) + if errors.Is(err, db.NotFoundError(db.ErrUserNotFound)) { ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + if errors.Is(err, op.ErrUserBanned) || errors.Is(err, op.ErrUserPending) { ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewAPIErrorResp(err)) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -82,6 +90,7 @@ func LoginUser(ctx *gin.Context) { http.StatusForbidden, model.NewAPIErrorStringResp("password incorrect"), ) + return } @@ -99,10 +108,13 @@ func handleUserToken(ctx *gin.Context, user *op.User) { "message": err.Error(), "role": user.Role, })) + return } + log.Errorf("failed to generate token: %v", err) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -160,6 +172,7 @@ func UserRooms(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + scopes = append(scopes, db.WhereRoomNameLikeOrCreatorInOrIDLike(keyword, ids, keyword)) case "name": scopes = append(scopes, db.WhereRoomNameLike(keyword)) @@ -195,6 +208,7 @@ func UserRooms(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -276,6 +290,7 @@ func UserJoinedRooms(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("not support sort"), ) + return } @@ -309,6 +324,7 @@ func UserCheckJoinedRoom(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusNotFound, model.NewAPIErrorResp(err)) return } + room := roomE.Value() status, err := room.LoadMemberStatus(user.ID) @@ -412,6 +428,7 @@ func UserBindProviders(ctx *gin.Context) { CreatedAt: 0, } } + return true }) @@ -456,6 +473,7 @@ func SendUserBindEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("captcha verify failed"), ) + return } @@ -464,6 +482,7 @@ func SendUserBindEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("this email same as current email"), ) + return } @@ -474,6 +493,7 @@ func SendUserBindEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("email already bind"), ) + return } @@ -503,6 +523,7 @@ func UserBindEmail(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("email captcha verify failed"), ) + return } @@ -555,6 +576,7 @@ func SendUserSignupEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user signup disabled"), ) + return } else if email.DisableUserSignup.Get() { log.Errorf("email signup disabled") @@ -579,6 +601,7 @@ func SendUserSignupEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("captcha verify failed"), ) + return } @@ -590,8 +613,10 @@ func SendUserSignupEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("email format error"), ) + return } + if !slices.Contains( strings.Split(email.EmailSignupWhiteList.Get(), ","), after, @@ -601,6 +626,7 @@ func SendUserSignupEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("email not in white list"), ) + return } } @@ -612,6 +638,7 @@ func SendUserSignupEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("email already exists"), ) + return } @@ -633,6 +660,7 @@ func UserSignupEmail(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user signup disabled"), ) + return } else if email.DisableUserSignup.Get() { log.Errorf("email signup disabled") @@ -653,12 +681,14 @@ func UserSignupEmail(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if !ok { log.Errorf("email captcha verify failed") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("email captcha verify failed"), ) + return } @@ -673,6 +703,7 @@ func UserSignupEmail(ctx *gin.Context) { } else { user, err = op.CreateUserWithEmail(req.Email, req.Password, req.Email) } + if err != nil { log.Errorf("failed to create user: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) @@ -718,6 +749,7 @@ func SendUserRetrievePasswordEmailCaptcha(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("captcha verify failed"), ) + return } @@ -736,12 +768,14 @@ func SendUserRetrievePasswordEmailCaptcha(ctx *gin.Context) { Host: ctx.Request.Host, }).String() } + if host == "" { log.Error("failed to get host on send retrieve password email") ctx.AbortWithStatusJSON( http.StatusInternalServerError, model.NewAPIErrorStringResp("failed to get host"), ) + return } @@ -770,6 +804,7 @@ func UserRetrievePasswordEmail(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + user := userE.Value() if ok, err := user.VerifyRetrievePasswordCaptchaEmail(req.Email, req.Captcha); err != nil || @@ -779,6 +814,7 @@ func UserRetrievePasswordEmail(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("email captcha verify failed"), ) + return } @@ -837,6 +873,7 @@ func UserSignupPassword(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("user signup disabled"), ) + return } else if !settings.EnablePasswordSignup.Get() { log.Errorf("password signup disabled") @@ -851,13 +888,17 @@ func UserSignupPassword(ctx *gin.Context) { return } - var user *op.UserEntry - var err error + var ( + user *op.UserEntry + err error + ) + if settings.SignupNeedReview.Get() || settings.PasswordSignupNeedReview.Get() { user, err = op.CreateUser(req.Username, req.Password, db.WithRole(dbModel.RolePending)) } else { user, err = op.CreateUser(req.Username, req.Password) } + if err != nil { log.Errorf("failed to create user: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) diff --git a/server/handlers/vendors/vendorAlist/alist.go b/server/handlers/vendors/vendorAlist/alist.go index f313f6f..427a953 100644 --- a/server/handlers/vendors/vendorAlist/alist.go +++ b/server/handlers/vendors/vendorAlist/alist.go @@ -34,6 +34,7 @@ func NewAlistVendorService(room *op.Room, movie *op.Movie) (*AlistVendorService, if movie.VendorInfo.Vendor != dbModel.VendorAlist { return nil, fmt.Errorf("alist vendor not support vendor %s", movie.VendorInfo.Vendor) } + return &AlistVendorService{ room: room, movie: movie, @@ -54,6 +55,7 @@ func (s *AlistVendorService) ListDynamicMovie( if reqUser.ID != s.movie.CreatorID { return nil, fmt.Errorf("list vendor dynamic folder error: %w", dbModel.ErrNoPermission) } + user := reqUser resp := &model.MovieList{ @@ -64,11 +66,13 @@ func (s *AlistVendorService) ListDynamicMovie( if err != nil { return nil, fmt.Errorf("load alist server id error: %w", err) } + newPath := path.Join(truePath, subPath) // check new path is in parent path if !strings.HasPrefix(newPath, truePath) { return nil, errors.New("sub path is not in parent path") } + aucd, err := user.AlistCache().LoadOrStore(ctx, serverID) if err != nil { if errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -76,6 +80,7 @@ func (s *AlistVendorService) ListDynamicMovie( } return nil, err } + cli := s.Client() if keyword != "" { data, err := cli.FsSearch(ctx, &alist.FsSearchReq{ @@ -90,7 +95,9 @@ func (s *AlistVendorService) ListDynamicMovie( if err != nil { return nil, err } + resp.Total = int64(data.GetTotal()) + resp.Movies = make([]*model.Movie, len(data.GetContent())) for i, flr := range data.GetContent() { fileSubPath := strings.TrimPrefix(strings.Trim(flr.GetParent(), "/"), truePath) @@ -123,7 +130,9 @@ func (s *AlistVendorService) ListDynamicMovie( }, } } + resp.Paths = model.GenDefaultSubPaths(s.movie.ID, subPath, true) + return resp, nil } @@ -139,7 +148,9 @@ func (s *AlistVendorService) ListDynamicMovie( if err != nil { return nil, err } + resp.Total = int64(data.GetTotal()) + resp.Movies = make([]*model.Movie, len(data.GetContent())) for i, flr := range data.GetContent() { resp.Movies[i] = &model.Movie{ @@ -164,7 +175,9 @@ func (s *AlistVendorService) ListDynamicMovie( }, } } + resp.Paths = model.GenDefaultSubPaths(s.movie.ID, subPath, true) + return resp, nil } @@ -221,6 +234,7 @@ func (s *AlistVendorService) handleAliProvider( ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if s.movie.Proxy { err := proxy.M3u8Data( ctx, @@ -243,6 +257,7 @@ func (s *AlistVendorService) handleAliProvider( ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if s.movie.Proxy { s.proxyURL(ctx, log, b.URL) } else { @@ -269,6 +284,7 @@ func (s *AlistVendorService) handleDefaultProvider( http.StatusBadRequest, model.NewAPIErrorStringResp("id is empty"), ) + return } @@ -285,10 +301,12 @@ func (s *AlistVendorService) handleDefaultProvider( http.StatusBadRequest, model.NewAPIErrorStringResp("id out of range"), ) + return } subtitle := data.Subtitles[id] + b, err := subtitle.Cache.Get(ctx) if err != nil { log.Errorf("proxy vendor movie error: %v", err) @@ -304,8 +322,10 @@ func (s *AlistVendorService) handleDefaultProvider( http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } + s.proxyURL(ctx, log, data.URL) } } @@ -363,6 +383,7 @@ func (s *AlistVendorService) handleAliSubtitle( http.StatusBadRequest, model.NewAPIErrorStringResp("id out of range"), ) + return } @@ -386,13 +407,16 @@ func (s *AlistVendorService) GenMovieInfo( } movie := s.movie.Clone() + var err error creator, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + alistCache := s.movie.AlistCache() + data, err := alistCache.Get(ctx, &cache.AlistMovieCacheFuncArgs{ UserCache: creator.Value().AlistCache(), UserAgent: utils.UA, @@ -405,6 +429,7 @@ func (s *AlistVendorService) GenMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -423,6 +448,7 @@ func (s *AlistVendorService) GenMovieInfo( if err != nil { return nil, err } + movie.URL = fmt.Sprintf( "/api/room/movie/proxy/%s?token=%s&roomId=%s", movie.ID, @@ -434,12 +460,14 @@ func (s *AlistVendorService) GenMovieInfo( rawStreamURL := data.URL subPath := s.movie.SubPath() + var rawType string if subPath == "" { rawType = utils.GetURLExtension(movie.VendorInfo.Alist.Path) } else { rawType = utils.GetURLExtension(subPath) } + movie.MoreSources = []*dbModel.MoreSource{ { Name: "raw", @@ -452,6 +480,7 @@ func (s *AlistVendorService) GenMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -472,7 +501,9 @@ func (s *AlistVendorService) GenMovieInfo( if err != nil { return nil, fmt.Errorf("refresh 115 movie cache error: %w", err) } + movie.URL = data.URL + movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) for _, subt := range data.Subtitles { movie.Subtitles[subt.Name] = &dbModel.Subtitle{ @@ -486,6 +517,7 @@ func (s *AlistVendorService) GenMovieInfo( } movie.VendorInfo.Alist.Password = "" + return movie, nil } @@ -495,13 +527,16 @@ func (s *AlistVendorService) GenProxyMovieInfo( _, userToken string, ) (*dbModel.Movie, error) { movie := s.movie.Clone() + var err error creator, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + alistCache := s.movie.AlistCache() + data, err := alistCache.Get(ctx, &cache.AlistMovieCacheFuncArgs{ UserCache: creator.Value().AlistCache(), UserAgent: utils.UA, @@ -514,6 +549,7 @@ func (s *AlistVendorService) GenProxyMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -532,6 +568,7 @@ func (s *AlistVendorService) GenProxyMovieInfo( if err != nil { return nil, err } + movie.URL = fmt.Sprintf( "/api/room/movie/proxy/%s?token=%s&roomId=%s", movie.ID, @@ -558,6 +595,7 @@ func (s *AlistVendorService) GenProxyMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -592,5 +630,6 @@ func (s *AlistVendorService) GenProxyMovieInfo( } movie.VendorInfo.Alist.Password = "" + return movie, nil } diff --git a/server/handlers/vendors/vendorAlist/list.go b/server/handlers/vendors/vendorAlist/list.go index cfafc34..34b3612 100644 --- a/server/handlers/vendors/vendorAlist/list.go +++ b/server/handlers/vendors/vendorAlist/list.go @@ -66,6 +66,7 @@ func List(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if total == 0 { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("alist server not found")) return @@ -78,9 +79,12 @@ func List(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("alist server not found"), ) + return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -117,6 +121,7 @@ func List(ctx *gin.Context) { AlistFSListResp: var serverID string + serverID, req.Path, err = dbModel.GetAlistServerIDFromPath(req.Path) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) @@ -135,6 +140,7 @@ AlistFSListResp: } ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -155,6 +161,7 @@ AlistFSListResp: } req.Path = strings.Trim(req.Path, "/") + resp := AlistFSListResp{ Total: data.GetTotal(), Paths: model.GenDefaultPaths(req.Path, true, @@ -183,6 +190,7 @@ AlistFSListResp: } ctx.JSON(http.StatusOK, model.NewAPIDataResp(&resp)) + return } @@ -201,6 +209,7 @@ AlistFSListResp: } req.Path = strings.Trim(req.Path, "/") + resp := AlistFSListResp{ Total: data.GetTotal(), Paths: model.GenDefaultPaths(req.Path, true, diff --git a/server/handlers/vendors/vendorAlist/login.go b/server/handlers/vendors/vendorAlist/login.go index 7225a10..4db64f6 100644 --- a/server/handlers/vendors/vendorAlist/login.go +++ b/server/handlers/vendors/vendorAlist/login.go @@ -29,17 +29,21 @@ func (r *LoginReq) Validate() error { if r.Host == "" { return errors.New("host is required") } + url, err := url.Parse(r.Host) if err != nil { return err } + if url.Scheme != "http" && url.Scheme != "https" { return errors.New("host is invalid") } + r.Host = strings.TrimRight(url.String(), "/") if r.Password != "" && r.HashedPassword != "" { return errors.New("password and hashedPassword can't be both set") } + return nil } diff --git a/server/handlers/vendors/vendorAlist/me.go b/server/handlers/vendors/vendorAlist/me.go index ca927d5..71d10c2 100644 --- a/server/handlers/vendors/vendorAlist/me.go +++ b/server/handlers/vendors/vendorAlist/me.go @@ -23,6 +23,7 @@ func Me(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorResp(errors.New("serverID is required")), ) + return } @@ -32,7 +33,9 @@ func Me(ctx *gin.Context) { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("alist server not found")) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -67,7 +70,9 @@ func Binds(ctx *gin.Context) { })) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } diff --git a/server/handlers/vendors/vendorBilibili/bilibili.go b/server/handlers/vendors/vendorBilibili/bilibili.go index 3856709..6b7f340 100644 --- a/server/handlers/vendors/vendorBilibili/bilibili.go +++ b/server/handlers/vendors/vendorBilibili/bilibili.go @@ -33,6 +33,7 @@ func NewBilibiliVendorService(room *op.Room, movie *op.Movie) (*BilibiliVendorSe if movie.VendorInfo.Vendor != dbModel.VendorBilibili { return nil, fmt.Errorf("bilibili vendor not support vendor %s", movie.VendorInfo.Vendor) } + return &BilibiliVendorService{ room: room, movie: movie, @@ -78,6 +79,7 @@ func (s *BilibiliVendorService) handleDanmuProxy(ctx *gin.Context, log *logrus.E ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Data(http.StatusOK, "application/xml", danmu) } @@ -88,14 +90,17 @@ func (s *BilibiliVendorService) handleLiveProxy(ctx *gin.Context, log *logrus.En ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if len(data) == 0 { log.Error("proxy vendor movie error: live data is empty") ctx.AbortWithStatusJSON( http.StatusNotFound, model.NewAPIErrorStringResp("live data is empty"), ) + return } + ctx.Data(http.StatusOK, "application/vnd.apple.mpegurl", data) } @@ -106,6 +111,7 @@ func (s *BilibiliVendorService) handleVideoProxy(ctx *gin.Context, log *logrus.E http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } @@ -138,18 +144,23 @@ func (s *BilibiliVendorService) handleMpdProxy( t string, mpdC *cache.BilibiliMpdCache, ) { - var mpd string - var err error + var ( + mpd string + err error + ) + if t == "hevc" { mpd, err = cache.BilibiliMpdToString(mpdC.HevcMpd, middlewares.GetToken(ctx)) } else { mpd, err = cache.BilibiliMpdToString(mpdC.Mpd, middlewares.GetToken(ctx)) } + if err != nil { log.Errorf("proxy vendor movie error: %v", err) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Data(http.StatusOK, "application/dash+xml", stream.StringToBytes(mpd)) } @@ -165,16 +176,19 @@ func (s *BilibiliVendorService) handleStreamProxy( ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + if streamID >= len(mpdC.URLs) { log.Errorf("proxy vendor movie error: %v", "stream id out of range") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("stream id out of range"), ) + return } headers := s.getProxyHeaders() + err = proxy.URL(ctx, mpdC.URLs[streamID], headers, @@ -196,6 +210,7 @@ func (s *BilibiliVendorService) getProxyHeaders() map[string]string { headers["Referer"] = "https://www.bilibili.com" headers["User-Agent"] = utils.UA } + return headers } @@ -228,7 +243,9 @@ func (s *BilibiliVendorService) handleSubtitleProxy(ctx *gin.Context, log *logru ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + http.ServeContent(ctx.Writer, ctx.Request, id, time.Now(), bytes.NewReader(srtData)) + return } @@ -246,7 +263,9 @@ func (s *BilibiliVendorService) GenMovieInfo( } movie := s.movie.Clone() + var err error + if movie.IsFolder { return nil, errors.New("bilibili folder not support") } @@ -267,6 +286,7 @@ func (s *BilibiliVendorService) GenMovieInfo( userToken, movie.RoomID, ) + return movie, nil } @@ -280,10 +300,12 @@ func (s *BilibiliVendorService) GenMovieInfo( var str string if movie.VendorInfo.Bilibili.Shared { var u *op.UserEntry + u, err = op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + str, err = s.movie.BilibiliCache().NoSharedMovie.LoadOrStore( ctx, movie.CreatorID, @@ -292,19 +314,23 @@ func (s *BilibiliVendorService) GenMovieInfo( } else { str, err = s.movie.BilibiliCache().NoSharedMovie.LoadOrStore(ctx, user.ID, user.BilibiliCache()) } + if err != nil { return nil, err } + movie.URL = str srt, err := bmc.Subtitle.Get(ctx, user.BilibiliCache()) if err != nil { return nil, err } + for k := range srt { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(srt)) } + movie.Subtitles[k] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&n=%s&token=%s&roomId=%s", @@ -316,6 +342,7 @@ func (s *BilibiliVendorService) GenMovieInfo( Type: "srt", } } + return movie, nil } @@ -325,7 +352,9 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( _, userToken string, ) (*dbModel.Movie, error) { movie := s.movie.Clone() + var err error + if movie.IsFolder { return nil, errors.New("bilibili folder not support") } @@ -346,6 +375,7 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( userToken, movie.RoomID, ) + return movie, nil } @@ -375,14 +405,17 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( ), }, } + srt, err := bmc.Subtitle.Get(ctx, user.BilibiliCache()) if err != nil { return nil, err } + for k := range srt { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(srt)) } + movie.Subtitles[k] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&n=%s&token=%s&roomId=%s", @@ -394,5 +427,6 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( Type: "srt", } } + return movie, nil } diff --git a/server/handlers/vendors/vendorBilibili/login.go b/server/handlers/vendors/vendorBilibili/login.go index 4548197..0785659 100644 --- a/server/handlers/vendors/vendorBilibili/login.go +++ b/server/handlers/vendors/vendorBilibili/login.go @@ -23,6 +23,7 @@ func NewQRCode(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(r)) } @@ -51,6 +52,7 @@ func LoginWithQR(ctx *gin.Context) { } backend := ctx.Query("backend") + resp, err := vendor.LoadBilibiliClient(backend). LoginWithQRCode(ctx, &bilibili.LoginWithQRCodeReq{ Key: req.Key, @@ -86,6 +88,7 @@ func LoginWithQR(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + _, err = user.BilibiliCache(). Data(). Refresh(ctx, func(_ context.Context, _ ...struct{}) (*cache.BilibiliUserCacheData, error) { @@ -98,6 +101,7 @@ func LoginWithQR(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(gin.H{ "status": "success", })) @@ -106,6 +110,7 @@ func LoginWithQR(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp("unknown status"), ) + return } } @@ -116,6 +121,7 @@ func NewCaptcha(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(r)) } @@ -130,15 +136,19 @@ func (r *SMSReq) Validate() error { if r.Token == "" { return errors.New("token is empty") } + if r.Challenge == "" { return errors.New("challenge is empty") } + if r.V == "" { return errors.New("validate is empty") } + if r.Telephone == "" { return errors.New("telephone is empty") } + return nil } @@ -163,6 +173,7 @@ func NewSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(gin.H{ "captchaKey": r.GetCaptchaKey(), })) @@ -178,12 +189,15 @@ func (r *SMSLoginReq) Validate() error { if r.Telephone == "" { return errors.New("telephone is empty") } + if r.CaptchaKey == "" { return errors.New("captchaKey is empty") } + if r.Code == "" { return errors.New("code is empty") } + return nil } @@ -201,6 +215,7 @@ func LoginWithSMS(ctx *gin.Context) { } backend := ctx.Query("backend") + c, err := vendor.LoadBilibiliClient(backend).LoginWithSMS(ctx, &bilibili.LoginWithSMSReq{ Phone: req.Telephone, CaptchaKey: req.CaptchaKey, @@ -210,6 +225,7 @@ func LoginWithSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + _, err = db.CreateOrSaveBilibiliVendor(&dbModel.BilibiliVendor{ UserID: user.ID, Backend: backend, @@ -219,6 +235,7 @@ func LoginWithSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + _, err = user.BilibiliCache(). Data(). Refresh(ctx, func(_ context.Context, _ ...struct{}) (*cache.BilibiliUserCacheData, error) { @@ -231,20 +248,24 @@ func LoginWithSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Status(http.StatusNoContent) } func Logout(ctx *gin.Context) { log := middlewares.GetLogger(ctx) user := middlewares.GetUserEntry(ctx).Value() + err := db.DeleteBilibiliVendor(user.ID) if err != nil { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + err = user.BilibiliCache().Clear(ctx) if err != nil { log.Errorf("clear bilibili cache: %v", err) } + ctx.Status(http.StatusNoContent) } diff --git a/server/handlers/vendors/vendorBilibili/me.go b/server/handlers/vendors/vendorBilibili/me.go index 1f8f7c4..a583b89 100644 --- a/server/handlers/vendors/vendorBilibili/me.go +++ b/server/handlers/vendors/vendorBilibili/me.go @@ -26,15 +26,19 @@ func Me(ctx *gin.Context) { })) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } + if len(bucd.Cookies) == 0 { ctx.JSON(http.StatusOK, model.NewAPIDataResp(&BilibiliMeResp{ IsLogin: false, })) return } + resp, err := vendor.LoadBilibiliClient(bucd.Backend).UserInfo(ctx, &bilibili.UserInfoReq{ Cookies: utils.HTTPCookieToMap(bucd.Cookies), }) diff --git a/server/handlers/vendors/vendorBilibili/parse.go b/server/handlers/vendors/vendorBilibili/parse.go index 4e768ac..7abd5bb 100644 --- a/server/handlers/vendors/vendorBilibili/parse.go +++ b/server/handlers/vendors/vendorBilibili/parse.go @@ -51,6 +51,7 @@ func Parse(ctx *gin.Context) { // can be no login var cookies []*http.Cookie + bucd, err := user.BilibiliCache().Get(ctx) if err != nil { if !errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -72,6 +73,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "av": aid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -79,6 +81,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParseVideoPage(ctx, &bilibili.ParseVideoPageReq{ Cookies: utils.HTTPCookieToMap(cookies), Aid: aid, @@ -88,6 +91,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "ep": epid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -95,6 +99,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParsePGCPage(ctx, &bilibili.ParsePGCPageReq{ Cookies: utils.HTTPCookieToMap(cookies), Epid: epid, @@ -103,6 +108,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "ss": ssid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -110,6 +116,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParsePGCPage(ctx, &bilibili.ParsePGCPageReq{ Cookies: utils.HTTPCookieToMap(cookies), Ssid: ssid, @@ -118,6 +125,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "live": roomid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -125,6 +133,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParseLivePage(ctx, &bilibili.ParseLivePageReq{ Cookies: utils.HTTPCookieToMap(cookies), RoomID: roomid, @@ -133,12 +142,14 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) default: ctx.AbortWithStatusJSON( http.StatusInternalServerError, model.NewAPIErrorStringResp("unknown match type "+resp.GetType()), ) + return } } diff --git a/server/handlers/vendors/vendorEmby/emby.go b/server/handlers/vendors/vendorEmby/emby.go index 93cc112..3d69195 100644 --- a/server/handlers/vendors/vendorEmby/emby.go +++ b/server/handlers/vendors/vendorEmby/emby.go @@ -31,6 +31,7 @@ func NewEmbyVendorService(room *op.Room, movie *op.Movie) (*EmbyVendorService, e if movie.VendorInfo.Vendor != dbModel.VendorEmby { return nil, fmt.Errorf("emby vendor not support vendor %s", movie.VendorInfo.Vendor) } + return &EmbyVendorService{ room: room, movie: movie, @@ -51,6 +52,7 @@ func (s *EmbyVendorService) ListDynamicMovie( if reqUser.ID != s.movie.CreatorID { return nil, fmt.Errorf("list vendor dynamic folder error: %w", dbModel.ErrNoPermission) } + user := reqUser resp := &model.MovieList{ @@ -61,9 +63,11 @@ func (s *EmbyVendorService) ListDynamicMovie( if err != nil { return nil, fmt.Errorf("load emby server id error: %w", err) } + if subPath != "" { truePath = subPath } + aucd, err := user.EmbyCache().LoadOrStore(ctx, serverID) if err != nil { if errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -71,6 +75,7 @@ func (s *EmbyVendorService) ListDynamicMovie( } return nil, err } + data, err := s.Client().FsList(ctx, &emby.FsListReq{ Host: aucd.Host, Path: truePath, @@ -83,7 +88,9 @@ func (s *EmbyVendorService) ListDynamicMovie( if err != nil { return nil, fmt.Errorf("emby fs list error: %w", err) } + resp.Total = int64(data.GetTotal()) + resp.Movies = make([]*model.Movie, len(data.GetItems())) for i, flr := range data.GetItems() { resp.Movies[i] = &model.Movie{ @@ -106,6 +113,7 @@ func (s *EmbyVendorService) ListDynamicMovie( }, } } + return resp, nil } @@ -118,6 +126,7 @@ func (s *EmbyVendorService) handleProxyMovie(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } @@ -154,6 +163,7 @@ func (s *EmbyVendorService) handleProxyMovie(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("source out of range"), ) + return } @@ -231,6 +241,7 @@ func (s *EmbyVendorService) handleSubtitle(ctx *gin.Context) error { time.Now(), bytes.NewReader(data), ) + return nil } @@ -258,12 +269,14 @@ func (s *EmbyVendorService) GenMovieInfo( } movie := s.movie.Clone() + var err error u, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + data, err := s.movie.EmbyCache().Get(ctx, u.Value().EmbyCache()) if err != nil { return nil, err @@ -272,16 +285,19 @@ func (s *EmbyVendorService) GenMovieInfo( if len(data.Sources) == 0 { return nil, errors.New("no source") } + movie.URL = data.Sources[0].URL for _, s := range data.Sources[0].Subtitles { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Sources[0].Subtitles)) } + movie.Subtitles[s.Name] = &dbModel.Subtitle{ URL: s.URL, Type: s.Type, } } + for _, s := range data.Sources[1:] { movie.MoreSources = append(movie.MoreSources, &dbModel.MoreSource{ @@ -294,6 +310,7 @@ func (s *EmbyVendorService) GenMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(s.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: subt.URL, Type: subt.Type, @@ -310,12 +327,14 @@ func (s *EmbyVendorService) GenProxyMovieInfo( _, userToken string, ) (*dbModel.Movie, error) { movie := s.movie.Clone() + var err error u, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + data, err := s.movie.EmbyCache().Get(ctx, u.Value().EmbyCache()) if err != nil { return nil, err @@ -326,6 +345,7 @@ func (s *EmbyVendorService) GenProxyMovieInfo( if si != len(data.Sources)-1 { continue } + if movie.URL == "" { return nil, errors.New("no source") } @@ -335,6 +355,7 @@ func (s *EmbyVendorService) GenProxyMovieInfo( if err != nil { return nil, err } + rawQuery := url.Values{} rawQuery.Set("source", strconv.Itoa(si)) rawQuery.Set("token", userToken) @@ -360,10 +381,12 @@ func (s *EmbyVendorService) GenProxyMovieInfo( if len(es.Subtitles) == 0 { continue } + for sbi, s := range es.Subtitles { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(es.Subtitles)) } + rawQuery := url.Values{} rawQuery.Set("t", "subtitle") rawQuery.Set("source", strconv.Itoa(si)) diff --git a/server/handlers/vendors/vendorEmby/list.go b/server/handlers/vendors/vendorEmby/list.go index 960ad20..f475935 100644 --- a/server/handlers/vendors/vendorEmby/list.go +++ b/server/handlers/vendors/vendorEmby/list.go @@ -61,8 +61,10 @@ func List(ctx *gin.Context) { "keywords is not supported when not choose server (server id is empty)", ), ) + return } + socpes := [](func(*gorm.DB) *gorm.DB){ db.OrderByCreatedAtAsc, } @@ -72,6 +74,7 @@ func List(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if total == 0 { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found")) return @@ -84,9 +87,12 @@ func List(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found"), ) + return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -124,6 +130,7 @@ func List(ctx *gin.Context) { EmbyFSListResp: var serverID string + serverID, req.Path, err = dbModel.GetEmbyServerIDFromPath(req.Path) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) @@ -136,11 +143,14 @@ EmbyFSListResp: ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found")) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } cli := vendor.LoadEmbyClient(ctx.Query("backend")) + data, err := cli.FsList(ctx, &emby.FsListReq{ Host: aucd.Host, Path: req.Path, @@ -155,6 +165,7 @@ EmbyFSListResp: http.StatusInternalServerError, model.NewAPIErrorResp(fmt.Errorf("emby fs list error: %w", err)), ) + return } @@ -168,11 +179,13 @@ EmbyFSListResp: if p.GetPath() == "1" { n = aucd.Host } + resp.Paths = append(resp.Paths, &model.Path{ Name: n, Path: fmt.Sprintf("%s/%s", aucd.ServerID, p.GetPath()), }) } + for _, i := range data.GetItems() { resp.Items = append(resp.Items, &EmbyFileItem{ Item: &model.Item{ diff --git a/server/handlers/vendors/vendorEmby/login.go b/server/handlers/vendors/vendorEmby/login.go index 1ad44de..4ce0559 100644 --- a/server/handlers/vendors/vendorEmby/login.go +++ b/server/handlers/vendors/vendorEmby/login.go @@ -28,17 +28,21 @@ func (r *LoginReq) Validate() error { if r.Host == "" { return errors.New("host is required") } + url, err := url.Parse(r.Host) if err != nil { return err } + if url.Scheme != "http" && url.Scheme != "https" { return errors.New("host is invalid") } + r.Host = strings.TrimRight(url.String(), "/") if r.Username == "" { return errors.New("username is required") } + return nil } @@ -73,6 +77,7 @@ func Login(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp("serverID is empty"), ) + return } @@ -135,6 +140,7 @@ func logoutEmby(eucd *cache.EmbyUserCacheData) { if eucd == nil || eucd.APIKey == "" { return } + _, _ = vendor.LoadEmbyClient(eucd.Backend).Logout(context.Background(), &emby.LogoutReq{ Host: eucd.Host, Token: eucd.APIKey, diff --git a/server/handlers/vendors/vendorEmby/me.go b/server/handlers/vendors/vendorEmby/me.go index 79a5faf..3d338f1 100644 --- a/server/handlers/vendors/vendorEmby/me.go +++ b/server/handlers/vendors/vendorEmby/me.go @@ -23,6 +23,7 @@ func Me(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorResp(errors.New("serverID is required")), ) + return } @@ -32,7 +33,9 @@ func Me(ctx *gin.Context) { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found")) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -67,7 +70,9 @@ func Binds(ctx *gin.Context) { })) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } diff --git a/server/handlers/vendors/vendoralist/alist.go b/server/handlers/vendors/vendoralist/alist.go index f313f6f..427a953 100644 --- a/server/handlers/vendors/vendoralist/alist.go +++ b/server/handlers/vendors/vendoralist/alist.go @@ -34,6 +34,7 @@ func NewAlistVendorService(room *op.Room, movie *op.Movie) (*AlistVendorService, if movie.VendorInfo.Vendor != dbModel.VendorAlist { return nil, fmt.Errorf("alist vendor not support vendor %s", movie.VendorInfo.Vendor) } + return &AlistVendorService{ room: room, movie: movie, @@ -54,6 +55,7 @@ func (s *AlistVendorService) ListDynamicMovie( if reqUser.ID != s.movie.CreatorID { return nil, fmt.Errorf("list vendor dynamic folder error: %w", dbModel.ErrNoPermission) } + user := reqUser resp := &model.MovieList{ @@ -64,11 +66,13 @@ func (s *AlistVendorService) ListDynamicMovie( if err != nil { return nil, fmt.Errorf("load alist server id error: %w", err) } + newPath := path.Join(truePath, subPath) // check new path is in parent path if !strings.HasPrefix(newPath, truePath) { return nil, errors.New("sub path is not in parent path") } + aucd, err := user.AlistCache().LoadOrStore(ctx, serverID) if err != nil { if errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -76,6 +80,7 @@ func (s *AlistVendorService) ListDynamicMovie( } return nil, err } + cli := s.Client() if keyword != "" { data, err := cli.FsSearch(ctx, &alist.FsSearchReq{ @@ -90,7 +95,9 @@ func (s *AlistVendorService) ListDynamicMovie( if err != nil { return nil, err } + resp.Total = int64(data.GetTotal()) + resp.Movies = make([]*model.Movie, len(data.GetContent())) for i, flr := range data.GetContent() { fileSubPath := strings.TrimPrefix(strings.Trim(flr.GetParent(), "/"), truePath) @@ -123,7 +130,9 @@ func (s *AlistVendorService) ListDynamicMovie( }, } } + resp.Paths = model.GenDefaultSubPaths(s.movie.ID, subPath, true) + return resp, nil } @@ -139,7 +148,9 @@ func (s *AlistVendorService) ListDynamicMovie( if err != nil { return nil, err } + resp.Total = int64(data.GetTotal()) + resp.Movies = make([]*model.Movie, len(data.GetContent())) for i, flr := range data.GetContent() { resp.Movies[i] = &model.Movie{ @@ -164,7 +175,9 @@ func (s *AlistVendorService) ListDynamicMovie( }, } } + resp.Paths = model.GenDefaultSubPaths(s.movie.ID, subPath, true) + return resp, nil } @@ -221,6 +234,7 @@ func (s *AlistVendorService) handleAliProvider( ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if s.movie.Proxy { err := proxy.M3u8Data( ctx, @@ -243,6 +257,7 @@ func (s *AlistVendorService) handleAliProvider( ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if s.movie.Proxy { s.proxyURL(ctx, log, b.URL) } else { @@ -269,6 +284,7 @@ func (s *AlistVendorService) handleDefaultProvider( http.StatusBadRequest, model.NewAPIErrorStringResp("id is empty"), ) + return } @@ -285,10 +301,12 @@ func (s *AlistVendorService) handleDefaultProvider( http.StatusBadRequest, model.NewAPIErrorStringResp("id out of range"), ) + return } subtitle := data.Subtitles[id] + b, err := subtitle.Cache.Get(ctx) if err != nil { log.Errorf("proxy vendor movie error: %v", err) @@ -304,8 +322,10 @@ func (s *AlistVendorService) handleDefaultProvider( http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } + s.proxyURL(ctx, log, data.URL) } } @@ -363,6 +383,7 @@ func (s *AlistVendorService) handleAliSubtitle( http.StatusBadRequest, model.NewAPIErrorStringResp("id out of range"), ) + return } @@ -386,13 +407,16 @@ func (s *AlistVendorService) GenMovieInfo( } movie := s.movie.Clone() + var err error creator, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + alistCache := s.movie.AlistCache() + data, err := alistCache.Get(ctx, &cache.AlistMovieCacheFuncArgs{ UserCache: creator.Value().AlistCache(), UserAgent: utils.UA, @@ -405,6 +429,7 @@ func (s *AlistVendorService) GenMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -423,6 +448,7 @@ func (s *AlistVendorService) GenMovieInfo( if err != nil { return nil, err } + movie.URL = fmt.Sprintf( "/api/room/movie/proxy/%s?token=%s&roomId=%s", movie.ID, @@ -434,12 +460,14 @@ func (s *AlistVendorService) GenMovieInfo( rawStreamURL := data.URL subPath := s.movie.SubPath() + var rawType string if subPath == "" { rawType = utils.GetURLExtension(movie.VendorInfo.Alist.Path) } else { rawType = utils.GetURLExtension(subPath) } + movie.MoreSources = []*dbModel.MoreSource{ { Name: "raw", @@ -452,6 +480,7 @@ func (s *AlistVendorService) GenMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -472,7 +501,9 @@ func (s *AlistVendorService) GenMovieInfo( if err != nil { return nil, fmt.Errorf("refresh 115 movie cache error: %w", err) } + movie.URL = data.URL + movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) for _, subt := range data.Subtitles { movie.Subtitles[subt.Name] = &dbModel.Subtitle{ @@ -486,6 +517,7 @@ func (s *AlistVendorService) GenMovieInfo( } movie.VendorInfo.Alist.Password = "" + return movie, nil } @@ -495,13 +527,16 @@ func (s *AlistVendorService) GenProxyMovieInfo( _, userToken string, ) (*dbModel.Movie, error) { movie := s.movie.Clone() + var err error creator, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + alistCache := s.movie.AlistCache() + data, err := alistCache.Get(ctx, &cache.AlistMovieCacheFuncArgs{ UserCache: creator.Value().AlistCache(), UserAgent: utils.UA, @@ -514,6 +549,7 @@ func (s *AlistVendorService) GenProxyMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -532,6 +568,7 @@ func (s *AlistVendorService) GenProxyMovieInfo( if err != nil { return nil, err } + movie.URL = fmt.Sprintf( "/api/room/movie/proxy/%s?token=%s&roomId=%s", movie.ID, @@ -558,6 +595,7 @@ func (s *AlistVendorService) GenProxyMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&id=%d&token=%s&roomId=%s", @@ -592,5 +630,6 @@ func (s *AlistVendorService) GenProxyMovieInfo( } movie.VendorInfo.Alist.Password = "" + return movie, nil } diff --git a/server/handlers/vendors/vendoralist/list.go b/server/handlers/vendors/vendoralist/list.go index cfafc34..34b3612 100644 --- a/server/handlers/vendors/vendoralist/list.go +++ b/server/handlers/vendors/vendoralist/list.go @@ -66,6 +66,7 @@ func List(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if total == 0 { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("alist server not found")) return @@ -78,9 +79,12 @@ func List(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("alist server not found"), ) + return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -117,6 +121,7 @@ func List(ctx *gin.Context) { AlistFSListResp: var serverID string + serverID, req.Path, err = dbModel.GetAlistServerIDFromPath(req.Path) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) @@ -135,6 +140,7 @@ AlistFSListResp: } ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -155,6 +161,7 @@ AlistFSListResp: } req.Path = strings.Trim(req.Path, "/") + resp := AlistFSListResp{ Total: data.GetTotal(), Paths: model.GenDefaultPaths(req.Path, true, @@ -183,6 +190,7 @@ AlistFSListResp: } ctx.JSON(http.StatusOK, model.NewAPIDataResp(&resp)) + return } @@ -201,6 +209,7 @@ AlistFSListResp: } req.Path = strings.Trim(req.Path, "/") + resp := AlistFSListResp{ Total: data.GetTotal(), Paths: model.GenDefaultPaths(req.Path, true, diff --git a/server/handlers/vendors/vendoralist/login.go b/server/handlers/vendors/vendoralist/login.go index 7225a10..4db64f6 100644 --- a/server/handlers/vendors/vendoralist/login.go +++ b/server/handlers/vendors/vendoralist/login.go @@ -29,17 +29,21 @@ func (r *LoginReq) Validate() error { if r.Host == "" { return errors.New("host is required") } + url, err := url.Parse(r.Host) if err != nil { return err } + if url.Scheme != "http" && url.Scheme != "https" { return errors.New("host is invalid") } + r.Host = strings.TrimRight(url.String(), "/") if r.Password != "" && r.HashedPassword != "" { return errors.New("password and hashedPassword can't be both set") } + return nil } diff --git a/server/handlers/vendors/vendoralist/me.go b/server/handlers/vendors/vendoralist/me.go index ca927d5..71d10c2 100644 --- a/server/handlers/vendors/vendoralist/me.go +++ b/server/handlers/vendors/vendoralist/me.go @@ -23,6 +23,7 @@ func Me(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorResp(errors.New("serverID is required")), ) + return } @@ -32,7 +33,9 @@ func Me(ctx *gin.Context) { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("alist server not found")) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -67,7 +70,9 @@ func Binds(ctx *gin.Context) { })) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } diff --git a/server/handlers/vendors/vendorbilibili/bilibili.go b/server/handlers/vendors/vendorbilibili/bilibili.go index 3856709..6b7f340 100644 --- a/server/handlers/vendors/vendorbilibili/bilibili.go +++ b/server/handlers/vendors/vendorbilibili/bilibili.go @@ -33,6 +33,7 @@ func NewBilibiliVendorService(room *op.Room, movie *op.Movie) (*BilibiliVendorSe if movie.VendorInfo.Vendor != dbModel.VendorBilibili { return nil, fmt.Errorf("bilibili vendor not support vendor %s", movie.VendorInfo.Vendor) } + return &BilibiliVendorService{ room: room, movie: movie, @@ -78,6 +79,7 @@ func (s *BilibiliVendorService) handleDanmuProxy(ctx *gin.Context, log *logrus.E ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Data(http.StatusOK, "application/xml", danmu) } @@ -88,14 +90,17 @@ func (s *BilibiliVendorService) handleLiveProxy(ctx *gin.Context, log *logrus.En ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if len(data) == 0 { log.Error("proxy vendor movie error: live data is empty") ctx.AbortWithStatusJSON( http.StatusNotFound, model.NewAPIErrorStringResp("live data is empty"), ) + return } + ctx.Data(http.StatusOK, "application/vnd.apple.mpegurl", data) } @@ -106,6 +111,7 @@ func (s *BilibiliVendorService) handleVideoProxy(ctx *gin.Context, log *logrus.E http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } @@ -138,18 +144,23 @@ func (s *BilibiliVendorService) handleMpdProxy( t string, mpdC *cache.BilibiliMpdCache, ) { - var mpd string - var err error + var ( + mpd string + err error + ) + if t == "hevc" { mpd, err = cache.BilibiliMpdToString(mpdC.HevcMpd, middlewares.GetToken(ctx)) } else { mpd, err = cache.BilibiliMpdToString(mpdC.Mpd, middlewares.GetToken(ctx)) } + if err != nil { log.Errorf("proxy vendor movie error: %v", err) ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Data(http.StatusOK, "application/dash+xml", stream.StringToBytes(mpd)) } @@ -165,16 +176,19 @@ func (s *BilibiliVendorService) handleStreamProxy( ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + if streamID >= len(mpdC.URLs) { log.Errorf("proxy vendor movie error: %v", "stream id out of range") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("stream id out of range"), ) + return } headers := s.getProxyHeaders() + err = proxy.URL(ctx, mpdC.URLs[streamID], headers, @@ -196,6 +210,7 @@ func (s *BilibiliVendorService) getProxyHeaders() map[string]string { headers["Referer"] = "https://www.bilibili.com" headers["User-Agent"] = utils.UA } + return headers } @@ -228,7 +243,9 @@ func (s *BilibiliVendorService) handleSubtitleProxy(ctx *gin.Context, log *logru ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + http.ServeContent(ctx.Writer, ctx.Request, id, time.Now(), bytes.NewReader(srtData)) + return } @@ -246,7 +263,9 @@ func (s *BilibiliVendorService) GenMovieInfo( } movie := s.movie.Clone() + var err error + if movie.IsFolder { return nil, errors.New("bilibili folder not support") } @@ -267,6 +286,7 @@ func (s *BilibiliVendorService) GenMovieInfo( userToken, movie.RoomID, ) + return movie, nil } @@ -280,10 +300,12 @@ func (s *BilibiliVendorService) GenMovieInfo( var str string if movie.VendorInfo.Bilibili.Shared { var u *op.UserEntry + u, err = op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + str, err = s.movie.BilibiliCache().NoSharedMovie.LoadOrStore( ctx, movie.CreatorID, @@ -292,19 +314,23 @@ func (s *BilibiliVendorService) GenMovieInfo( } else { str, err = s.movie.BilibiliCache().NoSharedMovie.LoadOrStore(ctx, user.ID, user.BilibiliCache()) } + if err != nil { return nil, err } + movie.URL = str srt, err := bmc.Subtitle.Get(ctx, user.BilibiliCache()) if err != nil { return nil, err } + for k := range srt { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(srt)) } + movie.Subtitles[k] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&n=%s&token=%s&roomId=%s", @@ -316,6 +342,7 @@ func (s *BilibiliVendorService) GenMovieInfo( Type: "srt", } } + return movie, nil } @@ -325,7 +352,9 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( _, userToken string, ) (*dbModel.Movie, error) { movie := s.movie.Clone() + var err error + if movie.IsFolder { return nil, errors.New("bilibili folder not support") } @@ -346,6 +375,7 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( userToken, movie.RoomID, ) + return movie, nil } @@ -375,14 +405,17 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( ), }, } + srt, err := bmc.Subtitle.Get(ctx, user.BilibiliCache()) if err != nil { return nil, err } + for k := range srt { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(srt)) } + movie.Subtitles[k] = &dbModel.Subtitle{ URL: fmt.Sprintf( "/api/room/movie/proxy/%s?t=subtitle&n=%s&token=%s&roomId=%s", @@ -394,5 +427,6 @@ func (s *BilibiliVendorService) GenProxyMovieInfo( Type: "srt", } } + return movie, nil } diff --git a/server/handlers/vendors/vendorbilibili/danmu.go b/server/handlers/vendors/vendorbilibili/danmu.go index 1144525..c192462 100644 --- a/server/handlers/vendors/vendorbilibili/danmu.go +++ b/server/handlers/vendors/vendorbilibili/danmu.go @@ -44,10 +44,12 @@ var headerLen = binary.Size(header{}) func (h *header) Marshal() ([]byte, error) { buf := bytes.NewBuffer(make([]byte, 0, headerLen)) + err := binary.Write(buf, binary.BigEndian, h) if err != nil { return nil, err } + return buf.Bytes(), nil } @@ -67,6 +69,7 @@ func newHeader(size uint32, command command, sequence uint32) header { case CmdHeartbeat, CmdAuth: h.Version = 1 } + return h } @@ -95,20 +98,25 @@ func writeVerifyHello(conn *websocket.Conn, hello *verifyHello) error { if err != nil { return err } + header := newHeader(uint32(len(msg)), CmdAuth, 1) + headerBytes, err := header.Marshal() if err != nil { return err } + return conn.WriteMessage(websocket.BinaryMessage, append(headerBytes, msg...)) } func writeHeartbeat(conn *websocket.Conn, sequence uint32) error { header := newHeader(0, CmdHeartbeat, sequence) + headerBytes, err := header.Marshal() if err != nil { return err } + return conn.WriteMessage(websocket.BinaryMessage, headerBytes) } @@ -126,9 +134,11 @@ func (v *BilibiliVendorService) StreamDanmu( if err != nil { return err } + if len(resp.GetHostList()) == 0 { return errors.New("no host list") } + wssHost := resp.GetHostList()[0].GetHost() wssPort := resp.GetHostList()[0].GetWssPort() @@ -167,6 +177,7 @@ func (v *BilibiliVendorService) StreamDanmu( go func() { ticker := time.NewTicker(time.Second * 20) defer ticker.Stop() + sequence := uint32(1) for { select { @@ -174,6 +185,7 @@ func (v *BilibiliVendorService) StreamDanmu( return case <-ticker.C: sequence++ + err = writeHeartbeat(conn, sequence) if err != nil { log.Errorf("write heartbeat error: %v", err) @@ -191,16 +203,20 @@ func (v *BilibiliVendorService) StreamDanmu( if err != nil { return err } + header := header{} + err = header.Unmarshal(message[:headerLen]) if err != nil { return err } + switch header.Command { case CmdHeartbeatReply: continue default: } + data := message[headerLen:] switch header.Version { case 2: @@ -210,6 +226,7 @@ func (v *BilibiliVendorService) StreamDanmu( return err } defer zlibReader.Close() + data, err = io.ReadAll(zlibReader) if err != nil { return err @@ -217,28 +234,36 @@ func (v *BilibiliVendorService) StreamDanmu( case 3: // brotli brotliReader := brotli.NewReader(bytes.NewReader(data)) + data, err = io.ReadAll(brotliReader) if err != nil { return err } + data = data[headerLen:] } + reply := replyCmd{} + err = json.Unmarshal(data, &reply) if err != nil { return err } + switch reply.Cmd { case "DANMU_MSG": danmu := danmuMsg{} + err = json.Unmarshal(data, &danmu) if err != nil { return err } + content, ok := danmu.Info[1].(string) if !ok { return errors.New("content is not string") } + _ = handler(content) case "DM_INTERACTION": } diff --git a/server/handlers/vendors/vendorbilibili/login.go b/server/handlers/vendors/vendorbilibili/login.go index 4548197..0785659 100644 --- a/server/handlers/vendors/vendorbilibili/login.go +++ b/server/handlers/vendors/vendorbilibili/login.go @@ -23,6 +23,7 @@ func NewQRCode(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(r)) } @@ -51,6 +52,7 @@ func LoginWithQR(ctx *gin.Context) { } backend := ctx.Query("backend") + resp, err := vendor.LoadBilibiliClient(backend). LoginWithQRCode(ctx, &bilibili.LoginWithQRCodeReq{ Key: req.Key, @@ -86,6 +88,7 @@ func LoginWithQR(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + _, err = user.BilibiliCache(). Data(). Refresh(ctx, func(_ context.Context, _ ...struct{}) (*cache.BilibiliUserCacheData, error) { @@ -98,6 +101,7 @@ func LoginWithQR(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(gin.H{ "status": "success", })) @@ -106,6 +110,7 @@ func LoginWithQR(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp("unknown status"), ) + return } } @@ -116,6 +121,7 @@ func NewCaptcha(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(r)) } @@ -130,15 +136,19 @@ func (r *SMSReq) Validate() error { if r.Token == "" { return errors.New("token is empty") } + if r.Challenge == "" { return errors.New("challenge is empty") } + if r.V == "" { return errors.New("validate is empty") } + if r.Telephone == "" { return errors.New("telephone is empty") } + return nil } @@ -163,6 +173,7 @@ func NewSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(gin.H{ "captchaKey": r.GetCaptchaKey(), })) @@ -178,12 +189,15 @@ func (r *SMSLoginReq) Validate() error { if r.Telephone == "" { return errors.New("telephone is empty") } + if r.CaptchaKey == "" { return errors.New("captchaKey is empty") } + if r.Code == "" { return errors.New("code is empty") } + return nil } @@ -201,6 +215,7 @@ func LoginWithSMS(ctx *gin.Context) { } backend := ctx.Query("backend") + c, err := vendor.LoadBilibiliClient(backend).LoginWithSMS(ctx, &bilibili.LoginWithSMSReq{ Phone: req.Telephone, CaptchaKey: req.CaptchaKey, @@ -210,6 +225,7 @@ func LoginWithSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + _, err = db.CreateOrSaveBilibiliVendor(&dbModel.BilibiliVendor{ UserID: user.ID, Backend: backend, @@ -219,6 +235,7 @@ func LoginWithSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + _, err = user.BilibiliCache(). Data(). Refresh(ctx, func(_ context.Context, _ ...struct{}) (*cache.BilibiliUserCacheData, error) { @@ -231,20 +248,24 @@ func LoginWithSMS(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + ctx.Status(http.StatusNoContent) } func Logout(ctx *gin.Context) { log := middlewares.GetLogger(ctx) user := middlewares.GetUserEntry(ctx).Value() + err := db.DeleteBilibiliVendor(user.ID) if err != nil { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + err = user.BilibiliCache().Clear(ctx) if err != nil { log.Errorf("clear bilibili cache: %v", err) } + ctx.Status(http.StatusNoContent) } diff --git a/server/handlers/vendors/vendorbilibili/me.go b/server/handlers/vendors/vendorbilibili/me.go index 1f8f7c4..a583b89 100644 --- a/server/handlers/vendors/vendorbilibili/me.go +++ b/server/handlers/vendors/vendorbilibili/me.go @@ -26,15 +26,19 @@ func Me(ctx *gin.Context) { })) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } + if len(bucd.Cookies) == 0 { ctx.JSON(http.StatusOK, model.NewAPIDataResp(&BilibiliMeResp{ IsLogin: false, })) return } + resp, err := vendor.LoadBilibiliClient(bucd.Backend).UserInfo(ctx, &bilibili.UserInfoReq{ Cookies: utils.HTTPCookieToMap(bucd.Cookies), }) diff --git a/server/handlers/vendors/vendorbilibili/parse.go b/server/handlers/vendors/vendorbilibili/parse.go index 4e768ac..7abd5bb 100644 --- a/server/handlers/vendors/vendorbilibili/parse.go +++ b/server/handlers/vendors/vendorbilibili/parse.go @@ -51,6 +51,7 @@ func Parse(ctx *gin.Context) { // can be no login var cookies []*http.Cookie + bucd, err := user.BilibiliCache().Get(ctx) if err != nil { if !errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -72,6 +73,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "av": aid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -79,6 +81,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParseVideoPage(ctx, &bilibili.ParseVideoPageReq{ Cookies: utils.HTTPCookieToMap(cookies), Aid: aid, @@ -88,6 +91,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "ep": epid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -95,6 +99,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParsePGCPage(ctx, &bilibili.ParsePGCPageReq{ Cookies: utils.HTTPCookieToMap(cookies), Epid: epid, @@ -103,6 +108,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "ss": ssid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -110,6 +116,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParsePGCPage(ctx, &bilibili.ParsePGCPageReq{ Cookies: utils.HTTPCookieToMap(cookies), Ssid: ssid, @@ -118,6 +125,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) case "live": roomid, err := strconv.ParseUint(resp.GetId(), 10, 64) @@ -125,6 +133,7 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + resp, err := cli.ParseLivePage(ctx, &bilibili.ParseLivePageReq{ Cookies: utils.HTTPCookieToMap(cookies), RoomID: roomid, @@ -133,12 +142,14 @@ func Parse(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(resp)) default: ctx.AbortWithStatusJSON( http.StatusInternalServerError, model.NewAPIErrorStringResp("unknown match type "+resp.GetType()), ) + return } } diff --git a/server/handlers/vendors/vendoremby/emby.go b/server/handlers/vendors/vendoremby/emby.go index 93cc112..3d69195 100644 --- a/server/handlers/vendors/vendoremby/emby.go +++ b/server/handlers/vendors/vendoremby/emby.go @@ -31,6 +31,7 @@ func NewEmbyVendorService(room *op.Room, movie *op.Movie) (*EmbyVendorService, e if movie.VendorInfo.Vendor != dbModel.VendorEmby { return nil, fmt.Errorf("emby vendor not support vendor %s", movie.VendorInfo.Vendor) } + return &EmbyVendorService{ room: room, movie: movie, @@ -51,6 +52,7 @@ func (s *EmbyVendorService) ListDynamicMovie( if reqUser.ID != s.movie.CreatorID { return nil, fmt.Errorf("list vendor dynamic folder error: %w", dbModel.ErrNoPermission) } + user := reqUser resp := &model.MovieList{ @@ -61,9 +63,11 @@ func (s *EmbyVendorService) ListDynamicMovie( if err != nil { return nil, fmt.Errorf("load emby server id error: %w", err) } + if subPath != "" { truePath = subPath } + aucd, err := user.EmbyCache().LoadOrStore(ctx, serverID) if err != nil { if errors.Is(err, db.NotFoundError(db.ErrVendorNotFound)) { @@ -71,6 +75,7 @@ func (s *EmbyVendorService) ListDynamicMovie( } return nil, err } + data, err := s.Client().FsList(ctx, &emby.FsListReq{ Host: aucd.Host, Path: truePath, @@ -83,7 +88,9 @@ func (s *EmbyVendorService) ListDynamicMovie( if err != nil { return nil, fmt.Errorf("emby fs list error: %w", err) } + resp.Total = int64(data.GetTotal()) + resp.Movies = make([]*model.Movie, len(data.GetItems())) for i, flr := range data.GetItems() { resp.Movies[i] = &model.Movie{ @@ -106,6 +113,7 @@ func (s *EmbyVendorService) ListDynamicMovie( }, } } + return resp, nil } @@ -118,6 +126,7 @@ func (s *EmbyVendorService) handleProxyMovie(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("proxy is not enabled"), ) + return } @@ -154,6 +163,7 @@ func (s *EmbyVendorService) handleProxyMovie(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("source out of range"), ) + return } @@ -231,6 +241,7 @@ func (s *EmbyVendorService) handleSubtitle(ctx *gin.Context) error { time.Now(), bytes.NewReader(data), ) + return nil } @@ -258,12 +269,14 @@ func (s *EmbyVendorService) GenMovieInfo( } movie := s.movie.Clone() + var err error u, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + data, err := s.movie.EmbyCache().Get(ctx, u.Value().EmbyCache()) if err != nil { return nil, err @@ -272,16 +285,19 @@ func (s *EmbyVendorService) GenMovieInfo( if len(data.Sources) == 0 { return nil, errors.New("no source") } + movie.URL = data.Sources[0].URL for _, s := range data.Sources[0].Subtitles { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(data.Sources[0].Subtitles)) } + movie.Subtitles[s.Name] = &dbModel.Subtitle{ URL: s.URL, Type: s.Type, } } + for _, s := range data.Sources[1:] { movie.MoreSources = append(movie.MoreSources, &dbModel.MoreSource{ @@ -294,6 +310,7 @@ func (s *EmbyVendorService) GenMovieInfo( if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(s.Subtitles)) } + movie.Subtitles[subt.Name] = &dbModel.Subtitle{ URL: subt.URL, Type: subt.Type, @@ -310,12 +327,14 @@ func (s *EmbyVendorService) GenProxyMovieInfo( _, userToken string, ) (*dbModel.Movie, error) { movie := s.movie.Clone() + var err error u, err := op.LoadOrInitUserByID(movie.CreatorID) if err != nil { return nil, err } + data, err := s.movie.EmbyCache().Get(ctx, u.Value().EmbyCache()) if err != nil { return nil, err @@ -326,6 +345,7 @@ func (s *EmbyVendorService) GenProxyMovieInfo( if si != len(data.Sources)-1 { continue } + if movie.URL == "" { return nil, errors.New("no source") } @@ -335,6 +355,7 @@ func (s *EmbyVendorService) GenProxyMovieInfo( if err != nil { return nil, err } + rawQuery := url.Values{} rawQuery.Set("source", strconv.Itoa(si)) rawQuery.Set("token", userToken) @@ -360,10 +381,12 @@ func (s *EmbyVendorService) GenProxyMovieInfo( if len(es.Subtitles) == 0 { continue } + for sbi, s := range es.Subtitles { if movie.Subtitles == nil { movie.Subtitles = make(map[string]*dbModel.Subtitle, len(es.Subtitles)) } + rawQuery := url.Values{} rawQuery.Set("t", "subtitle") rawQuery.Set("source", strconv.Itoa(si)) diff --git a/server/handlers/vendors/vendoremby/list.go b/server/handlers/vendors/vendoremby/list.go index 960ad20..f475935 100644 --- a/server/handlers/vendors/vendoremby/list.go +++ b/server/handlers/vendors/vendoremby/list.go @@ -61,8 +61,10 @@ func List(ctx *gin.Context) { "keywords is not supported when not choose server (server id is empty)", ), ) + return } + socpes := [](func(*gorm.DB) *gorm.DB){ db.OrderByCreatedAtAsc, } @@ -72,6 +74,7 @@ func List(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) return } + if total == 0 { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found")) return @@ -84,9 +87,12 @@ func List(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found"), ) + return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -124,6 +130,7 @@ func List(ctx *gin.Context) { EmbyFSListResp: var serverID string + serverID, req.Path, err = dbModel.GetEmbyServerIDFromPath(req.Path) if err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) @@ -136,11 +143,14 @@ EmbyFSListResp: ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found")) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } cli := vendor.LoadEmbyClient(ctx.Query("backend")) + data, err := cli.FsList(ctx, &emby.FsListReq{ Host: aucd.Host, Path: req.Path, @@ -155,6 +165,7 @@ EmbyFSListResp: http.StatusInternalServerError, model.NewAPIErrorResp(fmt.Errorf("emby fs list error: %w", err)), ) + return } @@ -168,11 +179,13 @@ EmbyFSListResp: if p.GetPath() == "1" { n = aucd.Host } + resp.Paths = append(resp.Paths, &model.Path{ Name: n, Path: fmt.Sprintf("%s/%s", aucd.ServerID, p.GetPath()), }) } + for _, i := range data.GetItems() { resp.Items = append(resp.Items, &EmbyFileItem{ Item: &model.Item{ diff --git a/server/handlers/vendors/vendoremby/login.go b/server/handlers/vendors/vendoremby/login.go index 1ad44de..4ce0559 100644 --- a/server/handlers/vendors/vendoremby/login.go +++ b/server/handlers/vendors/vendoremby/login.go @@ -28,17 +28,21 @@ func (r *LoginReq) Validate() error { if r.Host == "" { return errors.New("host is required") } + url, err := url.Parse(r.Host) if err != nil { return err } + if url.Scheme != "http" && url.Scheme != "https" { return errors.New("host is invalid") } + r.Host = strings.TrimRight(url.String(), "/") if r.Username == "" { return errors.New("username is required") } + return nil } @@ -73,6 +77,7 @@ func Login(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorStringResp("serverID is empty"), ) + return } @@ -135,6 +140,7 @@ func logoutEmby(eucd *cache.EmbyUserCacheData) { if eucd == nil || eucd.APIKey == "" { return } + _, _ = vendor.LoadEmbyClient(eucd.Backend).Logout(context.Background(), &emby.LogoutReq{ Host: eucd.Host, Token: eucd.APIKey, diff --git a/server/handlers/vendors/vendoremby/me.go b/server/handlers/vendors/vendoremby/me.go index 79a5faf..3d338f1 100644 --- a/server/handlers/vendors/vendoremby/me.go +++ b/server/handlers/vendors/vendoremby/me.go @@ -23,6 +23,7 @@ func Me(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorResp(errors.New("serverID is required")), ) + return } @@ -32,7 +33,9 @@ func Me(ctx *gin.Context) { ctx.JSON(http.StatusBadRequest, model.NewAPIErrorStringResp("emby server not found")) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } @@ -67,7 +70,9 @@ func Binds(ctx *gin.Context) { })) return } + ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewAPIErrorResp(err)) + return } diff --git a/server/handlers/vendors/vendors.go b/server/handlers/vendors/vendors.go index 84eec28..517b479 100644 --- a/server/handlers/vendors/vendors.go +++ b/server/handlers/vendors/vendors.go @@ -31,8 +31,10 @@ func Backends(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid vendor name"), ) + return } + ctx.JSON(http.StatusOK, model.NewAPIDataResp(backends)) } diff --git a/server/handlers/websocket.go b/server/handlers/websocket.go index 232096b..e34a3d6 100644 --- a/server/handlers/websocket.go +++ b/server/handlers/websocket.go @@ -46,6 +46,7 @@ func isNormalCloseError(err error) bool { if !errors.As(err, &we) { return false } + return we.Code == websocket.CloseNormalClosure } @@ -54,17 +55,20 @@ func NewWSMessageHandler(u *op.User, r *op.Room, l *log.Entry) func(c *websocket client, err := r.NewClient(u, c) if err != nil { l.Errorf("ws: register client error: %v", err) + wc, err2 := c.NextWriter(websocket.BinaryMessage) if err2 != nil { return err2 } defer wc.Close() + em := pb.Message{ Type: pb.MessageType_ERROR, Payload: &pb.Message_ErrorMessage{ ErrorMessage: fmt.Sprintf("register client error: %v", err), }, } + return em.Encode(wc) } @@ -81,9 +85,11 @@ func NewWSMessageHandler(u *op.User, r *op.Room, l *log.Entry) func(c *websocket if isNormalCloseError(err) { return } + l.Errorf("ws: handle reader message error: %v", err) } }() + return handleWriterMessage(client, l) } } @@ -92,6 +98,7 @@ func handleClientDisconnection(r *op.Room, client *op.Client, l *log.Entry) { if err := r.UnregisterClient(client); err != nil { l.Errorf("ws: unregister client error: %v", err) } + client.Close() l.Info("ws: disconnected") } @@ -112,6 +119,7 @@ func handleWriterMessage(c *op.Client, l *log.Entry) error { return err } } + return nil } @@ -152,6 +160,7 @@ func handleReaderMessage(c *op.Client, l *log.Entry) error { defer func() { leaveWebRTC(c) c.Close() + if r := recover(); r != nil { l.Errorf("ws: panic: %v", r) } @@ -163,11 +172,14 @@ func handleReaderMessage(c *op.Client, l *log.Entry) error { if isNormalCloseError(err) { return nil } + l.Errorf("ws: read message error: %v", err) + return err } l.Debugf("ws: receive message: %v", msg.String()) + if err = handleElementMsg(c, msg); err != nil { l.Errorf("ws: handle message error: %v", err) return err @@ -321,6 +333,7 @@ func handleWebRTCJoin(cli *op.Client) error { } cli.SetRTCJoined(true) + return cli.Broadcast(&pb.Message{ Type: pb.MessageType_WEBRTC_JOIN, Sender: &pb.Sender{ @@ -342,6 +355,7 @@ func handleWebRTCLeave(cli *op.Client) error { } cli.SetRTCJoined(false) + return cli.Broadcast(&pb.Message{ Type: pb.MessageType_WEBRTC_LEAVE, Sender: &pb.Sender{ @@ -360,13 +374,16 @@ func calculateTimeDiff(timestamp int64) float64 { if timestamp == 0 { return 0.0 } + timeDiff := time.Since(time.UnixMilli(timestamp)).Seconds() if timeDiff < 0 { return 0 } + if timeDiff > 1.5 { return 1.5 } + return timeDiff } @@ -374,14 +391,17 @@ func handleChatMessage(cli *op.Client, message string) error { if message == "" { return sendErrorMessage(cli, "message is empty") } + sanitizedMessage := template.HTMLEscapeString(message) if len(sanitizedMessage) > MaxChatMessageLength { return sendErrorMessage(cli, "message too long") } + err := cli.SendChatMessage(sanitizedMessage) if err != nil && errors.Is(err, model.ErrNoPermission) { return sendErrorMessage(cli, "failed to send message due to permission issue") } + return err } @@ -390,6 +410,7 @@ func handleStatusMessage(cli *op.Client, msg *pb.Message, timeDiff float64) erro if playbackStatus == nil { return sendErrorMessage(cli, "playback status is nil") } + err := cli.SetStatus( playbackStatus.GetIsPlaying(), playbackStatus.GetCurrentTime(), @@ -399,11 +420,13 @@ func handleStatusMessage(cli *op.Client, msg *pb.Message, timeDiff float64) erro if err != nil { return sendErrorMessage(cli, fmt.Sprintf("set status error: %v", err)) } + return nil } func handleSyncMessage(cli *op.Client) error { status := cli.Room().Current().Status + return cli.Send(&pb.Message{ Type: pb.MessageType_SYNC, Timestamp: time.Now().UnixMilli(), @@ -424,31 +447,38 @@ func handleExpiredMessage(cli *op.Client, expirationID uint64) error { if err != nil { return sendErrorMessage(cli, fmt.Sprintf("get movie by id error: %v", err)) } + ctx, cancel := context.WithTimeout(context.Background(), time.Second*10) defer cancel() + expired, err := currentMovie.CheckExpired(ctx, expirationID) if err != nil { return sendErrorMessage(cli, fmt.Sprintf("check expired error: %v", err)) } + if expired { return cli.Send(&pb.Message{ Type: pb.MessageType_EXPIRED, }) } } + return nil } func handleCheckStatusMessage(cli *op.Client, msg *pb.Message, timeDiff float64) error { current := cli.Room().Current() status := current.Status + cliStatus := msg.GetPlaybackStatus() if cliStatus == nil { return sendErrorMessage(cli, "playback status is nil") } + if needsSync(cliStatus, status, timeDiff) { return sendSyncStatus(cli, &status) } + return nil } @@ -459,6 +489,7 @@ func needsSync(clientStatus *pb.Status, serverStatus model.Status, timeDiff floa serverStatus.CurrentTime-maxInterval > clientStatus.GetCurrentTime()+timeDiff { return true } + return false } diff --git a/server/middlewares/auth.go b/server/middlewares/auth.go index d71c7af..e99129c 100644 --- a/server/middlewares/auth.go +++ b/server/middlewares/auth.go @@ -51,10 +51,12 @@ func authUser(authorization string) (*AuthClaims, error) { if err != nil || !t.Valid { return nil, ErrAuthFailed } + claims, ok := t.Claims.(*AuthClaims) if !ok { return nil, ErrAuthFailed } + return claims, nil } @@ -69,6 +71,7 @@ func AuthRoom(authorization, roomID string) (*op.UserEntry, *op.RoomEntry, error } user := userE.Value() + roomE, err := authenticateRoomAccess(roomID, user) if err != nil { return nil, nil, err @@ -98,6 +101,7 @@ func authenticateUser(authorization string) (*op.UserEntry, error) { if err != nil { return nil, err } + user := userE.Value() if err := validateUser(user, claims.UserVersion); err != nil { @@ -118,15 +122,19 @@ func validateUser(user *op.User, userVersion uint32) error { if user.IsGuest() { return ErrUserGuest } + if !user.CheckVersion(userVersion) { return ErrAuthExpired } + if user.IsBanned() { return ErrUserBanned } + if user.IsPending() { return ErrUserPending } + return nil } @@ -135,6 +143,7 @@ func authenticateRoomAccess(roomID string, user *op.User) (*op.RoomEntry, error) if err != nil { return nil, err } + room := roomE.Value() if err := validateRoomAccess(room, user); err != nil { @@ -149,6 +158,7 @@ func validateRoomAccess(room *op.Room, user *op.User) error { if room.Settings.DisableGuest { return ErrUserGuest } + if room.NeedPassword() { return ErrUserGuest } @@ -157,17 +167,22 @@ func validateRoomAccess(room *op.Room, user *op.User) error { if room.IsBanned() { return ErrRoomBanned } + if room.IsPending() { return ErrRoomPending } - var status dbModel.RoomMemberStatus - var err error + var ( + status dbModel.RoomMemberStatus + err error + ) + if room.NeedPassword() { status, err = room.LoadMemberStatus(user.ID) } else { status, err = room.LoadOrCreateMemberStatus(user.ID) } + if err != nil { return err } @@ -175,6 +190,7 @@ func validateRoomAccess(room *op.Room, user *op.User) error { if status.IsBanned() { return ErrUserBannedFromRoom } + if status.IsPending() { return ErrUserPending } @@ -196,6 +212,7 @@ func AuthUser(authorization string) (*op.UserEntry, error) { if err != nil { return nil, err } + user := userE.Value() if err := validateAuthUser(user, claims.UserVersion); err != nil { @@ -209,15 +226,19 @@ func validateAuthUser(user *op.User, userVersion uint32) error { if user.IsGuest() { return ErrUserGuest } + if !user.CheckVersion(userVersion) { return ErrAuthExpired } + if user.IsBanned() { return ErrUserBanned } + if user.IsPending() { return ErrUserPending } + return nil } @@ -239,6 +260,7 @@ func NewAuthUserToken(user *op.User) (string, error) { ExpiresAt: jwt.NewNumericDate(time.Now().Add(t)), }, } + return jwt.NewWithClaims(jwt.SigningMethodHS256, claims). SignedString(stream.StringToBytes(conf.Conf.Jwt.Secret)) } @@ -247,12 +269,15 @@ func validateNewAuthUserToken(user *op.User) error { if user.IsBanned() { return ErrUserBanned } + if user.IsPending() { return ErrUserPending } + if user.IsGuest() { return ErrUserGuest } + return nil } @@ -262,11 +287,13 @@ func AuthUserMiddleware(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusUnauthorized, model.NewAPIErrorResp(ErrEmptyToken)) return } + userE, err := AuthUser(token) if err != nil { ctx.AbortWithStatusJSON(http.StatusUnauthorized, model.NewAPIErrorResp(err)) return } + user := userE.Value() ctx.Set("user", userE) @@ -279,11 +306,13 @@ func AuthRoomMiddleware(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusUnauthorized, model.NewAPIErrorResp(err)) return } + userE, roomE, err := AuthRoom(GetAuthorizationTokenFromContext(ctx), roomID) if err != nil { ctx.AbortWithStatusJSON(http.StatusUnauthorized, model.NewAPIErrorResp(err)) return } + user := userE.Value() room := roomE.Value() @@ -294,6 +323,7 @@ func AuthRoomMiddleware(ctx *gin.Context) { func AuthRoomWithoutGuestMiddleware(ctx *gin.Context) { AuthRoomMiddleware(ctx) + if ctx.IsAborted() { return } @@ -304,8 +334,10 @@ func AuthRoomWithoutGuestMiddleware(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid user type")), ) + return } + user := userEntry.Value() if user.IsGuest() { ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewAPIErrorResp(ErrUserGuest)) @@ -315,6 +347,7 @@ func AuthRoomWithoutGuestMiddleware(ctx *gin.Context) { func AuthRoomAdminMiddleware(ctx *gin.Context) { AuthRoomMiddleware(ctx) + if ctx.IsAborted() { return } @@ -325,17 +358,22 @@ func AuthRoomAdminMiddleware(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid room type")), ) + return } + room := roomEntry.Value() + userEntry, ok := ctx.MustGet("user").(*synccache.Entry[*op.User]) if !ok { ctx.JSON( http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid user type")), ) + return } + user := userEntry.Value() if !user.IsRoomAdmin(room) { @@ -346,6 +384,7 @@ func AuthRoomAdminMiddleware(ctx *gin.Context) { func AuthRoomCreatorMiddleware(ctx *gin.Context) { AuthRoomMiddleware(ctx) + if ctx.IsAborted() { return } @@ -356,17 +395,22 @@ func AuthRoomCreatorMiddleware(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid room type")), ) + return } + room := roomEntry.Value() + userEntry, ok := ctx.MustGet("user").(*synccache.Entry[*op.User]) if !ok { ctx.JSON( http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid user type")), ) + return } + user := userEntry.Value() if room.CreatorID != user.ID { @@ -377,6 +421,7 @@ func AuthRoomCreatorMiddleware(ctx *gin.Context) { func AuthAdminMiddleware(ctx *gin.Context) { AuthUserMiddleware(ctx) + if ctx.IsAborted() { return } @@ -387,8 +432,10 @@ func AuthAdminMiddleware(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid user type")), ) + return } + user := userEntry.Value() if !user.IsAdmin() { ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewAPIErrorResp(ErrNotAdmin)) @@ -398,6 +445,7 @@ func AuthAdminMiddleware(ctx *gin.Context) { func AuthRootMiddleware(ctx *gin.Context) { AuthUserMiddleware(ctx) + if ctx.IsAborted() { return } @@ -408,8 +456,10 @@ func AuthRootMiddleware(ctx *gin.Context) { http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid user type")), ) + return } + user := userEntry.Value() if !user.IsRoot() { ctx.AbortWithStatusJSON(http.StatusForbidden, model.NewAPIErrorResp(ErrNotRoot)) @@ -437,6 +487,7 @@ func GetAuthorizationTokenFromContext(ctx *gin.Context) string { } ctx.Set("token", "") + return "" } @@ -452,15 +503,19 @@ func GetRoomIDFromContext(ctx *gin.Context) (string, error) { if roomID == "" { continue } + if len(roomID) == 32 { ctx.Set("roomId", roomID) return roomID, nil } + ctx.Set("roomId", "") + return "", ErrInvalidRoomID } ctx.Set("roomId", "") + return "", ErrInvalidRoomID } @@ -471,6 +526,7 @@ func setLogFields(ctx *gin.Context, user *op.User, room *op.Room) { log.Data["unm"] = user.Username log.Data["uro"] = user.Role.String() } + if room != nil { log.Data["rid"] = room.ID log.Data["rnm"] = room.Name @@ -482,6 +538,7 @@ func GetUserEntry(ctx *gin.Context) *op.UserEntry { if !ok { panic("invalid user type") } + return userEntry } @@ -490,6 +547,7 @@ func GetRoomEntry(ctx *gin.Context) *op.RoomEntry { if !ok { panic("invalid room type") } + return roomEntry } @@ -498,9 +556,11 @@ func GetToken(ctx *gin.Context) string { if !ok { return "" } + t, ok := token.(string) if !ok { panic("invalid token type") } + return t } diff --git a/server/middlewares/cors.go b/server/middlewares/cors.go index 2b804a0..aee60b7 100644 --- a/server/middlewares/cors.go +++ b/server/middlewares/cors.go @@ -10,5 +10,6 @@ func NewCors() gin.HandlerFunc { config.AllowAllOrigins = true config.AllowHeaders = []string{"*"} config.AllowMethods = []string{"*"} + return cors.New(config) } diff --git a/server/middlewares/init.go b/server/middlewares/init.go index 733f961..7c3a209 100644 --- a/server/middlewares/init.go +++ b/server/middlewares/init.go @@ -15,11 +15,13 @@ func Init(e *gin.Engine) { Use(NewLog(log.StandardLogger())). Use(gin.RecoveryWithWriter(w)). Use(NewCors()) + if conf.Conf.RateLimit.Enable { d, err := time.ParseDuration(conf.Conf.RateLimit.Period) if err != nil { log.Fatal(err) } + options := []limiter.Option{ limiter.WithTrustForwardHeader(conf.Conf.RateLimit.TrustForwardHeader), } @@ -29,6 +31,7 @@ func Init(e *gin.Engine) { limiter.WithClientIPHeader(conf.Conf.RateLimit.TrustedClientIPHeader), ) } + e.Use(NewLimiter(d, conf.Conf.RateLimit.Limit, options...)) } } diff --git a/server/middlewares/log.go b/server/middlewares/log.go index 2d16e51..bb68297 100644 --- a/server/middlewares/log.go +++ b/server/middlewares/log.go @@ -26,8 +26,10 @@ func NewLog(l *logrus.Logger) gin.HandlerFunc { http.StatusInternalServerError, model.NewAPIErrorResp(errors.New("invalid fields type")), ) + return } + defer func() { clear(fields) fieldsPool.Put(fields) @@ -72,6 +74,7 @@ func NewLog(l *logrus.Logger) gin.HandlerFunc { func logColor(logger *logrus.Entry, p gin.LogFormatterParams) { str := formatter(p) + code := p.StatusCode switch { case code >= http.StatusBadRequest && code < http.StatusInternalServerError: @@ -92,6 +95,7 @@ func formatter(param gin.LogFormatterParams) string { if param.Latency > time.Minute { param.Latency = param.Latency.Truncate(time.Second) } + return fmt.Sprintf("[GIN] |%s %3d %s| %13v | %15s |%s %-7s %s %#v\n%s", statusColor, param.StatusCode, resetColor, param.Latency, @@ -108,16 +112,20 @@ func GetLogger(c *gin.Context) *logrus.Entry { if !ok { panic("invalid log type") } + return entry } + fields, ok := fieldsPool.Get().(logrus.Fields) if !ok { panic("invalid fields type") } + entry := &logrus.Entry{ Logger: logrus.StandardLogger(), Data: fields, } c.Set("log", entry) + return entry } diff --git a/server/middlewares/rateLimit.go b/server/middlewares/rateLimit.go index 0be7bbd..ce7fbf4 100644 --- a/server/middlewares/rateLimit.go +++ b/server/middlewares/rateLimit.go @@ -16,6 +16,7 @@ func NewLimiter(period time.Duration, limit int64, options ...limiter.Option) gi Period: period, Limit: limit, }, options...) + return mgin.NewMiddleware(limiter, mgin.WithLimitReachedHandler(func(c *gin.Context) { c.JSON(http.StatusTooManyRequests, model.NewAPIErrorStringResp("too many requests")) })) diff --git a/server/model/admin.go b/server/model/admin.go index 17f79b2..a0311b4 100644 --- a/server/model/admin.go +++ b/server/model/admin.go @@ -146,16 +146,19 @@ func (avbr *AddVendorBackendReq) Validate() error { return errors.New("alist backend name has invalid char") } } + if avbr.UsedBy.BilibiliBackendName != "" { if !alnumPrintHanReg.MatchString(avbr.UsedBy.BilibiliBackendName) { return errors.New("bilibili backend name has invalid char") } } + if avbr.UsedBy.EmbyBackendName != "" { if !alnumPrintHanReg.MatchString(avbr.UsedBy.EmbyBackendName) { return errors.New("emby backend name has invalid char") } } + return avbr.Backend.Validate() } diff --git a/server/model/auth.go b/server/model/auth.go index 2d148c6..0500680 100644 --- a/server/model/auth.go +++ b/server/model/auth.go @@ -21,9 +21,11 @@ func (o *OAuth2CallbackReq) Validate() error { if o.Code == "" { return ErrInvalidOAuth2Code } + if o.State == "" { return ErrInvalidOAuth2State } + return nil } diff --git a/server/model/decode.go b/server/model/decode.go index b2fa5d2..6bbe744 100644 --- a/server/model/decode.go +++ b/server/model/decode.go @@ -11,8 +11,10 @@ func Decode(ctx *gin.Context, decoder Decoder) error { if err := decoder.Decode(ctx); err != nil { return err } + if err := decoder.Validate(); err != nil { return err } + return nil } diff --git a/server/model/movie.go b/server/model/movie.go index 521ed8d..359c4ca 100644 --- a/server/model/movie.go +++ b/server/model/movie.go @@ -58,6 +58,7 @@ func (p *PushMoviesReq) Validate() error { return err } } + return nil } @@ -113,9 +114,11 @@ func (e *EditMovieReq) Validate() error { if err := e.IDReq.Validate(); err != nil { return err } + if err := e.PushMovieReq.Validate(); err != nil { return err } + return nil } @@ -131,11 +134,13 @@ func (i *IDsReq) Validate() error { if len(i.IDs) == 0 { return ErrEmptyIDs } + for _, v := range i.IDs { if len(v) != 32 { return ErrID } } + return nil } @@ -161,6 +166,7 @@ func GenDefaultSubPaths(id, path string, skipEmpty bool, paths ...*MoviePath) [] if v == "" && skipEmpty { continue } + if l := len(paths); l != 0 { paths = append(paths, &MoviePath{ Name: v, @@ -175,6 +181,7 @@ func GenDefaultSubPaths(id, path string, skipEmpty bool, paths ...*MoviePath) [] }) } } + return paths } diff --git a/server/model/room.go b/server/model/room.go index 5a5653d..c9ae5a7 100644 --- a/server/model/room.go +++ b/server/model/room.go @@ -109,6 +109,7 @@ func (s *SetRoomPasswordReq) Validate() error { return ErrPasswordHasInvalidChar } } + return nil } diff --git a/server/model/user.go b/server/model/user.go index 3056cef..74c986d 100644 --- a/server/model/user.go +++ b/server/model/user.go @@ -33,6 +33,7 @@ func (s *SetUserPasswordReq) Validate() error { case !alnumPrintReg.MatchString(s.Password): return ErrPasswordHasInvalidChar } + return nil } @@ -59,6 +60,7 @@ func (l *LoginUserReq) Validate() error { if len(l.Username) > 32 { return ErrUsernameTooLong } + if !alnumPrintHanReg.MatchString(l.Username) { return ErrUsernameHasInvalidChar } @@ -78,6 +80,7 @@ func (l *LoginUserReq) Validate() error { case !alnumPrintReg.MatchString(l.Password): return ErrPasswordHasInvalidChar } + return nil } @@ -94,12 +97,15 @@ func (u *UserSignupPasswordReq) Validate() error { if u.Username == "" { return errors.New("username is empty") } + if len(u.Username) > 32 { return ErrUsernameTooLong } + if !alnumPrintHanReg.MatchString(u.Username) { return ErrUsernameHasInvalidChar } + switch { case u.Password == "": return FormatEmptyPasswordError("user") @@ -108,6 +114,7 @@ func (u *UserSignupPasswordReq) Validate() error { case !alnumPrintReg.MatchString(u.Password): return ErrPasswordHasInvalidChar } + return nil } @@ -132,6 +139,7 @@ func (s *SetUsernameReq) Validate() error { case !alnumPrintHanReg.MatchString(s.Username): return ErrUsernameHasInvalidChar } + return nil } @@ -188,12 +196,15 @@ func (u *UserSendBindEmailCaptchaReq) Validate() error { case !emailReg.MatchString(u.Email): return ErrInvalidEmail } + if u.CaptchaID == "" { return errors.New("captcha id is empty") } + if u.Answer == "" { return errors.New("answer is empty") } + return nil } @@ -215,9 +226,11 @@ func (u *UserBindEmailReq) Validate() error { case !emailReg.MatchString(u.Email): return ErrInvalidEmail } + if u.Captcha == "" { return errors.New("captcha is empty") } + return nil } @@ -236,6 +249,7 @@ func (u *UserSignupEmailReq) Validate() error { if err := u.UserBindEmailReq.Validate(); err != nil { return err } + switch { case u.Password == "": return FormatEmptyPasswordError("user") @@ -244,6 +258,7 @@ func (u *UserSignupEmailReq) Validate() error { case !alnumPrintReg.MatchString(u.Password): return ErrPasswordHasInvalidChar } + return nil } diff --git a/server/model/vendor.go b/server/model/vendor.go index ac4fa49..0b8798f 100644 --- a/server/model/vendor.go +++ b/server/model/vendor.go @@ -26,6 +26,7 @@ func GenDefaultPaths(path string, skipEmpty bool, paths ...*Path) []*Path { if v == "" && skipEmpty { continue } + if l := len(paths); l != 0 { paths = append(paths, &Path{ Name: v, @@ -38,6 +39,7 @@ func GenDefaultPaths(path string, skipEmpty bool, paths ...*Path) []*Path { }) } } + return paths } diff --git a/server/oauth2/auth.go b/server/oauth2/auth.go index 4eddc24..1ec6de5 100644 --- a/server/oauth2/auth.go +++ b/server/oauth2/auth.go @@ -32,12 +32,14 @@ func OAuth2(ctx *gin.Context) { } state := utils.RandString(16) + url, err := pi.NewAuthURL(ctx, state) if err != nil { log.Errorf("failed to get auth url: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + states.Store(state, newAuthFunc(ctx.Query("redirect")), time.Minute*5) err = RenderRedirect(ctx, url) @@ -64,12 +66,14 @@ func OAuth2Api(ctx *gin.Context) { } state := utils.RandString(16) + url, err := pi.NewAuthURL(ctx, state) if err != nil { log.Errorf("failed to get auth url: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + states.Store(state, newAuthFunc(meta.Redirect), time.Minute*5) ctx.JSON(http.StatusOK, model.NewAPIDataResp(gin.H{ "url": url, @@ -88,6 +92,7 @@ func OAuth2Callback(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 code"), ) + return } @@ -105,6 +110,7 @@ func OAuth2Callback(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 state"), ) + return } @@ -141,6 +147,7 @@ func OAuth2CallbackAPI(ctx *gin.Context) { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 state"), ) + return } @@ -171,14 +178,17 @@ func newAuthFunc(redirect string) stateHandler { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 provider user id"), ) + return } + if ui.Username == "" { log.Errorf("invalid oauth2 username") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 username"), ) + return } @@ -189,6 +199,7 @@ func newAuthFunc(redirect string) stateHandler { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 provider"), ) + return } @@ -202,11 +213,13 @@ func newAuthFunc(redirect string) stateHandler { userE, err = op.CreateOrLoadUserWithProvider(ui.Username, utils.RandString(16), pi.Provider(), ui.ProviderUserID) } } + if err != nil { log.Errorf("failed to create or load user: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + user := userE.Value() token, err := middlewares.NewAuthUserToken(user) @@ -218,10 +231,13 @@ func newAuthFunc(redirect string) stateHandler { "message": err.Error(), "role": user.Role, })) + return } + log.Errorf("failed to generate token: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) + return } diff --git a/server/oauth2/bind.go b/server/oauth2/bind.go index ef96f9f..9127f82 100644 --- a/server/oauth2/bind.go +++ b/server/oauth2/bind.go @@ -33,12 +33,14 @@ func BindAPI(ctx *gin.Context) { } state := utils.RandString(16) + url, err := pi.NewAuthURL(ctx, state) if err != nil { log.Errorf("failed to get auth url: %v", err) ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewAPIErrorResp(err)) return } + states.Store(state, newBindFunc(user.ID, meta.Redirect), time.Minute*5) ctx.JSON(http.StatusOK, model.NewAPIDataResp(gin.H{ "url": url, @@ -85,14 +87,17 @@ func newBindFunc(userID, redirect string) stateHandler { http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 provider user id"), ) + return } + if ui.Username == "" { log.Errorf("invalid oauth2 username") ctx.AbortWithStatusJSON( http.StatusBadRequest, model.NewAPIErrorStringResp("invalid oauth2 username"), ) + return } diff --git a/server/router.go b/server/router.go index 282c91f..2e889bf 100644 --- a/server/router.go +++ b/server/router.go @@ -13,6 +13,7 @@ func Init(e *gin.Engine) { middlewares.Init(e) auth.Init(e) handlers.Init(e) + if !flags.Server.DisableWeb { static.Init(e) } @@ -21,5 +22,5 @@ func Init(e *gin.Engine) { func NewAndInit() (e *gin.Engine) { e = gin.New() Init(e) - return + return e } diff --git a/server/static/static.go b/server/static/static.go index bd6cb2c..1698c29 100644 --- a/server/static/static.go +++ b/server/static/static.go @@ -56,18 +56,21 @@ func Init(e *gin.Engine) { func newFSHandler(fileSys fs.FS) func(ctx *gin.Context) { return func(ctx *gin.Context) { fp := strings.Trim(ctx.Param("filepath"), "/") + f, err := fileSys.Open(fp) if err != nil { fp = "" } else { f.Close() } + ctx.FileFromFS(fp, http.FS(fileSys)) } } func newStatCachedFSHandler(fileSys fs.FS) (func(ctx *gin.Context), error) { cache := make(map[string]struct{}) + err := fs.WalkDir(fileSys, ".", func(path string, _ fs.DirEntry, _ error) error { cache[`/`+path] = struct{}{} return nil @@ -75,11 +78,13 @@ func newStatCachedFSHandler(fileSys fs.FS) (func(ctx *gin.Context), error) { if err != nil { return nil, err } + return func(ctx *gin.Context) { fp := ctx.Param("filepath") if _, ok := cache[fp]; !ok { fp = "" } + ctx.FileFromFS(fp, http.FS(fileSys)) }, nil } @@ -88,6 +93,7 @@ func SiglePageAppFS(r *gin.RouterGroup, fileSys fs.FS, cacheStat bool) error { var h func(ctx *gin.Context) if cacheStat { var err error + h, err = newStatCachedFSHandler(fileSys) if err != nil { return err @@ -95,8 +101,10 @@ func SiglePageAppFS(r *gin.RouterGroup, fileSys fs.FS, cacheStat bool) error { } else { h = newFSHandler(fileSys) } + r.GET("/*filepath", h) r.HEAD("/*filepath", h) + return nil } diff --git a/utils/crypto.go b/utils/crypto.go index 40347c1..8381104 100644 --- a/utils/crypto.go +++ b/utils/crypto.go @@ -29,6 +29,7 @@ func Crypto(v, key []byte) ([]byte, error) { // Encrypt and authenticate the plaintext ciphertext := aead.Seal(nonce, nonce, v, nil) + return ciphertext, nil } @@ -67,6 +68,7 @@ func CryptoToBase64(v, key []byte) (string, error) { if err != nil { return "", err } + return base64.StdEncoding.EncodeToString(ciphertext), nil } @@ -75,6 +77,7 @@ func DecryptoFromBase64(v string, key []byte) ([]byte, error) { if err != nil { return nil, err } + return Decrypto(ciphertext, key) } @@ -83,6 +86,7 @@ func GenCryptoKey(base string) []byte { for i := range len(base) { key[i%32] ^= base[i] } + return key } @@ -91,5 +95,6 @@ func GenCryptoKeyWithBytes(base []byte) []byte { for i := range base { key[i%32] ^= base[i] } + return key } diff --git a/utils/crypto_test.go b/utils/crypto_test.go index ccf265b..fddf30e 100644 --- a/utils/crypto_test.go +++ b/utils/crypto_test.go @@ -9,14 +9,18 @@ import ( func TestCrypto(t *testing.T) { m := []byte("hello world") key := []byte(utils.RandString(32)) + m, err := utils.Crypto(m, key) if err != nil { t.Fatal(err) } + t.Log(string(m)) + m, err = utils.Decrypto(m, key) if err != nil { t.Fatal(err) } + t.Log(string(m)) } diff --git a/utils/fastJSONSerializer/fastJSONSerializer.go b/utils/fastJSONSerializer/fastJSONSerializer.go index 5c328a1..c956554 100644 --- a/utils/fastJSONSerializer/fastJSONSerializer.go +++ b/utils/fastJSONSerializer/fastJSONSerializer.go @@ -42,7 +42,8 @@ func (*JSONSerializer) Scan( } field.ReflectValueOf(ctx, dst).Set(fieldValue.Elem()) - return + + return err } func (*JSONSerializer) Value( diff --git a/utils/m3u8/m3u8.go b/utils/m3u8/m3u8.go index 9d2032f..824e18b 100644 --- a/utils/m3u8/m3u8.go +++ b/utils/m3u8/m3u8.go @@ -9,6 +9,7 @@ import ( func GetM3u8AllSegments(m3u8Str, baseURL string) ([]string, error) { var segments []string + err := RangeM3u8SegmentsWithBaseURL(m3u8Str, baseURL, func(segmentUrl string) (bool, error) { segments = append(segments, segmentUrl) return true, nil @@ -16,6 +17,7 @@ func GetM3u8AllSegments(m3u8Str, baseURL string) ([]string, error) { if err != nil { return nil, err } + return segments, nil } @@ -31,9 +33,11 @@ func RangeM3u8Segments(m3u8Str string, callback func(segmentUrl string) (bool, e } } } + if err := scanner.Err(); err != nil { return fmt.Errorf("scan m3u8 error: %w", err) } + return nil } @@ -45,14 +49,17 @@ func RangeM3u8SegmentsWithBaseURL( if err != nil { return fmt.Errorf("parse base url error: %w", err) } + return RangeM3u8Segments(m3u8Str, func(segmentURL string) (bool, error) { if !strings.HasPrefix(segmentURL, "http://") && !strings.HasPrefix(segmentURL, "https://") { segmentURLParsed, err := url.Parse(segmentURL) if err != nil { return false, fmt.Errorf("parse segment url error: %w", err) } + segmentURL = baseURLParsed.ResolveReference(segmentURLParsed).String() } + return callback(segmentURL) }) } @@ -62,6 +69,7 @@ func ReplaceM3u8Segments( callback func(segmentURL string) (string, error), ) (string, error) { var result strings.Builder + scanner := bufio.NewScanner(strings.NewReader(m3u8Str)) for scanner.Scan() { line := strings.TrimSpace(scanner.Text()) @@ -70,15 +78,19 @@ func ReplaceM3u8Segments( if err != nil { return "", fmt.Errorf("callback error: %w", err) } + result.WriteString(newSegment) } else { result.WriteString(line) } + result.WriteString("\n") } + if err := scanner.Err(); err != nil { return "", fmt.Errorf("scan m3u8 error: %w", err) } + return result.String(), nil } @@ -90,14 +102,17 @@ func ReplaceM3u8SegmentsWithBaseURL( if err != nil { return "", fmt.Errorf("parse base url error: %w", err) } + return ReplaceM3u8Segments(m3u8Str, func(segmentURL string) (string, error) { if !strings.HasPrefix(segmentURL, "http://") && !strings.HasPrefix(segmentURL, "https://") { segmentURLParsed, err := url.Parse(segmentURL) if err != nil { return "", fmt.Errorf("parse segment url error: %w", err) } + segmentURL = baseURLParsed.ResolveReference(segmentURLParsed).String() } + return callback(segmentURL) }) } diff --git a/utils/smtp/format.go b/utils/smtp/format.go index 6ac28bf..7162139 100644 --- a/utils/smtp/format.go +++ b/utils/smtp/format.go @@ -55,6 +55,7 @@ func FormatMail(from string, to []string, subject, body string, opts ...FormatMa for _, opt := range opts { opt(c) } + buf := bytes.NewBuffer(nil) fmt.Fprintf(buf, "From: %s\r\n", from) @@ -63,6 +64,7 @@ func FormatMail(from string, to []string, subject, body string, opts ...FormatMa fmt.Fprintf(buf, "Date: %s\r\n", c.date) fmt.Fprintf(buf, "MIME-Version: %s\r\n", c.mimeVersion) fmt.Fprintf(buf, "Content-Type: %s\r\n", c.contentType) + if c.contentTransferEncoding != "" { fmt.Fprintf(buf, "Content-Transfer-Encoding: %s\r\n", c.contentTransferEncoding) } @@ -77,6 +79,7 @@ func FormatMail(from string, to []string, subject, body string, opts ...FormatMa if end > len(encodedBody) { end = len(encodedBody) } + buf.WriteString(encodedBody[i:end] + "\r\n") } case "": diff --git a/utils/smtp/smtpool.go b/utils/smtp/smtpool.go index b4d6958..890d7de 100644 --- a/utils/smtp/smtpool.go +++ b/utils/smtp/smtpool.go @@ -24,21 +24,27 @@ func validateSMTPConfig(c *Config) error { if c == nil { return errors.New("smtp config is nil") } + if c.Host == "" { return errors.New("smtp host is empty") } + if c.Port == 0 { return errors.New("smtp port is empty") } + if c.Username == "" { return errors.New("smtp username is empty") } + if c.Password == "" { return errors.New("smtp password is empty") } + if c.From == "" { return errors.New("smtp from is empty") } + return nil } @@ -56,6 +62,7 @@ func newSMTPClient(c *Config) (*smtp.Client, error) { default: cli, err = smtp.Dial(fmt.Sprintf("%s:%d", c.Host, c.Port)) } + if err != nil { return nil, fmt.Errorf("dial smtp server failed: %w", err) } @@ -85,6 +92,7 @@ func NewSMTPPool(c *Config, poolCap int) (*Pool, error) { if err != nil { return nil, err } + return &Pool{ clients: make([]*smtp.Client, 0, poolCap), c: c, @@ -94,6 +102,7 @@ func NewSMTPPool(c *Config, poolCap int) (*Pool, error) { func (p *Pool) Get() (*smtp.Client, error) { p.mu.Lock() + if p.closed { p.mu.Unlock() return nil, ErrSMTPPoolClosed @@ -104,13 +113,16 @@ func (p *Pool) Get() (*smtp.Client, error) { p.clients = p.clients[:len(p.clients)-1] p.active++ p.mu.Unlock() + if cli.Noop() != nil { cli.Close() p.mu.Lock() p.active-- p.mu.Unlock() + return p.Get() } + return cli, nil } @@ -128,6 +140,7 @@ func (p *Pool) Get() (*smtp.Client, error) { p.active++ p.mu.Unlock() + return cli, nil } @@ -160,6 +173,7 @@ func (p *Pool) Close() { for _, cli := range p.clients { cli.Close() } + p.clients = nil } @@ -169,6 +183,7 @@ func (p *Pool) SendEmail(to []string, subject, body string, opts ...FormatMailOp return err } defer p.Put(cli) + return SendEmail(cli, p.c.From, to, subject, body, opts...) } diff --git a/utils/utils.go b/utils/utils.go index 5f2301c..1dfd479 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -51,6 +51,7 @@ func RandString(n int) string { for i := range b { b[i] = letters[rand.IntN(len(letters))] } + return string(b) } @@ -59,6 +60,7 @@ func RandBytes(n int) []byte { for i := range b { b[i] = byte(rand.IntN(256)) } + return b } @@ -71,15 +73,18 @@ func GetPageItemsRange(total, page, pageSize int) (start, end int) { if pageSize <= 0 || page <= 0 { return 0, 0 } + start = (page - 1) * pageSize if start > total { start = total } + end = page * pageSize if end > total { end = total } - return + + return start, end } func Index[T comparable](items []T, item T) int { @@ -88,6 +93,7 @@ func Index[T comparable](items []T, item T) int { return i } } + return -1 } @@ -105,11 +111,13 @@ func WriteYaml(file string, module any) error { if err != nil { return err } + f, err := os.Create(file) if err != nil { return err } defer f.Close() + return yamlcomment.NewEncoder(yaml.NewEncoder(f)).Encode(module) } @@ -119,6 +127,7 @@ func ReadYaml(file string, module any) error { return err } defer f.Close() + return yaml.NewDecoder(f).Decode(module) } @@ -142,6 +151,7 @@ func CompVersion(v1, v2 string) (int, error) { if err != nil { return VersionEqual, err } + v2Base, err := SplitVersion(strings.TrimLeft(v2Parts[0], "v")) if err != nil { return VersionEqual, err @@ -157,6 +167,7 @@ func CompVersion(v1, v2 string) (int, error) { if v1Base[i] > v2Base[i] { return VersionGreater, nil } + if v1Base[i] < v2Base[i] { return VersionLess, nil } @@ -170,9 +181,11 @@ func CompVersion(v1, v2 string) (int, error) { if len(v1PreRelease) == 0 && len(v2PreRelease) != 0 { return VersionGreater, nil } + if len(v1PreRelease) != 0 && len(v2PreRelease) == 0 { return VersionLess, nil } + if len(v1PreRelease) == 0 && len(v2PreRelease) == 0 { return VersionEqual, nil } @@ -217,14 +230,17 @@ func getPreReleaseType(s string) string { func SplitVersion(v string) ([]int, error) { split := strings.Split(v, ".") + vs := make([]int, 0, len(split)) for _, s := range split { i, err := strconv.Atoi(s) if err != nil { return nil, err } + vs = append(vs, i) } + return vs, nil } @@ -244,16 +260,19 @@ func (o *Once) Done() (doned bool) { o.m.Lock() defer o.m.Unlock() + switch o.done { case 0: doned = false + atomic.StoreUint32(&o.done, 2) case 1: doned = true default: doned = false } - return + + return doned } func (o *Once) Do(f func()) { @@ -265,8 +284,10 @@ func (o *Once) Do(f func()) { func (o *Once) doSlow(f func()) { o.m.Lock() defer o.m.Unlock() + if o.done == 0 { defer atomic.StoreUint32(&o.done, 1) + f() } } @@ -280,6 +301,7 @@ func ParseURLIsLocalIP(u string) (bool, error) { if err != nil { return false, err } + return IsLocalIP(url.Host), nil } @@ -326,9 +348,11 @@ func OptFilePath(filePath string) (string, error) { if filePath == "" { return "", nil } + if !filepath.IsAbs(filePath) { return filepath.Abs(filepath.Join(flags.Global.DataDir, filePath)) } + return filePath, nil } @@ -351,6 +375,7 @@ func HTTPCookieToMap(c []*http.Cookie) map[string]string { for _, v := range c { m[v.Name] = v.Value } + return m } @@ -362,6 +387,7 @@ func MapToHTTPCookie(m map[string]string) []*http.Cookie { Value: v, }) } + return c } @@ -373,14 +399,17 @@ func GetURLExtension(u string) string { if u == "" { return "" } + p, err := url.Parse(u) if err != nil { return "" } + ext := GetFileExtension(p.Path) if ext != "" { return ext } + return GetFileExtension(p.RawQuery) } @@ -399,8 +428,10 @@ func ForceColor() bool { needColor = false return } + needColor = colorable.IsTerminal(os.Stdout.Fd()) }) + return needColor } @@ -409,19 +440,23 @@ func GetPageAndMax(ctx *gin.Context) (page, _max int, err error) { if err != nil { return 0, 0, errors.New("max must be a number") } + page, err = strconv.Atoi(ctx.DefaultQuery("page", "1")) if err != nil { return 0, 0, errors.New("page must be a number") } + if page <= 0 { page = 1 } + if _max <= 0 { _max = 10 } else if _max > 100 { _max = 100 } - return + + return page, _max, err } func TruncateByRune(s string, length int) string { @@ -431,8 +466,10 @@ func TruncateByRune(s string, length int) string { if runeLen == -1 || total+runeLen > length { return s[:total] } + total += runeLen } + return s[:total] } diff --git a/utils/utils_test.go b/utils/utils_test.go index 3a2d58c..92e4665 100644 --- a/utils/utils_test.go +++ b/utils/utils_test.go @@ -14,6 +14,7 @@ func TestGetPageItems(t *testing.T) { page int pageSize int } + tests := []struct { name string args args @@ -69,6 +70,7 @@ func FuzzCompVersion(f *testing.F) { f.Add("v0.3.1", "v0.3.1-alpha.2") f.Fuzz(func(t *testing.T, a, b string) { t.Logf("a: %s, b: %s", a, b) + _, err := utils.CompVersion(a, b) if err != nil { t.Errorf("CompVersion error = %v", err) @@ -118,15 +120,19 @@ func TestTruncateByRune(t *testing.T) { if !strings.EqualFold(utils.TruncateByRune(name, 6), "abcd") { t.Errorf("TruncateByRune() = %v, want %v", utils.TruncateByRune(name, 6), "abcd") } + if !strings.EqualFold(utils.TruncateByRune(name, 7), "abcd测") { t.Errorf("TruncateByRune() = %v, want %v", utils.TruncateByRune(name, 7), "abcd测") } + if !strings.EqualFold(utils.TruncateByRune(name, 8), "abcd测") { t.Errorf("TruncateByRune() = %v, want %v", utils.TruncateByRune(name, 8), "abcd测") } + if !strings.EqualFold(utils.TruncateByRune(name, 9), "abcd测") { t.Errorf("TruncateByRune() = %v, want %v", utils.TruncateByRune(name, 9), "abcd测") } + if !strings.EqualFold(utils.TruncateByRune(name, 10), "abcd测试") { t.Errorf("TruncateByRune() = %v, want %v", utils.TruncateByRune(name, 10), "abcd测试") } diff --git a/utils/websocket.go b/utils/websocket.go index 934a132..f5a9edf 100644 --- a/utils/websocket.go +++ b/utils/websocket.go @@ -28,6 +28,7 @@ func NewWebSocketServer(conf ...WebSocketConfig) *WebSocket { for _, wsc := range conf { wsc(ws) } + return ws } @@ -41,11 +42,13 @@ func (ws *WebSocket) Server( if len(subprotocols) > 0 { conf = append(conf, WithSubprotocols(subprotocols)) } + wsc, err := ws.NewWebSocketClient(w, r, nil, conf...) if err != nil { return err } defer wsc.Close() + return handler(wsc) } @@ -69,6 +72,7 @@ func (ws *WebSocket) newUpgrader(conf ...UpgraderConf) *websocket.Upgrader { for _, uc := range conf { uc(ug) } + return ug } @@ -82,5 +86,6 @@ func (ws *WebSocket) NewWebSocketClient( if err != nil { return nil, err } + return conn, nil }