fix(server): close SSE clients during shutdown

Close long-lived SSE streams before HTTP shutdown so graceful shutdown is not held until the deadline. Also wait for background runners before closing the store to make shutdown ordering explicit.
pull/5864/head
boojack 1 month ago
parent a7fd1dacc9
commit a5ddd5adaf

@ -56,6 +56,7 @@ type SSEClient struct {
type SSEHub struct {
mu sync.RWMutex
clients map[*SSEClient]struct{}
closed bool
}
// NewSSEHub creates a new SSE hub.
@ -75,7 +76,11 @@ func (h *SSEHub) Subscribe(userID int32, role store.Role) *SSEClient {
role: role,
}
h.mu.Lock()
h.clients[c] = struct{}{}
if h.closed {
close(c.events)
} else {
h.clients[c] = struct{}{}
}
h.mu.Unlock()
return c
}
@ -90,6 +95,20 @@ func (h *SSEHub) Unsubscribe(c *SSEClient) {
h.mu.Unlock()
}
// Close disconnects all subscribed SSE clients.
func (h *SSEHub) Close() {
h.mu.Lock()
defer h.mu.Unlock()
if h.closed {
return
}
h.closed = true
for c := range h.clients {
delete(h.clients, c)
close(c.events)
}
}
// Broadcast sends an event to all connected clients.
// Slow clients that have a full buffer will have the event dropped
// to avoid blocking the broadcaster.

@ -47,6 +47,28 @@ func TestSSEHub_SubscribeUnsubscribe(t *testing.T) {
assert.False(t, ok, "channel should be closed after Unsubscribe")
}
func TestSSEHub_Close(t *testing.T) {
hub := NewSSEHub()
c1 := hub.Subscribe(1, store.RoleUser)
c2 := hub.Subscribe(2, store.RoleAdmin)
hub.Close()
hub.Close()
for _, ch := range []chan []byte{c1.events, c2.events} {
_, ok := <-ch
assert.False(t, ok, "channel should be closed after hub close")
}
late := hub.Subscribe(3, store.RoleUser)
_, ok := <-late.events
assert.False(t, ok, "late subscriber should be closed immediately")
hub.Broadcast(&SSEEvent{Type: SSEEventMemoCreated, Name: "memos/123"})
hub.Unsubscribe(c1)
hub.Unsubscribe(late)
}
func TestSSEHub_Broadcast(t *testing.T) {
hub := NewSSEHub()
client := hub.Subscribe(1, store.RoleUser)

@ -2,9 +2,11 @@ package test
import (
"context"
"io"
"net/http"
"net/http/httptest"
"testing"
"time"
"github.com/labstack/echo/v5"
"github.com/stretchr/testify/require"
@ -75,4 +77,34 @@ func TestSSEHandler_Authentication(t *testing.T) {
e.ServeHTTP(rec, req)
require.Equal(t, http.StatusUnauthorized, rec.Code)
})
t.Run("hub close disconnects stream", func(t *testing.T) {
server := httptest.NewServer(e)
defer server.Close()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, server.URL+"/api/v1/sse", nil)
require.NoError(t, err)
req.Header.Set("Authorization", "Bearer "+token)
resp, err := server.Client().Do(req)
require.NoError(t, err)
defer resp.Body.Close()
require.Equal(t, http.StatusOK, resp.StatusCode)
require.Equal(t, "text/event-stream", resp.Header.Get("Content-Type"))
ts.Service.SSEHub.Close()
done := make(chan error, 1)
go func() {
_, err := io.ReadAll(resp.Body)
done <- err
}()
select {
case err := <-done:
require.NoError(t, err)
case <-time.After(time.Second):
t.Fatal("SSE stream did not close after hub close")
}
})
}

