diff --git a/internal/bootstrap/vendor.go b/internal/bootstrap/vendor.go index 5983084..a4b38c5 100644 --- a/internal/bootstrap/vendor.go +++ b/internal/bootstrap/vendor.go @@ -3,18 +3,9 @@ package bootstrap import ( "context" - "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/vendor" ) func InitVendorBackend(ctx context.Context) error { - vb, err := db.GetAllVendorBackend() - if err != nil { - return err - } - bc, err := vendor.NewBackendConns(ctx, vb) - if err != nil { - return err - } - return vendor.StoreConns(bc) + return vendor.Init(ctx) } diff --git a/internal/vendor/alist.go b/internal/vendor/alist.go index bab58c8..b8a49e2 100644 --- a/internal/vendor/alist.go +++ b/internal/vendor/alist.go @@ -13,7 +13,7 @@ import ( type AlistInterface = alist.AlistHTTPServer func LoadAlistClient(name string) AlistInterface { - if cli, ok := clients.Load().alist[name]; ok { + if cli, ok := LoadClients().alist[name]; ok { return cli } return alistLocalClient diff --git a/internal/vendor/bilibili.go b/internal/vendor/bilibili.go index f697f9e..68cafed 100644 --- a/internal/vendor/bilibili.go +++ b/internal/vendor/bilibili.go @@ -13,7 +13,7 @@ import ( type BilibiliInterface = bilibili.BilibiliHTTPServer func LoadBilibiliClient(name string) BilibiliInterface { - if cli, ok := clients.Load().bilibili[name]; ok { + if cli, ok := LoadClients().bilibili[name]; ok { return cli } return bilibiliLocalClient diff --git a/internal/vendor/emby.go b/internal/vendor/emby.go index a82a44b..e52b5de 100644 --- a/internal/vendor/emby.go +++ b/internal/vendor/emby.go @@ -13,7 +13,7 @@ import ( type EmbyInterface = emby.EmbyHTTPServer func LoadEmbyClient(name string) EmbyInterface { - if cli, ok := clients.Load().emby[name]; ok && cli != nil { + if cli, ok := LoadClients().emby[name]; ok && cli != nil { return cli } return embyLocalClient diff --git a/internal/vendor/vendor.go b/internal/vendor/vendor.go index df62227..e661596 100644 --- a/internal/vendor/vendor.go +++ b/internal/vendor/vendor.go @@ -6,7 +6,9 @@ import ( "crypto/x509" "errors" "fmt" + "maps" "os" + "sync" "sync/atomic" "time" @@ -25,6 +27,7 @@ import ( jwtv4 "github.com/golang-jwt/jwt/v4" "github.com/hashicorp/consul/api" log "github.com/sirupsen/logrus" + "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/model" clientv3 "go.etcd.io/etcd/client/v3" "google.golang.org/grpc" @@ -35,30 +38,190 @@ func init() { selector.SetGlobalSelector(wrr.NewBuilder()) } +type Backends struct { + conns map[string]*BackendConn + clients *VendorClients +} + var ( - conns atomic.Value - clients atomic.Pointer[VendorClients] + backends atomic.Pointer[Backends] + lock sync.Mutex ) func LoadClients() *VendorClients { - return clients.Load() + return backends.Load().clients +} + +func storeBackends(conns map[string]*BackendConn, clients *VendorClients) { + backends.Store(&Backends{ + conns: conns, + clients: clients, + }) } -func storeClients(b *VendorClients) { - clients.Store(b) +func loadBackends() *Backends { + return backends.Load() } func LoadConns() map[string]*BackendConn { - return conns.Load().(map[string]*BackendConn) + return backends.Load().conns +} + +func Init(ctx context.Context) error { + vb, err := db.GetAllVendorBackend() + if err != nil { + return err + } + bc, err := newBackendConns(ctx, vb) + if err != nil { + return err + } + vc, err := newVendorClients(bc) + if err != nil { + return err + } + storeBackends(bc, vc) + return nil +} + +func AddVendorBackend(ctx context.Context, backend *model.VendorBackend) error { + if !lock.TryLock() { + return errors.New("vendor backend is updating") + } + defer lock.Unlock() + + raw := LoadConns() + if _, ok := raw[backend.Backend.Endpoint]; ok { + return fmt.Errorf("duplicate endpoint: %s", backend.Backend.Endpoint) + } + + bc, err := newBackendConn(ctx, backend) + if err != nil { + return err + } + + m := maps.Clone(raw) + m[backend.Backend.Endpoint] = bc + + vc, err := newVendorClients(m) + if err != nil { + bc.Conn.Close() + return err + } + + err = db.CreateVendorBackend(backend) + if err != nil { + bc.Conn.Close() + return err + } + + storeBackends(m, vc) + + return nil +} + +func DeleteVendorBackend(ctx context.Context, endpoint string) error { + if !lock.TryLock() { + return errors.New("vendor backend is updating") + } + defer lock.Unlock() + + raw := LoadConns() + if _, ok := raw[endpoint]; !ok { + return fmt.Errorf("endpoint not found: %s", endpoint) + } + + m := maps.Clone(raw) + beforeConn := m[endpoint].Conn + delete(m, endpoint) + + vc, err := newVendorClients(m) + if err != nil { + return err + } + + err = db.DeleteVendorBackend(endpoint) + if err != nil { + return err + } + + storeBackends(m, vc) + beforeConn.Close() + + return nil +} + +func DeleteVendorBackends(ctx context.Context, endpoints []string) error { + if !lock.TryLock() { + return errors.New("vendor backend is updating") + } + defer lock.Unlock() + + m := maps.Clone(LoadConns()) + + var beforeConn = make([]*grpc.ClientConn, len(endpoints)) + for i, endpoint := range endpoints { + if conn, ok := m[endpoint]; !ok { + return fmt.Errorf("endpoint not found: %s", endpoint) + } else { + beforeConn[i] = conn.Conn + } + delete(m, endpoint) + } + + vc, err := newVendorClients(m) + if err != nil { + return err + } + + err = db.DeleteVendorBackends(endpoints) + if err != nil { + return err + } + + storeBackends(m, vc) + for _, conn := range beforeConn { + conn.Close() + } + + return nil } -func StoreConns(c map[string]*BackendConn) error { - vc, err := newVendorClients(c) +func UpdateVendorBackend(ctx context.Context, backend *model.VendorBackend) error { + if !lock.TryLock() { + return errors.New("vendor backend is updating") + } + defer lock.Unlock() + + raw := LoadConns() + if _, ok := raw[backend.Backend.Endpoint]; !ok { + return fmt.Errorf("endpoint not found: %s", backend.Backend.Endpoint) + } + + bc, err := newBackendConn(ctx, backend) if err != nil { return err } - conns.Store(c) - storeClients(vc) + + m := maps.Clone(raw) + beforeConn := m[backend.Backend.Endpoint].Conn + m[backend.Backend.Endpoint] = bc + + vc, err := newVendorClients(m) + if err != nil { + bc.Conn.Close() + return err + } + + err = db.SaveVendorBackend(backend) + if err != nil { + bc.Conn.Close() + return err + } + + storeBackends(m, vc) + beforeConn.Close() + return nil } @@ -85,7 +248,7 @@ func (b *VendorClients) EmbyClients() map[string]EmbyInterface { return b.emby } -func NewBackendConn(ctx context.Context, conf *model.VendorBackend) (conns *BackendConn, err error) { +func newBackendConn(ctx context.Context, conf *model.VendorBackend) (conns *BackendConn, err error) { cc, err := NewGrpcClientConn(ctx, &conf.Backend) if err != nil { return conns, err @@ -96,7 +259,7 @@ func NewBackendConn(ctx context.Context, conf *model.VendorBackend) (conns *Back }, nil } -func NewBackendConns(ctx context.Context, conf []*model.VendorBackend) (conns map[string]*BackendConn, err error) { +func newBackendConns(ctx context.Context, conf []*model.VendorBackend) (conns map[string]*BackendConn, err error) { conns = make(map[string]*BackendConn, len(conf)) defer func() { if err != nil { @@ -110,7 +273,7 @@ func NewBackendConns(ctx context.Context, conf []*model.VendorBackend) (conns ma if _, ok := conns[vb.Backend.Endpoint]; ok { return conns, fmt.Errorf("duplicate endpoint: %s", vb.Backend.Endpoint) } - cc, err := NewBackendConn(ctx, vb) + cc, err := newBackendConn(ctx, vb) if err != nil { return conns, err } diff --git a/server/handlers/admin.go b/server/handlers/admin.go index de0c6d2..136504a 100644 --- a/server/handlers/admin.go +++ b/server/handlers/admin.go @@ -1,7 +1,6 @@ package handlers import ( - "fmt" "net/http" "reflect" @@ -13,7 +12,6 @@ import ( "github.com/synctv-org/synctv/internal/settings" "github.com/synctv-org/synctv/internal/vendor" "github.com/synctv-org/synctv/server/model" - "golang.org/x/exp/maps" "gorm.io/gorm" ) @@ -718,7 +716,7 @@ func AdminGetVendorBackends(ctx *gin.Context) { ctx.JSON(http.StatusOK, model.NewApiDataResp(resp)) } -func AdminAddVendorBackends(ctx *gin.Context) { +func AdminAddVendorBackend(ctx *gin.Context) { // user := ctx.MustGet("user").(*op.User) var req model.AddVendorBackendReq @@ -727,31 +725,7 @@ func AdminAddVendorBackends(ctx *gin.Context) { return } - raw := vendor.LoadConns() - if _, ok := raw[req.Backend.Endpoint]; ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("duplicate endpoint")) - return - } - - bc, err := vendor.NewBackendConn(ctx, (*dbModel.VendorBackend)(&req)) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - m := maps.Clone(raw) - m[req.Backend.Endpoint] = bc - - err = vendor.StoreConns(m) - if err != nil { - bc.Conn.Close() - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - err = db.CreateVendorBackend((*dbModel.VendorBackend)(&req)) - if err != nil { - bc.Conn.Close() + if err := vendor.AddVendorBackend(ctx, (*dbModel.VendorBackend)(&req)); err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } @@ -768,39 +742,11 @@ func AdminDeleteVendorBackends(ctx *gin.Context) { return } - raw := vendor.LoadConns() - for _, v := range req.Endpoints { - if _, ok := raw[v]; !ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp(fmt.Sprintf("endpoint %s not found", v))) - return - } - } - - err := db.DeleteVendorBackends(req.Endpoints) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - m := maps.Clone(raw) - - var deletedConn = make([]*vendor.BackendConn, len(req.Endpoints)) - - for i, v := range req.Endpoints { - deletedConn[i] = m[v] - delete(m, v) - } - - err = vendor.StoreConns(m) - if err != nil { + if err := vendor.DeleteVendorBackends(ctx, req.Endpoints); err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } - for _, v := range deletedConn { - v.Conn.Close() - } - ctx.Status(http.StatusNoContent) } @@ -813,37 +759,7 @@ func AdminUpdateVendorBackends(ctx *gin.Context) { return } - var beforeConn *vendor.BackendConn - - raw := vendor.LoadConns() - if c, ok := raw[req.Backend.Endpoint]; !ok { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorStringResp("endpoint not found")) - return - } else { - beforeConn = c - } - - bc, err := vendor.NewBackendConn(ctx, (*dbModel.VendorBackend)(&req)) - if err != nil { - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - m := maps.Clone(raw) - m[req.Backend.Endpoint] = bc - - err = vendor.StoreConns(m) - if err != nil { - bc.Conn.Close() - ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) - return - } - - beforeConn.Conn.Close() - - err = db.SaveVendorBackend((*dbModel.VendorBackend)(&req)) - if err != nil { - bc.Conn.Close() + if err := vendor.UpdateVendorBackend(ctx, (*dbModel.VendorBackend)(&req)); err != nil { ctx.AbortWithStatusJSON(http.StatusBadRequest, model.NewApiErrorResp(err)) return } diff --git a/server/handlers/init.go b/server/handlers/init.go index 3d12165..8556571 100644 --- a/server/handlers/init.go +++ b/server/handlers/init.go @@ -40,7 +40,7 @@ func Init(e *gin.Engine) { admin.GET("/vendors", AdminGetVendorBackends) - admin.POST("/vendors", AdminAddVendorBackends) + admin.POST("/vendors", AdminAddVendorBackend) admin.PUT("/vendors", AdminUpdateVendorBackends)