mirror of https://github.com/usememos/memos
chore: implement session sliding expiration and JWT authentication
- Added UpdateSessionLastAccessed method to update session access time. - Enhanced Authenticate method to support both session cookie and JWT token authentication. - Introduced AuthResult struct to encapsulate authentication results. - Added SetUserInContext function to simplify context management for authenticated users. refactor(auth): streamline gRPC and HTTP authentication - Removed gRPC authentication interceptor and replaced it with a unified approach using GatewayAuthMiddleware for HTTP requests. - Updated Connect interceptors to utilize the new authentication logic. - Consolidated public and admin-only method checks into service layer for better maintainability. chore(api): clean up unused code and improve documentation - Removed deprecated logger interceptor and unused gRPC server code. - Updated ACL configuration documentation for clarity on public and admin-only methods. - Enhanced metadata handling in Connect RPC to ensure consistent header access. fix(server): simplify server startup and shutdown process - Eliminated cmux dependency for handling HTTP and gRPC traffic. - Streamlined server initialization and shutdown logic for better performance and readability.pull/5349/head
parent
65a19df4be
commit
09afa579e4
@ -1,134 +0,0 @@
|
||||
package v1
|
||||
|
||||
// gRPC Authentication Interceptor
|
||||
//
|
||||
// This file implements the authentication interceptor for gRPC requests.
|
||||
// It extracts credentials from gRPC metadata and delegates to the shared Authenticator.
|
||||
//
|
||||
// Authentication flow:
|
||||
// 1. Extract session cookie or bearer token from metadata
|
||||
// 2. Validate credentials using Authenticator
|
||||
// 3. Check authorization (admin-only methods)
|
||||
// 4. Set user context and proceed with request
|
||||
//
|
||||
// For public methods (defined in acl_config.go), authentication is skipped.
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/metadata"
|
||||
"google.golang.org/grpc/status"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// GRPCAuthInterceptor is the authentication interceptor for gRPC server.
|
||||
// It validates incoming requests and sets user context for authenticated requests.
|
||||
type GRPCAuthInterceptor struct {
|
||||
authenticator *auth.Authenticator
|
||||
}
|
||||
|
||||
// NewGRPCAuthInterceptor creates a new gRPC authentication interceptor.
|
||||
func NewGRPCAuthInterceptor(store *store.Store, secret string) *GRPCAuthInterceptor {
|
||||
return &GRPCAuthInterceptor{
|
||||
authenticator: auth.NewAuthenticator(store, secret),
|
||||
}
|
||||
}
|
||||
|
||||
// AuthenticationInterceptor is the unary interceptor for gRPC API.
|
||||
//
|
||||
// Authentication strategy (in priority order):
|
||||
// 1. Session Cookie: "user_session" cookie with format "{userID}-{sessionID}"
|
||||
// 2. Bearer Token: "Authorization: Bearer {jwt_token}" header
|
||||
// 3. Public Methods: Allow without auth if method is in public allowlist
|
||||
// 4. Reject: Return Unauthenticated error
|
||||
//
|
||||
// On successful authentication, context values are set:
|
||||
// - auth.UserIDContextKey: The authenticated user's ID
|
||||
// - auth.SessionIDContextKey: Session ID (cookie auth only)
|
||||
// - auth.AccessTokenContextKey: JWT token (bearer auth only).
|
||||
func (in *GRPCAuthInterceptor) AuthenticationInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
md, ok := metadata.FromIncomingContext(ctx)
|
||||
if !ok {
|
||||
// If metadata is missing, only allow public methods
|
||||
if IsPublicMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
}
|
||||
return nil, status.Errorf(codes.Unauthenticated, "failed to parse metadata from incoming context")
|
||||
}
|
||||
|
||||
// Try session cookie authentication
|
||||
if sessionCookie := extractSessionCookieFromMetadata(md); sessionCookie != "" {
|
||||
user, err := in.authenticator.AuthenticateBySession(ctx, sessionCookie)
|
||||
if err == nil && user != nil {
|
||||
_, sessionID, err := auth.ParseSessionCookieValue(sessionCookie)
|
||||
if err != nil {
|
||||
// This should not happen since AuthenticateBySession already validated the cookie
|
||||
// but handle it gracefully anyway
|
||||
sessionID = ""
|
||||
}
|
||||
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, sessionID, "", IsAdminOnlyMethod)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err, codes.PermissionDenied)
|
||||
}
|
||||
return handler(ctx, request)
|
||||
}
|
||||
}
|
||||
|
||||
// Try bearer token authentication
|
||||
if token := extractBearerTokenFromMetadata(md); token != "" {
|
||||
user, err := in.authenticator.AuthenticateByJWT(ctx, token)
|
||||
if err == nil && user != nil {
|
||||
ctx, err = in.authenticator.AuthorizeAndSetContext(ctx, serverInfo.FullMethod, user, "", token, IsAdminOnlyMethod)
|
||||
if err != nil {
|
||||
return nil, toGRPCError(err, codes.PermissionDenied)
|
||||
}
|
||||
return handler(ctx, request)
|
||||
}
|
||||
}
|
||||
|
||||
// Allow public methods without authentication
|
||||
if IsPublicMethod(serverInfo.FullMethod) {
|
||||
return handler(ctx, request)
|
||||
}
|
||||
|
||||
return nil, status.Errorf(codes.Unauthenticated, "authentication required")
|
||||
}
|
||||
|
||||
// toGRPCError converts an error to a gRPC status error with the given code.
|
||||
// If the error is already a gRPC status error, it is returned as-is.
|
||||
func toGRPCError(err error, code codes.Code) error {
|
||||
if err == nil {
|
||||
return nil
|
||||
}
|
||||
if _, ok := status.FromError(err); ok {
|
||||
return err
|
||||
}
|
||||
return status.Errorf(code, "%v", err)
|
||||
}
|
||||
|
||||
// extractSessionCookieFromMetadata extracts the session cookie value from gRPC metadata.
|
||||
// Checks both "grpcgateway-cookie" (from gRPC-Gateway) and "cookie" (native gRPC).
|
||||
// Returns empty string if no session cookie is found.
|
||||
func extractSessionCookieFromMetadata(md metadata.MD) string {
|
||||
// gRPC-Gateway puts cookies in "grpcgateway-cookie", native gRPC uses "cookie"
|
||||
for _, cookieHeader := range append(md.Get("grpcgateway-cookie"), md.Get("cookie")...) {
|
||||
if cookie := auth.ExtractSessionCookieFromHeader(cookieHeader); cookie != "" {
|
||||
return cookie
|
||||
}
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
// extractBearerTokenFromMetadata extracts JWT token from Authorization header in gRPC metadata.
|
||||
// Returns empty string if no valid bearer token is found.
|
||||
func extractBearerTokenFromMetadata(md metadata.MD) string {
|
||||
authHeaders := md.Get("Authorization")
|
||||
if len(authHeaders) == 0 {
|
||||
return ""
|
||||
}
|
||||
return auth.ExtractBearerToken(authHeaders[0])
|
||||
}
|
||||
@ -1,56 +1,40 @@
|
||||
package v1
|
||||
|
||||
// Access Control List (ACL) Configuration
|
||||
// PublicMethods defines API endpoints that don't require authentication.
|
||||
// All other endpoints require a valid session or access token.
|
||||
//
|
||||
// This file defines which API methods require authentication and which require admin privileges.
|
||||
// Used by both gRPC and Connect interceptors to enforce access control.
|
||||
// This is the SINGLE SOURCE OF TRUTH for public endpoints.
|
||||
// Both Connect interceptor and gRPC-Gateway interceptor use this map.
|
||||
//
|
||||
// Method names follow the gRPC full method format: "/{package}.{service}/{method}"
|
||||
// Example: "/memos.api.v1.MemoService/CreateMemo"
|
||||
|
||||
// publicMethods lists methods that can be called without authentication.
|
||||
// These are typically read-only endpoints for public content or login-related endpoints.
|
||||
var publicMethods = map[string]bool{
|
||||
// Instance info - needed before login
|
||||
"/memos.api.v1.InstanceService/GetInstanceProfile": true,
|
||||
"/memos.api.v1.InstanceService/GetInstanceSetting": true,
|
||||
|
||||
// Auth - login/session endpoints
|
||||
"/memos.api.v1.AuthService/CreateSession": true,
|
||||
"/memos.api.v1.AuthService/GetCurrentSession": true,
|
||||
|
||||
// User - public user info and registration
|
||||
"/memos.api.v1.UserService/CreateUser": true, // Registration (also admin-only when not first user)
|
||||
"/memos.api.v1.UserService/GetUser": true,
|
||||
"/memos.api.v1.UserService/GetUserAvatar": true,
|
||||
"/memos.api.v1.UserService/GetUserStats": true,
|
||||
"/memos.api.v1.UserService/ListAllUserStats": true,
|
||||
"/memos.api.v1.UserService/SearchUsers": true,
|
||||
|
||||
// Identity providers - needed for SSO login
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": true,
|
||||
|
||||
// Memo - public memo access
|
||||
"/memos.api.v1.MemoService/GetMemo": true,
|
||||
"/memos.api.v1.MemoService/ListMemos": true,
|
||||
|
||||
// Attachment - public attachment access
|
||||
"/memos.api.v1.AttachmentService/GetAttachmentBinary": true,
|
||||
}
|
||||
|
||||
// adminOnlyMethods lists methods that require admin (Host or Admin role) privileges.
|
||||
// Regular users cannot call these methods even if authenticated.
|
||||
var adminOnlyMethods = map[string]bool{
|
||||
"/memos.api.v1.UserService/CreateUser": true, // Admin creates users (except first user registration)
|
||||
"/memos.api.v1.InstanceService/UpdateInstanceSetting": true,
|
||||
}
|
||||
|
||||
// IsPublicMethod returns true if the method can be called without authentication.
|
||||
func IsPublicMethod(fullMethodName string) bool {
|
||||
return publicMethods[fullMethodName]
|
||||
// Format: Full gRPC procedure path as returned by req.Spec().Procedure (Connect)
|
||||
// or info.FullMethod (gRPC interceptor).
|
||||
var PublicMethods = map[string]struct{}{
|
||||
// Auth Service - login flow must be accessible without auth
|
||||
"/memos.api.v1.AuthService/CreateSession": {},
|
||||
"/memos.api.v1.AuthService/GetCurrentSession": {},
|
||||
|
||||
// Instance Service - needed before login to show instance info
|
||||
"/memos.api.v1.InstanceService/GetInstanceProfile": {},
|
||||
"/memos.api.v1.InstanceService/GetInstanceSetting": {},
|
||||
|
||||
// User Service - public user profiles and stats
|
||||
"/memos.api.v1.UserService/GetUser": {},
|
||||
"/memos.api.v1.UserService/GetUserAvatar": {},
|
||||
"/memos.api.v1.UserService/GetUserStats": {},
|
||||
"/memos.api.v1.UserService/ListAllUserStats": {},
|
||||
"/memos.api.v1.UserService/SearchUsers": {},
|
||||
|
||||
// Identity Provider Service - SSO buttons on login page
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders": {},
|
||||
|
||||
// Memo Service - public memos (visibility filtering done in service layer)
|
||||
"/memos.api.v1.MemoService/GetMemo": {},
|
||||
"/memos.api.v1.MemoService/ListMemos": {},
|
||||
}
|
||||
|
||||
// IsAdminOnlyMethod returns true if the method requires admin privileges.
|
||||
func IsAdminOnlyMethod(fullMethodName string) bool {
|
||||
return adminOnlyMethods[fullMethodName]
|
||||
// IsPublicMethod checks if a procedure path is public (no authentication required).
|
||||
// Returns true for public methods, false for protected methods.
|
||||
func IsPublicMethod(procedure string) bool {
|
||||
_, ok := PublicMethods[procedure]
|
||||
return ok
|
||||
}
|
||||
|
||||
@ -0,0 +1,96 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
// TestPublicMethodsArePublic verifies that methods in PublicMethods are recognized as public.
|
||||
func TestPublicMethodsArePublic(t *testing.T) {
|
||||
publicMethods := []string{
|
||||
// Auth Service
|
||||
"/memos.api.v1.AuthService/CreateSession",
|
||||
"/memos.api.v1.AuthService/GetCurrentSession",
|
||||
// Instance Service
|
||||
"/memos.api.v1.InstanceService/GetInstanceProfile",
|
||||
"/memos.api.v1.InstanceService/GetInstanceSetting",
|
||||
// User Service
|
||||
"/memos.api.v1.UserService/GetUser",
|
||||
"/memos.api.v1.UserService/GetUserAvatar",
|
||||
"/memos.api.v1.UserService/GetUserStats",
|
||||
"/memos.api.v1.UserService/ListAllUserStats",
|
||||
"/memos.api.v1.UserService/SearchUsers",
|
||||
// Identity Provider Service
|
||||
"/memos.api.v1.IdentityProviderService/ListIdentityProviders",
|
||||
// Memo Service
|
||||
"/memos.api.v1.MemoService/GetMemo",
|
||||
"/memos.api.v1.MemoService/ListMemos",
|
||||
}
|
||||
|
||||
for _, method := range publicMethods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
assert.True(t, IsPublicMethod(method), "Expected %s to be public", method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestProtectedMethodsRequireAuth verifies that non-public methods are recognized as protected.
|
||||
func TestProtectedMethodsRequireAuth(t *testing.T) {
|
||||
protectedMethods := []string{
|
||||
// Auth Service - logout requires auth
|
||||
"/memos.api.v1.AuthService/DeleteSession",
|
||||
// Instance Service - admin operations
|
||||
"/memos.api.v1.InstanceService/UpdateInstanceSetting",
|
||||
// User Service - modification operations
|
||||
"/memos.api.v1.UserService/ListUsers",
|
||||
"/memos.api.v1.UserService/UpdateUser",
|
||||
"/memos.api.v1.UserService/DeleteUser",
|
||||
// Memo Service - write operations
|
||||
"/memos.api.v1.MemoService/CreateMemo",
|
||||
"/memos.api.v1.MemoService/UpdateMemo",
|
||||
"/memos.api.v1.MemoService/DeleteMemo",
|
||||
// Attachment Service - write operations
|
||||
"/memos.api.v1.AttachmentService/CreateAttachment",
|
||||
"/memos.api.v1.AttachmentService/DeleteAttachment",
|
||||
// Shortcut Service
|
||||
"/memos.api.v1.ShortcutService/CreateShortcut",
|
||||
"/memos.api.v1.ShortcutService/ListShortcuts",
|
||||
"/memos.api.v1.ShortcutService/UpdateShortcut",
|
||||
"/memos.api.v1.ShortcutService/DeleteShortcut",
|
||||
// Activity Service
|
||||
"/memos.api.v1.ActivityService/GetActivity",
|
||||
}
|
||||
|
||||
for _, method := range protectedMethods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
assert.False(t, IsPublicMethod(method), "Expected %s to require auth", method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestUnknownMethodsRequireAuth verifies that unknown methods default to requiring auth.
|
||||
func TestUnknownMethodsRequireAuth(t *testing.T) {
|
||||
unknownMethods := []string{
|
||||
"/unknown.Service/Method",
|
||||
"/memos.api.v1.UnknownService/Method",
|
||||
"",
|
||||
"invalid",
|
||||
}
|
||||
|
||||
for _, method := range unknownMethods {
|
||||
t.Run(method, func(t *testing.T) {
|
||||
assert.False(t, IsPublicMethod(method), "Unknown method %q should require auth", method)
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
// TestPublicMethodsMapConsistency verifies that PublicMethods map matches test expectations.
|
||||
func TestPublicMethodsMapConsistency(t *testing.T) {
|
||||
// Ensure the PublicMethods map has the expected number of entries
|
||||
expectedCount := 13
|
||||
actualCount := len(PublicMethods)
|
||||
assert.Equal(t, expectedCount, actualCount,
|
||||
"PublicMethods map has %d entries, expected %d. Update this test if public methods changed intentionally.",
|
||||
actualCount, expectedCount)
|
||||
}
|
||||
@ -1,53 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log/slog"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
)
|
||||
|
||||
type LoggerInterceptor struct {
|
||||
logStacktrace bool
|
||||
}
|
||||
|
||||
func NewLoggerInterceptor(logStacktrace bool) *LoggerInterceptor {
|
||||
return &LoggerInterceptor{logStacktrace: logStacktrace}
|
||||
}
|
||||
|
||||
func (in *LoggerInterceptor) LoggerInterceptor(ctx context.Context, request any, serverInfo *grpc.UnaryServerInfo, handler grpc.UnaryHandler) (any, error) {
|
||||
resp, err := handler(ctx, request)
|
||||
in.loggerInterceptorDo(ctx, serverInfo.FullMethod, err)
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (in *LoggerInterceptor) loggerInterceptorDo(ctx context.Context, fullMethod string, err error) {
|
||||
st := status.Convert(err)
|
||||
var logLevel slog.Level
|
||||
var logMsg string
|
||||
switch st.Code() {
|
||||
case codes.OK:
|
||||
logLevel = slog.LevelInfo
|
||||
logMsg = "OK"
|
||||
case codes.Unauthenticated, codes.OutOfRange, codes.PermissionDenied, codes.NotFound:
|
||||
logLevel = slog.LevelInfo
|
||||
logMsg = "client error"
|
||||
case codes.Internal, codes.Unknown, codes.DataLoss, codes.Unavailable, codes.DeadlineExceeded:
|
||||
logLevel = slog.LevelError
|
||||
logMsg = "server error"
|
||||
default:
|
||||
logLevel = slog.LevelError
|
||||
logMsg = "unknown error"
|
||||
}
|
||||
logAttrs := []slog.Attr{slog.String("method", fullMethod)}
|
||||
if err != nil {
|
||||
logAttrs = append(logAttrs, slog.String("error", err.Error()))
|
||||
if in.logStacktrace {
|
||||
logAttrs = append(logAttrs, slog.String("stacktrace", fmt.Sprintf("%v", err)))
|
||||
}
|
||||
}
|
||||
slog.LogAttrs(ctx, logLevel, logMsg, logAttrs...)
|
||||
}
|
||||
Loading…
Reference in New Issue