chore: fix resource routes

pull/3224/head
Steven 1 year ago
parent cebc46adc7
commit 75359854cc

@ -1,4 +1,4 @@
package v2
package auth
import (
"fmt"

@ -0,0 +1,170 @@
package auth
import (
"fmt"
"log/slog"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
"github.com/labstack/echo/v4"
"github.com/pkg/errors"
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
const (
// UserIDContextKey is the key name used to store user id in the context.
UserIDContextKey = "user-id"
)
func extractTokenFromHeader(c echo.Context) (string, error) {
authHeader := c.Request().Header.Get("Authorization")
if authHeader == "" {
return "", nil
}
authHeaderParts := strings.Fields(authHeader)
if len(authHeaderParts) != 2 || strings.ToLower(authHeaderParts[0]) != "bearer" {
return "", errors.New("Authorization header format must be Bearer {token}")
}
return authHeaderParts[1], nil
}
func findAccessToken(c echo.Context) string {
// Check the HTTP request header first.
accessToken, _ := extractTokenFromHeader(c)
if accessToken == "" {
// Check the cookie.
cookie, _ := c.Cookie(AccessTokenCookieName)
if cookie != nil {
accessToken = cookie.Value
}
}
return accessToken
}
// JWTMiddleware validates the access token.
func JWTMiddleware(storeInstance *store.Store, next echo.HandlerFunc, secret string) echo.HandlerFunc {
return func(c echo.Context) error {
ctx := c.Request().Context()
path := c.Request().URL.Path
method := c.Request().Method
// Skip validation for server status endpoints.
if util.HasPrefixes(path, "/api/v1/ping", "/api/v1/status") && method == http.MethodGet {
return next(c)
}
accessToken := findAccessToken(c)
if accessToken == "" {
// Allow the user to access the public endpoints.
if util.HasPrefixes(path, "/o") {
return next(c)
}
// When the request is not authenticated, we allow the user to access the memo endpoints for those public memos.
if util.HasPrefixes(path, "/api/v1/idp", "/api/v1/memo", "/api/v1/user") && path != "/api/v1/user" && method == http.MethodGet {
return next(c)
}
return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token")
}
userID, err := getUserIDFromAccessToken(accessToken, secret)
if err != nil {
err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken)
if err != nil {
slog.Warn("fail to remove AccessToken and Cookies", err)
}
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token")
}
accessTokens, err := storeInstance.GetUserAccessTokens(ctx, userID)
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err)
}
if !validateAccessToken(accessToken, accessTokens) {
err = removeAccessTokenAndCookies(c, storeInstance, userID, accessToken)
if err != nil {
slog.Warn("fail to remove AccessToken and Cookies", err)
}
return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.")
}
// Even if there is no error, we still need to make sure the user still exists.
user, err := storeInstance.GetUser(ctx, &store.FindUser{
ID: &userID,
})
if err != nil {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Server error to find user ID: %d", userID)).SetInternal(err)
}
if user == nil {
return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Failed to find user ID: %d", userID))
}
// Stores userID into context.
c.Set(UserIDContextKey, userID)
return next(c)
}
}
func getUserIDFromAccessToken(accessToken, secret string) (int32, error) {
claims := &ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
}
if kid, ok := t.Header["kid"].(string); ok {
if kid == "v1" {
return []byte(secret), nil
}
}
return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"])
})
if err != nil {
return 0, errors.Wrap(err, "Invalid or expired access token")
}
// We either have a valid access token or we will attempt to generate new access token.
userID, err := util.ConvertStringToInt32(claims.Subject)
if err != nil {
return 0, errors.Wrap(err, "Malformed ID in the token")
}
return userID, nil
}
func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool {
for _, userAccessToken := range userAccessTokens {
if accessTokenString == userAccessToken.AccessToken {
return true
}
}
return false
}
// removeAccessTokenAndCookies removes the jwt token from the cookies.
func removeAccessTokenAndCookies(c echo.Context, s *store.Store, userID int32, token string) error {
err := s.RemoveUserAccessToken(c.Request().Context(), userID, token)
if err != nil {
return err
}
cookieExp := time.Now().Add(-1 * time.Hour)
setTokenCookie(c, AccessTokenCookieName, "", cookieExp)
return nil
}
// setTokenCookie sets the token to the cookie.
func setTokenCookie(c echo.Context, name, token string, expiration time.Time) {
cookie := new(http.Cookie)
cookie.Name = name
cookie.Value = token
cookie.Expires = expiration
cookie.Path = "/"
// Http-only helps mitigate the risk of client side script accessing the protected cookie.
cookie.HttpOnly = true
cookie.SameSite = http.SameSiteStrictMode
c.SetCookie(cookie)
}

