diff --git a/cmd/server.go b/cmd/server.go index dfd140f..8e596c7 100644 --- a/cmd/server.go +++ b/cmd/server.go @@ -15,7 +15,6 @@ import ( "github.com/synctv-org/synctv/internal/rtmp" sysnotify "github.com/synctv-org/synctv/internal/sysNotify" "github.com/synctv-org/synctv/server" - "github.com/synctv-org/synctv/utils" ) var ServerCmd = &cobra.Command{ @@ -80,14 +79,6 @@ func Server(cmd *cobra.Command, args []string) { e := server.NewAndInit() switch { case conf.Conf.Server.Http.CertPath != "" && conf.Conf.Server.Http.KeyPath != "": - conf.Conf.Server.Http.CertPath, err = utils.OptFilePath(conf.Conf.Server.Http.CertPath) - if err != nil { - log.Fatalf("cert path error: %s", err) - } - conf.Conf.Server.Http.KeyPath, err = utils.OptFilePath(conf.Conf.Server.Http.KeyPath) - if err != nil { - log.Fatalf("key path error: %s", err) - } httpl := muxer.Match(cmux.HTTP2(), cmux.TLS()) go http.ServeTLS(httpl, e.Handler(), conf.Conf.Server.Http.CertPath, conf.Conf.Server.Http.KeyPath) if conf.Conf.Server.Http.Quic { @@ -106,14 +97,6 @@ func Server(cmd *cobra.Command, args []string) { e := server.NewAndInit() switch { case conf.Conf.Server.Http.CertPath != "" && conf.Conf.Server.Http.KeyPath != "": - conf.Conf.Server.Http.CertPath, err = utils.OptFilePath(conf.Conf.Server.Http.CertPath) - if err != nil { - log.Fatalf("cert path error: %s", err) - } - conf.Conf.Server.Http.KeyPath, err = utils.OptFilePath(conf.Conf.Server.Http.KeyPath) - if err != nil { - log.Fatalf("key path error: %s", err) - } go http.ServeTLS(serverHttpListener, e.Handler(), conf.Conf.Server.Http.CertPath, conf.Conf.Server.Http.KeyPath) if conf.Conf.Server.Http.Quic { go http3.ListenAndServeQUIC(udpServerHttpAddr.String(), conf.Conf.Server.Http.CertPath, conf.Conf.Server.Http.KeyPath, e.Handler()) diff --git a/go.mod b/go.mod index 5b112a4..cd8e3ae 100644 --- a/go.mod +++ b/go.mod @@ -35,9 +35,10 @@ require ( github.com/synctv-org/vendors v0.3.3 github.com/ulule/limiter/v3 v3.11.2 github.com/zencoder/go-dash/v3 v3.0.3 - github.com/zijiren233/gencontainer v0.0.0-20241028165332-af5906dd24c9 + github.com/zijiren233/gencontainer v0.0.0-20241030052007-1f7025eb92f5 github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb github.com/zijiren233/go-uhc v0.2.6 + github.com/zijiren233/ksync v0.2.0 github.com/zijiren233/livelib v0.3.3 github.com/zijiren233/stream v0.5.2 github.com/zijiren233/yaml-comment v0.2.2 diff --git a/go.sum b/go.sum index 852d74b..1bd2133 100644 --- a/go.sum +++ b/go.sum @@ -380,12 +380,14 @@ github.com/yuin/goldmark v1.2.1/go.mod h1:3hX8gzYuyVAZsxl0MRgGTJEmQBFcNTphYh9dec github.com/yuin/goldmark v1.4.13/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/zencoder/go-dash/v3 v3.0.3 h1:xqwGJ2fJCSArwONGx6sY26Z1lxQ7zTURoxdRjCpuodM= github.com/zencoder/go-dash/v3 v3.0.3/go.mod h1:30R5bKy1aUYY45yesjtZ9l8trNc2TwNqbS17WVQmCzk= -github.com/zijiren233/gencontainer v0.0.0-20241028165332-af5906dd24c9 h1:CbBpXy3eFGY5+pmTOUp7EyE4Uq9LvyY9vbvROUyl2zs= -github.com/zijiren233/gencontainer v0.0.0-20241028165332-af5906dd24c9/go.mod h1:bt31/uEP7Eq5qEEW+I/9AzCsueDohCwXLwy/rnLpNPY= +github.com/zijiren233/gencontainer v0.0.0-20241030052007-1f7025eb92f5 h1:/3MxA7C04j+3npN+uQw8EV5lcH4vhDWcLHLG+EaBp2Y= +github.com/zijiren233/gencontainer v0.0.0-20241030052007-1f7025eb92f5/go.mod h1:bt31/uEP7Eq5qEEW+I/9AzCsueDohCwXLwy/rnLpNPY= github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb h1:0DyOxf/TbbGodHhOVHNoPk+7v/YBJACs22gKpKlatWw= github.com/zijiren233/go-colorable v0.0.0-20230930131441-997304c961cb/go.mod h1:6TCzjDiQ8+5gWZiwsC3pnA5M0vUy2jV2Y7ciHJh729g= github.com/zijiren233/go-uhc v0.2.6 h1:7VG21KWa/o6jpOAqG08x9FUX4+Rj8GJRbhIpOuk4sxs= github.com/zijiren233/go-uhc v0.2.6/go.mod h1:4thRoKeIFjYBa1bB+/fVzL9IjxlXA8V8gnnQDVRfILI= +github.com/zijiren233/ksync v0.2.0 h1:OyXVXbVQYFEVfWM13NApt4LMHbLQ3HTs4oYcLmqL6NE= +github.com/zijiren233/ksync v0.2.0/go.mod h1:YNvvoErcbtF86Xn18J8QJ14jKOXinxFVOzyp4hn8FKw= github.com/zijiren233/livelib v0.3.3 h1:0hbOK9RJzdduOp2GE4MWi/pqHy7ifxGQeZh4x9zxXGo= github.com/zijiren233/livelib v0.3.3/go.mod h1:vEjPSCaZ9RAen2efc1lmjCoFV/3zpPtjVV5ngS2q/nE= github.com/zijiren233/stream v0.5.2 h1:K8xPvXtETH7qo9P99xdvi7q0MXALfxb1XBtzpz/Zn0A= diff --git a/internal/bootstrap/config.go b/internal/bootstrap/config.go index 561b11e..34aef66 100644 --- a/internal/bootstrap/config.go +++ b/internal/bootstrap/config.go @@ -3,6 +3,7 @@ package bootstrap import ( "context" "errors" + "fmt" "path/filepath" "github.com/caarlos0/env/v9" @@ -54,6 +55,33 @@ func InitConfig(ctx context.Context) (err error) { } log.Info("load config success from env") } + return optConfigPath(conf.Conf) +} + +func optConfigPath(conf *conf.Config) error { + var err error + conf.Server.ProxyCachePath, err = utils.OptFilePath(conf.Server.ProxyCachePath) + if err != nil { + return fmt.Errorf("get proxy cache path error: %w", err) + } + conf.Server.Http.CertPath, err = utils.OptFilePath(conf.Server.Http.CertPath) + if err != nil { + return fmt.Errorf("get http cert path error: %w", err) + } + conf.Server.Http.KeyPath, err = utils.OptFilePath(conf.Server.Http.KeyPath) + if err != nil { + return fmt.Errorf("get http key path error: %w", err) + } + conf.Log.FilePath, err = utils.OptFilePath(conf.Log.FilePath) + if err != nil { + return fmt.Errorf("get log file path error: %w", err) + } + for _, op := range conf.Oauth2Plugins { + op.PluginFile, err = utils.OptFilePath(op.PluginFile) + if err != nil { + return fmt.Errorf("get oauth2 plugin file path error: %w", err) + } + } return nil } diff --git a/internal/bootstrap/log.go b/internal/bootstrap/log.go index f8c2f05..33a44bb 100644 --- a/internal/bootstrap/log.go +++ b/internal/bootstrap/log.go @@ -35,10 +35,6 @@ func InitLog(ctx context.Context) (err error) { setLog(logrus.StandardLogger()) forceColor := utils.ForceColor() if conf.Conf.Log.Enable { - conf.Conf.Log.FilePath, err = utils.OptFilePath(conf.Conf.Log.FilePath) - if err != nil { - logrus.Fatalf("log: log file path error: %v", err) - } l := &lumberjack.Logger{ Filename: conf.Conf.Log.FilePath, MaxSize: conf.Conf.Log.MaxSize, diff --git a/internal/bootstrap/provider.go b/internal/bootstrap/provider.go index d6b54be..bcfc9d0 100644 --- a/internal/bootstrap/provider.go +++ b/internal/bootstrap/provider.go @@ -19,7 +19,6 @@ import ( "github.com/synctv-org/synctv/internal/provider/plugins" "github.com/synctv-org/synctv/internal/provider/providers" "github.com/synctv-org/synctv/internal/settings" - "github.com/synctv-org/synctv/utils" "github.com/zijiren233/gencontainer/refreshcache0" ) @@ -83,11 +82,6 @@ func InitProvider(ctx context.Context) (err error) { logLevle = hclog.Debug } for _, op := range conf.Conf.Oauth2Plugins { - op.PluginFile, err = utils.OptFilePath(op.PluginFile) - if err != nil { - log.Fatalf("oauth2 plugin file path error: %v", err) - return err - } log.Infof("load oauth2 plugin: %s", op.PluginFile) err := os.MkdirAll(filepath.Dir(op.PluginFile), 0o755) if err != nil { diff --git a/internal/cache/alist.go b/internal/cache/alist.go index 406d2dc..2bf0ea9 100644 --- a/internal/cache/alist.go +++ b/internal/cache/alist.go @@ -135,7 +135,7 @@ func newAliSubtitles(list []*alist.FsOtherResp_VideoPreviewPlayInfo_LiveTranscod return nil, fmt.Errorf("status code: %d", resp.StatusCode) } return io.ReadAll(resp.Body) - }, 0), + }, -1), Name: v.Language, URL: v.Url, Type: utils.GetFileExtension(v.Url), diff --git a/internal/cache/bilibili.go b/internal/cache/bilibili.go index 169189b..fcac8e4 100644 --- a/internal/cache/bilibili.go +++ b/internal/cache/bilibili.go @@ -370,7 +370,7 @@ func NewBilibiliMovieCache(movie *model.Movie) *BilibiliMovieCache { return &BilibiliMovieCache{ NoSharedMovie: newMapCache(NewBilibiliNoSharedMovieCacheInitFunc(movie), time.Minute*60), SharedMpd: refreshcache1.NewRefreshCache(NewBilibiliSharedMpdCacheInitFunc(movie), time.Minute*60), - Subtitle: refreshcache1.NewRefreshCache(NewBilibiliSubtitleCacheInitFunc(movie), 0), + Subtitle: refreshcache1.NewRefreshCache(NewBilibiliSubtitleCacheInitFunc(movie), -1), Live: refreshcache0.NewRefreshCache(NewBilibiliLiveCacheInitFunc(movie), time.Minute*55), } } @@ -386,7 +386,7 @@ func NewBilibiliUserCache(userID string) *BilibiliUserCache { f := BilibiliAuthorizationCacheWithUserIDInitFunc(userID) return refreshcache.NewRefreshCache(func(ctx context.Context, args ...struct{}) (*BilibiliUserCacheData, error) { return f(ctx) - }, 0) + }, -1) } func BilibiliAuthorizationCacheWithUserIDInitFunc(userID string) func(ctx context.Context, args ...struct{}) (*BilibiliUserCacheData, error) { diff --git a/internal/cache/emby.go b/internal/cache/emby.go index ae31e04..74322c6 100644 --- a/internal/cache/emby.go +++ b/internal/cache/emby.go @@ -78,7 +78,7 @@ type EmbyMovieCacheData struct { type EmbyMovieCache = refreshcache1.RefreshCache[*EmbyMovieCacheData, *EmbyUserCache] func NewEmbyMovieCache(movie *model.Movie, subPath string) *EmbyMovieCache { - cache := refreshcache1.NewRefreshCache(NewEmbyMovieCacheInitFunc(movie, subPath), 0) + cache := refreshcache1.NewRefreshCache(NewEmbyMovieCacheInitFunc(movie, subPath), -1) cache.SetClearFunc(NewEmbyMovieClearCacheFunc(movie, subPath)) return cache } @@ -217,7 +217,7 @@ func NewEmbyMovieCacheInitFunc(movie *model.Movie, subPath string) func(ctx cont URL: url, Type: subtutleType, Name: name, - Cache: refreshcache0.NewRefreshCache(newEmbySubtitleCacheInitFunc(url), 0), + Cache: refreshcache0.NewRefreshCache(newEmbySubtitleCacheInitFunc(url), -1), }) } } diff --git a/internal/conf/server.go b/internal/conf/server.go index a34fcbe..3d07dc6 100644 --- a/internal/conf/server.go +++ b/internal/conf/server.go @@ -1,8 +1,10 @@ package conf type ServerConfig struct { - Http HttpServerConfig `yaml:"http"` - Rtmp RtmpServerConfig `yaml:"rtmp"` + Http HttpServerConfig `yaml:"http"` + Rtmp RtmpServerConfig `yaml:"rtmp"` + ProxyCachePath string `yaml:"proxy_cache_path" env:"SERVER_PROXY_CACHE_PATH"` + ProxyCacheSize string `yaml:"proxy_cache_size" env:"SERVER_PROXY_CACHE_SIZE"` } type HttpServerConfig struct { @@ -33,5 +35,6 @@ func DefaultServerConfig() ServerConfig { Enable: true, Port: 0, }, + ProxyCachePath: "", } } diff --git a/internal/op/movie.go b/internal/op/movie.go index c50d95f..df10c2a 100644 --- a/internal/op/movie.go +++ b/internal/op/movie.go @@ -269,6 +269,8 @@ func (movie *Movie) Validate() error { return nil } switch { + case m.Live && m.RtmpSource: + return nil case m.Live && m.Proxy: if !settings.LiveProxy.Get() { return errors.New("live proxy is not enabled") @@ -309,7 +311,7 @@ func (movie *Movie) Validate() error { return fmt.Errorf("unsupported scheme: %s", u.Scheme) } default: - return errors.New("unknown error") + return errors.New("validate movie error: unknown error") } return nil } diff --git a/internal/settings/var.go b/internal/settings/var.go index 3da4636..75c1d84 100644 --- a/internal/settings/var.go +++ b/internal/settings/var.go @@ -59,6 +59,7 @@ var ( MovieProxy = NewBoolSetting("movie_proxy", true, model.SettingGroupProxy) LiveProxy = NewBoolSetting("live_proxy", true, model.SettingGroupProxy) AllowProxyToLocal = NewBoolSetting("allow_proxy_to_local", false, model.SettingGroupProxy) + ProxyCacheEnable = NewBoolSetting("proxy_cache_enable", false, model.SettingGroupProxy) ) var ( diff --git a/server/handlers/movie.go b/server/handlers/movie.go index 135c3ab..2fe344e 100644 --- a/server/handlers/movie.go +++ b/server/handlers/movie.go @@ -600,7 +600,7 @@ func ProxyMovie(ctx *gin.Context) { // TODO: cache mpd file fallthrough default: - err = proxy.AutoProxyURL(ctx, m.Movie.MovieBase.Url, m.Movie.MovieBase.Type, m.Movie.MovieBase.Headers, ctx.GetString("token"), room.ID, m.ID) + err = proxy.AutoProxyURL(ctx, m.Movie.MovieBase.Url, m.Movie.MovieBase.Type, m.Movie.MovieBase.Headers, true, ctx.GetString("token"), room.ID, m.ID) if err != nil { log.Errorf("proxy movie error: %v", err) return @@ -654,7 +654,7 @@ func ServeM3u8(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("invalid token")) return } - err = proxy.ProxyM3u8(ctx, claims.TargetUrl, m.Movie.MovieBase.Headers, claims.IsM3u8File, ctx.GetString("token"), room.ID, m.ID) + err = proxy.ProxyM3u8(ctx, claims.TargetUrl, m.Movie.MovieBase.Headers, true, claims.IsM3u8File, ctx.GetString("token"), room.ID, m.ID) if err != nil { log.Errorf("proxy m3u8 error: %v", err) } @@ -785,7 +785,7 @@ func JoinHlsLive(ctx *gin.Context) { } if utils.IsM3u8Url(m.Movie.MovieBase.Url) { - err = proxy.ProxyM3u8(ctx, m.Movie.MovieBase.Url, m.Movie.MovieBase.Headers, true, ctx.GetString("token"), room.ID, m.ID) + err = proxy.ProxyM3u8(ctx, m.Movie.MovieBase.Url, m.Movie.MovieBase.Headers, true, true, ctx.GetString("token"), room.ID, m.ID) if err != nil { log.Errorf("proxy m3u8 hls live error: %v", err) } diff --git a/server/handlers/proxy/cache.go b/server/handlers/proxy/cache.go new file mode 100644 index 0000000..8457f7c --- /dev/null +++ b/server/handlers/proxy/cache.go @@ -0,0 +1,477 @@ +package proxy + +import ( + "encoding/binary" + "fmt" + "io" + "net/http" + "os" + "path/filepath" + "sort" + "sync" + "sync/atomic" + "time" + + json "github.com/json-iterator/go" + "github.com/zijiren233/gencontainer/dllist" + "github.com/zijiren233/ksync" +) + +// Cache defines the interface for cache implementations +type Cache interface { + Get(key string) (*CacheItem, bool, error) + Set(key string, data *CacheItem) error +} + +// CacheMetadata stores metadata about a cached response +type CacheMetadata struct { + Headers http.Header `json:"headers"` + ContentType string `json:"content_type"` + ContentTotalLength int64 `json:"content_total_length"` +} + +func (m *CacheMetadata) MarshalBinary() ([]byte, error) { + return json.Marshal(m) +} + +// CacheItem represents a cached response with metadata and data +type CacheItem struct { + Metadata *CacheMetadata + Data []byte +} + +// WriteTo implements io.WriterTo to serialize the cache item +func (i *CacheItem) WriteTo(w io.Writer) (int64, error) { + if w == nil { + return 0, fmt.Errorf("cannot write to nil io.Writer") + } + + if i.Metadata == nil { + return 0, fmt.Errorf("CacheItem contains nil Metadata") + } + + metadata, err := i.Metadata.MarshalBinary() + if err != nil { + return 0, fmt.Errorf("failed to marshal metadata: %w", err) + } + + var written int64 + + // Write metadata length and content + if err := binary.Write(w, binary.BigEndian, int64(len(metadata))); err != nil { + return written, fmt.Errorf("failed to write metadata length: %w", err) + } + written += 8 + + n, err := w.Write(metadata) + written += int64(n) + if err != nil { + return written, fmt.Errorf("failed to write metadata bytes: %w", err) + } + + // Write data length and content + if err := binary.Write(w, binary.BigEndian, int64(len(i.Data))); err != nil { + return written, fmt.Errorf("failed to write data length: %w", err) + } + written += 8 + + n, err = w.Write(i.Data) + written += int64(n) + if err != nil { + return written, fmt.Errorf("failed to write data bytes: %w", err) + } + + return written, nil +} + +// ReadFrom implements io.ReaderFrom to deserialize the cache item +func (i *CacheItem) ReadFrom(r io.Reader) (int64, error) { + if r == nil { + return 0, fmt.Errorf("cannot read from nil io.Reader") + } + + var read int64 + + // Read metadata length and content + var metadataLen int64 + if err := binary.Read(r, binary.BigEndian, &metadataLen); err != nil { + return read, fmt.Errorf("failed to read metadata length: %w", err) + } + read += 8 + + if metadataLen <= 0 { + return read, fmt.Errorf("metadata length must be positive, got: %d", metadataLen) + } + + metadata := make([]byte, metadataLen) + n, err := io.ReadFull(r, metadata) + read += int64(n) + if err != nil { + return read, fmt.Errorf("failed to read metadata bytes: %w", err) + } + + i.Metadata = new(CacheMetadata) + if err := json.Unmarshal(metadata, i.Metadata); err != nil { + return read, fmt.Errorf("failed to unmarshal metadata: %w", err) + } + + // Read data length and content + var dataLen int64 + if err := binary.Read(r, binary.BigEndian, &dataLen); err != nil { + return read, fmt.Errorf("failed to read data length: %w", err) + } + read += 8 + + if dataLen < 0 { + return read, fmt.Errorf("data length cannot be negative, got: %d", dataLen) + } + + i.Data = make([]byte, dataLen) + n, err = io.ReadFull(r, i.Data) + read += int64(n) + if err != nil { + return read, fmt.Errorf("failed to read data bytes: %w", err) + } + + return read, nil +} + +// MemoryCache implements an in-memory Cache with LRU eviction +type MemoryCache struct { + m map[string]*dllist.Element[*cacheEntry] + lruList *dllist.Dllist[*cacheEntry] + capacity int + maxSizeBytes int64 + currentSize int64 + mu sync.RWMutex +} + +type MemoryCacheOption func(*MemoryCache) + +func WithMaxSizeBytes(size int64) MemoryCacheOption { + return func(c *MemoryCache) { + c.maxSizeBytes = size + } +} + +type cacheEntry struct { + item *CacheItem + key string + size int64 +} + +func NewMemoryCache(capacity int, opts ...MemoryCacheOption) *MemoryCache { + mc := &MemoryCache{ + m: make(map[string]*dllist.Element[*cacheEntry]), + lruList: dllist.New[*cacheEntry](), + capacity: capacity, + } + for _, opt := range opts { + opt(mc) + } + return mc +} + +func (c *MemoryCache) Get(key string) (*CacheItem, bool, error) { + if key == "" { + return nil, false, fmt.Errorf("cache key cannot be empty") + } + + c.mu.RLock() + element, exists := c.m[key] + if !exists { + c.mu.RUnlock() + return nil, false, nil + } + + // Upgrade to write lock for moving element + c.mu.RUnlock() + c.mu.Lock() + c.lruList.MoveToFront(element) + item := element.Value.item + c.mu.Unlock() + + return item, true, nil +} + +func (c *MemoryCache) Set(key string, data *CacheItem) error { + if key == "" { + return fmt.Errorf("cache key cannot be empty") + } + if data == nil { + return fmt.Errorf("cannot cache nil CacheItem") + } + + // Calculate size of new item + newSize := int64(len(data.Data)) + if data.Metadata != nil { + metadataBytes, err := data.Metadata.MarshalBinary() + if err == nil { + newSize += int64(len(metadataBytes)) + } + } + + c.mu.Lock() + defer c.mu.Unlock() + + // Update existing entry if present + if element, ok := c.m[key]; ok { + c.currentSize -= element.Value.size + c.currentSize += newSize + c.lruList.MoveToFront(element) + element.Value.item = data + element.Value.size = newSize + return nil + } + + // Evict entries if needed + for c.lruList.Len() > 0 && + ((c.capacity > 0 && c.lruList.Len() >= c.capacity) || + (c.maxSizeBytes > 0 && c.currentSize+newSize > c.maxSizeBytes)) { + + if back := c.lruList.Back(); back != nil { + entry := back.Value + c.currentSize -= entry.size + delete(c.m, entry.key) + c.lruList.Remove(back) + } + } + + // Add new entry + newEntry := &cacheEntry{key: key, item: data, size: newSize} + element := c.lruList.PushFront(newEntry) + c.m[key] = element + c.currentSize += newSize + return nil +} + +type FileCache struct { + mu *ksync.Krwmutex + filePath string + maxSizeBytes int64 + currentSize atomic.Int64 + lastCleanup atomic.Int64 + maxAge time.Duration + cleanMu sync.Mutex +} + +type FileCacheOption func(*FileCache) + +func WithFileCacheMaxSizeBytes(size int64) FileCacheOption { + return func(c *FileCache) { + c.maxSizeBytes = size + } +} + +func WithFileCacheMaxAge(age time.Duration) FileCacheOption { + return func(c *FileCache) { + if age > 0 { + c.maxAge = age + } + } +} + +func NewFileCache(filePath string, opts ...FileCacheOption) *FileCache { + fc := &FileCache{ + filePath: filePath, + mu: ksync.DefaultKrwmutex(), + maxAge: 24 * time.Hour, // Default 1 day + } + + for _, opt := range opts { + opt(fc) + } + + go fc.periodicCleanup() + return fc +} + +func (c *FileCache) periodicCleanup() { + ticker := time.NewTicker(5 * time.Minute) + defer ticker.Stop() + + for range ticker.C { + c.cleanup() + } +} + +func (c *FileCache) cleanup() { + maxSize := c.maxSizeBytes + if maxSize <= 0 { + return + } + + // Avoid frequent cleanups + now := time.Now().Unix() + if now-c.lastCleanup.Load() < 300 { + return + } + + c.cleanMu.Lock() + defer c.cleanMu.Unlock() + + // Double check after acquiring lock + if now-c.lastCleanup.Load() < 300 { + return + } + + entries, err := os.ReadDir(c.filePath) + if err != nil { + return + } + + type fileInfo struct { + modTime time.Time + path string + size int64 + } + + var files []fileInfo + var totalSize int64 + cutoffTime := time.Now().Add(-c.maxAge) + + // Collect file information and remove expired files + for _, entry := range entries { + if !entry.IsDir() { + continue + } + + subdir := filepath.Join(c.filePath, entry.Name()) + subEntries, err := os.ReadDir(subdir) + if err != nil { + continue + } + + for _, subEntry := range subEntries { + info, err := subEntry.Info() + if err != nil { + continue + } + + fullPath := filepath.Join(subdir, subEntry.Name()) + + // Remove expired files + if info.ModTime().Before(cutoffTime) { + os.Remove(fullPath) + continue + } + + files = append(files, fileInfo{ + path: fullPath, + size: info.Size(), + modTime: info.ModTime(), + }) + totalSize += info.Size() + } + } + + // If under size limit, just update size and return + if totalSize <= maxSize { + c.currentSize.Store(totalSize) + c.lastCleanup.Store(now) + return + } + + // Sort by modification time (oldest first) and remove until under limit + sort.Slice(files, func(i, j int) bool { + return files[i].modTime.Before(files[j].modTime) + }) + + for _, file := range files { + if totalSize <= maxSize { + break + } + if err := os.Remove(file.path); err == nil { + totalSize -= file.size + } + } + + c.currentSize.Store(totalSize) + c.lastCleanup.Store(now) +} + +func (c *FileCache) Get(key string) (*CacheItem, bool, error) { + if key == "" { + return nil, false, fmt.Errorf("cache key cannot be empty") + } + + prefix := string(key[0]) + filePath := filepath.Join(c.filePath, prefix, key) + + c.mu.RLock(key) + defer c.mu.RUnlock(key) + + file, err := os.OpenFile(filePath, os.O_RDONLY, 0o644) + if err != nil { + if os.IsNotExist(err) { + return nil, false, nil + } + return nil, false, fmt.Errorf("failed to open cache file: %w", err) + } + defer file.Close() + + // Check if file is expired + if info, err := file.Stat(); err == nil { + if time.Since(info.ModTime()) > c.maxAge { + os.Remove(filePath) + return nil, false, nil + } + } + + item := &CacheItem{} + if _, err := item.ReadFrom(file); err != nil { + return nil, false, fmt.Errorf("failed to read cache item: %w", err) + } + + return item, true, nil +} + +func (c *FileCache) Set(key string, data *CacheItem) error { + if key == "" { + return fmt.Errorf("cache key cannot be empty") + } + if data == nil { + return fmt.Errorf("cannot cache nil CacheItem") + } + + // Check and cleanup if needed + maxSize := c.maxSizeBytes + if maxSize > 0 { + newSize := int64(len(data.Data)) + if data.Metadata != nil { + if metadataBytes, err := data.Metadata.MarshalBinary(); err == nil { + newSize += int64(len(metadataBytes)) + } + } + if c.currentSize.Load()+newSize > maxSize { + c.cleanup() + } + } + + prefix := string(key[0]) + dirPath := filepath.Join(c.filePath, prefix) + filePath := filepath.Join(dirPath, key) + + c.mu.Lock(key) + defer c.mu.Unlock(key) + + if err := os.MkdirAll(dirPath, 0o755); err != nil { + return fmt.Errorf("failed to create cache directory: %w", err) + } + + file, err := os.OpenFile(filePath, os.O_WRONLY|os.O_CREATE|os.O_TRUNC, 0o644) + if err != nil { + return fmt.Errorf("failed to create cache file: %w", err) + } + defer file.Close() + + if _, err := data.WriteTo(file); err != nil { + return fmt.Errorf("failed to write cache item: %w", err) + } + + if info, err := file.Stat(); err == nil { + c.currentSize.Add(info.Size()) + } + + return nil +} diff --git a/server/handlers/proxy/m3u8.go b/server/handlers/proxy/m3u8.go index 3c9351e..b235e15 100644 --- a/server/handlers/proxy/m3u8.go +++ b/server/handlers/proxy/m3u8.go @@ -55,9 +55,10 @@ func NewM3u8TargetToken(targetUrl, roomId, movieId string, isM3u8File bool) (str const maxM3u8FileSize = 3 * 1024 * 1024 // -func ProxyM3u8(ctx *gin.Context, u string, headers map[string]string, isM3u8File bool, token, roomId, movieId string) error { +// only cache non-m3u8 files +func ProxyM3u8(ctx *gin.Context, u string, headers map[string]string, cache bool, isM3u8File bool, token, roomId, movieId string) error { if !isM3u8File { - return ProxyURL(ctx, u, headers) + return ProxyURL(ctx, u, headers, cache) } req, err := http.NewRequestWithContext(ctx, http.MethodGet, u, nil) diff --git a/server/handlers/proxy/proxy.go b/server/handlers/proxy/proxy.go index f57bb07..69ad458 100644 --- a/server/handlers/proxy/proxy.go +++ b/server/handlers/proxy/proxy.go @@ -6,16 +6,78 @@ import ( "fmt" "io" "net/http" + "strconv" "strings" + "sync" "github.com/gin-gonic/gin" + log "github.com/sirupsen/logrus" + "github.com/synctv-org/synctv/internal/conf" "github.com/synctv-org/synctv/internal/settings" "github.com/synctv-org/synctv/server/model" "github.com/synctv-org/synctv/utils" "github.com/zijiren233/go-uhc" ) -func ProxyURL(ctx *gin.Context, u string, headers map[string]string) error { +var ( + defaultCache *MemoryCache + fileCacheOnce sync.Once + fileCache Cache +) + +// MB GB KB +func parseProxyCacheSize(sizeStr string) (int64, error) { + if sizeStr == "" { + return 0, nil + } + sizeStr = strings.ToLower(sizeStr) + sizeStr = strings.TrimSpace(sizeStr) + + var multiplier int64 = 1024 * 1024 // Default MB + + if strings.HasSuffix(sizeStr, "gb") { + multiplier = 1024 * 1024 * 1024 + sizeStr = strings.TrimSuffix(sizeStr, "gb") + } else if strings.HasSuffix(sizeStr, "mb") { + multiplier = 1024 * 1024 + sizeStr = strings.TrimSuffix(sizeStr, "mb") + } else if strings.HasSuffix(sizeStr, "kb") { + multiplier = 1024 + sizeStr = strings.TrimSuffix(sizeStr, "kb") + } + + size, err := strconv.ParseInt(strings.TrimSpace(sizeStr), 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid size format: %w", err) + } + + return size * multiplier, nil +} + +func getCache() Cache { + fileCacheOnce.Do(func() { + size, err := parseProxyCacheSize(conf.Conf.Server.ProxyCacheSize) + if err != nil { + log.Fatalf("parse proxy cache size error: %v", err) + } + if size == 0 { + size = 1024 * 1024 * 1024 + } + if conf.Conf.Server.ProxyCachePath == "" { + log.Infof("proxy cache path is empty, use memory cache, size: %d", size) + defaultCache = NewMemoryCache(0, WithMaxSizeBytes(size)) + return + } + log.Infof("proxy cache path: %s, size: %d", conf.Conf.Server.ProxyCachePath, size) + fileCache = NewFileCache(conf.Conf.Server.ProxyCachePath, WithFileCacheMaxSizeBytes(size)) + }) + if fileCache != nil { + return fileCache + } + return defaultCache +} + +func ProxyURL(ctx *gin.Context, u string, headers map[string]string, cache bool) error { if !settings.AllowProxyToLocal.Get() { if l, err := utils.ParseURLIsLocalIP(u); err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, @@ -33,6 +95,17 @@ func ProxyURL(ctx *gin.Context, u string, headers map[string]string) error { return errors.New("not allow proxy to local") } } + + if cache && settings.ProxyCacheEnable.Get() { + rsc := NewHttpReadSeekCloser(u, + WithHeadersMap(headers), + WithNotSupportRange(ctx.GetHeader("Range") == ""), + ) + defer rsc.Close() + return NewSliceCacheProxy(u, 1024*512, rsc, getCache()). + Proxy(ctx.Writer, ctx.Request) + } + ctx2, cf := context.WithCancel(ctx) defer cf() req, err := http.NewRequestWithContext(ctx2, http.MethodGet, u, nil) @@ -103,9 +176,9 @@ func ProxyURL(ctx *gin.Context, u string, headers map[string]string) error { return nil } -func AutoProxyURL(ctx *gin.Context, u, t string, headers map[string]string, token, roomId, movieId string) error { +func AutoProxyURL(ctx *gin.Context, u, t string, headers map[string]string, cache bool, token, roomId, movieId string) error { if strings.HasPrefix(t, "m3u") || utils.IsM3u8Url(u) { - return ProxyM3u8(ctx, u, headers, true, token, roomId, movieId) + return ProxyM3u8(ctx, u, headers, cache, true, token, roomId, movieId) } - return ProxyURL(ctx, u, headers) + return ProxyURL(ctx, u, headers, cache) } diff --git a/server/handlers/proxy/readseeker.go b/server/handlers/proxy/readseeker.go new file mode 100644 index 0000000..5d4e1ec --- /dev/null +++ b/server/handlers/proxy/readseeker.go @@ -0,0 +1,541 @@ +package proxy + +import ( + "context" + "fmt" + "io" + "net/http" + "slices" + "strconv" + "strings" + + "github.com/synctv-org/synctv/utils" +) + +var ( + _ io.ReadSeekCloser = (*HttpReadSeekCloser)(nil) + _ Proxy = (*HttpReadSeekCloser)(nil) +) + +type HttpReadSeekCloser struct { + ctx context.Context + headHeaders http.Header + currentResp *http.Response + headers http.Header + client *http.Client + contentType string + method string + headMethod string + url string + allowedContentTypes []string + notAllowedStatusCodes []int + allowedStatusCodes []int + offset int64 + contentLength int64 + length int64 + currentRespMaxOffset int64 + notSupportRange bool +} + +type HttpReadSeekerConf func(h *HttpReadSeekCloser) + +func WithHeaders(headers http.Header) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if headers != nil { + h.headers = headers.Clone() + } + } +} + +func WithHeadersMap(headers map[string]string) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + for k, v := range headers { + h.headers.Set(k, v) + } + } +} + +func WithClient(client *http.Client) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if client != nil { + h.client = client + } + } +} + +func WithMethod(method string) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if method != "" { + h.method = method + } + } +} + +func WithHeadMethod(method string) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if method != "" { + h.headMethod = method + } + } +} + +func WithContext(ctx context.Context) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if ctx != nil { + h.ctx = ctx + } + } +} + +func WithContentLength(contentLength int64) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if contentLength >= 0 { + h.contentLength = contentLength + } + } +} + +func AllowedContentTypes(types ...string) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if len(types) > 0 { + h.allowedContentTypes = slices.Clone(types) + } + } +} + +func AllowedStatusCodes(codes ...int) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if len(codes) > 0 { + h.allowedStatusCodes = slices.Clone(codes) + } + } +} + +func NotAllowedStatusCodes(codes ...int) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if len(codes) > 0 { + h.notAllowedStatusCodes = slices.Clone(codes) + } + } +} + +func WithLength(length int64) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + if length > 0 { + h.length = length + } + } +} + +func WithNotSupportRange(notSupportRange bool) HttpReadSeekerConf { + return func(h *HttpReadSeekCloser) { + h.notSupportRange = notSupportRange + } +} + +func NewHttpReadSeekCloser(url string, conf ...HttpReadSeekerConf) *HttpReadSeekCloser { + rs := &HttpReadSeekCloser{ + url: url, + contentLength: -1, + method: http.MethodGet, + headMethod: http.MethodHead, + length: 1024 * 1024 * 16, + headers: make(http.Header), + ctx: context.Background(), + client: http.DefaultClient, + } + + for _, c := range conf { + if c != nil { + c(rs) + } + } + + rs.fix() + + return rs +} + +func (h *HttpReadSeekCloser) fix() *HttpReadSeekCloser { + if h.method == "" { + h.method = http.MethodGet + } + if h.headMethod == "" { + h.headMethod = http.MethodHead + } + if h.ctx == nil { + h.ctx = context.Background() + } + if h.client == nil { + h.client = http.DefaultClient + } + if len(h.notAllowedStatusCodes) == 0 { + h.notAllowedStatusCodes = []int{http.StatusNotFound} + } + if h.length <= 0 { + h.length = 64 * 1024 + } + if h.headers == nil { + h.headers = make(http.Header) + } + return h +} + +func (h *HttpReadSeekCloser) Read(p []byte) (n int, err error) { + for n < len(p) { + if h.currentResp == nil || h.offset > h.currentRespMaxOffset { + if err := h.FetchNextChunk(); err != nil { + if err == io.EOF { + return n, io.EOF + } + return 0, fmt.Errorf("failed to fetch next chunk: %w", err) + } + } + + readN, err := h.currentResp.Body.Read(p[n:]) + if readN > 0 { + n += readN + h.offset += int64(readN) + } + + if err == io.EOF { + h.closeCurrentResp() + if n < len(p) { + continue + } + break + } + if err != nil { + if n > 0 { + return n, nil + } + return 0, fmt.Errorf("error reading response body: %w", err) + } + } + + return n, nil +} + +func (h *HttpReadSeekCloser) FetchNextChunk() error { + h.closeCurrentResp() + + if h.contentLength > 0 && h.offset >= h.contentLength { + return io.EOF + } + + req, err := h.createRequest() + if err != nil { + return fmt.Errorf("failed to create request: %w", err) + } + + resp, err := h.client.Do(req) + if err != nil { + return fmt.Errorf("failed to execute HTTP request: %w", err) + } + + h.contentType = resp.Header.Get("Content-Type") + + if resp.StatusCode == http.StatusOK { + // if the maximum offset of the current response is less than the content length minus one, it means that the server does not support range requests + if h.currentRespMaxOffset < resp.ContentLength-1 || h.notSupportRange { + // if the offset is not 0, it means that the seek method is incorrectly used + if h.offset != 0 { + resp.Body.Close() + return fmt.Errorf("server does not support range requests, cannot seek to non-zero offset") + } + h.notSupportRange = true + h.contentLength = resp.ContentLength + h.currentRespMaxOffset = h.contentLength - 1 + h.currentResp = resp + return nil + } + // if the content length is not known, it may be because the requested length is too long, and a new request is needed + if h.contentLength < 0 { + h.contentLength = resp.ContentLength + resp.Body.Close() + return h.FetchNextChunk() + } + // if the offset is greater than 0, it means that the seek method is incorrectly used + if h.offset > 0 { + resp.Body.Close() + return fmt.Errorf("server does not support range requests, cannot seek to offset %d", h.offset) + } + h.notSupportRange = true + h.currentRespMaxOffset = h.contentLength - 1 + h.currentResp = resp + return nil + } + + if resp.StatusCode != http.StatusPartialContent { + resp.Body.Close() + return fmt.Errorf("unexpected HTTP status code: %d (expected 206 Partial Content)", resp.StatusCode) + } + + if err := h.checkResponse(resp); err != nil { + resp.Body.Close() + return fmt.Errorf("response validation failed: %w", err) + } + + contentTotalLength, err := ParseContentRangeTotalLength(resp.Header.Get("Content-Range")) + if err == nil && contentTotalLength > 0 { + h.contentLength = contentTotalLength + } + _, end, err := ParseContentRangeStartAndEnd(resp.Header.Get("Content-Range")) + if err == nil && end != -1 { + h.currentRespMaxOffset = end + } + + h.currentResp = resp + return nil +} + +func (h *HttpReadSeekCloser) createRequest() (*http.Request, error) { + if h.notSupportRange { + if h.contentLength != -1 { + h.currentRespMaxOffset = h.contentLength - 1 + } + return h.createRequestWithoutRange() + } + + req, err := h.createRequestWithoutRange() + if err != nil { + return nil, err + } + + end := h.offset + h.length - 1 + if h.contentLength > 0 && end > h.contentLength-1 { + end = h.contentLength - 1 + } + + h.currentRespMaxOffset = end + + req.Header.Set("Range", fmt.Sprintf("bytes=%d-%d", h.offset, end)) + return req, nil +} + +func (h *HttpReadSeekCloser) createRequestWithoutRange() (*http.Request, error) { + req, err := http.NewRequestWithContext(h.ctx, h.method, h.url, nil) + if err != nil { + return nil, fmt.Errorf("failed to create HTTP request: %w", err) + } + req.Header = h.headers.Clone() + req.Header.Del("Range") + if req.Header.Get("User-Agent") == "" { + req.Header.Set("User-Agent", utils.UA) + } + return req, nil +} + +func (h *HttpReadSeekCloser) checkResponse(resp *http.Response) error { + if err := h.checkStatusCode(resp.StatusCode); err != nil { + return err + } + return h.checkContentType(resp.Header.Get("Content-Type")) +} + +func (h *HttpReadSeekCloser) closeCurrentResp() { + if h.currentResp != nil { + h.currentResp.Body.Close() + h.currentResp = nil + } +} + +func (h *HttpReadSeekCloser) checkContentType(ct string) error { + if len(h.allowedContentTypes) != 0 { + if ct == "" || slices.Index(h.allowedContentTypes, ct) == -1 { + return fmt.Errorf("content type '%s' is not in the list of allowed content types: %v", ct, h.allowedContentTypes) + } + } + return nil +} + +func (h *HttpReadSeekCloser) checkStatusCode(code int) error { + if len(h.allowedStatusCodes) != 0 { + if slices.Index(h.allowedStatusCodes, code) == -1 { + return fmt.Errorf("HTTP status code %d is not in the list of allowed status codes: %v", code, h.allowedStatusCodes) + } + return nil + } + if len(h.notAllowedStatusCodes) != 0 { + if slices.Index(h.notAllowedStatusCodes, code) != -1 { + return fmt.Errorf("HTTP status code %d is in the list of not allowed status codes: %v", code, h.notAllowedStatusCodes) + } + } + return nil +} + +func (h *HttpReadSeekCloser) Seek(offset int64, whence int) (int64, error) { + newOffset, err := h.calculateNewOffset(offset, whence) + if err != nil { + return 0, fmt.Errorf("failed to calculate new offset: %w", err) + } + + if newOffset < 0 { + return 0, fmt.Errorf("cannot seek to negative offset: %d", newOffset) + } + + if newOffset != h.offset { + h.closeCurrentResp() + h.offset = newOffset + } + + return h.offset, nil +} + +func (h *HttpReadSeekCloser) calculateNewOffset(offset int64, whence int) (int64, error) { + switch whence { + case io.SeekStart: + if h.notSupportRange && offset != 0 && offset != h.offset { + return 0, fmt.Errorf("server does not support range requests, cannot seek to non-zero offset") + } + return offset, nil + case io.SeekCurrent: + if h.notSupportRange && offset != 0 { + return 0, fmt.Errorf("server does not support range requests, cannot seek to non-zero offset") + } + return h.offset + offset, nil + case io.SeekEnd: + if h.contentLength < 0 { + if err := h.fetchContentLength(); err != nil { + return 0, fmt.Errorf("failed to fetch content length: %w", err) + } + } + newOffset := h.contentLength - offset + if h.notSupportRange && newOffset != h.offset { + return 0, fmt.Errorf("server does not support range requests, cannot seek to non-zero offset") + } + return newOffset, nil + default: + return 0, fmt.Errorf("invalid seek whence value: %d (must be 0, 1, or 2)", whence) + } +} + +func (h *HttpReadSeekCloser) fetchContentLength() error { + req, err := h.createRequestWithoutRange() + if err != nil { + return err + } + req.Method = h.headMethod + + resp, err := h.client.Do(req) + if err != nil { + return fmt.Errorf("failed to execute HEAD request: %w", err) + } + defer resp.Body.Close() + + if resp.StatusCode != http.StatusOK { + return fmt.Errorf("unexpected HTTP status code in HEAD request: %d (expected 200 OK)", resp.StatusCode) + } + + if err := h.checkResponse(resp); err != nil { + return fmt.Errorf("HEAD response validation failed: %w", err) + } + + if resp.ContentLength < 0 { + return fmt.Errorf("server returned invalid content length: %d", resp.ContentLength) + } + + h.contentType = resp.Header.Get("Content-Type") + + h.contentLength = resp.ContentLength + h.headHeaders = resp.Header.Clone() + return nil +} + +func (h *HttpReadSeekCloser) Close() error { + if h.currentResp != nil { + return h.currentResp.Body.Close() + } + return nil +} + +func (h *HttpReadSeekCloser) Offset() int64 { + return h.offset +} + +func (h *HttpReadSeekCloser) ContentLength() int64 { + return h.contentLength +} + +func (h *HttpReadSeekCloser) ContentType() (string, error) { + if h.contentType != "" { + return h.contentType, nil + } + return "", fmt.Errorf("content type is not available - no successful response received yet") +} + +func (h *HttpReadSeekCloser) ContentTotalLength() (int64, error) { + if h.contentLength > 0 { + return h.contentLength, nil + } + return 0, fmt.Errorf("content total length is not available - no successful response received yet") +} + +func ParseContentRangeStartAndEnd(contentRange string) (int64, int64, error) { + if contentRange == "" { + return 0, 0, fmt.Errorf("Content-Range header is empty") + } + + if !strings.HasPrefix(contentRange, "bytes ") { + return 0, 0, fmt.Errorf("invalid Content-Range header format (expected 'bytes ' prefix): %s", contentRange) + } + + parts := strings.Split(strings.TrimPrefix(contentRange, "bytes "), "/") + if len(parts) != 2 { + return 0, 0, fmt.Errorf("invalid Content-Range header format (expected 2 parts separated by '/'): %s", contentRange) + } + + rangeParts := strings.Split(strings.TrimSpace(parts[0]), "-") + if len(rangeParts) != 2 { + return 0, 0, fmt.Errorf("invalid Content-Range range format (expected start-end): %s", contentRange) + } + + start, err := strconv.ParseInt(strings.TrimSpace(rangeParts[0]), 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid Content-Range start value '%s': %w", rangeParts[0], err) + } + + rangeParts[1] = strings.TrimSpace(rangeParts[1]) + var end int64 + if rangeParts[1] == "" || rangeParts[1] == "*" { + end = -1 + } else { + end, err = strconv.ParseInt(rangeParts[1], 10, 64) + if err != nil { + return 0, 0, fmt.Errorf("invalid Content-Range end value '%s': %w", rangeParts[1], err) + } + } + + return start, end, nil +} + +// ParseContentRangeTotalLength parses a Content-Range header value and returns the total length +func ParseContentRangeTotalLength(contentRange string) (int64, error) { + if contentRange == "" { + return 0, fmt.Errorf("Content-Range header is empty") + } + + if !strings.HasPrefix(contentRange, "bytes ") { + return 0, fmt.Errorf("invalid Content-Range header format (expected 'bytes ' prefix): %s", contentRange) + } + + parts := strings.Split(strings.TrimPrefix(contentRange, "bytes "), "/") + if len(parts) != 2 { + return 0, fmt.Errorf("invalid Content-Range header format (expected 2 parts separated by '/'): %s", contentRange) + } + + if parts[1] == "" || parts[1] == "*" { + return -1, nil + } + + length, err := strconv.ParseInt(strings.TrimSpace(parts[1]), 10, 64) + if err != nil { + return 0, fmt.Errorf("invalid Content-Range total length value '%s': %w", parts[1], err) + } + + if length < 0 { + return 0, fmt.Errorf("Content-Range total length cannot be negative: %d", length) + } + + return length, nil +} diff --git a/server/handlers/proxy/slice.go b/server/handlers/proxy/slice.go new file mode 100644 index 0000000..fe71e12 --- /dev/null +++ b/server/handlers/proxy/slice.go @@ -0,0 +1,323 @@ +package proxy + +import ( + "crypto/sha256" + "encoding/hex" + "fmt" + "io" + "net/http" + "strconv" + "strings" + + "github.com/zijiren233/ksync" +) + +var mu = ksync.DefaultKmutex() + +// Proxy defines the interface for proxy implementations +type Proxy interface { + io.ReadSeeker + ContentTotalLength() (int64, error) + ContentType() (string, error) +} + +// Headers defines the interface for accessing response headers +type Headers interface { + Headers() http.Header +} + +// SliceCacheProxy implements caching of content slices +type SliceCacheProxy struct { + r Proxy + cache Cache + key string + sliceSize int64 +} + +// NewSliceCacheProxy creates a new SliceCacheProxy instance +func NewSliceCacheProxy(key string, sliceSize int64, r Proxy, cache Cache) *SliceCacheProxy { + return &SliceCacheProxy{ + key: key, + sliceSize: sliceSize, + r: r, + cache: cache, + } +} + +func cacheKey(key string, offset int64, sliceSize int64) string { + key = fmt.Sprintf("%s-%d-%d", key, offset, sliceSize) + hash := sha256.Sum256([]byte(key)) + return hex.EncodeToString(hash[:]) +} + +func (c *SliceCacheProxy) alignedOffset(offset int64) int64 { + return (offset / c.sliceSize) * c.sliceSize +} + +func (c *SliceCacheProxy) fmtContentRange(start, end, total int64) string { + totalStr := "*" + if total >= 0 { + totalStr = strconv.FormatInt(total, 10) + } + if end == -1 { + if total >= 0 { + end = total - 1 + } + return fmt.Sprintf("bytes %d-%d/%s", start, end, totalStr) + } + return fmt.Sprintf("bytes %d-%d/%s", start, end, totalStr) +} + +func (c *SliceCacheProxy) contentLength(start, end, total int64) int64 { + if total == -1 && end == -1 { + return -1 + } + if end == -1 { + if total == -1 { + return -1 + } + return total - start + } + if end >= total && total != -1 { + return total - start + } + return end - start + 1 +} + +func (c *SliceCacheProxy) fmtContentLength(start, end, total int64) string { + length := c.contentLength(start, end, total) + if length == -1 { + return "" + } + return strconv.FormatInt(length, 10) +} + +// ServeHTTP implements http.Handler interface +func (c *SliceCacheProxy) ServeHTTP(w http.ResponseWriter, r *http.Request) { + _ = c.Proxy(w, r) +} + +func (c *SliceCacheProxy) Proxy(w http.ResponseWriter, r *http.Request) error { + byteRange, err := ParseByteRange(r.Header.Get("Range")) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to parse Range header: %v", err), http.StatusBadRequest) + return err + } + + alignedOffset := c.alignedOffset(byteRange.Start) + cacheItem, err := c.getCacheItem(alignedOffset) + if err != nil { + http.Error(w, fmt.Sprintf("Failed to get cache item: %v", err), http.StatusInternalServerError) + return err + } + + c.setResponseHeaders(w, byteRange, cacheItem, r.Header.Get("Range") != "") + if err := c.writeResponse(w, byteRange, alignedOffset, cacheItem); err != nil { + return fmt.Errorf("failed to write response: %w", err) + } + return nil +} + +func (c *SliceCacheProxy) setResponseHeaders(w http.ResponseWriter, byteRange *ByteRange, cacheItem *CacheItem, hasRange bool) { + // Copy headers excluding special ones + for k, v := range cacheItem.Metadata.Headers { + switch k { + case "Content-Type", "Content-Length", "Content-Range", "Accept-Ranges": + continue + default: + w.Header()[k] = v + } + } + + w.Header().Set("Content-Length", c.fmtContentLength(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength)) + w.Header().Set("Content-Type", cacheItem.Metadata.ContentType) + if hasRange { + w.Header().Set("Accept-Ranges", "bytes") + w.Header().Set("Content-Range", c.fmtContentRange(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength)) + w.WriteHeader(http.StatusPartialContent) + } else { + w.WriteHeader(http.StatusOK) + } +} + +func (c *SliceCacheProxy) writeResponse(w http.ResponseWriter, byteRange *ByteRange, alignedOffset int64, cacheItem *CacheItem) error { + sliceOffset := byteRange.Start - alignedOffset + if sliceOffset < 0 { + return fmt.Errorf("slice offset cannot be negative, got: %d", sliceOffset) + } + + remainingLength := c.contentLength(byteRange.Start, byteRange.End, cacheItem.Metadata.ContentTotalLength) + if remainingLength == 0 { + return nil + } + + // Write initial slice + if remainingLength > 0 { + n := int64(len(cacheItem.Data)) - sliceOffset + if n > remainingLength { + n = remainingLength + } + if n > 0 { + if _, err := w.Write(cacheItem.Data[sliceOffset : sliceOffset+n]); err != nil { + return fmt.Errorf("failed to write initial data slice: %w", err) + } + remainingLength -= n + } + } + + // Write subsequent slices + currentOffset := alignedOffset + c.sliceSize + for remainingLength > 0 { + cacheItem, err := c.getCacheItem(currentOffset) + if err != nil { + return fmt.Errorf("failed to get cache item at offset %d: %w", currentOffset, err) + } + + n := int64(len(cacheItem.Data)) + if n > remainingLength { + n = remainingLength + } + if n > 0 { + if _, err := w.Write(cacheItem.Data[:n]); err != nil { + return fmt.Errorf("failed to write data slice at offset %d: %w", currentOffset, err) + } + remainingLength -= n + } + currentOffset += c.sliceSize + } + + return nil +} + +func (c *SliceCacheProxy) getCacheItem(alignedOffset int64) (*CacheItem, error) { + if alignedOffset < 0 { + return nil, fmt.Errorf("cache item offset cannot be negative, got: %d", alignedOffset) + } + + cacheKey := cacheKey(c.key, alignedOffset, c.sliceSize) + mu.Lock(cacheKey) + defer mu.Unlock(cacheKey) + + // Try to get from cache first + slice, ok, err := c.cache.Get(cacheKey) + if err != nil { + return nil, fmt.Errorf("failed to get item from cache: %w", err) + } + if ok { + return slice, nil + } + + // Fetch from source if not in cache + slice, err = c.fetchFromSource(alignedOffset) + if err != nil { + return nil, fmt.Errorf("failed to fetch item from source: %w", err) + } + + // Store in cache + if err = c.cache.Set(cacheKey, slice); err != nil { + return nil, fmt.Errorf("failed to store item in cache: %w", err) + } + + return slice, nil +} + +func (c *SliceCacheProxy) fetchFromSource(offset int64) (*CacheItem, error) { + if offset < 0 { + return nil, fmt.Errorf("source offset cannot be negative, got: %d", offset) + } + if _, err := c.r.Seek(offset, io.SeekStart); err != nil { + return nil, fmt.Errorf("failed to seek to offset %d in source: %w", offset, err) + } + + buf := make([]byte, c.sliceSize) + n, err := io.ReadFull(c.r, buf) + if err != nil && err != io.ErrUnexpectedEOF { + return nil, fmt.Errorf("failed to read %d bytes from source at offset %d: %w", c.sliceSize, offset, err) + } + + var headers http.Header + if h, ok := c.r.(Headers); ok { + headers = h.Headers().Clone() + } else { + headers = make(http.Header) + } + + contentTotalLength, err := c.r.ContentTotalLength() + if err != nil { + return nil, fmt.Errorf("failed to get content total length from source: %w", err) + } + + contentType, err := c.r.ContentType() + if err != nil { + return nil, fmt.Errorf("failed to get content type from source: %w", err) + } + + return &CacheItem{ + Metadata: &CacheMetadata{ + Headers: headers, + ContentTotalLength: contentTotalLength, + ContentType: contentType, + }, + Data: buf[:n], + }, nil +} + +// ByteRange represents an HTTP Range header value +type ByteRange struct { + Start int64 + End int64 +} + +// ParseByteRange parses a Range header value in the format: +// bytes=- +// where end is optional +func ParseByteRange(r string) (*ByteRange, error) { + if r == "" { + return &ByteRange{Start: 0, End: -1}, nil + } + + if !strings.HasPrefix(r, "bytes=") { + return nil, fmt.Errorf("range header must start with 'bytes=', got: %s", r) + } + + r = strings.TrimPrefix(r, "bytes=") + parts := strings.Split(r, "-") + if len(parts) != 2 { + return nil, fmt.Errorf("range header must contain exactly one hyphen (-) separator, got: %s", r) + } + + parts[0] = strings.TrimSpace(parts[0]) + parts[1] = strings.TrimSpace(parts[1]) + + if parts[0] == "" && parts[1] == "" { + return nil, fmt.Errorf("range header cannot have empty start and end values: %s", r) + } + + var start, end int64 = 0, -1 + var err error + + if parts[0] != "" { + start, err = strconv.ParseInt(parts[0], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse range start value '%s': %v", parts[0], err) + } + if start < 0 { + return nil, fmt.Errorf("range start value must be non-negative, got: %d", start) + } + } + + if parts[1] != "" { + end, err = strconv.ParseInt(parts[1], 10, 64) + if err != nil { + return nil, fmt.Errorf("failed to parse range end value '%s': %v", parts[1], err) + } + if end < 0 { + return nil, fmt.Errorf("range end value must be non-negative, got: %d", end) + } + if start > end { + return nil, fmt.Errorf("range start value (%d) cannot be greater than end value (%d)", start, end) + } + } + + return &ByteRange{Start: start, End: end}, nil +} diff --git a/server/handlers/vendors/vendorAlist/alist.go b/server/handlers/vendors/vendorAlist/alist.go index 9e2bcf5..780f342 100644 --- a/server/handlers/vendors/vendorAlist/alist.go +++ b/server/handlers/vendors/vendorAlist/alist.go @@ -134,7 +134,7 @@ func (s *alistVendorService) ProxyMovie(ctx *gin.Context) { ctx.Data(http.StatusOK, "audio/mpegurl", data.Ali.M3U8ListFile) return case "raw": - err := proxy.AutoProxyURL(ctx, data.URL, s.movie.MovieBase.Type, nil, ctx.GetString("token"), s.movie.RoomID, s.movie.ID) + err := proxy.AutoProxyURL(ctx, data.URL, s.movie.MovieBase.Type, nil, true, ctx.GetString("token"), s.movie.RoomID, s.movie.ID) if err != nil { log.Errorf("proxy vendor movie error: %v", err) } @@ -173,7 +173,7 @@ func (s *alistVendorService) ProxyMovie(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("proxy is not enabled")) return } - err = proxy.AutoProxyURL(ctx, data.URL, s.movie.MovieBase.Type, nil, ctx.GetString("token"), s.movie.RoomID, s.movie.ID) + err = proxy.AutoProxyURL(ctx, data.URL, s.movie.MovieBase.Type, nil, true, ctx.GetString("token"), s.movie.RoomID, s.movie.ID) if err != nil { log.Errorf("proxy vendor movie error: %v", err) } diff --git a/server/handlers/vendors/vendorBilibili/bilibili.go b/server/handlers/vendors/vendorBilibili/bilibili.go index b4700be..f8e2b09 100644 --- a/server/handlers/vendors/vendorBilibili/bilibili.go +++ b/server/handlers/vendors/vendorBilibili/bilibili.go @@ -127,7 +127,7 @@ func (s *bilibiliVendorService) ProxyMovie(ctx *gin.Context) { headers["Referer"] = "https://www.bilibili.com" headers["User-Agent"] = utils.UA } - err = proxy.ProxyURL(ctx, mpdC.Urls[streamId], headers) + err = proxy.ProxyURL(ctx, mpdC.Urls[streamId], headers, true) if err != nil { log.Errorf("proxy vendor movie [%s] error: %v", mpdC.Urls[streamId], err) } diff --git a/server/handlers/vendors/vendorEmby/emby.go b/server/handlers/vendors/vendorEmby/emby.go index 9db698d..dd69cda 100644 --- a/server/handlers/vendors/vendorEmby/emby.go +++ b/server/handlers/vendors/vendorEmby/emby.go @@ -125,6 +125,11 @@ func (s *embyVendorService) ProxyMovie(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusInternalServerError, model.NewApiErrorResp(err)) return } + if len(embyC.Sources) == 0 { + log.Errorf("proxy vendor movie error: %v", "no source") + ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("no source")) + return + } source, err := strconv.Atoi(ctx.Query("source")) if err != nil { log.Errorf("proxy vendor movie error: %v", err) @@ -136,22 +141,11 @@ func (s *embyVendorService) ProxyMovie(ctx *gin.Context) { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("source out of range")) return } - id, err := strconv.Atoi(ctx.Query("source")) - if err != nil { - log.Errorf("proxy vendor movie error: %v", err) - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - if id >= len(embyC.Sources[source].URL) { - log.Errorf("proxy vendor movie error: %v", "id out of range") - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("id out of range")) - return - } if embyC.Sources[source].IsTranscode { ctx.Redirect(http.StatusFound, embyC.Sources[source].URL) return } - err = proxy.AutoProxyURL(ctx, embyC.Sources[source].URL, "", nil, ctx.GetString("token"), s.movie.RoomID, s.movie.ID) + err = proxy.AutoProxyURL(ctx, embyC.Sources[source].URL, "", nil, true, ctx.GetString("token"), s.movie.RoomID, s.movie.ID) if err != nil { log.Errorf("proxy vendor movie error: %v", err) } diff --git a/utils/utils.go b/utils/utils.go index 7ef1b9a..8897079 100644 --- a/utils/utils.go +++ b/utils/utils.go @@ -298,6 +298,9 @@ func getLocalIPs() []net.IP { } func OptFilePath(filePath string) (string, error) { + if filePath == "" { + return "", nil + } if !filepath.IsAbs(filePath) { return filepath.Abs(filepath.Join(flags.Global.DataDir, filePath)) }