@ -6,7 +6,7 @@ import (
"log/slog"
"net"
"net/http"
"runtime"
"sync"
"time"
"github.com/google/uuid"
@ -25,14 +25,19 @@ import (
"github.com/usememos/memos/store"
)
const shutdownTimeout = 10 * time.Second
type Server struct {
Secret string
Profile *profile.Profile
Store *store.Store
echoServer *echo.Echo
httpServer *http.Server
runnerCancelFuncs []context.CancelFunc
echoServer *echo.Echo
httpServer *http.Server
sseHub *apiv1.SSEHub
backgroundRunnerCancels []context.CancelFunc
backgroundRunnerWG sync.WaitGroup
}
func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store) (*Server, error) {
@ -67,6 +72,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
rootGroup := echoServer.Group("")
apiV1Service := apiv1.NewAPIV1Service(s.Secret, profile, store)
s.sseHub = apiV1Service.SSEHub
// Register HTTP file server routes BEFORE gRPC-Gateway to ensure proper range request handling for Safari.
// This uses native HTTP serving (http.ServeContent) instead of gRPC for video/audio files.
@ -109,30 +115,21 @@ func (s *Server) Start(ctx context.Context) error {
slog.Error("failed to start echo server", "error", err)
}
}()
s.StartBackgroundRunners(ctx)
s.startBackgroundRunners(ctx)
return nil
}
func (s *Server) Shutdown(ctx context.Context) {
ctx, cancel := context.WithTimeout(ctx, 10*time.Second)
ctx, cancel := context.WithTimeout(ctx, shutdownTimeout)
defer cancel()
slog.Info("server shutting down")
// Cancel all background runners
for _, cancelFunc := range s.runnerCancelFuncs {
if cancelFunc != nil {
cancelFunc()
}
}
// Shutdown HTTP server.
if s.httpServer != nil {
if err := s.httpServer.Shutdown(ctx); err != nil {
slog.Error("failed to shutdown server", slog.String("error", err.Error()))
}
}
s.stopBackgroundRunners()
s.closeLongLivedConnections()
s.shutdownHTTPServer(ctx)
s.waitBackgroundRunners(ctx)
// Close database connection.
if err := s.Store.Close(); err != nil {
@ -142,26 +139,73 @@ func (s *Server) Shutdown(ctx context.Context) {
slog.Info("memos stopped properly")
}
func (s *Server) StartBackgroundRunners(ctx context.Context) {
func (s *Server) startBackgroundRunners(ctx context.Context) {
// Create a separate context for each background runner
// This allows us to control cancellation for each runner independently
s3Context, s3Cancel := context.WithCancel(ctx)
// Store the cancel function so we can properly shut down runners
s.runnerCancelFuncs = append(s.runnerCancelFuncs, s3Cancel)
s.backgroundRunnerCancels = append(s.backgroundRunnerCancels, s3Cancel)
// Create and start S3 presign runner
s3presignRunner := s3presign.NewRunner(s.Store)
s3presignRunner.RunOnce(ctx)
// Start continuous S3 presign runner
s.backgroundRunnerWG.Add(1)
go func() {
defer s.backgroundRunnerWG.Done()
s3presignRunner.Run(s3Context)
slog.Info("s3presign runner stopped")
}()
// Log the number of goroutines running
slog.Info("background runners started", "goroutines", runtime.NumGoroutine())
slog.Info("background runners started")
}
func (s *Server) stopBackgroundRunners() {
for _, cancelFunc := range s.backgroundRunnerCancels {
if cancelFunc != nil {
cancelFunc()
}
}
}
func (s *Server) waitBackgroundRunners(ctx context.Context) {
done := make(chan struct{})
go func() {
s.backgroundRunnerWG.Wait()
close(done)
}()
select {
case <-done:
case <-ctx.Done():
select {
case <-done:
return
default:
}
slog.Error("failed to stop background runners", slog.String("error", ctx.Err().Error()))
}
}
func (s *Server) closeLongLivedConnections() {
// Long-lived SSE requests do not finish on their own during http.Server.Shutdown.
if s.sseHub != nil {
s.sseHub.Close()
}
}
func (s *Server) shutdownHTTPServer(ctx context.Context) {
if s.httpServer == nil {
return
}
if err := s.httpServer.Shutdown(ctx); err != nil {
slog.Error("failed to shutdown server", slog.String("error", err.Error()))
if closeErr := s.httpServer.Close(); closeErr != nil && closeErr != http.ErrServerClosed {
slog.Error("failed to close server", slog.String("error", closeErr.Error()))
}
}
}
func (s *Server) getOrUpsertInstanceBasicSetting(ctx context.Context) (*storepb.InstanceBasicSetting, error) {

Loading…
Cancel
Save