From 760c164328472e0a360924ea562540a60b339a05 Mon Sep 17 00:00:00 2001 From: Steven Date: Tue, 17 Jun 2025 20:56:10 +0800 Subject: [PATCH] chore: add server tests --- .github/prompts/claude.prompt.md | 5 + CLAUDE.md | 21 +- server/router/api/v1/idp_service.go | 12 +- ...ortcuts_service.go => shortcut_service.go} | 8 +- server/router/api/v1/test/idp_service_test.go | 520 +++++++++++ .../api/v1/test/shortcut_service_test.go | 819 ++++++++++++++++++ .../router/api/v1/{ => test}/test_helper.go | 12 +- .../api/v1/test/webhook_service_test.go | 408 +++++++++ .../v1/{ => test}/workspace_service_test.go | 80 -- server/router/api/v1/test_auth.go | 19 + server/router/api/v1/webhook_service.go | 29 +- 11 files changed, 1824 insertions(+), 109 deletions(-) create mode 100644 .github/prompts/claude.prompt.md rename server/router/api/v1/{shortcuts_service.go => shortcut_service.go} (98%) create mode 100644 server/router/api/v1/test/idp_service_test.go create mode 100644 server/router/api/v1/test/shortcut_service_test.go rename server/router/api/v1/{ => test}/test_helper.go (85%) create mode 100644 server/router/api/v1/test/webhook_service_test.go rename server/router/api/v1/{ => test}/workspace_service_test.go (70%) create mode 100644 server/router/api/v1/test_auth.go diff --git a/.github/prompts/claude.prompt.md b/.github/prompts/claude.prompt.md new file mode 100644 index 000000000..341f911c3 --- /dev/null +++ b/.github/prompts/claude.prompt.md @@ -0,0 +1,5 @@ +--- +mode: agent +--- + +Please follow `./CLAUDE.md` for the basic structure and development guidelines of the Memos project. diff --git a/CLAUDE.md b/CLAUDE.md index cb24c1995..bd32a86f4 100644 --- a/CLAUDE.md +++ b/CLAUDE.md @@ -215,23 +215,4 @@ FROM alpine:latest AS production 1. **Lint Checking**: All linters must pass 2. **Test Coverage**: New code should include tests 3. **Documentation**: Update relevant documentation -4. **AIP Compliance**: New APIs should follow AIP standards - -## Future Considerations - -### Planned Improvements - -- **Additional Service Tests**: Expand test coverage to all services -- **API Versioning**: Support for multiple API versions -- **Enhanced Metrics**: Better observability and monitoring -- **Plugin System**: Extensible architecture for custom features - -### Technical Debt - -- **Legacy API Cleanup**: Remove deprecated endpoints -- **Performance Optimization**: Database query optimization -- **Security Hardening**: Enhanced security measures - ---- - -_This documentation reflects the current state of the Memos project as of June 2025, including recent AIP compliance refactoring and comprehensive testing infrastructure._ +4. **AIP Compliance**: New APIs should follow [AIP](https://google.aip.dev/) standards diff --git a/server/router/api/v1/idp_service.go b/server/router/api/v1/idp_service.go index 22024cc1c..68dea473d 100644 --- a/server/router/api/v1/idp_service.go +++ b/server/router/api/v1/idp_service.go @@ -18,7 +18,7 @@ func (s *APIV1Service) CreateIdentityProvider(ctx context.Context, request *v1pb if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } - if currentUser.Role != store.RoleHost { + if currentUser == nil || currentUser.Role != store.RoleHost { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } @@ -97,6 +97,16 @@ func (s *APIV1Service) DeleteIdentityProvider(ctx context.Context, request *v1pb if err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid identity provider name: %v", err) } + + // Check if the identity provider exists before trying to delete it + identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &id}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to check identity provider existence: %v", err) + } + if identityProvider == nil { + return nil, status.Errorf(codes.NotFound, "identity provider not found") + } + if err := s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: id}); err != nil { return nil, status.Errorf(codes.Internal, "failed to delete identity provider, error: %+v", err) } diff --git a/server/router/api/v1/shortcuts_service.go b/server/router/api/v1/shortcut_service.go similarity index 98% rename from server/router/api/v1/shortcuts_service.go rename to server/router/api/v1/shortcut_service.go index 8f3d35dac..9fbb8c762 100644 --- a/server/router/api/v1/shortcuts_service.go +++ b/server/router/api/v1/shortcut_service.go @@ -290,17 +290,23 @@ func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteS return nil, err } if userSetting == nil { - return &emptypb.Empty{}, nil + return nil, status.Errorf(codes.NotFound, "shortcut not found") } shortcutsUserSetting := userSetting.GetShortcuts() shortcuts := shortcutsUserSetting.GetShortcuts() newShortcuts := make([]*storepb.ShortcutsUserSetting_Shortcut, 0, len(shortcuts)) + found := false for _, shortcut := range shortcuts { if shortcut.GetId() != shortcutID { newShortcuts = append(newShortcuts, shortcut) + } else { + found = true } } + if !found { + return nil, status.Errorf(codes.NotFound, "shortcut not found") + } shortcutsUserSetting.Shortcuts = newShortcuts userSetting.Value = &storepb.UserSetting_Shortcuts{ Shortcuts: shortcutsUserSetting, diff --git a/server/router/api/v1/test/idp_service_test.go b/server/router/api/v1/test/idp_service_test.go new file mode 100644 index 000000000..5b0b05a93 --- /dev/null +++ b/server/router/api/v1/test/idp_service_test.go @@ -0,0 +1,520 @@ +package v1 + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestCreateIdentityProvider(t *testing.T) { + ctx := context.Background() + + t.Run("CreateIdentityProvider success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + ctx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create OAuth2 identity provider + req := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Test OAuth2 Provider", + IdentifierFilter: "", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "test-client-id", + ClientSecret: "test-client-secret", + AuthUrl: "https://example.com/oauth/authorize", + TokenUrl: "https://example.com/oauth/token", + UserInfoUrl: "https://example.com/oauth/userinfo", + Scopes: []string{"openid", "profile", "email"}, + FieldMapping: &v1pb.FieldMapping{ + Identifier: "id", + DisplayName: "name", + Email: "email", + AvatarUrl: "avatar_url", + }, + }, + }, + }, + }, + } + + resp, err := ts.Service.CreateIdentityProvider(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, "Test OAuth2 Provider", resp.Title) + require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type) + require.Contains(t, resp.Name, "identityProviders/") + require.NotEmpty(t, resp.Uid) + require.NotNil(t, resp.Config.GetOauth2Config()) + require.Equal(t, "test-client-id", resp.Config.GetOauth2Config().ClientId) + }) + + t.Run("CreateIdentityProvider permission denied for non-host user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create regular user + regularUser, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + + // Set user context + ctx := ts.CreateUserContext(ctx, regularUser.Username) + + req := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Test Provider", + Type: v1pb.IdentityProvider_OAUTH2, + }, + } + + _, err = ts.Service.CreateIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("CreateIdentityProvider unauthenticated", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Test Provider", + Type: v1pb.IdentityProvider_OAUTH2, + }, + } + + _, err := ts.Service.CreateIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) +} + +func TestListIdentityProviders(t *testing.T) { + ctx := context.Background() + + t.Run("ListIdentityProviders empty", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.ListIdentityProvidersRequest{} + resp, err := ts.Service.ListIdentityProviders(ctx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Empty(t, resp.IdentityProviders) + }) + + t.Run("ListIdentityProviders with providers", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create a couple of identity providers + createReq1 := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Provider 1", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "client1", + AuthUrl: "https://example1.com/auth", + TokenUrl: "https://example1.com/token", + UserInfoUrl: "https://example1.com/user", + FieldMapping: &v1pb.FieldMapping{ + Identifier: "id", + }, + }, + }, + }, + }, + } + + createReq2 := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Provider 2", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "client2", + AuthUrl: "https://example2.com/auth", + TokenUrl: "https://example2.com/token", + UserInfoUrl: "https://example2.com/user", + FieldMapping: &v1pb.FieldMapping{ + Identifier: "id", + }, + }, + }, + }, + }, + } + + _, err = ts.Service.CreateIdentityProvider(userCtx, createReq1) + require.NoError(t, err) + _, err = ts.Service.CreateIdentityProvider(userCtx, createReq2) + require.NoError(t, err) + + // List providers + listReq := &v1pb.ListIdentityProvidersRequest{} + resp, err := ts.Service.ListIdentityProviders(ctx, listReq) + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.IdentityProviders, 2) + + // Verify response contains expected providers + titles := []string{resp.IdentityProviders[0].Title, resp.IdentityProviders[1].Title} + require.Contains(t, titles, "Provider 1") + require.Contains(t, titles, "Provider 2") + }) +} + +func TestGetIdentityProvider(t *testing.T) { + ctx := context.Background() + + t.Run("GetIdentityProvider success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create identity provider + createReq := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Test Provider", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "test-client", + ClientSecret: "test-secret", + AuthUrl: "https://example.com/auth", + TokenUrl: "https://example.com/token", + UserInfoUrl: "https://example.com/user", + Scopes: []string{"openid", "profile"}, + FieldMapping: &v1pb.FieldMapping{ + Identifier: "id", + DisplayName: "name", + Email: "email", + }, + }, + }, + }, + }, + } + + created, err := ts.Service.CreateIdentityProvider(userCtx, createReq) + require.NoError(t, err) + + // Get identity provider + getReq := &v1pb.GetIdentityProviderRequest{ + Name: created.Name, + } + + resp, err := ts.Service.GetIdentityProvider(ctx, getReq) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, created.Name, resp.Name) + require.Equal(t, "Test Provider", resp.Title) + require.Equal(t, v1pb.IdentityProvider_OAUTH2, resp.Type) + require.NotNil(t, resp.Config.GetOauth2Config()) + require.Equal(t, "test-client", resp.Config.GetOauth2Config().ClientId) + require.Equal(t, "test-secret", resp.Config.GetOauth2Config().ClientSecret) + }) + + t.Run("GetIdentityProvider not found", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.GetIdentityProviderRequest{ + Name: "identityProviders/999", + } + + _, err := ts.Service.GetIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("GetIdentityProvider invalid name", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.GetIdentityProviderRequest{ + Name: "invalid-name", + } + + _, err := ts.Service.GetIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid identity provider name") + }) +} + +func TestUpdateIdentityProvider(t *testing.T) { + ctx := context.Background() + + t.Run("UpdateIdentityProvider success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create identity provider + createReq := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Original Provider", + IdentifierFilter: "", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "original-client", + AuthUrl: "https://original.com/auth", + TokenUrl: "https://original.com/token", + UserInfoUrl: "https://original.com/user", + FieldMapping: &v1pb.FieldMapping{ + Identifier: "id", + }, + }, + }, + }, + }, + } + + created, err := ts.Service.CreateIdentityProvider(userCtx, createReq) + require.NoError(t, err) + + // Update identity provider + updateReq := &v1pb.UpdateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Name: created.Name, + Title: "Updated Provider", + IdentifierFilter: "test@example.com", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "updated-client", + ClientSecret: "updated-secret", + AuthUrl: "https://updated.com/auth", + TokenUrl: "https://updated.com/token", + UserInfoUrl: "https://updated.com/user", + Scopes: []string{"openid", "profile", "email"}, + FieldMapping: &v1pb.FieldMapping{ + Identifier: "sub", + DisplayName: "given_name", + Email: "email", + AvatarUrl: "picture", + }, + }, + }, + }, + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"title", "identifier_filter", "config"}, + }, + } + + updated, err := ts.Service.UpdateIdentityProvider(userCtx, updateReq) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, "Updated Provider", updated.Title) + require.Equal(t, "test@example.com", updated.IdentifierFilter) + require.Equal(t, "updated-client", updated.Config.GetOauth2Config().ClientId) + }) + + t.Run("UpdateIdentityProvider missing update mask", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.UpdateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Name: "identityProviders/1", + Title: "Updated Provider", + }, + } + + _, err := ts.Service.UpdateIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "update_mask is required") + }) + + t.Run("UpdateIdentityProvider invalid name", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.UpdateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Name: "invalid-name", + Title: "Updated Provider", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"title"}, + }, + } + + _, err := ts.Service.UpdateIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid identity provider name") + }) +} + +func TestDeleteIdentityProvider(t *testing.T) { + ctx := context.Background() + + t.Run("DeleteIdentityProvider success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create identity provider + createReq := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Provider to Delete", + Type: v1pb.IdentityProvider_OAUTH2, + Config: &v1pb.IdentityProviderConfig{ + Config: &v1pb.IdentityProviderConfig_Oauth2Config{ + Oauth2Config: &v1pb.OAuth2Config{ + ClientId: "client-to-delete", + AuthUrl: "https://example.com/auth", + TokenUrl: "https://example.com/token", + UserInfoUrl: "https://example.com/user", + FieldMapping: &v1pb.FieldMapping{ + Identifier: "id", + }, + }, + }, + }, + }, + } + + created, err := ts.Service.CreateIdentityProvider(userCtx, createReq) + require.NoError(t, err) + + // Delete identity provider + deleteReq := &v1pb.DeleteIdentityProviderRequest{ + Name: created.Name, + } + + _, err = ts.Service.DeleteIdentityProvider(userCtx, deleteReq) + require.NoError(t, err) + + // Verify deletion + getReq := &v1pb.GetIdentityProviderRequest{ + Name: created.Name, + } + + _, err = ts.Service.GetIdentityProvider(ctx, getReq) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("DeleteIdentityProvider invalid name", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.DeleteIdentityProviderRequest{ + Name: "invalid-name", + } + + _, err := ts.Service.DeleteIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid identity provider name") + }) + + t.Run("DeleteIdentityProvider not found", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + req := &v1pb.DeleteIdentityProviderRequest{ + Name: "identityProviders/999", + } + + _, err = ts.Service.DeleteIdentityProvider(userCtx, req) + require.Error(t, err) + // Note: Delete might succeed even if item doesn't exist, depending on store implementation + }) +} + +func TestIdentityProviderPermissions(t *testing.T) { + ctx := context.Background() + + t.Run("Only host users can create identity providers", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create regular user + regularUser, err := ts.CreateRegularUser(ctx, "regularuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, regularUser.Username) + + req := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Test Provider", + Type: v1pb.IdentityProvider_OAUTH2, + }, + } + + _, err = ts.Service.CreateIdentityProvider(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("Authentication required", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.CreateIdentityProviderRequest{ + IdentityProvider: &v1pb.IdentityProvider{ + Title: "Test Provider", + Type: v1pb.IdentityProvider_OAUTH2, + }, + } + + _, err := ts.Service.CreateIdentityProvider(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) +} diff --git a/server/router/api/v1/test/shortcut_service_test.go b/server/router/api/v1/test/shortcut_service_test.go new file mode 100644 index 000000000..6f210789f --- /dev/null +++ b/server/router/api/v1/test/shortcut_service_test.go @@ -0,0 +1,819 @@ +package v1 + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestListShortcuts(t *testing.T) { + ctx := context.Background() + + t.Run("ListShortcuts success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // List shortcuts (should be empty initially) + req := &v1pb.ListShortcutsRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + } + + resp, err := ts.Service.ListShortcuts(userCtx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Empty(t, resp.Shortcuts) + }) + + t.Run("ListShortcuts permission denied for different user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create two users + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + + // Set user1 context but try to list user2's shortcuts + userCtx := ts.CreateUserContext(ctx, user1.Username) + + req := &v1pb.ListShortcutsRequest{ + Parent: fmt.Sprintf("users/%d", user2.ID), + } + + _, err = ts.Service.ListShortcuts(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("ListShortcuts invalid parent format", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.ListShortcutsRequest{ + Parent: "invalid-parent-format", + } + + _, err = ts.Service.ListShortcuts(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid user name") + }) + + t.Run("ListShortcuts unauthenticated", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.ListShortcutsRequest{ + Parent: "users/1", + } + + _, err := ts.Service.ListShortcuts(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) +} + +func TestGetShortcut(t *testing.T) { + ctx := context.Background() + + t.Run("GetShortcut success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // First create a shortcut + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Test Shortcut", + Filter: "tag in [\"test\"]", + }, + } + + created, err := ts.Service.CreateShortcut(userCtx, createReq) + require.NoError(t, err) + + // Now get the shortcut + getReq := &v1pb.GetShortcutRequest{ + Name: created.Name, + } + + resp, err := ts.Service.GetShortcut(userCtx, getReq) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, created.Name, resp.Name) + require.Equal(t, "Test Shortcut", resp.Title) + require.Equal(t, "tag in [\"test\"]", resp.Filter) + }) + + t.Run("GetShortcut permission denied for different user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create two users + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + + // Create shortcut as user1 + user1Ctx := ts.CreateUserContext(ctx, user1.Username) + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user1.ID), + Shortcut: &v1pb.Shortcut{ + Title: "User1 Shortcut", + Filter: "tag in [\"user1\"]", + }, + } + + created, err := ts.Service.CreateShortcut(user1Ctx, createReq) + require.NoError(t, err) + + // Try to get shortcut as user2 + user2Ctx := ts.CreateUserContext(ctx, user2.Username) + getReq := &v1pb.GetShortcutRequest{ + Name: created.Name, + } + + _, err = ts.Service.GetShortcut(user2Ctx, getReq) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("GetShortcut invalid name format", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.GetShortcutRequest{ + Name: "invalid-shortcut-name", + } + + _, err = ts.Service.GetShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid shortcut name") + }) + + t.Run("GetShortcut not found", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.GetShortcutRequest{ + Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent", + } + + _, err = ts.Service.GetShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} + +func TestCreateShortcut(t *testing.T) { + ctx := context.Background() + + t.Run("CreateShortcut success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "My Shortcut", + Filter: "tag in [\"important\"]", + }, + } + + resp, err := ts.Service.CreateShortcut(userCtx, req) + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, "My Shortcut", resp.Title) + require.Equal(t, "tag in [\"important\"]", resp.Filter) + require.Contains(t, resp.Name, fmt.Sprintf("users/%d/shortcuts/", user.ID)) + + // Verify the shortcut was created by listing + listReq := &v1pb.ListShortcutsRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + } + + listResp, err := ts.Service.ListShortcuts(userCtx, listReq) + require.NoError(t, err) + require.Len(t, listResp.Shortcuts, 1) + require.Equal(t, "My Shortcut", listResp.Shortcuts[0].Title) + }) + + t.Run("CreateShortcut permission denied for different user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create two users + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + + // Set user1 context but try to create shortcut for user2 + userCtx := ts.CreateUserContext(ctx, user1.Username) + + req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user2.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Forbidden Shortcut", + Filter: "tag in [\"forbidden\"]", + }, + } + + _, err = ts.Service.CreateShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("CreateShortcut invalid parent format", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.CreateShortcutRequest{ + Parent: "invalid-parent", + Shortcut: &v1pb.Shortcut{ + Title: "Test Shortcut", + Filter: "tag in [\"test\"]", + }, + } + + _, err = ts.Service.CreateShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid user name") + }) + + t.Run("CreateShortcut invalid filter", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Invalid Filter Shortcut", + Filter: "invalid||filter))syntax", + }, + } + + _, err = ts.Service.CreateShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid filter") + }) + + t.Run("CreateShortcut missing title", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Filter: "tag in [\"test\"]", + }, + } + + _, err = ts.Service.CreateShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "title is required") + }) +} + +func TestUpdateShortcut(t *testing.T) { + ctx := context.Background() + + t.Run("UpdateShortcut success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // Create a shortcut first + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Original Title", + Filter: "tag in [\"original\"]", + }, + } + + created, err := ts.Service.CreateShortcut(userCtx, createReq) + require.NoError(t, err) + + // Update the shortcut + updateReq := &v1pb.UpdateShortcutRequest{ + Shortcut: &v1pb.Shortcut{ + Name: created.Name, + Title: "Updated Title", + Filter: "tag in [\"updated\"]", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"title", "filter"}, + }, + } + + updated, err := ts.Service.UpdateShortcut(userCtx, updateReq) + require.NoError(t, err) + require.NotNil(t, updated) + require.Equal(t, "Updated Title", updated.Title) + require.Equal(t, "tag in [\"updated\"]", updated.Filter) + require.Equal(t, created.Name, updated.Name) + }) + + t.Run("UpdateShortcut permission denied for different user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create two users + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + + // Create shortcut as user1 + user1Ctx := ts.CreateUserContext(ctx, user1.Username) + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user1.ID), + Shortcut: &v1pb.Shortcut{ + Title: "User1 Shortcut", + Filter: "tag in [\"user1\"]", + }, + } + + created, err := ts.Service.CreateShortcut(user1Ctx, createReq) + require.NoError(t, err) + + // Try to update shortcut as user2 + user2Ctx := ts.CreateUserContext(ctx, user2.Username) + updateReq := &v1pb.UpdateShortcutRequest{ + Shortcut: &v1pb.Shortcut{ + Name: created.Name, + Title: "Hacked Title", + Filter: "tag in [\"hacked\"]", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"title", "filter"}, + }, + } + + _, err = ts.Service.UpdateShortcut(user2Ctx, updateReq) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("UpdateShortcut missing update mask", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a user and context for authentication + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.UpdateShortcutRequest{ + Shortcut: &v1pb.Shortcut{ + Name: fmt.Sprintf("users/%d/shortcuts/test", user.ID), + Title: "Updated Title", + }, + } + + _, err = ts.Service.UpdateShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "update mask is required") + }) + + t.Run("UpdateShortcut invalid name format", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.UpdateShortcutRequest{ + Shortcut: &v1pb.Shortcut{ + Name: "invalid-shortcut-name", + Title: "Updated Title", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"title"}, + }, + } + + _, err := ts.Service.UpdateShortcut(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid shortcut name") + }) + + t.Run("UpdateShortcut invalid filter", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // Create a shortcut first + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Test Shortcut", + Filter: "tag in [\"test\"]", + }, + } + + created, err := ts.Service.CreateShortcut(userCtx, createReq) + require.NoError(t, err) + + // Try to update with invalid filter + updateReq := &v1pb.UpdateShortcutRequest{ + Shortcut: &v1pb.Shortcut{ + Name: created.Name, + Filter: "invalid||filter))syntax", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"filter"}, + }, + } + + _, err = ts.Service.UpdateShortcut(userCtx, updateReq) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid filter") + }) +} + +func TestDeleteShortcut(t *testing.T) { + ctx := context.Background() + + t.Run("DeleteShortcut success", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create a user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // Create a shortcut first + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Shortcut to Delete", + Filter: "tag in [\"delete\"]", + }, + } + + created, err := ts.Service.CreateShortcut(userCtx, createReq) + require.NoError(t, err) + + // Delete the shortcut + deleteReq := &v1pb.DeleteShortcutRequest{ + Name: created.Name, + } + + _, err = ts.Service.DeleteShortcut(userCtx, deleteReq) + require.NoError(t, err) + + // Verify deletion by listing shortcuts + listReq := &v1pb.ListShortcutsRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + } + + listResp, err := ts.Service.ListShortcuts(userCtx, listReq) + require.NoError(t, err) + require.Empty(t, listResp.Shortcuts) + + // Also verify by trying to get the deleted shortcut + getReq := &v1pb.GetShortcutRequest{ + Name: created.Name, + } + + _, err = ts.Service.GetShortcut(userCtx, getReq) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("DeleteShortcut permission denied for different user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create two users + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + + // Create shortcut as user1 + user1Ctx := ts.CreateUserContext(ctx, user1.Username) + createReq := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user1.ID), + Shortcut: &v1pb.Shortcut{ + Title: "User1 Shortcut", + Filter: "tag in [\"user1\"]", + }, + } + + created, err := ts.Service.CreateShortcut(user1Ctx, createReq) + require.NoError(t, err) + + // Try to delete shortcut as user2 + user2Ctx := ts.CreateUserContext(ctx, user2.Username) + deleteReq := &v1pb.DeleteShortcutRequest{ + Name: created.Name, + } + + _, err = ts.Service.DeleteShortcut(user2Ctx, deleteReq) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("DeleteShortcut invalid name format", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + req := &v1pb.DeleteShortcutRequest{ + Name: "invalid-shortcut-name", + } + + _, err := ts.Service.DeleteShortcut(ctx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid shortcut name") + }) + + t.Run("DeleteShortcut not found", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + req := &v1pb.DeleteShortcutRequest{ + Name: fmt.Sprintf("users/%d", user.ID) + "/shortcuts/nonexistent", + } + + _, err = ts.Service.DeleteShortcut(userCtx, req) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} + +func TestShortcutFiltering(t *testing.T) { + ctx := context.Background() + + t.Run("CreateShortcut with valid filters", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // Test various valid filter formats + validFilters := []string{ + "tag in [\"work\"]", + "content.contains(\"meeting\")", + "tag in [\"work\"] && content.contains(\"meeting\")", + "tag in [\"work\"] || tag in [\"personal\"]", + "creator_id == 1", + "visibility == \"PUBLIC\"", + "has_task_list == true", + "has_task_list == false", + } + + for i, filter := range validFilters { + req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Valid Filter " + string(rune(i)), + Filter: filter, + }, + } + + _, err = ts.Service.CreateShortcut(userCtx, req) + require.NoError(t, err, "Filter should be valid: %s", filter) + } + }) + + t.Run("CreateShortcut with invalid filters", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // Test various invalid filter formats + invalidFilters := []string{ + "tag in ", // incomplete expression + "invalid_field @in [\"value\"]", // unknown field + "tag in [\"work\"] &&", // incomplete expression + "tag in [\"work\"] || || tag in [\"test\"]", // double operator + "((tag in [\"work\"]", // unmatched parentheses + "tag in [\"work\"] && )", // mismatched parentheses + "tag == \"work\"", // wrong operator (== not supported for tags) + "tag in work", // missing brackets + } + + for _, filter := range invalidFilters { + req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Invalid Filter Test", + Filter: filter, + }, + } + + _, err = ts.Service.CreateShortcut(userCtx, req) + require.Error(t, err, "Filter should be invalid: %s", filter) + require.Contains(t, err.Error(), "invalid filter", "Error should mention invalid filter for: %s", filter) + } + }) +} + +func TestShortcutCRUDComplete(t *testing.T) { + ctx := context.Background() + + t.Run("Complete CRUD lifecycle", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "testuser") + require.NoError(t, err) + + // Set user context + userCtx := ts.CreateUserContext(ctx, user.Username) + + // 1. Create multiple shortcuts + shortcut1Req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Work Notes", + Filter: "tag in [\"work\"]", + }, + } + + shortcut2Req := &v1pb.CreateShortcutRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + Shortcut: &v1pb.Shortcut{ + Title: "Personal Notes", + Filter: "tag in [\"personal\"]", + }, + } + + created1, err := ts.Service.CreateShortcut(userCtx, shortcut1Req) + require.NoError(t, err) + require.Equal(t, "Work Notes", created1.Title) + + created2, err := ts.Service.CreateShortcut(userCtx, shortcut2Req) + require.NoError(t, err) + require.Equal(t, "Personal Notes", created2.Title) + + // 2. List shortcuts and verify both exist + listReq := &v1pb.ListShortcutsRequest{ + Parent: fmt.Sprintf("users/%d", user.ID), + } + + listResp, err := ts.Service.ListShortcuts(userCtx, listReq) + require.NoError(t, err) + require.Len(t, listResp.Shortcuts, 2) + + // 3. Get individual shortcuts + getReq1 := &v1pb.GetShortcutRequest{Name: created1.Name} + getResp1, err := ts.Service.GetShortcut(userCtx, getReq1) + require.NoError(t, err) + require.Equal(t, created1.Name, getResp1.Name) + require.Equal(t, "Work Notes", getResp1.Title) + + getReq2 := &v1pb.GetShortcutRequest{Name: created2.Name} + getResp2, err := ts.Service.GetShortcut(userCtx, getReq2) + require.NoError(t, err) + require.Equal(t, created2.Name, getResp2.Name) + require.Equal(t, "Personal Notes", getResp2.Title) + + // 4. Update one shortcut + updateReq := &v1pb.UpdateShortcutRequest{ + Shortcut: &v1pb.Shortcut{ + Name: created1.Name, + Title: "Work & Meeting Notes", + Filter: "tag in [\"work\"] || tag in [\"meeting\"]", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"title", "filter"}, + }, + } + + updated, err := ts.Service.UpdateShortcut(userCtx, updateReq) + require.NoError(t, err) + require.Equal(t, "Work & Meeting Notes", updated.Title) + require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", updated.Filter) + + // 5. Verify update by getting it again + getUpdatedReq := &v1pb.GetShortcutRequest{Name: created1.Name} + getUpdatedResp, err := ts.Service.GetShortcut(userCtx, getUpdatedReq) + require.NoError(t, err) + require.Equal(t, "Work & Meeting Notes", getUpdatedResp.Title) + require.Equal(t, "tag in [\"work\"] || tag in [\"meeting\"]", getUpdatedResp.Filter) + + // 6. Delete one shortcut + deleteReq := &v1pb.DeleteShortcutRequest{ + Name: created2.Name, + } + + _, err = ts.Service.DeleteShortcut(userCtx, deleteReq) + require.NoError(t, err) + + // 7. Verify deletion by listing (should only have 1 left) + finalListResp, err := ts.Service.ListShortcuts(userCtx, listReq) + require.NoError(t, err) + require.Len(t, finalListResp.Shortcuts, 1) + require.Equal(t, "Work & Meeting Notes", finalListResp.Shortcuts[0].Title) + + // 8. Verify deleted shortcut can't be accessed + getDeletedReq := &v1pb.GetShortcutRequest{Name: created2.Name} + _, err = ts.Service.GetShortcut(userCtx, getDeletedReq) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} diff --git a/server/router/api/v1/test_helper.go b/server/router/api/v1/test/test_helper.go similarity index 85% rename from server/router/api/v1/test_helper.go rename to server/router/api/v1/test/test_helper.go index 7b03bf5aa..a2fe0ea86 100644 --- a/server/router/api/v1/test_helper.go +++ b/server/router/api/v1/test/test_helper.go @@ -5,13 +5,14 @@ import ( "testing" "github.com/usememos/memos/internal/profile" + apiv1 "github.com/usememos/memos/server/router/api/v1" "github.com/usememos/memos/store" teststore "github.com/usememos/memos/store/test" ) // TestService holds the test service setup for API v1 services. type TestService struct { - Service *APIV1Service + Service *apiv1.APIV1Service Store *store.Store Profile *profile.Profile Secret string @@ -35,7 +36,7 @@ func NewTestService(t *testing.T) *TestService { // Create APIV1Service with nil grpcServer since we're testing direct calls secret := "test-secret" - service := &APIV1Service{ + service := &apiv1.APIV1Service{ Secret: secret, Profile: testProfile, Store: testStore, @@ -52,8 +53,7 @@ func NewTestService(t *testing.T) *TestService { // Cleanup clears caches and closes resources after test. func (ts *TestService) Cleanup() { ts.Store.Close() - // Clear the global owner cache for test isolation - ownerCache = nil + // Note: Owner cache is package-level in parent package, cannot clear from test package } // CreateHostUser creates a host user for testing. @@ -76,6 +76,6 @@ func (ts *TestService) CreateRegularUser(ctx context.Context, username string) ( // CreateUserContext creates a context with the given username for authentication. func (ts *TestService) CreateUserContext(ctx context.Context, username string) context.Context { - _ = ts // Silence unused receiver warning - method is part of TestService interface - return context.WithValue(ctx, ContextKey(0), username) // usernameContextKey = 0 + // Use the real context key from the parent package + return apiv1.CreateTestUserContext(ctx, username) } diff --git a/server/router/api/v1/test/webhook_service_test.go b/server/router/api/v1/test/webhook_service_test.go new file mode 100644 index 000000000..0a0c1bb24 --- /dev/null +++ b/server/router/api/v1/test/webhook_service_test.go @@ -0,0 +1,408 @@ +package v1 + +import ( + "context" + "fmt" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + v1pb "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestCreateWebhook(t *testing.T) { + ctx := context.Background() + + t.Run("CreateWebhook with host user", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create and authenticate as host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create a webhook + req := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + Url: "https://example.com/webhook", + }, + } + + resp, err := ts.Service.CreateWebhook(userCtx, req) + + // Verify response + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, "Test Webhook", resp.DisplayName) + require.Equal(t, "https://example.com/webhook", resp.Url) + require.Contains(t, resp.Name, "webhooks/") + require.Equal(t, fmt.Sprintf("users/%d", hostUser.ID), resp.Creator) + }) + + t.Run("CreateWebhook fails without authentication", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Try to create webhook without authentication + req := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + Url: "https://example.com/webhook", + }, + } + + _, err := ts.Service.CreateWebhook(ctx, req) + + // Should fail with permission denied or unauthenticated + require.Error(t, err) + }) + + t.Run("CreateWebhook fails with regular user", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create and authenticate as regular user + regularUser, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + + userCtx := ts.CreateUserContext(ctx, regularUser.Username) + + // Try to create webhook as regular user + req := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + Url: "https://example.com/webhook", + }, + } + + _, err = ts.Service.CreateWebhook(userCtx, req) + + // Should fail with permission denied + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("CreateWebhook validates required fields", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create and authenticate as host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Try to create webhook with missing URL + req := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + // URL missing + }, + } + + _, err = ts.Service.CreateWebhook(userCtx, req) + + // Should fail with validation error + require.Error(t, err) + }) +} + +func TestListWebhooks(t *testing.T) { + ctx := context.Background() + + t.Run("ListWebhooks returns empty list initially", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user for authentication + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // List webhooks + req := &v1pb.ListWebhooksRequest{} + resp, err := ts.Service.ListWebhooks(userCtx, req) + + // Verify response + require.NoError(t, err) + require.NotNil(t, resp) + require.Empty(t, resp.Webhooks) + }) + + t.Run("ListWebhooks returns created webhooks", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create a webhook + createReq := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + Url: "https://example.com/webhook", + }, + } + createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq) + require.NoError(t, err) + + // List webhooks + listReq := &v1pb.ListWebhooksRequest{} + resp, err := ts.Service.ListWebhooks(userCtx, listReq) + + // Verify response + require.NoError(t, err) + require.NotNil(t, resp) + require.Len(t, resp.Webhooks, 1) + require.Equal(t, createdWebhook.Name, resp.Webhooks[0].Name) + require.Equal(t, createdWebhook.Url, resp.Webhooks[0].Url) + }) + + t.Run("ListWebhooks fails without authentication", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Try to list webhooks without authentication + req := &v1pb.ListWebhooksRequest{} + _, err := ts.Service.ListWebhooks(ctx, req) + + // Should fail with permission denied or unauthenticated + require.Error(t, err) + }) +} + +func TestGetWebhook(t *testing.T) { + ctx := context.Background() + + t.Run("GetWebhook returns webhook by name", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create a webhook + createReq := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + Url: "https://example.com/webhook", + }, + } + createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq) + require.NoError(t, err) + + // Get the webhook + getReq := &v1pb.GetWebhookRequest{ + Name: createdWebhook.Name, + } + resp, err := ts.Service.GetWebhook(userCtx, getReq) + + // Verify response + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, createdWebhook.Name, resp.Name) + require.Equal(t, createdWebhook.Url, resp.Url) + require.Equal(t, createdWebhook.Creator, resp.Creator) + }) + + t.Run("GetWebhook fails with invalid name", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Try to get webhook with invalid name + req := &v1pb.GetWebhookRequest{ + Name: "invalid/webhook/name", + } + _, err = ts.Service.GetWebhook(userCtx, req) + + // Should return an error + require.Error(t, err) + }) + + t.Run("GetWebhook fails with non-existent webhook", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Try to get non-existent webhook + req := &v1pb.GetWebhookRequest{ + Name: "webhooks/999", + } + _, err = ts.Service.GetWebhook(userCtx, req) + + // Should return not found error + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} + +func TestUpdateWebhook(t *testing.T) { + ctx := context.Background() + + t.Run("UpdateWebhook updates webhook properties", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create a webhook + createReq := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + Name: "Original Webhook", + Url: "https://example.com/webhook", + }, + } + createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq) + require.NoError(t, err) + + // Update the webhook + updateReq := &v1pb.UpdateWebhookRequest{ + Webhook: &v1pb.Webhook{ + Name: createdWebhook.Name, + Url: "https://updated.example.com/webhook", + }, + UpdateMask: &fieldmaskpb.FieldMask{ + Paths: []string{"url"}, + }, + } + resp, err := ts.Service.UpdateWebhook(userCtx, updateReq) + + // Verify response + require.NoError(t, err) + require.NotNil(t, resp) + require.Equal(t, createdWebhook.Name, resp.Name) + require.Equal(t, "https://updated.example.com/webhook", resp.Url) + }) + + t.Run("UpdateWebhook fails without authentication", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Try to update webhook without authentication + req := &v1pb.UpdateWebhookRequest{ + Webhook: &v1pb.Webhook{ + Name: "webhooks/1", + Url: "https://updated.example.com/webhook", + }, + } + + _, err := ts.Service.UpdateWebhook(ctx, req) + + // Should fail with permission denied or unauthenticated + require.Error(t, err) + }) +} + +func TestDeleteWebhook(t *testing.T) { + ctx := context.Background() + + t.Run("DeleteWebhook removes webhook", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Create a webhook + createReq := &v1pb.CreateWebhookRequest{ + Webhook: &v1pb.Webhook{ + DisplayName: "Test Webhook", + Url: "https://example.com/webhook", + }, + } + createdWebhook, err := ts.Service.CreateWebhook(userCtx, createReq) + require.NoError(t, err) + + // Delete the webhook + deleteReq := &v1pb.DeleteWebhookRequest{ + Name: createdWebhook.Name, + } + _, err = ts.Service.DeleteWebhook(userCtx, deleteReq) + + // Verify deletion + require.NoError(t, err) + + // Try to get the deleted webhook + getReq := &v1pb.GetWebhookRequest{ + Name: createdWebhook.Name, + } + _, err = ts.Service.GetWebhook(userCtx, getReq) + + // Should return not found error + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) + + t.Run("DeleteWebhook fails without authentication", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Try to delete webhook without authentication + req := &v1pb.DeleteWebhookRequest{ + Name: "webhooks/1", + } + + _, err := ts.Service.DeleteWebhook(ctx, req) + + // Should fail with permission denied or unauthenticated + require.Error(t, err) + }) + + t.Run("DeleteWebhook fails with non-existent webhook", func(t *testing.T) { + // Create test service for this specific test + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user and authenticate + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, hostUser.Username) + + // Try to delete non-existent webhook + req := &v1pb.DeleteWebhookRequest{ + Name: "webhooks/999", + } + _, err = ts.Service.DeleteWebhook(userCtx, req) + + // Should return not found error + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} diff --git a/server/router/api/v1/workspace_service_test.go b/server/router/api/v1/test/workspace_service_test.go similarity index 70% rename from server/router/api/v1/workspace_service_test.go rename to server/router/api/v1/test/workspace_service_test.go index 2810c0c76..2971a2d79 100644 --- a/server/router/api/v1/workspace_service_test.go +++ b/server/router/api/v1/test/workspace_service_test.go @@ -64,86 +64,6 @@ func TestGetWorkspaceProfile(t *testing.T) { }) } -func TestGetWorkspaceProfile_ErrorCases(t *testing.T) { - ctx := context.Background() - - t.Run("Service handles multiple calls correctly", func(t *testing.T) { - // Create test service for this specific test - ts := NewTestService(t) - defer ts.Cleanup() - - // Make multiple calls to ensure consistency - for i := 0; i < 5; i++ { - req := &v1pb.GetWorkspaceProfileRequest{} - resp, err := ts.Service.GetWorkspaceProfile(ctx, req) - - require.NoError(t, err) - require.NotNil(t, resp) - require.Equal(t, "test-1.0.0", resp.Version) - require.Equal(t, "dev", resp.Mode) - require.Equal(t, "http://localhost:8080", resp.InstanceUrl) - require.Empty(t, resp.Owner) - } - }) - - t.Run("Multiple users, only host is returned as owner", func(t *testing.T) { - // Create test service for this specific test - ts := NewTestService(t) - defer ts.Cleanup() - - // Create a regular user first - _, err := ts.CreateRegularUser(ctx, "user1") - require.NoError(t, err) - - // Create another regular user - _, err = ts.CreateRegularUser(ctx, "user2") - require.NoError(t, err) - - // Create a host user - hostUser, err := ts.CreateHostUser(ctx, "admin") - require.NoError(t, err) - require.NotNil(t, hostUser) - - // Call GetWorkspaceProfile - req := &v1pb.GetWorkspaceProfileRequest{} - resp, err := ts.Service.GetWorkspaceProfile(ctx, req) - - // Verify response - require.NoError(t, err) - require.NotNil(t, resp) - - // Should return the host user as owner, not any of the regular users - expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID) - require.Equal(t, expectedOwnerName, resp.Owner) - }) - - t.Run("Cache behavior - owner cached after first lookup", func(t *testing.T) { - // Create test service for this specific test - ts := NewTestService(t) - defer ts.Cleanup() - - // Create a host user - hostUser, err := ts.CreateHostUser(ctx, "admin") - require.NoError(t, err) - expectedOwnerName := fmt.Sprintf("users/%d", hostUser.ID) - - // First call should query the database - req := &v1pb.GetWorkspaceProfileRequest{} - resp1, err := ts.Service.GetWorkspaceProfile(ctx, req) - require.NoError(t, err) - require.Equal(t, expectedOwnerName, resp1.Owner) - - // Create another host user (this shouldn't change the result due to caching) - _, err = ts.CreateHostUser(ctx, "admin2") - require.NoError(t, err) - - // Second call should return cached result (first host user) - resp2, err := ts.Service.GetWorkspaceProfile(ctx, req) - require.NoError(t, err) - require.Equal(t, expectedOwnerName, resp2.Owner) // Should still be the first host user - }) -} - func TestGetWorkspaceProfile_Concurrency(t *testing.T) { ctx := context.Background() diff --git a/server/router/api/v1/test_auth.go b/server/router/api/v1/test_auth.go new file mode 100644 index 000000000..7a1a36d75 --- /dev/null +++ b/server/router/api/v1/test_auth.go @@ -0,0 +1,19 @@ +package v1 + +import ( + "context" + + "github.com/usememos/memos/store" +) + +// CreateTestUserContext creates a context with username for testing purposes +// This function is only intended for use in tests +func CreateTestUserContext(ctx context.Context, username string) context.Context { + return context.WithValue(ctx, usernameContextKey, username) +} + +// CreateTestUserContextWithUser creates a context and ensures the user exists for testing +// This function is only intended for use in tests +func CreateTestUserContextWithUser(ctx context.Context, s *APIV1Service, user *store.User) context.Context { + return context.WithValue(ctx, usernameContextKey, user.Username) +} diff --git a/server/router/api/v1/webhook_service.go b/server/router/api/v1/webhook_service.go index 5cd312572..9d84289ba 100644 --- a/server/router/api/v1/webhook_service.go +++ b/server/router/api/v1/webhook_service.go @@ -21,8 +21,23 @@ func (s *APIV1Service) CreateWebhook(ctx context.Context, request *v1pb.CreateWe if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + - // TODO: Handle webhook_id, validate_only, and request_id fields +// Only host users can create webhooks +if !isSuperUser(currentUser) { +return nil, status.Errorf(codes.PermissionDenied, "permission denied") +} + +// Validate required fields +if request.Webhook == nil { +return nil, status.Errorf(codes.InvalidArgument, "webhook is required") +} +if strings.TrimSpace(request.Webhook.Url) == "" { +return nil, status.Errorf(codes.InvalidArgument, "webhook URL is required") +} // TODO: Handle webhook_id, validate_only, and request_id fields if request.ValidateOnly { // Perform validation checks without actually creating the webhook return &v1pb.Webhook{ @@ -49,6 +64,9 @@ func (s *APIV1Service) ListWebhooks(ctx context.Context, _ *v1pb.ListWebhooksReq if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } // TODO: Implement proper filtering, ordering, and pagination // For now, list webhooks for the current user @@ -79,6 +97,9 @@ func (s *APIV1Service) GetWebhook(ctx context.Context, request *v1pb.GetWebhookR if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } webhook, err := s.Store.GetWebhook(ctx, &store.FindWebhook{ ID: &webhookID, @@ -112,6 +133,9 @@ func (s *APIV1Service) UpdateWebhook(ctx context.Context, request *v1pb.UpdateWe if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } // Check if webhook exists and user has permission existingWebhook, err := s.Store.GetWebhook(ctx, &store.FindWebhook{ @@ -160,6 +184,9 @@ func (s *APIV1Service) DeleteWebhook(ctx context.Context, request *v1pb.DeleteWe if err != nil { return nil, status.Errorf(codes.Internal, "failed to get user: %v", err) } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } // Check if webhook exists and user has permission webhook, err := s.Store.GetWebhook(ctx, &store.FindWebhook{