diff --git a/server/router/api/v1/acl.go b/server/router/api/v1/acl.go index 34eb4fc5a..14171b594 100644 --- a/server/router/api/v1/acl.go +++ b/server/router/api/v1/acl.go @@ -4,6 +4,7 @@ import ( "context" "net/http" "strings" + "time" "github.com/golang-jwt/jwt/v5" "github.com/pkg/errors" @@ -11,6 +12,7 @@ import ( "google.golang.org/grpc/codes" "google.golang.org/grpc/metadata" "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/timestamppb" "github.com/usememos/memos/internal/util" storepb "github.com/usememos/memos/proto/gen/store" @@ -24,6 +26,9 @@ const ( // The key name used to store username in the context // user id is extracted from the jwt token subject field. usernameContextKey ContextKey = iota + // The key name used to store session ID in the context (for session-based auth). + sessionIDContextKey + // The key name used to store access token in the context (for token-based auth). accessTokenContextKey ) @@ -43,31 +48,29 @@ func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthIntercep // AuthenticationInterceptor is the unary interceptor for gRPC API. func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) { + // Check if this method is in the allowlist first + if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { + return handler(ctx, request) + } + md, ok := metadata.FromIncomingContext(ctx) if !ok { return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context") } + + // Try to get access token from either Authorization header or cookie accessToken, err := getTokenFromMetadata(md) if err != nil { return nil, status.Errorf(codes.Unauthenticated, "failed to get access token: %v", err) } - username, err := in.authenticate(ctx, accessToken) + // Authenticate using access token (which also validates sessions when it's from cookie) + username, user, err := in.authenticateByAccessToken(ctx, accessToken) if err != nil { - if isUnauthorizeAllowedMethod(serverInfo.FullMethod) { - return handler(ctx, request) - } return nil, err } - user, err := in.Store.GetUser(ctx, &store.FindUser{ - Username: &username, - }) - if err != nil { - return nil, errors.Wrap(err, "failed to get user") - } - if user == nil { - return nil, errors.Errorf("user %q not exists", username) - } + + // Check user status if user.RowStatus == store.Archived { return nil, errors.Errorf("user %q is archived", username) } @@ -75,14 +78,27 @@ func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, re return nil, errors.Errorf("user %q is not admin", username) } + // Set context values ctx = context.WithValue(ctx, usernameContextKey, username) - ctx = context.WithValue(ctx, accessTokenContextKey, accessToken) + + // Determine if this came from cookie (session) or header (API token) + if _, headerErr := getAccessTokenFromMetadata(md); headerErr != nil { + // Came from cookie, treat as session + ctx = context.WithValue(ctx, sessionIDContextKey, accessToken) + // Update session last accessed time + _ = in.updateSessionLastAccessed(ctx, user.ID, accessToken) + } else { + // Came from Authorization header, treat as API token + ctx = context.WithValue(ctx, accessTokenContextKey, accessToken) + } + return handler(ctx, request) } -func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken string) (string, error) { +// authenticateByAccessToken authenticates a user using access token from Authorization header or cookie. +func (in *GRPCAuthInterceptor) authenticateByAccessToken(ctx context.Context, accessToken string) (string, *store.User, error) { if accessToken == "" { - return "", status.Errorf(codes.Unauthenticated, "access token not found") + return "", nil, status.Errorf(codes.Unauthenticated, "access token not found") } claims := &ClaimsMessage{} _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { @@ -97,42 +113,81 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str return nil, status.Errorf(codes.Unauthenticated, "unexpected access token kid=%v", t.Header["kid"]) }) if err != nil { - return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token") + return "", nil, status.Errorf(codes.Unauthenticated, "Invalid or expired access token") } // We either have a valid access token or we will attempt to generate new access token. userID, err := util.ConvertStringToInt32(claims.Subject) if err != nil { - return "", errors.Wrap(err, "malformed ID in the token") + return "", nil, errors.Wrap(err, "malformed ID in the token") } user, err := in.Store.GetUser(ctx, &store.FindUser{ ID: &userID, }) if err != nil { - return "", errors.Wrap(err, "failed to get user") + return "", nil, errors.Wrap(err, "failed to get user") } if user == nil { - return "", errors.Errorf("user %q not exists", userID) + return "", nil, errors.Errorf("user %q not exists", userID) } if user.RowStatus == store.Archived { - return "", errors.Errorf("user %q is archived", userID) + return "", nil, errors.Errorf("user %q is archived", userID) } accessTokens, err := in.Store.GetUserAccessTokens(ctx, user.ID) if err != nil { - return "", errors.Wrapf(err, "failed to get user access tokens") + return "", nil, errors.Wrapf(err, "failed to get user access tokens") } if !validateAccessToken(accessToken, accessTokens) { - return "", status.Errorf(codes.Unauthenticated, "invalid access token") + return "", nil, status.Errorf(codes.Unauthenticated, "invalid access token") + } + + // For tokens that might be used as session IDs (from cookies), also validate session existence + // This is a best-effort check - if sessions can't be retrieved or token isn't a session, that's ok + if sessions, err := in.Store.GetUserSessions(ctx, user.ID); err == nil { + validateUserSession(accessToken, sessions) // Result doesn't matter for API tokens } - return user.Username, nil + return user.Username, user, nil +} + +// updateSessionLastAccessed updates the last accessed time for a user session. +func (in *GRPCAuthInterceptor) updateSessionLastAccessed(ctx context.Context, userID int32, sessionID string) error { + return in.Store.UpdateUserSessionLastAccessed(ctx, userID, sessionID, timestamppb.Now()) +} + +// validateUserSession checks if a session exists and is still valid. +func validateUserSession(sessionID string, userSessions []*storepb.SessionsUserSetting_Session) bool { + for _, session := range userSessions { + if sessionID == session.SessionId { + // Check if session has expired + if session.ExpireTime != nil && session.ExpireTime.AsTime().Before(time.Now()) { + return false + } + return true + } + } + return false +} + +// getAccessTokenFromMetadata extracts access token from Authorization header. +func getAccessTokenFromMetadata(md metadata.MD) (string, error) { + // Check the HTTP request Authorization header. + authorizationHeaders := md.Get("Authorization") + if len(authorizationHeaders) == 0 { + return "", errors.New("authorization header not found") + } + authHeaderParts := strings.Fields(authorizationHeaders[0]) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("authorization header format must be Bearer {token}") + } + return authHeaderParts[1], nil } func getTokenFromMetadata(md metadata.MD) (string, error) { // Check the HTTP request header first. authorizationHeaders := md.Get("Authorization") - if len(md.Get("Authorization")) > 0 { + if len(authorizationHeaders) > 0 { authHeaderParts := strings.Fields(authorizationHeaders[0]) if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { return "", errors.New("authorization header format must be Bearer {token}") @@ -149,6 +204,9 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { accessToken = v.Value } } + if accessToken == "" { + return "", errors.New("access token not found") + } return accessToken, nil } diff --git a/server/router/api/v1/acl_config.go b/server/router/api/v1/acl_config.go index 7de891678..3db136b50 100644 --- a/server/router/api/v1/acl_config.go +++ b/server/router/api/v1/acl_config.go @@ -5,10 +5,8 @@ var authenticationAllowlistMethods = map[string]bool{ "/memos.api.v1.WorkspaceService/GetWorkspaceSetting": true, "/memos.api.v1.IdentityProviderService/GetIdentityProvider": true, "/memos.api.v1.IdentityProviderService/ListIdentityProviders": true, - "/memos.api.v1.AuthService/GetCurrentSession": true, "/memos.api.v1.AuthService/CreateSession": true, "/memos.api.v1.AuthService/SignUp": true, - "/memos.api.v1.AuthService/DeleteSession": true, "/memos.api.v1.UserService/GetUser": true, "/memos.api.v1.UserService/GetUserAvatar": true, "/memos.api.v1.UserService/GetUserStats": true, diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 868e61a4c..e4db57b9c 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -42,6 +42,15 @@ func (s *APIV1Service) GetCurrentSession(ctx context.Context, _ *v1pb.GetCurrent } return nil, status.Errorf(codes.Unauthenticated, "user not found") } + + // Update session last accessed time if we have a session ID + if sessionID, ok := ctx.Value(sessionIDContextKey).(string); ok && sessionID != "" { + if err := s.Store.UpdateUserSessionLastAccessed(ctx, user.ID, sessionID, timestamppb.Now()); err != nil { + // Log error but don't fail the request + slog.Error("failed to update session last accessed time", "error", err) + } + } + return convertUserFromStore(user), nil } @@ -181,7 +190,7 @@ func (s *APIV1Service) doSignIn(ctx context.Context, user *store.User, expireTim if err := s.trackUserSession(ctx, user.ID, accessToken, expireTime); err != nil { // Log the error but don't fail the login if session tracking fails // This ensures backward compatibility - // TODO: Add proper logging here + slog.Error("failed to track user session", "error", err) } cookie, err := s.buildAccessTokenCookie(ctx, accessToken, expireTime) @@ -246,16 +255,29 @@ func (s *APIV1Service) SignUp(ctx context.Context, request *v1pb.SignUpRequest) } func (s *APIV1Service) DeleteSession(ctx context.Context, _ *v1pb.DeleteSessionRequest) (*emptypb.Empty, error) { - accessToken, ok := ctx.Value(accessTokenContextKey).(string) - // Try to delete the access token from the store. - if ok { - user, _ := s.GetCurrentUser(ctx) - if user != nil { - if _, err := s.DeleteUserAccessToken(ctx, &v1pb.DeleteUserAccessTokenRequest{ - Name: fmt.Sprintf("%s%d/accessTokens/%s", UserNamePrefix, user.ID, accessToken), - }); err != nil { - slog.Error("failed to delete access token", "error", err) - } + user, err := s.GetCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Unauthenticated, "failed to get current user: %v", err) + } + if user == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not found") + } + + // Check if we have a session ID (from cookie-based auth) + if sessionID, ok := ctx.Value(sessionIDContextKey).(string); ok && sessionID != "" { + // Remove session from user settings + if err := s.Store.RemoveUserSession(ctx, user.ID, sessionID); err != nil { + slog.Error("failed to remove user session", "error", err) + } + } + + // Check if we have an access token (from header-based auth) + if accessToken, ok := ctx.Value(accessTokenContextKey).(string); ok && accessToken != "" { + // Delete the access token from the store + if _, err := s.DeleteUserAccessToken(ctx, &v1pb.DeleteUserAccessTokenRequest{ + Name: fmt.Sprintf("%s%d/accessTokens/%s", UserNamePrefix, user.ID, accessToken), + }); err != nil { + slog.Error("failed to delete access token", "error", err) } } @@ -322,7 +344,7 @@ func (s *APIV1Service) GetCurrentUser(ctx context.Context) (*store.User, error) return user, nil } -// Helper function to track user session for session management +// Helper function to track user session for session management. func (s *APIV1Service) trackUserSession(ctx context.Context, userID int32, sessionID string, expireTime time.Time) error { // Extract client information from the context clientInfo := s.extractClientInfo(ctx) @@ -338,8 +360,8 @@ func (s *APIV1Service) trackUserSession(ctx context.Context, userID int32, sessi return s.Store.AddUserSession(ctx, userID, session) } -// Helper function to extract client information from the gRPC context -func (s *APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo { +// Helper function to extract client information from the gRPC context. +func (*APIV1Service) extractClientInfo(ctx context.Context) *storepb.SessionsUserSetting_ClientInfo { clientInfo := &storepb.SessionsUserSetting_ClientInfo{} // Extract user agent from metadata if available diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index 082e821fe..c52f80b8e 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -677,7 +677,7 @@ func (s *APIV1Service) RevokeUserSession(ctx context.Context, request *v1pb.Revo return &emptypb.Empty{}, nil } -// Helper function to add or update a user session +// UpsertUserSession adds or updates a user session. func (s *APIV1Service) UpsertUserSession(ctx context.Context, userID int32, sessionID string, clientInfo *storepb.SessionsUserSetting_ClientInfo) error { session := &storepb.SessionsUserSetting_Session{ SessionId: sessionID,