diff --git a/server/route/api/v2/auth.go b/server/route/api/auth/auth.go similarity index 99% rename from server/route/api/v2/auth.go rename to server/route/api/auth/auth.go index 672b260e4..5a46d0105 100644 --- a/server/route/api/v2/auth.go +++ b/server/route/api/auth/auth.go @@ -1,4 +1,4 @@ -package v2 +package auth import ( "fmt" diff --git a/server/route/api/auth/jwt.go b/server/route/api/auth/jwt.go new file mode 100644 index 000000000..8d5bd1e91 --- /dev/null +++ b/server/route/api/auth/jwt.go @@ -0,0 +1,170 @@ +package auth + +import ( + "fmt" + "log/slog" + "net/http" + "strings" + "time" + + "github.com/golang-jwt/jwt/v5" + "github.com/labstack/echo/v4" + "github.com/pkg/errors" + + "github.com/usememos/memos/internal/util" + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +const ( + // UserIDContextKey is the key name used to store user id in the context. + UserIDContextKey = "user-id" +) + +func extractTokenFromHeader(c echo.Context) (string, error) { + authHeader := c.Request().Header.Get("Authorization") + if authHeader == "" { + return "", nil + } + + authHeaderParts := strings.Fields(authHeader) + if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" { + return "", errors.New("Authorization header format must be Bearer {token}") + } + + return authHeaderParts[1], nil +} + +func findAccessToken(c echo.Context) string { + // Check the HTTP request header first. + accessToken, _ := extractTokenFromHeader(c) + if accessToken == "" { + // Check the cookie. + cookie, _ := c.Cookie(AccessTokenCookieName) + if cookie != nil { + accessToken = cookie.Value + } + } + return accessToken +} + +// JWTMiddleware validates the access token. +func JWTMiddleware(storeInstance *store.Store, next echo.HandlerFunc, secret string) echo.HandlerFunc { + return func(c echo.Context) error { + ctx := c.Request().Context() + path := c.Request().URL.Path + method := c.Request().Method + + // Skip validation for server status endpoints. + if util.HasPrefixes(path, "/api/v1/ping", "/api/v1/status") && method == http.MethodGet { + return next(c) + } + + accessToken := findAccessToken(c) + if accessToken == "" { + // Allow the user to access the public endpoints. + if util.HasPrefixes(path, "/o") { + return next(c) + } + // When the request is not authenticated, we allow the user to access the memo endpoints for those public memos. + if util.HasPrefixes(path, "/api/v1/idp", "/api/v1/memo", "/api/v1/user") && path != "/api/v1/user" && method == http.MethodGet { + return next(c) + } + return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") + } + + userID, err := getUserIDFromAccessToken(accessToken, secret) + if err != nil { + err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) + if err != nil { + slog.Warn("fail to remove AccessToken and Cookies", err) + } + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token") + } + + accessTokens, err := storeInstance.GetUserAccessTokens(ctx, userID) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) + } + if !validateAccessToken(accessToken, accessTokens) { + err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken) + if err != nil { + slog.Warn("fail to remove AccessToken and Cookies", err) + } + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") + } + + // Even if there is no error, we still need to make sure the user still exists. + user, err := storeInstance.GetUser(ctx, &store.FindUser{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err) + } + if user == nil { + return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID)) + } + + // Stores userID into context. + c.Set(UserIDContextKey, userID) + return next(c) + } +} + +func getUserIDFromAccessToken(accessToken, secret string) (int32, error) { + claims := &ClaimsMessage{} + _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + if err != nil { + return 0, errors.Wrap(err, "Invalid or expired access token") + } + // We either have a valid access token or we will attempt to generate new access token. + userID, err := util.ConvertStringToInt32(claims.Subject) + if err != nil { + return 0, errors.Wrap(err, "Malformed ID in the token") + } + return userID, nil +} + +func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { + for _, userAccessToken := range userAccessTokens { + if accessTokenString == userAccessToken.AccessToken { + return true + } + } + return false +} + +// removeAccessTokenAndCookies removes the jwt token from the cookies. +func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error { + err := s.RemoveUserAccessToken(c.Request().Context(), userID, token) + if err != nil { + return err + } + + cookieExp := time.Now().Add(-1 * time.Hour) + setTokenCookie(c, AccessTokenCookieName, "", cookieExp) + return nil +} + +// setTokenCookie sets the token to the cookie. +func setTokenCookie(c echo.Context, name, token string, expiration time.Time) { + cookie := new(http.Cookie) + cookie.Name = name + cookie.Value = token + cookie.Expires = expiration + cookie.Path = "/" + // Http-only helps mitigate the risk of client side script accessing the protected cookie. + cookie.HttpOnly = true + cookie.SameSite = http.SameSiteStrictMode + c.SetCookie(cookie) +} diff --git a/server/route/api/v2/acl.go b/server/route/api/v2/acl.go index 1ff6c9f13..23b0a5976 100644 --- a/server/route/api/v2/acl.go +++ b/server/route/api/v2/acl.go @@ -14,6 +14,7 @@ import ( "github.com/usememos/memos/internal/util" storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) @@ -83,7 +84,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str if accessToken == "" { return "", status.Errorf(codes.Unauthenticated, "access token not found") } - claims := &ClaimsMessage{} + claims := &auth.ClaimsMessage{} _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) @@ -144,7 +145,7 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { header := http.Header{} header.Add("Cookie", t) request := http.Request{Header: header} - if v, _ := request.Cookie(AccessTokenCookieName); v != nil { + if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil { accessToken = v.Value } } diff --git a/server/route/api/v2/auth_service.go b/server/route/api/v2/auth_service.go index 167ea643d..f1dec6430 100644 --- a/server/route/api/v2/auth_service.go +++ b/server/route/api/v2/auth_service.go @@ -19,6 +19,7 @@ import ( "github.com/usememos/memos/plugin/idp/oauth2" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) @@ -57,7 +58,7 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password") } - expireTime := time.Now().Add(AccessTokenDuration) + expireTime := time.Now().Add(auth.AccessTokenDuration) if request.NeverExpire { // Set the expire time to 100 years. expireTime = time.Now().Add(100 * 365 * 24 * time.Hour) @@ -140,7 +141,7 @@ func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignI return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier)) } - if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { + if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) } return &apiv2pb.SignInWithSSOResponse{ @@ -149,7 +150,7 @@ func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignI } func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error { - accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) + accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret)) if err != nil { return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err)) } @@ -212,7 +213,7 @@ func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpReques return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err)) } - if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil { + if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil { return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err)) } return &apiv2pb.SignUpResponse{ @@ -242,7 +243,7 @@ func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error { func (*APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) { attrs := []string{ - fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken), + fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken), "Path=/", "HttpOnly", } diff --git a/server/route/api/v2/reaction_service.go b/server/route/api/v2/reaction_service.go index 94f27d50d..616e27b6c 100644 --- a/server/route/api/v2/reaction_service.go +++ b/server/route/api/v2/reaction_service.go @@ -38,10 +38,10 @@ func (s *APIV2Service) UpsertMemoReaction(ctx context.Context, request *apiv2pb. if err != nil { return nil, status.Errorf(codes.Internal, "failed to get current user") } - reaction, err := s.Store.UpsertReaction(ctx, &storepb.Reaction{ - CreatorId: user.ID, - ContentId: request.Reaction.ContentId, - ReactionType: storepb.Reaction_Type(request.Reaction.ReactionType), + reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{ + CreatorID: user.ID, + ContentID: request.Reaction.ContentId, + ReactionType: storepb.ReactionType(request.Reaction.ReactionType), }) if err != nil { return nil, status.Errorf(codes.Internal, "failed to upsert reaction") @@ -66,17 +66,17 @@ func (s *APIV2Service) DeleteMemoReaction(ctx context.Context, request *apiv2pb. return &apiv2pb.DeleteMemoReactionResponse{}, nil } -func (s *APIV2Service) convertReactionFromStore(ctx context.Context, reaction *storepb.Reaction) (*apiv2pb.Reaction, error) { +func (s *APIV2Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*apiv2pb.Reaction, error) { creator, err := s.Store.GetUser(ctx, &store.FindUser{ - ID: &reaction.CreatorId, + ID: &reaction.CreatorID, }) if err != nil { return nil, err } return &apiv2pb.Reaction{ - Id: reaction.Id, + Id: reaction.ID, Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID), - ContentId: reaction.ContentId, + ContentId: reaction.ContentID, ReactionType: apiv2pb.Reaction_Type(reaction.ReactionType), }, nil } diff --git a/server/route/api/v2/user_service.go b/server/route/api/v2/user_service.go index 0c47101b5..547899939 100644 --- a/server/route/api/v2/user_service.go +++ b/server/route/api/v2/user_service.go @@ -21,6 +21,7 @@ import ( "github.com/usememos/memos/internal/util" apiv2pb "github.com/usememos/memos/proto/gen/api/v2" storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) @@ -355,7 +356,7 @@ func (s *APIV2Service) ListUserAccessTokens(ctx context.Context, _ *apiv2pb.List accessTokens := []*apiv2pb.UserAccessToken{} for _, userAccessToken := range userAccessTokens { - claims := &ClaimsMessage{} + claims := &auth.ClaimsMessage{} _, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) @@ -404,12 +405,12 @@ func (s *APIV2Service) CreateUserAccessToken(ctx context.Context, request *apiv2 expiresAt = request.ExpiresAt.AsTime() } - accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) + accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret)) if err != nil { return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err) } - claims := &ClaimsMessage{} + claims := &auth.ClaimsMessage{} _, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { if t.Method.Alg() != jwt.SigningMethodHS256.Name { return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) diff --git a/server/route/resource/resource.go b/server/route/resource/resource.go index 9c83a4538..c9e365b09 100644 --- a/server/route/resource/resource.go +++ b/server/route/resource/resource.go @@ -18,13 +18,11 @@ import ( "github.com/usememos/memos/internal/util" "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/server/route/api/auth" "github.com/usememos/memos/store" ) const ( - // The key name used to store user id in the context - // user id is extracted from the jwt token subject field. - userIDContextKey = "user-id" // thumbnailImagePath is the directory to store image thumbnails. thumbnailImagePath = ".thumbnail_cache" ) @@ -68,7 +66,7 @@ func (s *ResourceService) streamResource(c echo.Context) error { return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", resource.MemoID)).SetInternal(err) } if memo != nil && memo.Visibility != store.Public { - userID, ok := c.Get(userIDContextKey).(int32) + userID, ok := c.Get(auth.UserIDContextKey).(int32) if !ok || (memo.Visibility == store.Private && userID != resource.CreatorID) { return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match") } diff --git a/server/server.go b/server/server.go index 3b0b941cd..5c3692531 100644 --- a/server/server.go +++ b/server/server.go @@ -15,8 +15,11 @@ import ( storepb "github.com/usememos/memos/proto/gen/store" "github.com/usememos/memos/server/integration" "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/server/route/api/auth" apiv2 "github.com/usememos/memos/server/route/api/v2" "github.com/usememos/memos/server/route/frontend" + "github.com/usememos/memos/server/route/resource" + "github.com/usememos/memos/server/route/rss" versionchecker "github.com/usememos/memos/server/service/version_checker" "github.com/usememos/memos/store" ) @@ -74,6 +77,20 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store frontendService.Serve(ctx, e) } + rootGroup := e.Group("") + + // Register public routes. + publicGroup := rootGroup.Group("/o") + publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc { + return auth.JWTMiddleware(s.Store, next, s.Secret) + }) + + // Create and register resource public routes. + resource.NewResourceService(s.Profile, s.Store).RegisterRoutes(publicGroup) + + // Create and register rss public routes. + rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup) + apiV2Service := apiv2.NewAPIV2Service(s.Secret, profile, store, s.Profile.Port+1) // Register gRPC gateway as api v2. if err := apiV2Service.RegisterGateway(ctx, e); err != nil {