diff --git a/server/router/api/v1/sse_hub.go b/server/router/api/v1/sse_hub.go
index a04c2474c..88887c4a3 100644
--- a/server/router/api/v1/sse_hub.go
+++ b/server/router/api/v1/sse_hub.go
@@ -56,6 +56,7 @@ type SSEClient struct {
 type SSEHub struct {
 	mu      sync.RWMutex
 	clients map[*SSEClient]struct{}
+	closed  bool // set once by Close; Subscribe checks it while holding mu
 }
 
 // NewSSEHub creates a new SSE hub.
@@ -75,7 +76,11 @@ func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient {
 		role:   role,
 	}
 	h.mu.Lock()
-	h.clients[c] = struct{}{}
+	if h.closed {
+		close(c.events) // hub already closed: the client is not registered and its channel is returned closed
+	} else {
+		h.clients[c] = struct{}{}
+	}
 	h.mu.Unlock()
 	return c
 }
@@ -90,6 +95,20 @@ func (h *SSEHub) Unsubscribe(c *SSEClient) {
 	h.mu.Unlock()
 }
 
+// Close marks the hub closed, removes every registered client and closes its events channel; repeated calls are no-ops.
+func (h *SSEHub) Close() {
+	h.mu.Lock()
+	defer h.mu.Unlock()
+	if h.closed {
+		return // already closed; the channels were closed by the first call
+	}
+	h.closed = true
+	for c := range h.clients {
+		delete(h.clients, c)
+		close(c.events)
+	}
+}
+
 // Broadcast sends an event to all connected clients.
 // Slow clients that have a full buffer will have the event dropped
 // to avoid blocking the broadcaster.
diff --git a/server/router/api/v1/sse_hub_test.go b/server/router/api/v1/sse_hub_test.go
index 42e01091a..2e2005ff1 100644
--- a/server/router/api/v1/sse_hub_test.go
+++ b/server/router/api/v1/sse_hub_test.go
@@ -47,6 +47,28 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
 	assert.False(t, ok, "channel should be closed after Unsubscribe")
 }
 
+func TestSSEHub_Close(t *testing.T) {
+	hub := NewSSEHub()
+	c1 := hub.Subscribe(1, store.RoleUser)
+	c2 := hub.Subscribe(2, store.RoleAdmin)
+
+	hub.Close()
+	hub.Close() // second call exercises the already-closed path
+
+	for _, ch := range []chan []byte{c1.events, c2.events} {
+		_, ok := <-ch
+		assert.False(t, ok, "channel should be closed after hub close")
+	}
+
+	late := hub.Subscribe(3, store.RoleUser)
+	_, ok := <-late.events
+	assert.False(t, ok, "late subscriber should be closed immediately")
+
+	hub.Broadcast(&SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"}) // no assertions from here on; presumably these calls only need to not panic on a closed hub
+	hub.Unsubscribe(c1)
+	hub.Unsubscribe(late)
+}
+
 func TestSSEHub_Broadcast(t *testing.T) {
 	hub := NewSSEHub()
 	client := hub.Subscribe(1, store.RoleUser)
diff --git a/server/router/api/v1/test/sse_handler_test.go b/server/router/api/v1/test/sse_handler_test.go
index c06f9e3fd..2a755e1ed 100644
--- a/server/router/api/v1/test/sse_handler_test.go
+++ b/server/router/api/v1/test/sse_handler_test.go
@@ -2,9 +2,11 @@ package test
 
 import (
 	"context"
+	"io"
 	"net/http"
 	"net/http/httptest"
 	"testing"
+	"time"
 
 	"github.com/labstack/echo/v5"
 	"github.com/stretchr/testify/require"
@@ -75,4 +77,34 @@ func TestSSEHandler_Authentication(t *testing.T) {
 		e.ServeHTTP(rec, req)
 		require.Equal(t, http.StatusUnauthorized, rec.Code)
 	})
+
+	t.Run("hub close disconnects stream", func(t *testing.T) {
+		server := httptest.NewServer(e)
+		defer server.Close()
+
+		req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+"/api/v1/sse", nil)
+		require.NoError(t, err)
+		req.Header.Set("Authorization", "Bearer "+token)
+
+		resp, err := server.Client().Do(req)
+		require.NoError(t, err)
+		defer resp.Body.Close()
+		require.Equal(t, http.StatusOK, resp.StatusCode)
+		require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
+
+		ts.Service.SSEHub.Close() // the open stream is expected to end once the hub is closed
+
+		done := make(chan error, 1) // buffered so the reader goroutine can finish even if the timeout below fires first
+		go func() {
+			_, err := io.ReadAll(resp.Body)
+			done <- err
+		}()
+
+		select {
+		case err := <-done:
+			require.NoError(t, err)
+		case <-time.After(time.Second):
+			t.Fatal("SSE stream did not close after hub close")
+		}
+	})
 }
