mirror of https://github.com/usememos/memos
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
606 lines
21 KiB
Go
606 lines
21 KiB
Go
package mcp
|
|
|
|
import (
|
|
"bytes"
|
|
"context"
|
|
"encoding/json"
|
|
"net/http"
|
|
"net/http/httptest"
|
|
"reflect"
|
|
"testing"
|
|
"time"
|
|
"unsafe"
|
|
|
|
"github.com/labstack/echo/v5"
|
|
"github.com/lithammer/shortuuid/v4"
|
|
"github.com/mark3labs/mcp-go/mcp"
|
|
"github.com/stretchr/testify/require"
|
|
|
|
"github.com/usememos/memos/internal/profile"
|
|
storepb "github.com/usememos/memos/proto/gen/store"
|
|
"github.com/usememos/memos/server/auth"
|
|
apiv1service "github.com/usememos/memos/server/router/api/v1"
|
|
"github.com/usememos/memos/store"
|
|
teststore "github.com/usememos/memos/store/test"
|
|
)
|
|
|
|
type testMCPService struct {
|
|
service *MCPService
|
|
store *store.Store
|
|
}
|
|
|
|
func newTestMCPService(t *testing.T) *testMCPService {
|
|
t.Helper()
|
|
|
|
ctx := context.Background()
|
|
stores := teststore.NewTestingStore(ctx, t)
|
|
t.Cleanup(func() {
|
|
require.NoError(t, stores.Close())
|
|
})
|
|
|
|
profile := &profile.Profile{
|
|
Driver: "sqlite",
|
|
InstanceURL: "https://notes.example.com",
|
|
}
|
|
apiV1Service := apiv1service.NewAPIV1Service("test-secret", profile, stores)
|
|
svc := NewMCPService(profile, stores, "test-secret", apiV1Service)
|
|
return &testMCPService{
|
|
service: svc,
|
|
store: stores,
|
|
}
|
|
}
|
|
|
|
func (s *testMCPService) createUser(t *testing.T, username string) *store.User {
|
|
t.Helper()
|
|
|
|
user, err := s.store.CreateUser(context.Background(), &store.User{
|
|
Username: username,
|
|
Role: store.RoleUser,
|
|
Email: username + "@example.com",
|
|
})
|
|
require.NoError(t, err)
|
|
return user
|
|
}
|
|
|
|
func (s *testMCPService) createMemo(t *testing.T, creatorID int32, visibility store.Visibility, content string) *store.Memo {
|
|
t.Helper()
|
|
|
|
memo, err := s.store.CreateMemo(context.Background(), &store.Memo{
|
|
UID: shortuuid.New(),
|
|
CreatorID: creatorID,
|
|
RowStatus: store.Normal,
|
|
Visibility: visibility,
|
|
Content: content,
|
|
})
|
|
require.NoError(t, err)
|
|
return memo
|
|
}
|
|
|
|
func (s *testMCPService) archiveMemo(t *testing.T, memoID int32) {
|
|
t.Helper()
|
|
|
|
rowStatus := store.Archived
|
|
require.NoError(t, s.store.UpdateMemo(context.Background(), &store.UpdateMemo{
|
|
ID: memoID,
|
|
RowStatus: &rowStatus,
|
|
}))
|
|
}
|
|
|
|
func (s *testMCPService) createAttachment(t *testing.T, creatorID int32, memoID *int32) *store.Attachment {
|
|
t.Helper()
|
|
|
|
attachment, err := s.store.CreateAttachment(context.Background(), &store.Attachment{
|
|
UID: shortuuid.New(),
|
|
CreatorID: creatorID,
|
|
Filename: "note.txt",
|
|
Type: "text/plain",
|
|
Size: 4,
|
|
StorageType: storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED,
|
|
Reference: "db://attachment/note.txt",
|
|
MemoID: memoID,
|
|
})
|
|
require.NoError(t, err)
|
|
return attachment
|
|
}
|
|
|
|
func withUser(ctx context.Context, userID int32) context.Context {
|
|
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
|
}
|
|
|
|
func toolRequest(name string, arguments map[string]any) mcp.CallToolRequest {
|
|
return mcp.CallToolRequest{
|
|
Params: mcp.CallToolParams{
|
|
Name: name,
|
|
Arguments: arguments,
|
|
},
|
|
}
|
|
}
|
|
|
|
func firstText(t *testing.T, result *mcp.CallToolResult) string {
|
|
t.Helper()
|
|
require.NotEmpty(t, result.Content)
|
|
text, ok := result.Content[0].(mcp.TextContent)
|
|
require.True(t, ok)
|
|
return text.Text
|
|
}
|
|
|
|
func initializeMCPHTTP(t *testing.T, e *echo.Echo, path string, headers map[string]string) string {
|
|
t.Helper()
|
|
payload := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": 1,
|
|
"method": "initialize",
|
|
"params": map[string]any{
|
|
"protocolVersion": "2025-06-18",
|
|
"capabilities": map[string]any{},
|
|
"clientInfo": map[string]any{
|
|
"name": "mcp-test",
|
|
"version": "1.0.0",
|
|
},
|
|
},
|
|
}
|
|
resp := postMCPHTTP(t, e, path, "", headers, payload)
|
|
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
|
|
sessionID := resp.Header().Get("Mcp-Session-Id")
|
|
require.NotEmpty(t, sessionID)
|
|
return sessionID
|
|
}
|
|
|
|
func callMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, method string, params any) map[string]any {
|
|
t.Helper()
|
|
payload := map[string]any{
|
|
"jsonrpc": "2.0",
|
|
"id": 2,
|
|
"method": method,
|
|
}
|
|
if params != nil {
|
|
payload["params"] = params
|
|
}
|
|
resp := postMCPHTTP(t, e, path, sessionID, headers, payload)
|
|
require.Equal(t, http.StatusOK, resp.Code, resp.Body.String())
|
|
var decoded map[string]any
|
|
require.NoError(t, json.Unmarshal(resp.Body.Bytes(), &decoded))
|
|
return decoded
|
|
}
|
|
|
|
func postMCPHTTP(t *testing.T, e *echo.Echo, path string, sessionID string, headers map[string]string, payload map[string]any) *httptest.ResponseRecorder {
|
|
t.Helper()
|
|
body, err := json.Marshal(payload)
|
|
require.NoError(t, err)
|
|
req := httptest.NewRequest(http.MethodPost, path, bytes.NewReader(body))
|
|
req.Header.Set("Content-Type", "application/json")
|
|
if sessionID != "" {
|
|
req.Header.Set("Mcp-Session-Id", sessionID)
|
|
}
|
|
for key, value := range headers {
|
|
req.Header.Set(key, value)
|
|
}
|
|
resp := httptest.NewRecorder()
|
|
e.ServeHTTP(resp, req)
|
|
return resp
|
|
}
|
|
|
|
func toolNamesFromListResponse(t *testing.T, response map[string]any) map[string]struct{} {
|
|
t.Helper()
|
|
result, ok := response["result"].(map[string]any)
|
|
require.True(t, ok, "missing result: %#v", response)
|
|
rawTools, ok := result["tools"].([]any)
|
|
require.True(t, ok, "missing tools: %#v", result)
|
|
names := map[string]struct{}{}
|
|
for _, rawTool := range rawTools {
|
|
tool, ok := rawTool.(map[string]any)
|
|
require.True(t, ok)
|
|
name, ok := tool["name"].(string)
|
|
require.True(t, ok)
|
|
names[name] = struct{}{}
|
|
}
|
|
return names
|
|
}
|
|
|
|
func requireToolPresent(t *testing.T, names map[string]struct{}, name string) {
|
|
t.Helper()
|
|
_, ok := names[name]
|
|
require.True(t, ok, "expected tool %q to be present in %#v", name, names)
|
|
}
|
|
|
|
func requireToolAbsent(t *testing.T, names map[string]struct{}, name string) {
|
|
t.Helper()
|
|
_, ok := names[name]
|
|
require.False(t, ok, "expected tool %q to be absent in %#v", name, names)
|
|
}
|
|
|
|
func nextSSEEvent(t *testing.T, client *apiv1service.SSEClient) *apiv1service.SSEEvent {
|
|
t.Helper()
|
|
events := sseClientEvents(t, client)
|
|
var data []byte
|
|
select {
|
|
case eventData, ok := <-events:
|
|
require.True(t, ok, "SSE client channel closed")
|
|
data = eventData
|
|
case <-time.After(time.Second):
|
|
t.Fatal("timed out waiting for SSE event")
|
|
}
|
|
var event apiv1service.SSEEvent
|
|
require.NoError(t, json.Unmarshal(data, &event))
|
|
return &event
|
|
}
|
|
|
|
func requireNoSSEEvent(t *testing.T, client *apiv1service.SSEClient) {
|
|
t.Helper()
|
|
select {
|
|
case eventData, ok := <-sseClientEvents(t, client):
|
|
require.True(t, ok, "SSE client channel closed")
|
|
t.Fatalf("unexpected SSE event received: %s", string(eventData))
|
|
case <-time.After(150 * time.Millisecond):
|
|
}
|
|
}
|
|
|
|
func sseClientEvents(t *testing.T, client *apiv1service.SSEClient) <-chan []byte {
|
|
t.Helper()
|
|
field := reflect.ValueOf(client).Elem().FieldByName("events")
|
|
events, ok := reflect.NewAt(field.Type(), unsafe.Pointer(field.UnsafeAddr())).Elem().Interface().(chan []byte)
|
|
require.True(t, ok)
|
|
return events
|
|
}
|
|
|
|
func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
owner := ts.createUser(t, "owner")
|
|
other := ts.createUser(t, "other")
|
|
|
|
memo := ts.createMemo(t, owner.ID, store.Public, "archived")
|
|
ts.archiveMemo(t, memo.ID)
|
|
|
|
ctx := withUser(context.Background(), other.ID)
|
|
result, err := ts.service.handleGetMemo(ctx, toolRequest("get_memo", map[string]any{
|
|
"name": "memos/" + memo.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.True(t, result.IsError)
|
|
require.Contains(t, firstText(t, result), "permission denied")
|
|
|
|
_, err = ts.service.handleReadMemoResource(ctx, mcp.ReadResourceRequest{
|
|
Params: mcp.ReadResourceParams{
|
|
URI: "memo://memos/" + memo.UID,
|
|
},
|
|
})
|
|
require.ErrorContains(t, err, "permission denied")
|
|
}
|
|
|
|
func TestHandleListMemosArchivedOnlyReturnsCreatorMemos(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
owner := ts.createUser(t, "owner")
|
|
other := ts.createUser(t, "other")
|
|
|
|
ownerMemo := ts.createMemo(t, owner.ID, store.Public, "owner archived")
|
|
ts.archiveMemo(t, ownerMemo.ID)
|
|
otherMemo := ts.createMemo(t, other.ID, store.Public, "other archived")
|
|
ts.archiveMemo(t, otherMemo.ID)
|
|
|
|
result, err := ts.service.handleListMemos(withUser(context.Background(), owner.ID), toolRequest("list_memos", map[string]any{
|
|
"state": "ARCHIVED",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, result.IsError)
|
|
|
|
var payload struct {
|
|
Memos []memoJSON `json:"memos"`
|
|
}
|
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &payload))
|
|
require.Len(t, payload.Memos, 1)
|
|
require.Equal(t, "memos/"+ownerMemo.UID, payload.Memos[0].Name)
|
|
|
|
anonResult, err := ts.service.handleListMemos(context.Background(), toolRequest("list_memos", map[string]any{
|
|
"state": "ARCHIVED",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, anonResult)), &payload))
|
|
require.Empty(t, payload.Memos)
|
|
}
|
|
|
|
func TestHandleListMemoRelationsFiltersUnreadableTargets(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
owner := ts.createUser(t, "owner")
|
|
privateUser := ts.createUser(t, "private-user")
|
|
publicUser := ts.createUser(t, "public-user")
|
|
|
|
source := ts.createMemo(t, owner.ID, store.Public, "source")
|
|
privateTarget := ts.createMemo(t, privateUser.ID, store.Private, "private")
|
|
publicTarget := ts.createMemo(t, publicUser.ID, store.Public, "public")
|
|
|
|
_, err := ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
|
MemoID: source.ID,
|
|
RelatedMemoID: privateTarget.ID,
|
|
Type: store.MemoRelationReference,
|
|
})
|
|
require.NoError(t, err)
|
|
_, err = ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
|
MemoID: source.ID,
|
|
RelatedMemoID: publicTarget.ID,
|
|
Type: store.MemoRelationReference,
|
|
})
|
|
require.NoError(t, err)
|
|
|
|
result, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
|
"name": "memos/" + source.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, result.IsError)
|
|
|
|
var relations []relationJSON
|
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &relations))
|
|
require.Len(t, relations, 1)
|
|
require.Equal(t, "memos/"+publicTarget.UID, relations[0].RelatedMemo)
|
|
|
|
denied, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
|
"name": "memos/" + privateTarget.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.True(t, denied.IsError)
|
|
require.Contains(t, firstText(t, denied), "permission denied")
|
|
}
|
|
|
|
func TestHandleLinkAttachmentToMemoRequiresMemoOwnership(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
attachmentOwner := ts.createUser(t, "attachment-owner")
|
|
memoOwner := ts.createUser(t, "memo-owner")
|
|
|
|
attachment := ts.createAttachment(t, attachmentOwner.ID, nil)
|
|
memo := ts.createMemo(t, memoOwner.ID, store.Public, "target")
|
|
|
|
result, err := ts.service.handleLinkAttachmentToMemo(withUser(context.Background(), attachmentOwner.ID), toolRequest("link_attachment_to_memo", map[string]any{
|
|
"name": "attachments/" + attachment.UID,
|
|
"memo": "memos/" + memo.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.True(t, result.IsError)
|
|
require.Contains(t, firstText(t, result), "permission denied")
|
|
}
|
|
|
|
func TestHandleGetAttachmentDeniesArchivedLinkedMemoToNonCreator(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
owner := ts.createUser(t, "owner")
|
|
other := ts.createUser(t, "other")
|
|
|
|
memo := ts.createMemo(t, owner.ID, store.Public, "memo")
|
|
ts.archiveMemo(t, memo.ID)
|
|
attachment := ts.createAttachment(t, owner.ID, &memo.ID)
|
|
|
|
result, err := ts.service.handleGetAttachment(withUser(context.Background(), other.ID), toolRequest("get_attachment", map[string]any{
|
|
"name": "attachments/" + attachment.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.True(t, result.IsError)
|
|
require.Contains(t, firstText(t, result), "permission denied")
|
|
}
|
|
|
|
func TestIsAllowedOrigin(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
|
|
t.Run("allow missing origin", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
|
require.True(t, ts.service.isAllowedOrigin(req))
|
|
})
|
|
|
|
t.Run("allow same origin as request host", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
|
req.Header.Set("Origin", "http://localhost:5230")
|
|
require.True(t, ts.service.isAllowedOrigin(req))
|
|
})
|
|
|
|
t.Run("allow configured instance origin", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://127.0.0.1:5230/mcp", nil)
|
|
req.Host = "127.0.0.1:5230"
|
|
req.Header.Set("Origin", "https://notes.example.com")
|
|
require.True(t, ts.service.isAllowedOrigin(req))
|
|
})
|
|
|
|
t.Run("reject cross origin", func(t *testing.T) {
|
|
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
|
req.Header.Set("Origin", "https://evil.example.com")
|
|
require.False(t, ts.service.isAllowedOrigin(req))
|
|
})
|
|
}
|
|
|
|
func TestMCPToolFilteringRoutesAndHeaders(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
e := echo.New()
|
|
ts.service.RegisterRoutes(e)
|
|
|
|
t.Run("default endpoint lists all tools", func(t *testing.T) {
|
|
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
|
|
response := callMCPHTTP(t, e, "/mcp", sessionID, nil, "tools/list", map[string]any{})
|
|
names := toolNamesFromListResponse(t, response)
|
|
require.Len(t, names, len(allMCPToolNames))
|
|
requireToolPresent(t, names, "create_memo")
|
|
requireToolPresent(t, names, "list_tags")
|
|
requireToolPresent(t, names, "upsert_reaction")
|
|
})
|
|
|
|
t.Run("readonly header hides and blocks mutation tools", func(t *testing.T) {
|
|
headers := map[string]string{headerMCPReadonly: "true"}
|
|
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
|
|
response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{})
|
|
names := toolNamesFromListResponse(t, response)
|
|
requireToolPresent(t, names, "list_memos")
|
|
requireToolPresent(t, names, "list_tags")
|
|
requireToolAbsent(t, names, "create_memo")
|
|
requireToolAbsent(t, names, "delete_memo")
|
|
|
|
callResponse := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/call", map[string]any{
|
|
"name": "create_memo",
|
|
"arguments": map[string]any{"content": "blocked"},
|
|
})
|
|
result, ok := callResponse["result"].(map[string]any)
|
|
require.True(t, ok)
|
|
require.Equal(t, true, result["isError"])
|
|
rawContent, ok := result["content"].([]any)
|
|
require.True(t, ok)
|
|
content, ok := rawContent[0].(map[string]any)
|
|
require.True(t, ok)
|
|
require.Contains(t, content["text"], "not enabled")
|
|
})
|
|
|
|
t.Run("readonly alias applies path config", func(t *testing.T) {
|
|
sessionID := initializeMCPHTTP(t, e, "/mcp/readonly", nil)
|
|
response := callMCPHTTP(t, e, "/mcp/readonly", sessionID, nil, "tools/list", map[string]any{})
|
|
names := toolNamesFromListResponse(t, response)
|
|
requireToolPresent(t, names, "get_memo")
|
|
requireToolAbsent(t, names, "create_memo")
|
|
requireToolAbsent(t, names, "upsert_reaction")
|
|
})
|
|
|
|
t.Run("toolsets include and exclude compose", func(t *testing.T) {
|
|
headers := map[string]string{
|
|
headerMCPToolsets: "memos",
|
|
headerMCPTools: "list_tags",
|
|
headerMCPExcludeTools: "get_memo",
|
|
}
|
|
sessionID := initializeMCPHTTP(t, e, "/mcp", nil)
|
|
response := callMCPHTTP(t, e, "/mcp", sessionID, headers, "tools/list", map[string]any{})
|
|
names := toolNamesFromListResponse(t, response)
|
|
requireToolPresent(t, names, "list_memos")
|
|
requireToolPresent(t, names, "list_tags")
|
|
requireToolAbsent(t, names, "get_memo")
|
|
requireToolAbsent(t, names, "list_attachments")
|
|
})
|
|
|
|
t.Run("path toolsets and readonly compose", func(t *testing.T) {
|
|
sessionID := initializeMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", nil)
|
|
response := callMCPHTTP(t, e, "/mcp/x/memos,tags/readonly", sessionID, nil, "tools/list", map[string]any{})
|
|
names := toolNamesFromListResponse(t, response)
|
|
requireToolPresent(t, names, "list_memos")
|
|
requireToolPresent(t, names, "list_tags")
|
|
requireToolAbsent(t, names, "create_memo")
|
|
requireToolAbsent(t, names, "list_attachments")
|
|
})
|
|
|
|
t.Run("unknown toolset returns empty tool list", func(t *testing.T) {
|
|
sessionID := initializeMCPHTTP(t, e, "/mcp/x/notreal", nil)
|
|
response := callMCPHTTP(t, e, "/mcp/x/notreal", sessionID, nil, "tools/list", map[string]any{})
|
|
names := toolNamesFromListResponse(t, response)
|
|
require.Empty(t, names)
|
|
})
|
|
}
|
|
|
|
func TestMCPMemoAndReactionMutationsEmitSSEEvents(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
user := ts.createUser(t, "author")
|
|
ctx := withUser(context.Background(), user.ID)
|
|
client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser)
|
|
defer ts.service.apiV1Service.SSEHub.Unsubscribe(client)
|
|
|
|
createResult, err := ts.service.handleCreateMemo(ctx, toolRequest("create_memo", map[string]any{
|
|
"content": "created from MCP",
|
|
"visibility": "PRIVATE",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, createResult.IsError)
|
|
createEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventMemoCreated, createEvent.Type)
|
|
|
|
var created memoJSON
|
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, createResult)), &created))
|
|
|
|
updateResult, err := ts.service.handleUpdateMemo(ctx, toolRequest("update_memo", map[string]any{
|
|
"name": created.Name,
|
|
"content": "updated from MCP",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, updateResult.IsError)
|
|
updateEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventMemoUpdated, updateEvent.Type)
|
|
|
|
commentResult, err := ts.service.handleCreateMemoComment(ctx, toolRequest("create_memo_comment", map[string]any{
|
|
"name": created.Name,
|
|
"content": "comment from MCP",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, commentResult.IsError)
|
|
commentEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventMemoCommentCreated, commentEvent.Type)
|
|
require.Equal(t, created.Name, commentEvent.Name)
|
|
|
|
upsertReactionResult, err := ts.service.handleUpsertReaction(ctx, toolRequest("upsert_reaction", map[string]any{
|
|
"name": created.Name,
|
|
"reaction_type": "👍",
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, upsertReactionResult.IsError)
|
|
reactionEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventReactionUpserted, reactionEvent.Type)
|
|
|
|
var reaction reactionJSON
|
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, upsertReactionResult)), &reaction))
|
|
deleteReactionResult, err := ts.service.handleDeleteReaction(ctx, toolRequest("delete_reaction", map[string]any{
|
|
"id": float64(reaction.ID),
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, deleteReactionResult.IsError)
|
|
deleteReactionEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventReactionDeleted, deleteReactionEvent.Type)
|
|
|
|
deleteResult, err := ts.service.handleDeleteMemo(ctx, toolRequest("delete_memo", map[string]any{
|
|
"name": created.Name,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, deleteResult.IsError)
|
|
deleteEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventMemoDeleted, deleteEvent.Type)
|
|
}
|
|
|
|
func TestMCPRelationAndAttachmentMutationsEmitMemoUpdated(t *testing.T) {
|
|
ts := newTestMCPService(t)
|
|
user := ts.createUser(t, "owner")
|
|
ctx := withUser(context.Background(), user.ID)
|
|
source := ts.createMemo(t, user.ID, store.Private, "source")
|
|
target := ts.createMemo(t, user.ID, store.Private, "target")
|
|
attachment := ts.createAttachment(t, user.ID, nil)
|
|
client := ts.service.apiV1Service.SSEHub.Subscribe(user.ID, store.RoleUser)
|
|
defer ts.service.apiV1Service.SSEHub.Unsubscribe(client)
|
|
|
|
relationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
|
|
"name": "memos/" + source.UID,
|
|
"related_memo": "memos/" + target.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, relationResult.IsError)
|
|
relationEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventMemoUpdated, relationEvent.Type)
|
|
require.Equal(t, "memos/"+source.UID, relationEvent.Name)
|
|
|
|
duplicateRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
|
|
"name": "memos/" + source.UID,
|
|
"related_memo": "memos/" + target.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, duplicateRelationResult.IsError)
|
|
requireNoSSEEvent(t, client)
|
|
|
|
selfRelationResult, err := ts.service.handleCreateMemoRelation(ctx, toolRequest("create_memo_relation", map[string]any{
|
|
"name": "memos/" + source.UID,
|
|
"related_memo": "memos/" + source.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.True(t, selfRelationResult.IsError)
|
|
require.Contains(t, firstText(t, selfRelationResult), "itself")
|
|
|
|
linkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{
|
|
"name": "attachments/" + attachment.UID,
|
|
"memo": "memos/" + source.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, linkResult.IsError)
|
|
attachmentEvent := nextSSEEvent(t, client)
|
|
require.Equal(t, apiv1service.SSEEventMemoUpdated, attachmentEvent.Type)
|
|
require.Equal(t, "memos/"+source.UID, attachmentEvent.Name)
|
|
|
|
relinkResult, err := ts.service.handleLinkAttachmentToMemo(ctx, toolRequest("link_attachment_to_memo", map[string]any{
|
|
"name": "attachments/" + attachment.UID,
|
|
"memo": "memos/" + source.UID,
|
|
}))
|
|
require.NoError(t, err)
|
|
require.False(t, relinkResult.IsError)
|
|
requireNoSSEEvent(t, client)
|
|
}
|