From ae1e22931f7729fa2adb8a3ae52823f124ff6432 Mon Sep 17 00:00:00 2001 From: Steven Date: Wed, 20 Sep 2023 19:24:26 +0800 Subject: [PATCH] chore: auto remove current access token when sign out --- api/v1/auth.go | 34 +++++++++++++++++++++++---- api/v1/jwt.go | 62 +++++++++++++++++++++++--------------------------- api/v2/acl.go | 16 ------------- 3 files changed, 58 insertions(+), 54 deletions(-) diff --git a/api/v1/auth.go b/api/v1/auth.go index 0225b8d0..a5425346 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -251,8 +251,34 @@ func (s *APIV1Service) SignInSSO(c echo.Context) error { // @Produce json // @Success 200 {boolean} true "Sign-out success" // @Router /api/v1/auth/signout [POST] -func (*APIV1Service) SignOut(c echo.Context) error { - RemoveTokensAndCookies(c) +func (s *APIV1Service) SignOut(c echo.Context) error { + ctx := c.Request().Context() + accessToken := findAccessToken(c) + userID, _ := getUserIDFromAccessToken(accessToken, s.Secret) + userAccessTokens, err := s.Store.GetUserAccessTokens(ctx, userID) + // Auto remove the current access token from the user access tokens. + if err == nil && len(userAccessTokens) != 0 { + accessTokens := []*storepb.AccessTokensUserSetting_AccessToken{} + for _, userAccessToken := range userAccessTokens { + if accessToken != userAccessToken.AccessToken { + accessTokens = append(accessTokens, userAccessToken) + } + } + + if _, err := s.Store.UpsertUserSettingV1(ctx, &storepb.UserSetting{ + UserId: userID, + Key: storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS, + Value: &storepb.UserSetting_AccessTokens{ + AccessTokens: &storepb.AccessTokensUserSetting{ + AccessTokens: accessTokens, + }, + }, + }); err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, fmt.Sprintf("failed to upsert user setting, err: %s", err)).SetInternal(err) + } + } + + removeAccessTokenAndCookies(c) return c.JSON(http.StatusOK, true) } @@ -411,8 +437,8 @@ func (s *APIV1Service) createAuthSignUpActivity(c echo.Context, user *store.User return err } -// RemoveTokensAndCookies removes the jwt token from the cookies. -func RemoveTokensAndCookies(c echo.Context) { +// removeAccessTokenAndCookies removes the jwt token from the cookies. +func removeAccessTokenAndCookies(c echo.Context) { cookieExp := time.Now().Add(-1 * time.Hour) setTokenCookie(c, auth.AccessTokenCookieName, "", cookieExp) } diff --git a/api/v1/jwt.go b/api/v1/jwt.go index 6921908c..bc641fa5 100644 --- a/api/v1/jwt.go +++ b/api/v1/jwt.go @@ -48,15 +48,6 @@ func findAccessToken(c echo.Context) string { return accessToken } -func audienceContains(audience jwt.ClaimStrings, token string) bool { - for _, v := range audience { - if v == token { - return true - } - } - return false -} - // JWTMiddleware validates the access token. func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) echo.HandlerFunc { return func(c echo.Context) error { @@ -86,31 +77,10 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusUnauthorized, "Missing access token") } - claims := &auth.ClaimsMessage{} - _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { - if t.Method.Alg() != jwt.SigningMethodHS256.Name { - return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) - } - if kid, ok := t.Header["kid"].(string); ok { - if kid == "v1" { - return []byte(secret), nil - } - } - return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) - }) - + userID, err := getUserIDFromAccessToken(accessToken, secret) if err != nil { - RemoveTokensAndCookies(c) - return echo.NewHTTPError(http.StatusUnauthorized, errors.Wrap(err, "Invalid or expired access token")) - } - if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { - return echo.NewHTTPError(http.StatusUnauthorized, fmt.Sprintf("Invalid access token, audience mismatch, got %q, expected %q.", claims.Audience, auth.AccessTokenAudienceName)) - } - - // We either have a valid access token or we will attempt to generate new access token. - userID, err := util.ConvertStringToInt32(claims.Subject) - if err != nil { - return echo.NewHTTPError(http.StatusUnauthorized, "Malformed ID in the token.") + removeAccessTokenAndCookies(c) + return echo.NewHTTPError(http.StatusUnauthorized, "Invalid or expired access token") } accessTokens, err := server.Store.GetUserAccessTokens(ctx, userID) @@ -118,7 +88,7 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get user access tokens.").WithInternal(err) } if !validateAccessToken(accessToken, accessTokens) { - RemoveTokensAndCookies(c) + removeAccessTokenAndCookies(c) return echo.NewHTTPError(http.StatusUnauthorized, "Invalid access token.") } @@ -139,6 +109,30 @@ func JWTMiddleware(server *APIV1Service, next echo.HandlerFunc, secret string) e } } +func getUserIDFromAccessToken(accessToken, secret string) (int32, error) { + claims := &auth.ClaimsMessage{} + _, err := jwt.ParseWithClaims(accessToken, claims, func(t *jwt.Token) (any, error) { + if t.Method.Alg() != jwt.SigningMethodHS256.Name { + return nil, errors.Errorf("unexpected access token signing method=%v, expect %v", t.Header["alg"], jwt.SigningMethodHS256) + } + if kid, ok := t.Header["kid"].(string); ok { + if kid == "v1" { + return []byte(secret), nil + } + } + return nil, errors.Errorf("unexpected access token kid=%v", t.Header["kid"]) + }) + if err != nil { + return 0, errors.Wrap(err, "Invalid or expired access token") + } + // We either have a valid access token or we will attempt to generate new access token. + userID, err := util.ConvertStringToInt32(claims.Subject) + if err != nil { + return 0, errors.Wrap(err, "Malformed ID in the token") + } + return userID, nil +} + func (*APIV1Service) defaultAuthSkipper(c echo.Context) bool { path := c.Path() return util.HasPrefixes(path, "/api/v1/auth") diff --git a/api/v2/acl.go b/api/v2/acl.go index 86f45e77..804604ba 100644 --- a/api/v2/acl.go +++ b/api/v2/acl.go @@ -96,13 +96,6 @@ func (in *GRPCAuthInterceptor) authenticate(ctx context.Context, accessToken str if err != nil { return "", status.Errorf(codes.Unauthenticated, "Invalid or expired access token") } - if !audienceContains(claims.Audience, auth.AccessTokenAudienceName) { - return "", status.Errorf(codes.Unauthenticated, - "invalid access token, audience mismatch, got %q, expected %q. you may send request to the wrong environment", - claims.Audience, - auth.AccessTokenAudienceName, - ) - } // We either have a valid access token or we will attempt to generate new access token. userID, err := util.ConvertStringToInt32(claims.Subject) @@ -155,15 +148,6 @@ func getTokenFromMetadata(md metadata.MD) (string, error) { return accessToken, nil } -func audienceContains(audience jwt.ClaimStrings, token string) bool { - for _, v := range audience { - if v == token { - return true - } - } - return false -} - func validateAccessToken(accessTokenString string, userAccessTokens []*storepb.AccessTokensUserSetting_AccessToken) bool { for _, userAccessToken := range userAccessTokens { if accessTokenString == userAccessToken.AccessToken {