fix(security): implement security review recommendations (#5228)

Co-authored-by: Claude <noreply@anthropic.com>
pull/5231/head
boojack 1 week ago committed by GitHub
parent bb8fa90496
commit 21d31e3609
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -29,6 +29,9 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo") return nil, status.Errorf(codes.Internal, "failed to get memo")
} }
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) { if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied") return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }

@ -29,6 +29,9 @@ func (s *APIV1Service) SetMemoRelations(ctx context.Context, request *v1pb.SetMe
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to get memo") return nil, status.Errorf(codes.Internal, "failed to get memo")
} }
if memo == nil {
return nil, status.Errorf(codes.NotFound, "memo not found")
}
if memo.CreatorID != user.ID && !isSuperUser(user) { if memo.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied") return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }

@ -68,18 +68,18 @@ func (s *APIV1Service) DeleteMemoReaction(ctx context.Context, request *v1pb.Del
return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err) return nil, status.Errorf(codes.InvalidArgument, "invalid reaction name: %v", err)
} }
// Get reaction and check ownership // Get reaction and check ownership.
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ reaction, err := s.Store.GetReaction(ctx, &store.FindReaction{
ID: &reactionID, ID: &reactionID,
}) })
if err != nil { if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions") return nil, status.Errorf(codes.Internal, "failed to get reaction")
} }
if len(reactions) == 0 { if reaction == nil {
return nil, status.Errorf(codes.NotFound, "reaction not found") // Return permission denied to avoid revealing if reaction exists.
return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }
reaction := reactions[0]
if reaction.CreatorID != user.ID && !isSuperUser(user) { if reaction.CreatorID != user.ID && !isSuperUser(user) {
return nil, status.Errorf(codes.PermissionDenied, "permission denied") return nil, status.Errorf(codes.PermissionDenied, "permission denied")
} }

@ -0,0 +1,166 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestSetMemoAttachments(t *testing.T) {
ctx := context.Background()
t.Run("SetMemoAttachments success by memo owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create attachment
attachment, err := ts.Service.CreateAttachment(userCtx, &apiv1.CreateAttachmentRequest{
Attachment: &apiv1.Attachment{
Filename: "test.txt",
Size: 5,
Type: "text/plain",
Content: []byte("hello"),
},
})
require.NoError(t, err)
require.NotNil(t, attachment)
// Set memo attachments - should succeed
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{
{Name: attachment.Name},
},
})
require.NoError(t, err)
})
t.Run("SetMemoAttachments success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Host user can modify attachments - should succeed
_, err = ts.Service.SetMemoAttachments(hostCtx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.NoError(t, err)
})
t.Run("SetMemoAttachments permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// User2 tries to modify attachments - should fail
_, err = ts.Service.SetMemoAttachments(user2Ctx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("SetMemoAttachments unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Unauthenticated user tries to modify attachments - should fail
_, err = ts.Service.SetMemoAttachments(ctx, &apiv1.SetMemoAttachmentsRequest{
Name: memo.Name,
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("SetMemoAttachments memo not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to set attachments on non-existent memo - should fail
_, err = ts.Service.SetMemoAttachments(userCtx, &apiv1.SetMemoAttachmentsRequest{
Name: "memos/nonexistent-uid-12345",
Attachments: []*apiv1.Attachment{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

@ -0,0 +1,169 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestSetMemoRelations(t *testing.T) {
ctx := context.Background()
t.Run("SetMemoRelations success by memo owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo1
memo1, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo 1",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo1)
// Create memo2
memo2, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo 2",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo2)
// Set memo relations - should succeed
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
Name: memo1.Name,
Relations: []*apiv1.MemoRelation{
{
RelatedMemo: &apiv1.MemoRelation_Memo{
Name: memo2.Name,
},
Type: apiv1.MemoRelation_REFERENCE,
},
},
})
require.NoError(t, err)
})
t.Run("SetMemoRelations success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Host user can modify relations - should succeed
_, err = ts.Service.SetMemoRelations(hostCtx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.NoError(t, err)
})
t.Run("SetMemoRelations permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// User2 tries to modify relations - should fail
_, err = ts.Service.SetMemoRelations(user2Ctx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("SetMemoRelations unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PRIVATE,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Unauthenticated user tries to modify relations - should fail
_, err = ts.Service.SetMemoRelations(ctx, &apiv1.SetMemoRelationsRequest{
Name: memo.Name,
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("SetMemoRelations memo not found", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to set relations on non-existent memo - should fail
_, err = ts.Service.SetMemoRelations(userCtx, &apiv1.SetMemoRelationsRequest{
Name: "memos/nonexistent-uid-12345",
Relations: []*apiv1.MemoRelation{},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not found")
})
}

@ -0,0 +1,193 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
)
func TestDeleteMemoReaction(t *testing.T) {
ctx := context.Background()
t.Run("DeleteMemoReaction success by reaction owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Delete reaction - should succeed
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
})
t.Run("DeleteMemoReaction success by host user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Create memo by regular user
memo, err := ts.Service.CreateMemo(regularUserCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction by regular user
reaction, err := ts.Service.UpsertMemoReaction(regularUserCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Host user can delete reaction - should succeed
_, err = ts.Service.DeleteMemoReaction(hostCtx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.NoError(t, err)
})
t.Run("DeleteMemoReaction permission denied for non-owner", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user1
user1, err := ts.CreateRegularUser(ctx, "user1")
require.NoError(t, err)
user1Ctx := ts.CreateUserContext(ctx, user1.ID)
// Create user2
user2, err := ts.CreateRegularUser(ctx, "user2")
require.NoError(t, err)
user2Ctx := ts.CreateUserContext(ctx, user2.ID)
// Create memo by user1
memo, err := ts.Service.CreateMemo(user1Ctx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction by user1
reaction, err := ts.Service.UpsertMemoReaction(user1Ctx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// User2 tries to delete reaction - should fail with permission denied
_, err = ts.Service.DeleteMemoReaction(user2Ctx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
})
t.Run("DeleteMemoReaction unauthenticated", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Create memo
memo, err := ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{
Memo: &apiv1.Memo{
Content: "Test memo",
Visibility: apiv1.Visibility_PUBLIC,
},
})
require.NoError(t, err)
require.NotNil(t, memo)
// Create reaction
reaction, err := ts.Service.UpsertMemoReaction(userCtx, &apiv1.UpsertMemoReactionRequest{
Name: memo.Name,
Reaction: &apiv1.Reaction{
ContentId: memo.Name,
ReactionType: "👍",
},
})
require.NoError(t, err)
require.NotNil(t, reaction)
// Unauthenticated user tries to delete reaction - should fail
_, err = ts.Service.DeleteMemoReaction(ctx, &apiv1.DeleteMemoReactionRequest{
Name: reaction.Name,
})
require.Error(t, err)
require.Contains(t, err.Error(), "not authenticated")
})
t.Run("DeleteMemoReaction not found returns permission denied", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create user
user, err := ts.CreateRegularUser(ctx, "user")
require.NoError(t, err)
userCtx := ts.CreateUserContext(ctx, user.ID)
// Try to delete non-existent reaction - should fail with permission denied
// (not "not found" to avoid information disclosure)
_, err = ts.Service.DeleteMemoReaction(userCtx, &apiv1.DeleteMemoReactionRequest{
Name: "reactions/99999",
})
require.Error(t, err)
require.Contains(t, err.Error(), "permission denied")
require.NotContains(t, err.Error(), "not found")
})
}

@ -0,0 +1,165 @@
package test
import (
"context"
"testing"
"github.com/stretchr/testify/require"
apiv1 "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
)
func TestCreateUserRegistration(t *testing.T) {
ctx := context.Background()
t.Run("CreateUser success when registration enabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// User registration is enabled by default, no need to set it explicitly
// Create user without authentication - should succeed
_, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.NoError(t, err)
})
t.Run("CreateUser blocked when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Disable user registration
_, err := ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Try to create user without authentication - should fail
_, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not allowed")
})
t.Run("CreateUser succeeds for superuser even when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Host user can create users even when registration is disabled - should succeed
_, err = ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.NoError(t, err)
})
t.Run("CreateUser regular user cannot create users when registration disabled", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create regular user
regularUser, err := ts.CreateRegularUser(ctx, "regularuser")
require.NoError(t, err)
regularUserCtx := ts.CreateUserContext(ctx, regularUser.ID)
// Disable user registration
_, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{
Key: storepb.InstanceSettingKey_GENERAL,
Value: &storepb.InstanceSetting_GeneralSetting{
GeneralSetting: &storepb.InstanceGeneralSetting{
DisallowUserRegistration: true,
},
},
})
require.NoError(t, err)
// Regular user tries to create user when registration is disabled - should fail
_, err = ts.Service.CreateUser(regularUserCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newuser",
Email: "newuser@example.com",
Password: "password123",
},
})
require.Error(t, err)
require.Contains(t, err.Error(), "not allowed")
})
t.Run("CreateUser host can assign roles", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// Create host user
hostUser, err := ts.CreateHostUser(ctx, "admin")
require.NoError(t, err)
hostCtx := ts.CreateUserContext(ctx, hostUser.ID)
// Host user can create user with specific role - should succeed
createdUser, err := ts.Service.CreateUser(hostCtx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "newadmin",
Email: "newadmin@example.com",
Password: "password123",
Role: apiv1.User_ADMIN,
},
})
require.NoError(t, err)
require.NotNil(t, createdUser)
require.Equal(t, apiv1.User_ADMIN, createdUser.Role)
})
t.Run("CreateUser unauthenticated user can only create regular user", func(t *testing.T) {
ts := NewTestService(t)
defer ts.Cleanup()
// User registration is enabled by default
// Unauthenticated user tries to create admin user - role should be ignored
createdUser, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{
User: &apiv1.User{
Username: "wannabeadmin",
Email: "wannabeadmin@example.com",
Password: "password123",
Role: apiv1.User_ADMIN, // This should be ignored
},
})
require.NoError(t, err)
require.NotNil(t, createdUser)
require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role")
})
}

