mirror of https://github.com/usememos/memos
refactor: user auth improvements (#5360)
parent
2c2ef53737
commit
7932f6d0d0
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,306 @@
|
||||
package auth
|
||||
|
||||
import (
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestGenerateAccessTokenV2(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("generates valid access token", func(t *testing.T) {
|
||||
token, expiresAt, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiresAt.After(time.Now()))
|
||||
assert.True(t, expiresAt.Before(time.Now().Add(AccessTokenDuration+time.Minute)))
|
||||
})
|
||||
|
||||
t.Run("generates different tokens for same user", func(t *testing.T) {
|
||||
token1, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
time.Sleep(2 * time.Second) // Ensure different timestamps (tokens have 1s precision)
|
||||
|
||||
token2, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, token1, token2, "tokens should be different due to different timestamps")
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseAccessTokenV2(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("parses valid access token", func(t *testing.T) {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := ParseAccessTokenV2(token, secret)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1", claims.Subject)
|
||||
assert.Equal(t, "testuser", claims.Username)
|
||||
assert.Equal(t, "USER", claims.Role)
|
||||
assert.Equal(t, "ACTIVE", claims.Status)
|
||||
assert.Equal(t, "access", claims.Type)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongSecret := []byte("wrong-secret")
|
||||
_, err = ParseAccessTokenV2(token, wrongSecret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
_, err := ParseAccessTokenV2("invalid-token", secret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with refresh token", func(t *testing.T) {
|
||||
// Generate a refresh token and try to parse it as access token
|
||||
// Should fail because audience mismatch is caught before type check
|
||||
refreshToken, _, err := GenerateRefreshToken(1, "token-id", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ParseAccessTokenV2(refreshToken, secret)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid audience")
|
||||
})
|
||||
|
||||
t.Run("parses token with different roles", func(t *testing.T) {
|
||||
roles := []string{"USER", "ADMIN", "HOST"}
|
||||
for _, role := range roles {
|
||||
token, _, err := GenerateAccessTokenV2(1, "testuser", role, "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := ParseAccessTokenV2(token, secret)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, role, claims.Role)
|
||||
}
|
||||
})
|
||||
}
|
||||
|
||||
func TestGenerateRefreshToken(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("generates valid refresh token", func(t *testing.T) {
|
||||
token, expiresAt, err := GenerateRefreshToken(1, "token-id-123", secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, expiresAt.After(time.Now().Add(29*24*time.Hour)))
|
||||
})
|
||||
|
||||
t.Run("generates different tokens for different token IDs", func(t *testing.T) {
|
||||
token1, _, err := GenerateRefreshToken(1, "token-id-1", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
token2, _, err := GenerateRefreshToken(1, "token-id-2", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
assert.NotEqual(t, token1, token2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestParseRefreshToken(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("parses valid refresh token", func(t *testing.T) {
|
||||
token, _, err := GenerateRefreshToken(1, "token-id-123", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
claims, err := ParseRefreshToken(token, secret)
|
||||
require.NoError(t, err)
|
||||
assert.Equal(t, "1", claims.Subject)
|
||||
assert.Equal(t, "token-id-123", claims.TokenID)
|
||||
assert.Equal(t, "refresh", claims.Type)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
token, _, err := GenerateRefreshToken(1, "token-id-123", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
wrongSecret := []byte("wrong-secret")
|
||||
_, err = ParseRefreshToken(token, wrongSecret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
_, err := ParseRefreshToken("invalid-token", secret)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with access token", func(t *testing.T) {
|
||||
// Generate an access token and try to parse it as refresh token
|
||||
// Should fail because audience mismatch is caught before type check
|
||||
accessToken, _, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
_, err = ParseRefreshToken(accessToken, secret)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid audience")
|
||||
})
|
||||
}
|
||||
|
||||
func TestGeneratePersonalAccessToken(t *testing.T) {
|
||||
t.Run("generates token with correct prefix", func(t *testing.T) {
|
||||
token := GeneratePersonalAccessToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, len(token) > len(PersonalAccessTokenPrefix))
|
||||
assert.Equal(t, PersonalAccessTokenPrefix, token[:len(PersonalAccessTokenPrefix)])
|
||||
})
|
||||
|
||||
t.Run("generates unique tokens", func(t *testing.T) {
|
||||
token1 := GeneratePersonalAccessToken()
|
||||
token2 := GeneratePersonalAccessToken()
|
||||
assert.NotEqual(t, token1, token2)
|
||||
})
|
||||
|
||||
t.Run("generates token of sufficient length", func(t *testing.T) {
|
||||
token := GeneratePersonalAccessToken()
|
||||
// Prefix is "memos_pat_" (10 chars) + 32 random chars = at least 42 chars
|
||||
assert.True(t, len(token) >= 42, "token should be at least 42 characters")
|
||||
})
|
||||
}
|
||||
|
||||
func TestHashPersonalAccessToken(t *testing.T) {
|
||||
t.Run("generates SHA-256 hash", func(t *testing.T) {
|
||||
token := "memos_pat_abc123"
|
||||
hash := HashPersonalAccessToken(token)
|
||||
assert.NotEmpty(t, hash)
|
||||
assert.Len(t, hash, 64, "SHA-256 hex should be 64 characters")
|
||||
})
|
||||
|
||||
t.Run("same input produces same hash", func(t *testing.T) {
|
||||
token := "memos_pat_abc123"
|
||||
hash1 := HashPersonalAccessToken(token)
|
||||
hash2 := HashPersonalAccessToken(token)
|
||||
assert.Equal(t, hash1, hash2)
|
||||
})
|
||||
|
||||
t.Run("different inputs produce different hashes", func(t *testing.T) {
|
||||
token1 := "memos_pat_abc123"
|
||||
token2 := "memos_pat_xyz789"
|
||||
hash1 := HashPersonalAccessToken(token1)
|
||||
hash2 := HashPersonalAccessToken(token2)
|
||||
assert.NotEqual(t, hash1, hash2)
|
||||
})
|
||||
|
||||
t.Run("hash is deterministic", func(t *testing.T) {
|
||||
token := GeneratePersonalAccessToken()
|
||||
hash1 := HashPersonalAccessToken(token)
|
||||
hash2 := HashPersonalAccessToken(token)
|
||||
assert.Equal(t, hash1, hash2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAccessTokenV2Integration(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("full lifecycle: generate, parse, validate", func(t *testing.T) {
|
||||
userID := int32(42)
|
||||
username := "john_doe"
|
||||
role := "ADMIN"
|
||||
status := "ACTIVE"
|
||||
|
||||
// Generate token
|
||||
token, expiresAt, err := GenerateAccessTokenV2(userID, username, role, status, secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Parse token
|
||||
claims, err := ParseAccessTokenV2(token, secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate claims
|
||||
assert.Equal(t, "42", claims.Subject)
|
||||
assert.Equal(t, username, claims.Username)
|
||||
assert.Equal(t, role, claims.Role)
|
||||
assert.Equal(t, status, claims.Status)
|
||||
assert.Equal(t, "access", claims.Type)
|
||||
assert.Equal(t, Issuer, claims.Issuer)
|
||||
assert.NotNil(t, claims.IssuedAt)
|
||||
assert.NotNil(t, claims.ExpiresAt)
|
||||
|
||||
// Validate expiration
|
||||
assert.True(t, claims.ExpiresAt.Equal(expiresAt) || claims.ExpiresAt.Before(expiresAt))
|
||||
})
|
||||
}
|
||||
|
||||
func TestRefreshTokenIntegration(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("full lifecycle: generate, parse, validate", func(t *testing.T) {
|
||||
userID := int32(42)
|
||||
tokenID := "unique-token-id-456"
|
||||
|
||||
// Generate token
|
||||
token, expiresAt, err := GenerateRefreshToken(userID, tokenID, secret)
|
||||
require.NoError(t, err)
|
||||
assert.NotEmpty(t, token)
|
||||
|
||||
// Parse token
|
||||
claims, err := ParseRefreshToken(token, secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Validate claims
|
||||
assert.Equal(t, "42", claims.Subject)
|
||||
assert.Equal(t, tokenID, claims.TokenID)
|
||||
assert.Equal(t, "refresh", claims.Type)
|
||||
assert.Equal(t, Issuer, claims.Issuer)
|
||||
assert.NotNil(t, claims.IssuedAt)
|
||||
assert.NotNil(t, claims.ExpiresAt)
|
||||
|
||||
// Validate expiration
|
||||
assert.True(t, claims.ExpiresAt.Equal(expiresAt) || claims.ExpiresAt.Before(expiresAt))
|
||||
})
|
||||
}
|
||||
|
||||
func TestPersonalAccessTokenIntegration(t *testing.T) {
|
||||
t.Run("full lifecycle: generate, hash, verify", func(t *testing.T) {
|
||||
// Generate token
|
||||
token := GeneratePersonalAccessToken()
|
||||
assert.NotEmpty(t, token)
|
||||
assert.True(t, len(token) > len(PersonalAccessTokenPrefix))
|
||||
|
||||
// Hash token
|
||||
hash := HashPersonalAccessToken(token)
|
||||
assert.Len(t, hash, 64)
|
||||
|
||||
// Verify same token produces same hash
|
||||
hashAgain := HashPersonalAccessToken(token)
|
||||
assert.Equal(t, hash, hashAgain)
|
||||
|
||||
// Verify different token produces different hash
|
||||
token2 := GeneratePersonalAccessToken()
|
||||
hash2 := HashPersonalAccessToken(token2)
|
||||
assert.NotEqual(t, hash, hash2)
|
||||
})
|
||||
}
|
||||
|
||||
func TestTokenExpiration(t *testing.T) {
|
||||
secret := []byte("test-secret")
|
||||
|
||||
t.Run("access token expires after AccessTokenDuration", func(t *testing.T) {
|
||||
_, expiresAt, err := GenerateAccessTokenV2(1, "testuser", "USER", "ACTIVE", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedExpiry := time.Now().Add(AccessTokenDuration)
|
||||
delta := expiresAt.Sub(expectedExpiry)
|
||||
assert.True(t, delta < time.Second, "expiration should be within 1 second of expected")
|
||||
})
|
||||
|
||||
t.Run("refresh token expires after RefreshTokenDuration", func(t *testing.T) {
|
||||
_, expiresAt, err := GenerateRefreshToken(1, "token-id", secret)
|
||||
require.NoError(t, err)
|
||||
|
||||
expectedExpiry := time.Now().Add(RefreshTokenDuration)
|
||||
delta := expiresAt.Sub(expectedExpiry)
|
||||
assert.True(t, delta < time.Second, "expiration should be within 1 second of expected")
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,655 @@
|
||||
package test
|
||||
|
||||
import (
|
||||
"context"
|
||||
"testing"
|
||||
"time"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
"github.com/stretchr/testify/require"
|
||||
"google.golang.org/protobuf/types/known/timestamppb"
|
||||
|
||||
"github.com/usememos/memos/internal/util"
|
||||
storepb "github.com/usememos/memos/proto/gen/store"
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func TestAuthenticatorAccessTokenV2(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid access token v2", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate access token v2
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte(ts.Secret),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
claims, err := authenticator.AuthenticateByAccessTokenV2(token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, claims)
|
||||
assert.Equal(t, user.ID, claims.UserID)
|
||||
assert.Equal(t, user.Username, claims.Username)
|
||||
assert.Equal(t, string(user.Role), claims.Role)
|
||||
assert.Equal(t, string(user.RowStatus), claims.Status)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, err := authenticator.AuthenticateByAccessTokenV2("invalid-token")
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with wrong secret", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate token with one secret
|
||||
token, _, err := auth.GenerateAccessTokenV2(
|
||||
user.ID,
|
||||
user.Username,
|
||||
string(user.Role),
|
||||
string(user.RowStatus),
|
||||
[]byte("secret-1"),
|
||||
)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate with different secret
|
||||
authenticator := auth.NewAuthenticator(ts.Store, "secret-2")
|
||||
_, err = authenticator.AuthenticateByAccessTokenV2(token)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticatorRefreshToken(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create refresh token record in store
|
||||
tokenID := util.GenUUID()
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate refresh token JWT
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, returnedTokenID, err := authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.Equal(t, user.ID, authenticatedUser.ID)
|
||||
assert.Equal(t, tokenID, returnedTokenID)
|
||||
})
|
||||
|
||||
t.Run("fails with revoked token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
// Generate refresh token JWT but don't store it in database (simulates revocation)
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "revoked")
|
||||
})
|
||||
|
||||
t.Run("fails with expired token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create expired refresh token record in store
|
||||
tokenID := util.GenUUID()
|
||||
expiredToken := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, expiredToken)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate refresh token JWT (JWT itself isn't expired yet)
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired")
|
||||
})
|
||||
|
||||
t.Run("fails with archived user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Create valid refresh token
|
||||
tokenID := util.GenUUID()
|
||||
refreshTokenRecord := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(auth.RefreshTokenDuration)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, refreshTokenRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
token, _, err := auth.GenerateRefreshToken(user.ID, tokenID, []byte(ts.Secret))
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the user
|
||||
archivedStatus := store.Archived
|
||||
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByRefreshToken(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "archived")
|
||||
})
|
||||
}
|
||||
|
||||
func TestAuthenticatorPAT(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("authenticates valid PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Create a test user
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
// Store PAT in database
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.NotNil(t, pat)
|
||||
assert.Equal(t, user.ID, authenticatedUser.ID)
|
||||
assert.Equal(t, tokenID, pat.TokenId)
|
||||
})
|
||||
|
||||
t.Run("fails with invalid PAT format", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err := authenticator.AuthenticateByPAT(ctx, "invalid-token-without-prefix")
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "invalid PAT format")
|
||||
})
|
||||
|
||||
t.Run("fails with non-existent PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
// Generate a PAT but don't store it
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
})
|
||||
|
||||
t.Run("fails with expired PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store expired PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
expiredPAT := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Expired PAT",
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(-1 * time.Hour)), // Expired
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, expiredPAT)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "expired")
|
||||
})
|
||||
|
||||
t.Run("succeeds with non-expiring PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store PAT without expiration
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Never-expiring PAT",
|
||||
ExpiresAt: nil, // No expiration
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
authenticatedUser, pat, err := authenticator.AuthenticateByPAT(ctx, token)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, authenticatedUser)
|
||||
assert.NotNil(t, pat)
|
||||
assert.Nil(t, pat.ExpiresAt)
|
||||
})
|
||||
|
||||
t.Run("fails with archived user", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Generate and store PAT
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
patRecord := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, patRecord)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Archive the user
|
||||
archivedStatus := store.Archived
|
||||
_, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{
|
||||
ID: user.ID,
|
||||
RowStatus: &archivedStatus,
|
||||
})
|
||||
require.NoError(t, err)
|
||||
|
||||
// Try to authenticate
|
||||
authenticator := auth.NewAuthenticator(ts.Store, ts.Secret)
|
||||
_, _, err = authenticator.AuthenticateByPAT(ctx, token)
|
||||
assert.Error(t, err)
|
||||
assert.Contains(t, err.Error(), "archived")
|
||||
})
|
||||
}
|
||||
|
||||
func TestStoreRefreshTokenMethods(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("adds and retrieves refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve tokens
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 1)
|
||||
assert.Equal(t, tokenID, tokens[0].TokenId)
|
||||
})
|
||||
|
||||
t.Run("retrieves specific refresh token by ID", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve specific token
|
||||
retrievedToken, err := ts.Store.GetUserRefreshTokenByID(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, retrievedToken)
|
||||
assert.Equal(t, tokenID, retrievedToken.TokenId)
|
||||
})
|
||||
|
||||
t.Run("removes refresh token", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
tokenID := util.GenUUID()
|
||||
token := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove token
|
||||
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify removal
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 0)
|
||||
})
|
||||
|
||||
t.Run("handles multiple refresh tokens", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add multiple tokens
|
||||
tokenID1 := util.GenUUID()
|
||||
tokenID2 := util.GenUUID()
|
||||
|
||||
token1 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID1,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
token2 := &storepb.RefreshTokensUserSetting_RefreshToken{
|
||||
TokenId: tokenID2,
|
||||
ExpiresAt: timestamppb.New(time.Now().Add(30 * 24 * time.Hour)),
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token1)
|
||||
require.NoError(t, err)
|
||||
err = ts.Store.AddUserRefreshToken(ctx, user.ID, token2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all tokens
|
||||
tokens, err := ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 2)
|
||||
|
||||
// Remove one token
|
||||
err = ts.Store.RemoveUserRefreshToken(ctx, user.ID, tokenID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one token remains
|
||||
tokens, err = ts.Store.GetUserRefreshTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, tokens, 1)
|
||||
assert.Equal(t, tokenID2, tokens[0].TokenId)
|
||||
})
|
||||
}
|
||||
|
||||
func TestStorePersonalAccessTokenMethods(t *testing.T) {
|
||||
ctx := context.Background()
|
||||
|
||||
t.Run("adds and retrieves PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve PATs
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.Equal(t, tokenID, pats[0].TokenId)
|
||||
assert.Equal(t, tokenHash, pats[0].TokenHash)
|
||||
})
|
||||
|
||||
t.Run("removes PAT", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Remove PAT
|
||||
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify removal
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 0)
|
||||
})
|
||||
|
||||
t.Run("updates PAT last used time", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Update last used time
|
||||
lastUsed := timestamppb.Now()
|
||||
err = ts.Store.UpdatePATLastUsed(ctx, user.ID, tokenID, lastUsed)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify update
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.NotNil(t, pats[0].LastUsedAt)
|
||||
})
|
||||
|
||||
t.Run("handles multiple PATs", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
// Add multiple PATs
|
||||
token1 := auth.GeneratePersonalAccessToken()
|
||||
tokenHash1 := auth.HashPersonalAccessToken(token1)
|
||||
tokenID1 := util.GenUUID()
|
||||
|
||||
token2 := auth.GeneratePersonalAccessToken()
|
||||
tokenHash2 := auth.HashPersonalAccessToken(token2)
|
||||
tokenID2 := util.GenUUID()
|
||||
|
||||
pat1 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID1,
|
||||
TokenHash: tokenHash1,
|
||||
Description: "PAT 1",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
pat2 := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID2,
|
||||
TokenHash: tokenHash2,
|
||||
Description: "PAT 2",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat1)
|
||||
require.NoError(t, err)
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat2)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Retrieve all PATs
|
||||
pats, err := ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 2)
|
||||
|
||||
// Remove one PAT
|
||||
err = ts.Store.RemoveUserPersonalAccessToken(ctx, user.ID, tokenID1)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Verify only one PAT remains
|
||||
pats, err = ts.Store.GetUserPersonalAccessTokens(ctx, user.ID)
|
||||
require.NoError(t, err)
|
||||
assert.Len(t, pats, 1)
|
||||
assert.Equal(t, tokenID2, pats[0].TokenId)
|
||||
})
|
||||
|
||||
t.Run("finds user by PAT hash", func(t *testing.T) {
|
||||
ts := NewTestService(t)
|
||||
defer ts.Cleanup()
|
||||
|
||||
user, err := ts.CreateRegularUser(ctx, "testuser")
|
||||
require.NoError(t, err)
|
||||
|
||||
token := auth.GeneratePersonalAccessToken()
|
||||
tokenHash := auth.HashPersonalAccessToken(token)
|
||||
tokenID := util.GenUUID()
|
||||
|
||||
pat := &storepb.PersonalAccessTokensUserSetting_PersonalAccessToken{
|
||||
TokenId: tokenID,
|
||||
TokenHash: tokenHash,
|
||||
Description: "Test PAT",
|
||||
CreatedAt: timestamppb.Now(),
|
||||
}
|
||||
|
||||
err = ts.Store.AddUserPersonalAccessToken(ctx, user.ID, pat)
|
||||
require.NoError(t, err)
|
||||
|
||||
// Find user by PAT hash
|
||||
result, err := ts.Store.GetUserByPATHash(ctx, tokenHash)
|
||||
require.NoError(t, err)
|
||||
assert.NotNil(t, result)
|
||||
assert.Equal(t, user.ID, result.UserID)
|
||||
assert.NotNil(t, result.User)
|
||||
assert.Equal(t, user.Username, result.User.Username)
|
||||
assert.NotNil(t, result.PAT)
|
||||
assert.Equal(t, tokenID, result.PAT.TokenId)
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,21 @@
|
||||
// In-memory storage for access token (not persisted for security)
|
||||
let accessToken: string | null = null;
|
||||
let tokenExpiresAt: Date | null = null;
|
||||
|
||||
export const getAccessToken = (): string | null => accessToken;
|
||||
|
||||
export const setAccessToken = (token: string | null, expiresAt?: Date): void => {
|
||||
accessToken = token;
|
||||
tokenExpiresAt = expiresAt || null;
|
||||
};
|
||||
|
||||
export const isTokenExpired = (): boolean => {
|
||||
if (!tokenExpiresAt) return true;
|
||||
// Consider expired 30 seconds before actual expiry for safety
|
||||
return new Date() >= new Date(tokenExpiresAt.getTime() - 30000);
|
||||
};
|
||||
|
||||
export const clearAccessToken = (): void => {
|
||||
accessToken = null;
|
||||
tokenExpiresAt = null;
|
||||
};
|
||||
File diff suppressed because one or more lines are too long
Loading…
Reference in New Issue