@ -14,6 +14,7 @@ import (
"github.com/usememos/memos/internal/util"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
@ -83,7 +84,7 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str
if accessToken == "" {
return "", status.Errorf(codes.Unauthenticated, "access token not found")
}
claims := &ClaimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, status.Errorf(codes.Unauthenticated, "unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
@ -144,7 +145,7 @@ func getTokenFromMetadata(md metadata.MD) (string, error) {
header := http.Header{}
header.Add("Cookie", t)
request := http.Request{Header: header}
if v, _ := request.Cookie(AccessTokenCookieName); v != nil {
if v, _ := request.Cookie(auth.AccessTokenCookieName); v != nil {
accessToken = v.Value
}
}

@ -19,6 +19,7 @@ import (
"github.com/usememos/memos/plugin/idp/oauth2"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
@ -57,7 +58,7 @@ func (s *APIV2Service) SignIn(ctx context.Context, request *apiv2pb.SignInReques
return nil, status.Errorf(codes.InvalidArgument, "unmatched email and password")
}
expireTime := time.Now().Add(AccessTokenDuration)
expireTime := time.Now().Add(auth.AccessTokenDuration)
if request.NeverExpire {
// Set the expire time to 100 years.
expireTime = time.Now().Add(100 * 365 * 24 * time.Hour)
@ -140,7 +141,7 @@ func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignI
return nil, status.Errorf(codes.PermissionDenied, fmt.Sprintf("user has been archived with username %s", userInfo.Identifier))
}
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return &apiv2pb.SignInWithSSOResponse{
@ -149,7 +150,7 @@ func (s *APIV2Service) SignInWithSSO(ctx context.Context, request *apiv2pb.SignI
}
func (s *APIV2Service) doSignIn(ctx context.Context, user *store.User, expireTime time.Time) error {
accessToken, err := GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
accessToken, err := auth.GenerateAccessToken(user.Email, user.ID, expireTime, []byte(s.Secret))
if err != nil {
return status.Errorf(codes.Internal, fmt.Sprintf("failed to generate tokens, err: %s", err))
}
@ -212,7 +213,7 @@ func (s *APIV2Service) SignUp(ctx context.Context, request *apiv2pb.SignUpReques
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to create user, err: %s", err))
}
if err := s.doSignIn(ctx, user, time.Now().Add(AccessTokenDuration)); err != nil {
if err := s.doSignIn(ctx, user, time.Now().Add(auth.AccessTokenDuration)); err != nil {
return nil, status.Errorf(codes.Internal, fmt.Sprintf("failed to sign in, err: %s", err))
}
return &apiv2pb.SignUpResponse{
@ -242,7 +243,7 @@ func (s *APIV2Service) clearAccessTokenCookie(ctx context.Context) error {
func (*APIV2Service) buildAccessTokenCookie(ctx context.Context, accessToken string, expireTime time.Time) (string, error) {
attrs := []string{
fmt.Sprintf("%s=%s", AccessTokenCookieName, accessToken),
fmt.Sprintf("%s=%s", auth.AccessTokenCookieName, accessToken),
"Path=/",
"HttpOnly",
}

@ -38,10 +38,10 @@ func (s *APIV2Service) UpsertMemoReaction(ctx context.Context, request *apiv2pb.
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get current user")
}
reaction, err := s.Store.UpsertReaction(ctx, &storepb.Reaction{
CreatorId: user.ID,
ContentId: request.Reaction.ContentId,
ReactionType: storepb.Reaction_Type(request.Reaction.ReactionType),
reaction, err := s.Store.UpsertReaction(ctx, &store.Reaction{
CreatorID: user.ID,
ContentID: request.Reaction.ContentId,
ReactionType: storepb.ReactionType(request.Reaction.ReactionType),
})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to upsert reaction")
@ -66,17 +66,17 @@ func (s *APIV2Service) DeleteMemoReaction(ctx context.Context, request *apiv2pb.
return &apiv2pb.DeleteMemoReactionResponse{}, nil
}
func (s *APIV2Service) convertReactionFromStore(ctx context.Context, reaction *storepb.Reaction) (*apiv2pb.Reaction, error) {
func (s *APIV2Service) convertReactionFromStore(ctx context.Context, reaction *store.Reaction) (*apiv2pb.Reaction, error) {
creator, err := s.Store.GetUser(ctx, &store.FindUser{
ID: &reaction.CreatorId,
ID: &reaction.CreatorID,
})
if err != nil {
return nil, err
}
return &apiv2pb.Reaction{
Id: reaction.Id,
Id: reaction.ID,
Creator: fmt.Sprintf("%s%d", UserNamePrefix, creator.ID),
ContentId: reaction.ContentId,
ContentId: reaction.ContentID,
ReactionType: apiv2pb.Reaction_Type(reaction.ReactionType),
}, nil
}

@ -21,6 +21,7 @@ import (
"github.com/usememos/memos/internal/util"
apiv2pb "github.com/usememos/memos/proto/gen/api/v2"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
@ -355,7 +356,7 @@ func (s *APIV2Service) ListUserAccessTokens(ctx context.Context, _ *apiv2pb.List
accessTokens := []*apiv2pb.UserAccessToken{}
for _, userAccessToken := range userAccessTokens {
claims := &ClaimsMessage{}
claims := &auth.ClaimsMessage{}
_, err := jwt.ParseWithClaims(userAccessToken.AccessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)
@ -404,12 +405,12 @@ func (s *APIV2Service) CreateUserAccessToken(ctx context.Context, request *apiv2
expiresAt = request.ExpiresAt.AsTime()
}
accessToken, err := GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
accessToken, err := auth.GenerateAccessToken(user.Username, user.ID, expiresAt, []byte(s.Secret))
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to generate access token: %v", err)
}
claims := &ClaimsMessage{}
claims := &auth.ClaimsMessage{}
_, err = jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) {
if t.Method.Alg() != jwt.SigningMethodHS256.Name {
return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256)

@ -18,13 +18,11 @@ import (
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/server/route/api/auth"
"github.com/usememos/memos/store"
)
const (
// The key name used to store user id in the context
// user id is extracted from the jwt token subject field.
userIDContextKey = "user-id"
// thumbnailImagePath is the directory to store image thumbnails.
thumbnailImagePath = ".thumbnail_cache"
)
@ -68,7 +66,7 @@ func (s *ResourceService) streamResource(c echo.Context) error {
return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("Failed to find memo by ID: %v", resource.MemoID)).SetInternal(err)
}
if memo != nil && memo.Visibility != store.Public {
userID, ok := c.Get(userIDContextKey).(int32)
userID, ok := c.Get(auth.UserIDContextKey).(int32)
if !ok || (memo.Visibility == store.Private && userID != resource.CreatorID) {
return echo.NewHTTPError(http.StatusUnauthorized, "Resource visibility not match")
}

@ -15,8 +15,11 @@ import (
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/server/integration"
"github.com/usememos/memos/server/profile"
"github.com/usememos/memos/server/route/api/auth"
apiv2 "github.com/usememos/memos/server/route/api/v2"
"github.com/usememos/memos/server/route/frontend"
"github.com/usememos/memos/server/route/resource"
"github.com/usememos/memos/server/route/rss"
versionchecker "github.com/usememos/memos/server/service/version_checker"
"github.com/usememos/memos/store"
)
@ -74,6 +77,20 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
frontendService.Serve(ctx, e)
}
rootGroup := e.Group("")
// Register public routes.
publicGroup := rootGroup.Group("/o")
publicGroup.Use(func(next echo.HandlerFunc) echo.HandlerFunc {
return auth.JWTMiddleware(s.Store, next, s.Secret)
})
// Create and register resource public routes.
resource.NewResourceService(s.Profile, s.Store).RegisterRoutes(publicGroup)
// Create and register rss public routes.
rss.NewRSSService(s.Profile, s.Store).RegisterRoutes(rootGroup)
apiV2Service := apiv2.NewAPIV2Service(s.Secret, profile, store, s.Profile.Port+1)
// Register gRPC gateway as api v2.
if err := apiV2Service.RegisterGateway(ctx, e); err != nil {

Loading…
Cancel
Save