@ -82,6 +82,19 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st
return list, nil return list, nil
} }
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
list, err := d.ListReactions(ctx, find)
if err != nil {
return nil, err
}
if len(list) == 0 {
return nil, nil
}
reaction := list[0]
return reaction, nil
}
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error { func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
_, err := d.db.ExecContext(ctx, "DELETE FROM reaction WHERE id = $1", delete.ID) _, err := d.db.ExecContext(ctx, "DELETE FROM reaction WHERE id = $1", delete.ID)
return err return err

@ -2,6 +2,8 @@ package sqlite
import ( import (
"context" "context"
"database/sql"
"errors"
"strings" "strings"
"github.com/usememos/memos/store" "github.com/usememos/memos/store"
@ -87,6 +89,47 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st
return list, nil return list, nil
} }
func (d *DB) GetReaction(ctx context.Context, find *store.FindReaction) (*store.Reaction, error) {
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = ?"), append(args, *find.ID)
}
if find.CreatorID != nil {
where, args = append(where, "creator_id = ?"), append(args, *find.CreatorID)
}
if find.ContentID != nil {
where, args = append(where, "content_id = ?"), append(args, *find.ContentID)
}
reaction := &store.Reaction{}
if err := d.db.QueryRowContext(ctx, `
SELECT
id,
created_ts,
creator_id,
content_id,
reaction_type
FROM reaction
WHERE `+strings.Join(where, " AND ")+`
LIMIT 1`,
args...,
).Scan(
&reaction.ID,
&reaction.CreatedTs,
&reaction.CreatorID,
&reaction.ContentID,
&reaction.ReactionType,
); err != nil {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
}
return nil, err
}
return reaction, nil
}
func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error { func (d *DB) DeleteReaction(ctx context.Context, delete *store.DeleteReaction) error {
_, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID) _, err := d.db.ExecContext(ctx, "DELETE FROM `reaction` WHERE `id` = ?", delete.ID)
return err return err

@ -71,5 +71,6 @@ type Driver interface {
// Reaction model related methods. // Reaction model related methods.
UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error) UpsertReaction(ctx context.Context, create *Reaction) (*Reaction, error)
ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error) ListReactions(ctx context.Context, find *FindReaction) ([]*Reaction, error)
GetReaction(ctx context.Context, find *FindReaction) (*Reaction, error)
DeleteReaction(ctx context.Context, delete *DeleteReaction) error DeleteReaction(ctx context.Context, delete *DeleteReaction) error
} }

