diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index 6f6e31a14..07344769f 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -35,7 +35,7 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set if !canModifyMemo(user, memo) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } - if err := s.setMemoAttachmentsInternal(ctx, memo, request.Attachments); err != nil { + if err := s.setMemoAttachmentsInternal(ctx, user, memo, request.Attachments); err != nil { return nil, err } if err := s.touchMemoUpdatedTimestamp(ctx, memo.ID); err != nil { @@ -50,7 +50,7 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set return &emptypb.Empty{}, nil } -func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *store.Memo, requestAttachments []*v1pb.Attachment) error { +func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, user *store.User, memo *store.Memo, requestAttachments []*v1pb.Attachment) error { currentAttachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{ MemoID: &memo.ID, }) @@ -58,7 +58,7 @@ func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *sto return status.Errorf(codes.Internal, "failed to list attachments") } - normalizedAttachments, err := s.normalizeMemoAttachmentRequest(ctx, currentAttachments, requestAttachments) + normalizedAttachments, err := s.normalizeMemoAttachmentRequest(ctx, user, currentAttachments, requestAttachments) if err != nil { return err } @@ -71,6 +71,9 @@ func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *sto // Delete attachments that are not in the request. for _, attachment := range currentAttachments { if !requestedIDs[attachment.ID] { + if attachment.CreatorID != user.ID && !isSuperUser(user) { + return status.Errorf(codes.PermissionDenied, "cannot remove another user's attachment") + } if err = s.Store.DeleteAttachment(ctx, &store.DeleteAttachment{ ID: int32(attachment.ID), MemoID: &memo.ID, @@ -98,6 +101,7 @@ func (s *APIV1Service) setMemoAttachmentsInternal(ctx context.Context, memo *sto func (s *APIV1Service) normalizeMemoAttachmentRequest( ctx context.Context, + user *store.User, currentAttachments []*store.Attachment, requestAttachments []*v1pb.Attachment, ) ([]*store.Attachment, error) { @@ -114,6 +118,9 @@ func (s *APIV1Service) normalizeMemoAttachmentRequest( if attachment == nil { return nil, status.Errorf(codes.NotFound, "attachment not found: %s", attachmentUID) } + if attachment.CreatorID != user.ID && !isSuperUser(user) { + return nil, status.Errorf(codes.PermissionDenied, "cannot attach another user's attachment") + } requestedAttachments = append(requestedAttachments, attachment) } diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 420a59ffc..3d949452b 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -508,7 +508,7 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR payload.Location = convertLocationToStore(request.Memo.Location) update.Payload = payload } else if path == "attachments" { - if err := s.setMemoAttachmentsInternal(ctx, memo, request.Memo.Attachments); err != nil { + if err := s.setMemoAttachmentsInternal(ctx, user, memo, request.Memo.Attachments); err != nil { return nil, errors.Wrap(err, "failed to set memo attachments") } } else if path == "relations" { diff --git a/server/router/api/v1/test/memo_attachment_service_test.go b/server/router/api/v1/test/memo_attachment_service_test.go index f14437b03..a715a81b3 100644 --- a/server/router/api/v1/test/memo_attachment_service_test.go +++ b/server/router/api/v1/test/memo_attachment_service_test.go @@ -2,11 +2,13 @@ package test import ( "context" + "strings" "testing" "github.com/stretchr/testify/require" apiv1 "github.com/usememos/memos/proto/gen/api/v1" + "github.com/usememos/memos/store" ) func TestSetMemoAttachments(t *testing.T) { @@ -223,4 +225,122 @@ func TestSetMemoAttachments(t *testing.T) { require.NoError(t, err) require.Len(t, response.Attachments, 0) }) + + t.Run("SetMemoAttachments denies attaching another user's attachment", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + victim, err := ts.CreateRegularUser(ctx, "attachment_victim") + require.NoError(t, err) + attacker, err := ts.CreateRegularUser(ctx, "attachment_attacker") + require.NoError(t, err) + victimCtx := ts.CreateUserContext(ctx, victim.ID) + attackerCtx := ts.CreateUserContext(ctx, attacker.ID) + + victimAttachment, err := ts.Service.CreateAttachment(victimCtx, &apiv1.CreateAttachmentRequest{ + Attachment: &apiv1.Attachment{ + Filename: "secret.txt", + Size: 6, + Type: "text/plain", + Content: []byte("secret"), + }, + }) + require.NoError(t, err) + + victimMemo, err := ts.Service.CreateMemo(victimCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "victim protected memo", + Visibility: apiv1.Visibility_PROTECTED, + Attachments: []*apiv1.Attachment{ + {Name: victimAttachment.Name}, + }, + }, + }) + require.NoError(t, err) + + attackerMemo, err := ts.Service.CreateMemo(attackerCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "attacker public memo", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + + _, err = ts.Service.SetMemoAttachments(attackerCtx, &apiv1.SetMemoAttachmentsRequest{ + Name: attackerMemo.Name, + Attachments: []*apiv1.Attachment{ + {Name: victimAttachment.Name}, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot attach another user's attachment") + + victimAttachments, err := ts.Service.ListMemoAttachments(victimCtx, &apiv1.ListMemoAttachmentsRequest{Name: victimMemo.Name}) + require.NoError(t, err) + require.Len(t, victimAttachments.Attachments, 1) + require.Equal(t, victimAttachment.Name, victimAttachments.Attachments[0].Name) + + attackerAttachments, err := ts.Service.ListMemoAttachments(attackerCtx, &apiv1.ListMemoAttachmentsRequest{Name: attackerMemo.Name}) + require.NoError(t, err) + require.Empty(t, attackerAttachments.Attachments) + }) + + t.Run("SetMemoAttachments denies removing another user's attached attachment", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + victim, err := ts.CreateRegularUser(ctx, "remove_victim") + require.NoError(t, err) + attacker, err := ts.CreateRegularUser(ctx, "remove_attacker") + require.NoError(t, err) + victimCtx := ts.CreateUserContext(ctx, victim.ID) + attackerCtx := ts.CreateUserContext(ctx, attacker.ID) + + victimAttachment, err := ts.Service.CreateAttachment(victimCtx, &apiv1.CreateAttachmentRequest{ + Attachment: &apiv1.Attachment{ + Filename: "kept.txt", + Size: 4, + Type: "text/plain", + Content: []byte("kept"), + }, + }) + require.NoError(t, err) + + attackerMemo, err := ts.Service.CreateMemo(attackerCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "contaminated memo", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + + attachmentUID := strings.TrimPrefix(victimAttachment.Name, "attachments/") + attachment, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) + require.NoError(t, err) + require.NotNil(t, attachment) + + memoUID := strings.TrimPrefix(attackerMemo.Name, "memos/") + memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) + require.NoError(t, err) + require.NotNil(t, memo) + + err = ts.Store.UpdateAttachment(ctx, &store.UpdateAttachment{ + ID: attachment.ID, + MemoID: &memo.ID, + }) + require.NoError(t, err) + + _, err = ts.Service.SetMemoAttachments(attackerCtx, &apiv1.SetMemoAttachmentsRequest{ + Name: attackerMemo.Name, + Attachments: []*apiv1.Attachment{}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "cannot remove another user's attachment") + + attachmentAfter, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{ID: &attachment.ID}) + require.NoError(t, err) + require.NotNil(t, attachmentAfter) + require.NotNil(t, attachmentAfter.MemoID) + require.Equal(t, memo.ID, *attachmentAfter.MemoID) + }) } diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index 4d59b0114..dec37ea44 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -381,8 +381,15 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR } var attachmentCleanupErr error failedAttachmentIDs := make([]int32, 0) + attachmentStorageSetting, attachmentStorageSettingErr := getDeleteUserAttachmentStorageSetting(ctx, s.Store, attachments) for _, attachment := range attachments { - if err := s.Store.DeleteAttachmentStorage(ctx, attachment); err != nil { + var err error + if attachmentStorageSettingErr != nil && store.AttachmentNeedsInstanceStorageSetting(attachment) { + err = attachmentStorageSettingErr + } else { + err = s.Store.DeleteAttachmentStorageWithInstanceSetting(ctx, attachment, attachmentStorageSetting) + } + if err != nil { slog.Warn("failed to delete attachment storage after deleting user", "user_id", userID, "attachment_id", attachment.ID, "error", err) failedAttachmentIDs = append(failedAttachmentIDs, attachment.ID) if attachmentCleanupErr == nil { @@ -408,6 +415,19 @@ func (s *APIV1Service) DeleteUser(ctx context.Context, request *v1pb.DeleteUserR return &emptypb.Empty{}, nil } +func getDeleteUserAttachmentStorageSetting(ctx context.Context, stores *store.Store, attachments []*store.Attachment) (*storepb.InstanceStorageSetting, error) { + for _, attachment := range attachments { + if store.AttachmentNeedsInstanceStorageSetting(attachment) { + instanceStorageSetting, err := stores.GetInstanceStorageSetting(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get instance storage setting") + } + return instanceStorageSetting, nil + } + } + return nil, nil +} + func getDefaultUserGeneralSetting() *v1pb.UserSetting_GeneralSetting { return &v1pb.UserSetting_GeneralSetting{ Locale: "en", diff --git a/store/attachment.go b/store/attachment.go index 38245855d..d0b2742f3 100644 --- a/store/attachment.go +++ b/store/attachment.go @@ -168,11 +168,18 @@ func (s *Store) DeleteAttachments(ctx context.Context, attachments []*Attachment return err } + instanceStorageSetting, instanceStorageSettingErr := s.getAttachmentStorageCleanupInstanceSetting(ctx, attachments) for _, attachment := range attachments { if attachment == nil { continue } - if err := s.DeleteAttachmentStorage(ctx, attachment); err != nil { + var err error + if instanceStorageSettingErr != nil && AttachmentNeedsInstanceStorageSetting(attachment) { + err = instanceStorageSettingErr + } else { + err = s.deleteAttachmentStorage(ctx, attachment, instanceStorageSetting) + } + if err != nil { if attachment.StorageType == storepb.AttachmentStorageType_LOCAL { return errors.Wrap(err, "failed to delete local file") } @@ -184,6 +191,15 @@ func (s *Store) DeleteAttachments(ctx context.Context, attachments []*Attachment } func (s *Store) DeleteAttachmentStorage(ctx context.Context, attachment *Attachment) error { + return s.deleteAttachmentStorage(ctx, attachment, nil) +} + +// DeleteAttachmentStorageWithInstanceSetting deletes attachment storage using a preloaded instance storage setting. +func (s *Store) DeleteAttachmentStorageWithInstanceSetting(ctx context.Context, attachment *Attachment, instanceStorageSetting *storepb.InstanceStorageSetting) error { + return s.deleteAttachmentStorage(ctx, attachment, instanceStorageSetting) +} + +func (s *Store) deleteAttachmentStorage(ctx context.Context, attachment *Attachment, instanceStorageSetting *storepb.InstanceStorageSetting) error { if attachment == nil { return nil } @@ -211,12 +227,15 @@ func (s *Store) DeleteAttachmentStorage(ctx context.Context, attachment *Attachm if s3ObjectPayload == nil { return errors.Errorf("No s3 object found") } - instanceStorageSetting, err := s.GetInstanceStorageSetting(ctx) - if err != nil { - return errors.Wrap(err, "failed to get instance storage setting") - } s3Config := s3ObjectPayload.S3Config if s3Config == nil { + if instanceStorageSetting == nil { + var err error + instanceStorageSetting, err = s.GetInstanceStorageSetting(ctx) + if err != nil { + return errors.Wrap(err, "failed to get instance storage setting") + } + } if instanceStorageSetting.S3Config == nil { return errors.Errorf("S3 config is not found") } @@ -240,6 +259,28 @@ func (s *Store) DeleteAttachmentStorage(ctx context.Context, attachment *Attachm return nil } +func (s *Store) getAttachmentStorageCleanupInstanceSetting(ctx context.Context, attachments []*Attachment) (*storepb.InstanceStorageSetting, error) { + for _, attachment := range attachments { + if AttachmentNeedsInstanceStorageSetting(attachment) { + instanceStorageSetting, err := s.GetInstanceStorageSetting(ctx) + if err != nil { + return nil, errors.Wrap(err, "failed to get instance storage setting") + } + return instanceStorageSetting, nil + } + } + return nil, nil +} + +// AttachmentNeedsInstanceStorageSetting reports whether S3 cleanup needs the instance fallback storage setting. +func AttachmentNeedsInstanceStorageSetting(attachment *Attachment) bool { + if attachment == nil || attachment.StorageType != storepb.AttachmentStorageType_S3 { + return false + } + s3ObjectPayload := attachment.Payload.GetS3Object() + return s3ObjectPayload != nil && s3ObjectPayload.S3Config == nil +} + func (s *Store) deleteAttachmentDerivedCaches(attachment *Attachment) { for _, cachePath := range []string{ filepath.Join(s.profile.Data, thumbnailCacheFolder, attachment.UID+".jpeg"), diff --git a/store/attachment_test.go b/store/attachment_test.go new file mode 100644 index 000000000..71db540a7 --- /dev/null +++ b/store/attachment_test.go @@ -0,0 +1,64 @@ +package store + +import ( + "testing" + + storepb "github.com/usememos/memos/proto/gen/store" +) + +func TestAttachmentNeedsInstanceStorageSetting(t *testing.T) { + tests := []struct { + name string + attachment *Attachment + want bool + }{ + { + name: "nil attachment", + }, + { + name: "local attachment", + attachment: &Attachment{ + StorageType: storepb.AttachmentStorageType_LOCAL, + }, + }, + { + name: "s3 attachment without payload", + attachment: &Attachment{ + StorageType: storepb.AttachmentStorageType_S3, + }, + }, + { + name: "s3 attachment with embedded config", + attachment: &Attachment{ + StorageType: storepb.AttachmentStorageType_S3, + Payload: &storepb.AttachmentPayload{ + Payload: &storepb.AttachmentPayload_S3Object_{ + S3Object: &storepb.AttachmentPayload_S3Object{ + S3Config: &storepb.StorageS3Config{}, + }, + }, + }, + }, + }, + { + name: "s3 attachment without embedded config", + attachment: &Attachment{ + StorageType: storepb.AttachmentStorageType_S3, + Payload: &storepb.AttachmentPayload{ + Payload: &storepb.AttachmentPayload_S3Object_{ + S3Object: &storepb.AttachmentPayload_S3Object{}, + }, + }, + }, + want: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + if got := AttachmentNeedsInstanceStorageSetting(test.attachment); got != test.want { + t.Fatalf("AttachmentNeedsInstanceStorageSetting() = %v, want %v", got, test.want) + } + }) + } +}