From 21d31e3609be7bbc8204ce61a4e1e14f922a07e4 Mon Sep 17 00:00:00 2001 From: boojack Date: Thu, 6 Nov 2025 23:32:27 +0800 Subject: [PATCH] fix(security): implement security review recommendations (#5228) Co-authored-by: Claude --- .../router/api/v1/memo_attachment_service.go | 3 + server/router/api/v1/memo_relation_service.go | 3 + server/router/api/v1/reaction_service.go | 12 +- .../v1/test/memo_attachment_service_test.go | 166 +++++++++++++++ .../api/v1/test/memo_relation_service_test.go | 169 +++++++++++++++ .../api/v1/test/reaction_service_test.go | 193 ++++++++++++++++++ .../v1/test/user_service_registration_test.go | 165 +++++++++++++++ store/db/postgres/reaction.go | 13 ++ store/db/sqlite/reaction.go | 43 ++++ store/driver.go | 1 + store/reaction.go | 4 + store/test/reaction_test.go | 19 ++ 12 files changed, 785 insertions(+), 6 deletions(-) create mode 100644 server/router/api/v1/test/memo_attachment_service_test.go create mode 100644 server/router/api/v1/test/memo_relation_service_test.go create mode 100644 server/router/api/v1/test/reaction_service_test.go create mode 100644 server/router/api/v1/test/user_service_registration_test.go diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index 4084c9a8a..e396ac760 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -29,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } if memo.CreatorID != user.ID && !isSuperUser(user) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } diff --git a/server/router/api/v1/memo_relation_service.go b/server/router/api/v1/memo_relation_service.go index 77cff1a38..4f49f7975 100644 --- a/server/router/api/v1/memo_relation_service.go +++ b/server/router/api/v1/memo_relation_service.go @@ -29,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } if memo.CreatorID != user.ID && !isSuperUser(user) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } diff --git a/server/router/api/v1/reaction_service.go b/server/router/api/v1/reaction_service.go index f5ec6d96d..d644ad14b 100644 --- a/server/router/api/v1/reaction_service.go +++ b/server/router/api/v1/reaction_service.go @@ -68,18 +68,18 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err) } - // Get reaction and check ownership - reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ + // Get reaction and check ownership. + reaction, err := s.Store.GetReaction(ctx, &store.FindReaction{ ID: &reactionID, }) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list reactions") + return nil, status.Errorf(codes.Internal, "failed to get reaction") } - if len(reactions) == 0 { - return nil, status.Errorf(codes.NotFound, "reaction not found") + if reaction == nil { + // Return permission denied to avoid revealing if reaction exists. + return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - reaction := reactions[0] if reaction.CreatorID != user.ID && !isSuperUser(user) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } diff --git a/server/router/api/v1/test/memo_attachment_service_test.go b/server/router/api/v1/test/memo_attachment_service_test.go new file mode 100644 index 000000000..41abb629e --- /dev/null +++ b/server/router/api/v1/test/memo_attachment_service_test.go @@ -0,0 +1,166 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + apiv1 "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestSetMemoAttachments(t *testing.T) { + ctx := context.Background() + + t.Run("SetMemoAttachments success by memo owner", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Create memo + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Create attachment + attachment, err := ts.Service.CreateAttachment(userCtx, &apiv1.CreateAttachmentRequest{ + Attachment: &apiv1.Attachment{ + Filename: "test.txt", + Size: 5, + Type: "text/plain", + Content: []byte("hello"), + }, + }) + require.NoError(t, err) + require.NotNil(t, attachment) + + // Set memo attachments - should succeed + _, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{ + Name: memo.Name, + Attachments: []*apiv1.Attachment{ + {Name: attachment.Name}, + }, + }) + require.NoError(t, err) + }) + + t.Run("SetMemoAttachments success by host user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create regular user + regularUser, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID) + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + hostCtx := ts.CreateUserContext(ctx, hostUser.ID) + + // Create memo by regular user + memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Host user can modify attachments - should succeed + _, err = ts.Service.SetMemoAttachments(hostCtx, &apiv1.SetMemoAttachmentsRequest{ + Name: memo.Name, + Attachments: []*apiv1.Attachment{}, + }) + require.NoError(t, err) + }) + + t.Run("SetMemoAttachments permission denied for non-owner", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user1 + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user1Ctx := ts.CreateUserContext(ctx, user1.ID) + + // Create user2 + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + user2Ctx := ts.CreateUserContext(ctx, user2.ID) + + // Create memo by user1 + memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // User2 tries to modify attachments - should fail + _, err = ts.Service.SetMemoAttachments(user2Ctx, &apiv1.SetMemoAttachmentsRequest{ + Name: memo.Name, + Attachments: []*apiv1.Attachment{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("SetMemoAttachments unauthenticated", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Create memo + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Unauthenticated user tries to modify attachments - should fail + _, err = ts.Service.SetMemoAttachments(ctx, &apiv1.SetMemoAttachmentsRequest{ + Name: memo.Name, + Attachments: []*apiv1.Attachment{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("SetMemoAttachments memo not found", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Try to set attachments on non-existent memo - should fail + _, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{ + Name: "memos/nonexistent-uid-12345", + Attachments: []*apiv1.Attachment{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} diff --git a/server/router/api/v1/test/memo_relation_service_test.go b/server/router/api/v1/test/memo_relation_service_test.go new file mode 100644 index 000000000..98c2c5659 --- /dev/null +++ b/server/router/api/v1/test/memo_relation_service_test.go @@ -0,0 +1,169 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + apiv1 "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestSetMemoRelations(t *testing.T) { + ctx := context.Background() + + t.Run("SetMemoRelations success by memo owner", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Create memo1 + memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo 1", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo1) + + // Create memo2 + memo2, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo 2", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo2) + + // Set memo relations - should succeed + _, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{ + Name: memo1.Name, + Relations: []*apiv1.MemoRelation{ + { + RelatedMemo: &apiv1.MemoRelation_Memo{ + Name: memo2.Name, + }, + Type: apiv1.MemoRelation_REFERENCE, + }, + }, + }) + require.NoError(t, err) + }) + + t.Run("SetMemoRelations success by host user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create regular user + regularUser, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID) + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + hostCtx := ts.CreateUserContext(ctx, hostUser.ID) + + // Create memo by regular user + memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Host user can modify relations - should succeed + _, err = ts.Service.SetMemoRelations(hostCtx, &apiv1.SetMemoRelationsRequest{ + Name: memo.Name, + Relations: []*apiv1.MemoRelation{}, + }) + require.NoError(t, err) + }) + + t.Run("SetMemoRelations permission denied for non-owner", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user1 + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user1Ctx := ts.CreateUserContext(ctx, user1.ID) + + // Create user2 + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + user2Ctx := ts.CreateUserContext(ctx, user2.ID) + + // Create memo by user1 + memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // User2 tries to modify relations - should fail + _, err = ts.Service.SetMemoRelations(user2Ctx, &apiv1.SetMemoRelationsRequest{ + Name: memo.Name, + Relations: []*apiv1.MemoRelation{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("SetMemoRelations unauthenticated", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Create memo + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Unauthenticated user tries to modify relations - should fail + _, err = ts.Service.SetMemoRelations(ctx, &apiv1.SetMemoRelationsRequest{ + Name: memo.Name, + Relations: []*apiv1.MemoRelation{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("SetMemoRelations memo not found", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Try to set relations on non-existent memo - should fail + _, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{ + Name: "memos/nonexistent-uid-12345", + Relations: []*apiv1.MemoRelation{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not found") + }) +} diff --git a/server/router/api/v1/test/reaction_service_test.go b/server/router/api/v1/test/reaction_service_test.go new file mode 100644 index 000000000..f763da266 --- /dev/null +++ b/server/router/api/v1/test/reaction_service_test.go @@ -0,0 +1,193 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + apiv1 "github.com/usememos/memos/proto/gen/api/v1" +) + +func TestDeleteMemoReaction(t *testing.T) { + ctx := context.Background() + + t.Run("DeleteMemoReaction success by reaction owner", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Create memo + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Create reaction + reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{ + Name: memo.Name, + Reaction: &apiv1.Reaction{ + ContentId: memo.Name, + ReactionType: "👍", + }, + }) + require.NoError(t, err) + require.NotNil(t, reaction) + + // Delete reaction - should succeed + _, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{ + Name: reaction.Name, + }) + require.NoError(t, err) + }) + + t.Run("DeleteMemoReaction success by host user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create regular user + regularUser, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID) + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + hostCtx := ts.CreateUserContext(ctx, hostUser.ID) + + // Create memo by regular user + memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Create reaction by regular user + reaction, err := ts.Service.UpsertMemoReaction(regularUserCtx, &apiv1.UpsertMemoReactionRequest{ + Name: memo.Name, + Reaction: &apiv1.Reaction{ + ContentId: memo.Name, + ReactionType: "👍", + }, + }) + require.NoError(t, err) + require.NotNil(t, reaction) + + // Host user can delete reaction - should succeed + _, err = ts.Service.DeleteMemoReaction(hostCtx, &apiv1.DeleteMemoReactionRequest{ + Name: reaction.Name, + }) + require.NoError(t, err) + }) + + t.Run("DeleteMemoReaction permission denied for non-owner", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user1 + user1, err := ts.CreateRegularUser(ctx, "user1") + require.NoError(t, err) + user1Ctx := ts.CreateUserContext(ctx, user1.ID) + + // Create user2 + user2, err := ts.CreateRegularUser(ctx, "user2") + require.NoError(t, err) + user2Ctx := ts.CreateUserContext(ctx, user2.ID) + + // Create memo by user1 + memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Create reaction by user1 + reaction, err := ts.Service.UpsertMemoReaction(user1Ctx, &apiv1.UpsertMemoReactionRequest{ + Name: memo.Name, + Reaction: &apiv1.Reaction{ + ContentId: memo.Name, + ReactionType: "👍", + }, + }) + require.NoError(t, err) + require.NotNil(t, reaction) + + // User2 tries to delete reaction - should fail with permission denied + _, err = ts.Service.DeleteMemoReaction(user2Ctx, &apiv1.DeleteMemoReactionRequest{ + Name: reaction.Name, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + }) + + t.Run("DeleteMemoReaction unauthenticated", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Create memo + memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "Test memo", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + require.NotNil(t, memo) + + // Create reaction + reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{ + Name: memo.Name, + Reaction: &apiv1.Reaction{ + ContentId: memo.Name, + ReactionType: "👍", + }, + }) + require.NoError(t, err) + require.NotNil(t, reaction) + + // Unauthenticated user tries to delete reaction - should fail + _, err = ts.Service.DeleteMemoReaction(ctx, &apiv1.DeleteMemoReactionRequest{ + Name: reaction.Name, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not authenticated") + }) + + t.Run("DeleteMemoReaction not found returns permission denied", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create user + user, err := ts.CreateRegularUser(ctx, "user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + + // Try to delete non-existent reaction - should fail with permission denied + // (not "not found" to avoid information disclosure) + _, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{ + Name: "reactions/99999", + }) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + require.NotContains(t, err.Error(), "not found") + }) +} diff --git a/server/router/api/v1/test/user_service_registration_test.go b/server/router/api/v1/test/user_service_registration_test.go new file mode 100644 index 000000000..a6c6ccad4 --- /dev/null +++ b/server/router/api/v1/test/user_service_registration_test.go @@ -0,0 +1,165 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + + apiv1 "github.com/usememos/memos/proto/gen/api/v1" + storepb "github.com/usememos/memos/proto/gen/store" +) + +func TestCreateUserRegistration(t *testing.T) { + ctx := context.Background() + + t.Run("CreateUser success when registration enabled", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // User registration is enabled by default, no need to set it explicitly + + // Create user without authentication - should succeed + _, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newuser", + Email: "newuser@example.com", + Password: "password123", + }, + }) + require.NoError(t, err) + }) + + t.Run("CreateUser blocked when registration disabled", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Disable user registration + _, err := ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_GENERAL, + Value: &storepb.InstanceSetting_GeneralSetting{ + GeneralSetting: &storepb.InstanceGeneralSetting{ + DisallowUserRegistration: true, + }, + }, + }) + require.NoError(t, err) + + // Try to create user without authentication - should fail + _, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newuser", + Email: "newuser@example.com", + Password: "password123", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not allowed") + }) + + t.Run("CreateUser succeeds for superuser even when registration disabled", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + hostCtx := ts.CreateUserContext(ctx, hostUser.ID) + + // Disable user registration + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_GENERAL, + Value: &storepb.InstanceSetting_GeneralSetting{ + GeneralSetting: &storepb.InstanceGeneralSetting{ + DisallowUserRegistration: true, + }, + }, + }) + require.NoError(t, err) + + // Host user can create users even when registration is disabled - should succeed + _, err = ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newuser", + Email: "newuser@example.com", + Password: "password123", + }, + }) + require.NoError(t, err) + }) + + t.Run("CreateUser regular user cannot create users when registration disabled", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create regular user + regularUser, err := ts.CreateRegularUser(ctx, "regularuser") + require.NoError(t, err) + regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID) + + // Disable user registration + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_GENERAL, + Value: &storepb.InstanceSetting_GeneralSetting{ + GeneralSetting: &storepb.InstanceGeneralSetting{ + DisallowUserRegistration: true, + }, + }, + }) + require.NoError(t, err) + + // Regular user tries to create user when registration is disabled - should fail + _, err = ts.Service.CreateUser(regularUserCtx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newuser", + Email: "newuser@example.com", + Password: "password123", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "not allowed") + }) + + t.Run("CreateUser host can assign roles", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // Create host user + hostUser, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + hostCtx := ts.CreateUserContext(ctx, hostUser.ID) + + // Host user can create user with specific role - should succeed + createdUser, err := ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newadmin", + Email: "newadmin@example.com", + Password: "password123", + Role: apiv1.User_ADMIN, + }, + }) + require.NoError(t, err) + require.NotNil(t, createdUser) + require.Equal(t, apiv1.User_ADMIN, createdUser.Role) + }) + + t.Run("CreateUser unauthenticated user can only create regular user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + // User registration is enabled by default + + // Unauthenticated user tries to create admin user - role should be ignored + createdUser, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "wannabeadmin", + Email: "wannabeadmin@example.com", + Password: "password123", + Role: apiv1.User_ADMIN, // This should be ignored + }, + }) + require.NoError(t, err) + require.NotNil(t, createdUser) + require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role") + }) +} diff --git a/store/db/postgres/reaction.go b/store/db/postgres/reaction.go index e2b64737c..b5af43e3e 100644 --- a/store/db/postgres/reaction.go +++ b/store/db/postgres/reaction.go @@ -82,6 +82,19 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st return list, nil } +func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) { + list, err := d.ListReactions(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + reaction := list[0] + return reaction, nil +} + func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error { _, err := d.db.ExecContext(ctx, "DELETE FROM reaction WHERE id = $1", delete.ID) return err diff --git a/store/db/sqlite/reaction.go b/store/db/sqlite/reaction.go index a6f87cdc5..c4edfd613 100644 --- a/store/db/sqlite/reaction.go +++ b/store/db/sqlite/reaction.go @@ -2,6 +2,8 @@ package sqlite import ( "context" + "database/sql" + "errors" "strings" "github.com/usememos/memos/store" @@ -87,6 +89,47 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st return list, nil } +func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) { + where, args := []string{"1 = 1"}, []any{} + + if find.ID != nil { + where, args = append(where, "id = ?"), append(args, *find.ID) + } + if find.CreatorID != nil { + where, args = append(where, "creator_id = ?"), append(args, *find.CreatorID) + } + if find.ContentID != nil { + where, args = append(where, "content_id = ?"), append(args, *find.ContentID) + } + + reaction := &store.Reaction{} + if err := d.db.QueryRowContext(ctx, ` + SELECT + id, + created_ts, + creator_id, + content_id, + reaction_type + FROM reaction + WHERE `+strings.Join(where, " AND ")+` + LIMIT 1`, + args..., + ).Scan( + &reaction.ID, + &reaction.CreatedTs, + &reaction.CreatorID, + &reaction.ContentID, + &reaction.ReactionType, + ); err != nil { + if errors.Is(err, sql.ErrNoRows) { + return nil, nil + } + return nil, err + } + + return reaction, nil +} + func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error { _, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID) return err diff --git a/store/driver.go b/store/driver.go index bee6ed508..029f522d0 100644 --- a/store/driver.go +++ b/store/driver.go @@ -71,5 +71,6 @@ type Driver interface { // Reaction model related methods. UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error) ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error) + GetReaction(ctx context.Context, find *FindReaction) (*Reaction, error) DeleteReaction(ctx context.Context, delete *DeleteReaction) error } diff --git a/store/reaction.go b/store/reaction.go index a10093128..389da3d23 100644 --- a/store/reaction.go +++ b/store/reaction.go @@ -32,6 +32,10 @@ func (s *Store) ListReactions(ctx context.Context, find *FindReaction) ([]*React return s.driver.ListReactions(ctx, find) } +func (s *Store) GetReaction(ctx context.Context, find *FindReaction) (*Reaction, error) { + return s.driver.GetReaction(ctx, find) +} + func (s *Store) DeleteReaction(ctx context.Context, delete *DeleteReaction) error { return s.driver.DeleteReaction(ctx, delete) } diff --git a/store/test/reaction_test.go b/store/test/reaction_test.go index fc83861e4..dc0817547 100644 --- a/store/test/reaction_test.go +++ b/store/test/reaction_test.go @@ -33,6 +33,25 @@ func TestReactionStore(t *testing.T) { require.Len(t, reactions, 1) require.Equal(t, reaction, reactions[0]) + // Test GetReaction. + gotReaction, err := ts.GetReaction(ctx, &store.FindReaction{ + ID: &reaction.ID, + }) + require.NoError(t, err) + require.NotNil(t, gotReaction) + require.Equal(t, reaction.ID, gotReaction.ID) + require.Equal(t, reaction.CreatorID, gotReaction.CreatorID) + require.Equal(t, reaction.ContentID, gotReaction.ContentID) + require.Equal(t, reaction.ReactionType, gotReaction.ReactionType) + + // Test GetReaction with non-existent ID. + nonExistentID := int32(99999) + notFoundReaction, err := ts.GetReaction(ctx, &store.FindReaction{ + ID: &nonExistentID, + }) + require.NoError(t, err) + require.Nil(t, notFoundReaction) + err = ts.DeleteReaction(ctx, &store.DeleteReaction{ ID: reaction.ID, })