diff --git a/internal/op/hub.go b/internal/op/hub.go index a583da3..7fd9ac3 100644 --- a/internal/op/hub.go +++ b/internal/op/hub.go @@ -14,9 +14,14 @@ import ( "github.com/zijiren233/gencontainer/rwmap" ) +type clients struct { + lock sync.RWMutex + m map[*Client]struct{} +} + type Hub struct { id string - clients rwmap.RWMap[string, *rwmap.RWMap[*Client, struct{}]] + clients rwmap.RWMap[string, *clients] broadcast chan *broadcastMessage exit chan struct{} closed uint32 @@ -66,8 +71,11 @@ func (h *Hub) serve() error { select { case message := <-h.broadcast: h.devMessage(message.data) - h.clients.Range(func(id string, cli *rwmap.RWMap[*Client, struct{}]) bool { - cli.Range(func(c *Client, value struct{}) bool { + h.clients.Range(func(id string, clients *clients) bool { + clients.lock.RLock() + defer clients.lock.RUnlock() + + for c := range clients.m { if utils.In(message.ignoreId, c.u.ID) { return true } @@ -75,11 +83,9 @@ func (h *Hub) serve() error { return true } if err := c.Send(message.data); err != nil { - log.Debugf("hub: %s, write to client err: %s\nmessage: %+v", h.id, err, message) c.Close() } - return true - }) + } return true }) @@ -122,7 +128,7 @@ func (h *Hub) ping() { func (h *Hub) devMessage(msg Message) { switch msg.MessageType() { - case websocket.TextMessage: + case websocket.BinaryMessage: log.Debugf("hub: %s, broadcast:\nmessage: %+v", h.id, msg.String()) } } @@ -140,13 +146,12 @@ func (h *Hub) Close() error { return ErrAlreadyClosed } close(h.exit) - h.clients.Range(func(id string, client *rwmap.RWMap[*Client, struct{}]) bool { + h.clients.Range(func(id string, clients *clients) bool { h.clients.Delete(id) - client.Range(func(key *Client, value struct{}) bool { - client.Delete(key) - key.Close() - return true - }) + for c := range clients.m { + delete(clients.m, c) + c.Close() + } return true }) h.wg.Wait() @@ -181,11 +186,15 @@ func (h *Hub) RegClient(cli *Client) error { if err != nil { return err } - c, _ := h.clients.LoadOrStore(cli.u.ID, &rwmap.RWMap[*Client, struct{}]{}) - _, loaded := c.LoadOrStore(cli, struct{}{}) - if loaded { - return errors.New("client already exist") + c, _ := h.clients.LoadOrStore(cli.u.ID, &clients{}) + c.lock.Lock() + defer c.lock.Unlock() + if c.m == nil { + c.m = make(map[*Client]struct{}) + } else if _, ok := c.m[cli]; ok { + return errors.New("client already exists") } + c.m[cli] = struct{}{} return nil } @@ -200,18 +209,14 @@ func (h *Hub) UnRegClient(cli *Client) error { if !loaded { return errors.New("client not found") } - _, loaded2 := c.LoadAndDelete(cli) - if !loaded2 { + c.lock.Lock() + defer c.lock.Unlock() + if _, ok := c.m[cli]; !ok { return errors.New("client not found") } - if c.Len() == 0 { - if h.clients.CompareAndDelete(cli.u.ID, c) { - c.Range(func(key *Client, value struct{}) bool { - c.Delete(key) - h.RegClient(key) - return true - }) - } + delete(c.m, cli) + if len(c.m) == 0 { + h.clients.CompareAndDelete(cli.u.ID, c) } return nil } @@ -228,17 +233,12 @@ func (h *Hub) SendToUser(userID string, data Message) (err error) { if !ok { return nil } - cli.Range(func(key *Client, value struct{}) bool { - if err = key.Send(data); err != nil { - cli.CompareAndDelete(key, value) - log.Debugf("hub: %s, write to client err: %s\nmessage: %+v", h.id, err, data) - key.Close() + cli.lock.RLock() + defer cli.lock.RUnlock() + for c := range cli.m { + if err = c.Send(data); err != nil { + c.Close() } - return true - }) + } return } - -func (h *Hub) LoadClient(userID string) (*rwmap.RWMap[*Client, struct{}], bool) { - return h.clients.Load(userID) -} diff --git a/internal/op/room.go b/internal/op/room.go index 5466c93..d40b146 100644 --- a/internal/op/room.go +++ b/internal/op/room.go @@ -9,7 +9,6 @@ import ( "github.com/synctv-org/synctv/internal/db" "github.com/synctv-org/synctv/internal/model" "github.com/synctv-org/synctv/utils" - "github.com/zijiren233/gencontainer/rwmap" rtmps "github.com/zijiren233/livelib/server" "github.com/zijiren233/stream" "golang.org/x/crypto/bcrypt" @@ -44,13 +43,6 @@ func (r *Room) Broadcast(data Message, conf ...BroadcastConf) error { return r.hub.Broadcast(data, conf...) } -func (r *Room) LoadClient(userID string) (*rwmap.RWMap[*Client, struct{}], bool) { - if r.hub == nil { - return nil, false - } - return r.hub.LoadClient(userID) -} - func (r *Room) SendToUser(user *User, data Message) error { if r.hub == nil { return nil