@ -32,6 +32,10 @@ func (s *Store) ListReactions(ctx context.Context, find *FindReaction) ([]*React
return s.driver.ListReactions(ctx, find) return s.driver.ListReactions(ctx, find)
} }
func (s *Store) GetReaction(ctx context.Context, find *FindReaction) (*Reaction, error) {
return s.driver.GetReaction(ctx, find)
}
func (s *Store) DeleteReaction(ctx context.Context, delete *DeleteReaction) error { func (s *Store) DeleteReaction(ctx context.Context, delete *DeleteReaction) error {
return s.driver.DeleteReaction(ctx, delete) return s.driver.DeleteReaction(ctx, delete)
} }

@ -33,6 +33,25 @@ func TestReactionStore(t *testing.T) {
require.Len(t, reactions, 1) require.Len(t, reactions, 1)
require.Equal(t, reaction, reactions[0]) require.Equal(t, reaction, reactions[0])
// Test GetReaction.
gotReaction, err := ts.GetReaction(ctx, &store.FindReaction{
ID: &reaction.ID,
})
require.NoError(t, err)
require.NotNil(t, gotReaction)
require.Equal(t, reaction.ID, gotReaction.ID)
require.Equal(t, reaction.CreatorID, gotReaction.CreatorID)
require.Equal(t, reaction.ContentID, gotReaction.ContentID)
require.Equal(t, reaction.ReactionType, gotReaction.ReactionType)
// Test GetReaction with non-existent ID.
nonExistentID := int32(99999)
notFoundReaction, err := ts.GetReaction(ctx, &store.FindReaction{
ID: &nonExistentID,
})
require.NoError(t, err)
require.Nil(t, notFoundReaction)
err = ts.DeleteReaction(ctx, &store.DeleteReaction{ err = ts.DeleteReaction(ctx, &store.DeleteReaction{
ID: reaction.ID, ID: reaction.ID,
}) })

Loading…
Cancel
Save