diff --git a/server/server.go b/server/server.go
index dd3247def..6096228d7 100644
--- a/server/server.go
+++ b/server/server.go
@@ -6,7 +6,7 @@ import (
 	"log/slog"
 	"net"
 	"net/http"
-	"runtime"
+	"sync"
 	"time"
 
 	"github.com/google/uuid"
@@ -25,14 +25,19 @@ import (
 	"github.com/usememos/memos/store"
 )
 
+const shutdownTimeout = 10 * time.Second // timeout applied to the context Shutdown works under
+
 type Server struct {
 	Secret  string
 	Profile *profile.Profile
 	Store   *store.Store
 
-	echoServer        *echo.Echo
-	httpServer        *http.Server
-	runnerCancelFuncs []context.CancelFunc
+	echoServer *echo.Echo
+	httpServer *http.Server
+	sseHub     *apiv1.SSEHub // taken from the API v1 service in NewServer; closed during Shutdown
+
+	backgroundRunnerCancels []context.CancelFunc
+	backgroundRunnerWG      sync.WaitGroup // counts the goroutines launched by startBackgroundRunners
 }
 
 func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
@@ -67,6 +72,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
 	rootGroup := echoServer.Group("")
 
 	apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
+	s.sseHub = apiV1Service.SSEHub // kept so Shutdown can close the hub
 
 	// Register HTTP file server routes BEFORE gRPC-Gateway to ensure proper range request handling for Safari.
 	// This uses native HTTP serving (http.ServeContent) instead of gRPC for video/audio files.
@@ -109,30 +115,21 @@ func (s *Server) Start(ctx context.Context) error {
 			slog.Error("failed to start echo server", "error", err)
 		}
 	}()
-	s.StartBackgroundRunners(ctx)
+	s.startBackgroundRunners(ctx)
 
 	return nil
 }
 
 func (s *Server) Shutdown(ctx context.Context) {
-	ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
+	ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
 	defer cancel()
 
 	slog.Info("server shutting down")
 
-	// Cancel all background runners
-	for _, cancelFunc := range s.runnerCancelFuncs {
-		if cancelFunc != nil {
-			cancelFunc()
-		}
-	}
-
-	// Shutdown HTTP server.
-	if s.httpServer != nil {
-		if err := s.httpServer.Shutdown(ctx); err != nil {
-			slog.Error("failed to shutdown server", slog.String("error", err.Error()))
-		}
-	}
+	s.stopBackgroundRunners()     // cancel every runner context
+	s.closeLongLivedConnections() // close the SSE hub before the HTTP server is shut down
+	s.shutdownHTTPServer(ctx)     // graceful Shutdown, falling back to Close on error
+	s.waitBackgroundRunners(ctx)  // wait for runner goroutines until ctx expires
 
 	// Close database connection.
 	if err := s.Store.Close(); err != nil {
@@ -142,26 +139,73 @@ func (s *Server) Shutdown(ctx context.Context) {
 	slog.Info("memos stopped properly")
 }
 
-func (s *Server) StartBackgroundRunners(ctx context.Context) {
+func (s *Server) startBackgroundRunners(ctx context.Context) {
 	// Create a separate context for each background runner
 	// This allows us to control cancellation for each runner independently
 	s3Context, s3Cancel := context.WithCancel(ctx)
 
 	// Store the cancel function so we can properly shut down runners
-	s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel)
+	s.backgroundRunnerCancels = append(s.backgroundRunnerCancels, s3Cancel)
 
 	// Create and start S3 presign runner
 	s3presignRunner := s3presign.NewRunner(s.Store)
 	s3presignRunner.RunOnce(ctx)
 
 	// Start continuous S3 presign runner
+	s.backgroundRunnerWG.Add(1) // counted before the goroutine is launched
 	go func() {
+		defer s.backgroundRunnerWG.Done()
 		s3presignRunner.Run(s3Context)
 		slog.Info("s3presign runner stopped")
 	}()
 
-	// Log the number of goroutines running
-	slog.Info("background runners started", "goroutines", runtime.NumGoroutine())
+	slog.Info("background runners started")
+}
+
+func (s *Server) stopBackgroundRunners() {
+	for _, cancelFunc := range s.backgroundRunnerCancels {
+		if cancelFunc != nil {
+			cancelFunc()
+		}
+	}
+}
+
+func (s *Server) waitBackgroundRunners(ctx context.Context) {
+	done := make(chan struct{}) // closed once backgroundRunnerWG.Wait returns
+	go func() {
+		s.backgroundRunnerWG.Wait()
+		close(done)
+	}()
+
+	select {
+	case <-done:
+	case <-ctx.Done():
+		select { // ctx is done: check done once more so a finish that raced with it is not logged as a failure
+		case <-done:
+			return
+		default:
+		}
+		slog.Error("failed to stop background runners", slog.String("error", ctx.Err().Error()))
+	}
+}
+
+func (s *Server) closeLongLivedConnections() {
+	// Long-lived SSE requests do not finish on their own during http.Server.Shutdown.
+	if s.sseHub != nil {
+		s.sseHub.Close()
+	}
+}
+
+func (s *Server) shutdownHTTPServer(ctx context.Context) {
+	if s.httpServer == nil {
+		return
+	}
+	if err := s.httpServer.Shutdown(ctx); err != nil { // graceful shutdown failed: log it, then force-close below
+		slog.Error("failed to shutdown server", slog.String("error", err.Error()))
+		if closeErr := s.httpServer.Close(); closeErr != nil && closeErr != http.ErrServerClosed {
+			slog.Error("failed to close server", slog.String("error", closeErr.Error()))
+		}
+	}
 }
 
 func (s *Server) getOrUpsertInstanceBasicSetting(ctx context.Context) (*storepb.InstanceBasicSetting, error) {