diff --git a/server/router/api/v1/memo_service.go b/server/router/api/v1/memo_service.go index 912caa222..4570188f6 100644 --- a/server/router/api/v1/memo_service.go +++ b/server/router/api/v1/memo_service.go @@ -38,6 +38,37 @@ func isSSESuppressed(ctx context.Context) bool { return ok && v } +func (s *APIV1Service) checkMemoReadAccess(ctx context.Context, memo *store.Memo) error { + if memo == nil { + return status.Errorf(codes.NotFound, "memo not found") + } + + // Archived memos are only visible to their creator. + if memo.RowStatus == store.Archived { + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return status.Errorf(codes.Internal, "failed to get user") + } + if user == nil || memo.CreatorID != user.ID { + return status.Errorf(codes.NotFound, "memo not found") + } + } + + if memo.Visibility != store.Public { + user, err := s.fetchCurrentUser(ctx) + if err != nil { + return status.Errorf(codes.Internal, "failed to get user") + } + if user == nil { + return status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if memo.Visibility == store.Private && memo.CreatorID != user.ID { + return status.Errorf(codes.PermissionDenied, "permission denied") + } + } + return nil +} + func (s *APIV1Service) CreateMemo(ctx context.Context, request *v1pb.CreateMemoRequest) (*v1pb.Memo, error) { user, err := s.fetchCurrentUser(ctx) if err != nil { @@ -335,27 +366,19 @@ func (s *APIV1Service) GetMemo(ctx context.Context, request *v1pb.GetMemoRequest return nil, status.Errorf(codes.NotFound, "memo not found") } - // Archived memos are only visible to their creator. - if memo.RowStatus == store.Archived { - user, err := s.fetchCurrentUser(ctx) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user") - } - if user == nil || memo.CreatorID != user.ID { - return nil, status.Errorf(codes.NotFound, "memo not found") - } + if err := s.checkMemoReadAccess(ctx, memo); err != nil { + return nil, err } - - if memo.Visibility != store.Public { - user, err := s.fetchCurrentUser(ctx) + if memo.ParentUID != nil { + parentMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) if err != nil { - return nil, status.Errorf(codes.Internal, "failed to get user") + return nil, status.Errorf(codes.Internal, "failed to get parent memo") } - if user == nil { - return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + if parentMemo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") } - if memo.Visibility == store.Private && memo.CreatorID != user.ID { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") + if err := s.checkMemoReadAccess(ctx, parentMemo); err != nil { + return nil, err } } @@ -486,6 +509,16 @@ func (s *APIV1Service) UpdateMemo(ctx context.Context, request *v1pb.UpdateMemoR update.Payload = memo.Payload } else if path == "visibility" { visibility := convertVisibilityToStore(request.Memo.Visibility) + if memo.ParentUID != nil { + parentMemo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: memo.ParentUID}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get parent memo") + } + if parentMemo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } + visibility = parentMemo.Visibility + } update.Visibility = &visibility } else if path == "pinned" { update.Pinned = &request.Memo.Pinned @@ -641,11 +674,17 @@ func (s *APIV1Service) CreateMemoComment(ctx context.Context, request *v1pb.Crea if relatedMemo.Visibility == store.Private && relatedMemo.CreatorID != user.ID && !isSuperUser(user) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } + if request.Comment == nil { + return nil, status.Errorf(codes.InvalidArgument, "comment is required") + } + + comment := *request.Comment + comment.Visibility = convertVisibilityFromStore(relatedMemo.Visibility) // Create the memo comment first; suppress the generic memo.created SSE event // since CreateMemoComment broadcasts memo.comment.created for the parent instead. memoComment, err := s.CreateMemo(withSuppressMentionNotifications(withSuppressSSE(ctx)), &v1pb.CreateMemoRequest{ - Memo: request.Comment, + Memo: &comment, MemoId: request.CommentId, }) if err != nil { @@ -722,6 +761,12 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM if err != nil { return nil, status.Errorf(codes.Internal, "failed to get memo") } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found") + } + if err := s.checkMemoReadAccess(ctx, memo); err != nil { + return nil, err + } currentUser, err := s.fetchCurrentUser(ctx) if err != nil { diff --git a/server/router/api/v1/test/memo_service_test.go b/server/router/api/v1/test/memo_service_test.go index dffff1ec5..c07ba269c 100644 --- a/server/router/api/v1/test/memo_service_test.go +++ b/server/router/api/v1/test/memo_service_test.go @@ -8,6 +8,9 @@ import ( "time" "github.com/stretchr/testify/require" + "google.golang.org/grpc/codes" + "google.golang.org/grpc/status" + "google.golang.org/protobuf/types/known/fieldmaskpb" "google.golang.org/protobuf/types/known/timestamppb" apiv1 "github.com/usememos/memos/proto/gen/api/v1" @@ -527,6 +530,113 @@ func TestListMemoCommentsPaginates(t *testing.T) { require.Empty(t, secondPage.NextPageToken) } +func TestCreateMemoCommentInheritsParentVisibility(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + owner, err := ts.CreateRegularUser(ctx, "private-comment-owner") + require.NoError(t, err) + ownerCtx := ts.CreateUserContext(ctx, owner.ID) + + parent, err := ts.Service.CreateMemo(ownerCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "private parent", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + + comment, err := ts.Service.CreateMemoComment(ownerCtx, &apiv1.CreateMemoCommentRequest{ + Name: parent.Name, + Comment: &apiv1.Memo{ + Content: "client requested public comment", + Visibility: apiv1.Visibility_PUBLIC, + }, + }) + require.NoError(t, err) + require.Equal(t, apiv1.Visibility_PRIVATE, comment.Visibility) + + updatedComment, err := ts.Service.UpdateMemo(ownerCtx, &apiv1.UpdateMemoRequest{ + Memo: &apiv1.Memo{ + Name: comment.Name, + Visibility: apiv1.Visibility_PUBLIC, + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"visibility"}}, + }) + require.NoError(t, err) + require.Equal(t, apiv1.Visibility_PRIVATE, updatedComment.Visibility) + + _, err = ts.Service.GetMemo(ctx, &apiv1.GetMemoRequest{Name: comment.Name}) + require.Equal(t, codes.Unauthenticated, status.Code(err)) +} + +func TestGetMemoCommentRequiresParentReadAccess(t *testing.T) { + ctx := context.Background() + + ts := NewTestService(t) + defer ts.Cleanup() + + owner, err := ts.CreateRegularUser(ctx, "legacy-comment-owner") + require.NoError(t, err) + ownerCtx := ts.CreateUserContext(ctx, owner.ID) + + other, err := ts.CreateRegularUser(ctx, "legacy-comment-other") + require.NoError(t, err) + otherCtx := ts.CreateUserContext(ctx, other.ID) + + parent, err := ts.Service.CreateMemo(ownerCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "private parent for legacy comment", + Visibility: apiv1.Visibility_PRIVATE, + }, + }) + require.NoError(t, err) + + legacyComment, err := ts.Store.CreateMemo(ctx, &store.Memo{ + UID: "legacy-public-comment", + CreatorID: owner.ID, + Content: "legacy public comment under private parent", + Visibility: store.Public, + }) + require.NoError(t, err) + + parentUID := parent.Name[len("memos/"):] + parentMemo, err := ts.Store.GetMemo(ctx, &store.FindMemo{UID: &parentUID}) + require.NoError(t, err) + require.NotNil(t, parentMemo) + + _, err = ts.Store.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: legacyComment.ID, + RelatedMemoID: parentMemo.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + commentName := "memos/" + legacyComment.UID + _, err = ts.Service.GetMemo(ctx, &apiv1.GetMemoRequest{Name: commentName}) + require.Equal(t, codes.Unauthenticated, status.Code(err)) + + _, err = ts.Service.GetMemo(otherCtx, &apiv1.GetMemoRequest{Name: commentName}) + require.Equal(t, codes.PermissionDenied, status.Code(err)) + + comment, err := ts.Service.GetMemo(ownerCtx, &apiv1.GetMemoRequest{Name: commentName}) + require.NoError(t, err) + require.Equal(t, parent.Name, comment.GetParent()) + + _, err = ts.Service.ListMemoComments(ctx, &apiv1.ListMemoCommentsRequest{Name: parent.Name}) + require.Equal(t, codes.Unauthenticated, status.Code(err)) + + _, err = ts.Service.ListMemoComments(otherCtx, &apiv1.ListMemoCommentsRequest{Name: parent.Name}) + require.Equal(t, codes.PermissionDenied, status.Code(err)) + + comments, err := ts.Service.ListMemoComments(ownerCtx, &apiv1.ListMemoCommentsRequest{Name: parent.Name}) + require.NoError(t, err) + require.Len(t, comments.Memos, 1) + require.Equal(t, commentName, comments.Memos[0].Name) +} + // TestCreateMemoWithCustomTimestamps tests that custom timestamps can be set when creating memos and comments. // This addresses issue #5483: https://github.com/usememos/memos/issues/5483 func TestCreateMemoWithCustomTimestamps(t *testing.T) { diff --git a/server/router/rss/rss.go b/server/router/rss/rss.go index 472ddfed5..f30c65c3d 100644 --- a/server/router/rss/rss.go +++ b/server/router/rss/rss.go @@ -86,9 +86,10 @@ func (s *RSSService) GetExploreRSS(c *echo.Context) error { normalStatus := store.Normal limit := maxRSSItemCount memoFind := store.FindMemo{ - RowStatus: &normalStatus, - VisibilityList: []store.Visibility{store.Public}, - Limit: &limit, + RowStatus: &normalStatus, + VisibilityList: []store.Visibility{store.Public}, + ExcludeComments: true, + Limit: &limit, } memoList, err := s.Store.ListMemos(ctx, &memoFind) if err != nil { @@ -135,10 +136,11 @@ func (s *RSSService) GetUserRSS(c *echo.Context) error { normalStatus := store.Normal limit := maxRSSItemCount memoFind := store.FindMemo{ - CreatorID: &user.ID, - RowStatus: &normalStatus, - VisibilityList: []store.Visibility{store.Public}, - Limit: &limit, + CreatorID: &user.ID, + RowStatus: &normalStatus, + VisibilityList: []store.Visibility{store.Public}, + ExcludeComments: true, + Limit: &limit, } memoList, err := s.Store.ListMemos(ctx, &memoFind) if err != nil { diff --git a/server/router/rss/rss_test.go b/server/router/rss/rss_test.go new file mode 100644 index 000000000..3324b8061 --- /dev/null +++ b/server/router/rss/rss_test.go @@ -0,0 +1,86 @@ +package rss + +import ( + "context" + "net/http" + "net/http/httptest" + "strings" + "testing" + + "github.com/labstack/echo/v5" + "github.com/stretchr/testify/require" + + "github.com/usememos/memos/internal/markdown" + "github.com/usememos/memos/internal/profile" + "github.com/usememos/memos/store" + teststore "github.com/usememos/memos/store/test" +) + +func TestPublicRSSExcludesComments(t *testing.T) { + ctx := context.Background() + stores := teststore.NewTestingStore(ctx, t) + defer stores.Close() + + user, err := stores.CreateUser(ctx, &store.User{ + Username: "rss-comment-owner", + Role: store.RoleUser, + Email: "rss-comment-owner@example.com", + }) + require.NoError(t, err) + + parent, err := stores.CreateMemo(ctx, &store.Memo{ + UID: "rss-public-parent", + CreatorID: user.ID, + Content: "public parent should stay in rss", + Visibility: store.Public, + }) + require.NoError(t, err) + + comment, err := stores.CreateMemo(ctx, &store.Memo{ + UID: "rss-public-comment", + CreatorID: user.ID, + Content: "public comment should not be in rss", + Visibility: store.Public, + }) + require.NoError(t, err) + + _, err = stores.UpsertMemoRelation(ctx, &store.MemoRelation{ + MemoID: comment.ID, + RelatedMemoID: parent.ID, + Type: store.MemoRelationComment, + }) + require.NoError(t, err) + + service := NewRSSService(&profile.Profile{}, stores, markdown.NewService()) + + exploreRSS := renderRSS(t, service, "/explore/rss.xml", "") + require.Contains(t, exploreRSS, "public parent should stay in rss") + require.NotContains(t, exploreRSS, "public comment should not be in rss") + + userRSS := renderRSS(t, service, "/u/rss-comment-owner/rss.xml", user.Username) + require.Contains(t, userRSS, "public parent should stay in rss") + require.NotContains(t, userRSS, "public comment should not be in rss") +} + +func renderRSS(t *testing.T, service *RSSService, target string, username string) string { + t.Helper() + + e := echo.New() + req := httptest.NewRequest(http.MethodGet, target, strings.NewReader("")) + req.Host = "example.com" + rec := httptest.NewRecorder() + c := e.NewContext(req, rec) + if username != "" { + c.SetPathValues(echo.PathValues{{Name: "username", Value: username}}) + } + + var err error + if username == "" { + err = service.GetExploreRSS(c) + } else { + err = service.GetUserRSS(c) + } + require.NoError(t, err) + require.Equal(t, http.StatusOK, rec.Code) + return rec.Body.String() +}