You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
memos/server/router/mcp/tools_relation.go

298 lines
9.7 KiB
Go

package mcp
import (
"context"
"fmt"
"github.com/mark3labs/mcp-go/mcp"
mcpserver "github.com/mark3labs/mcp-go/server"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
"github.com/usememos/memos/server/auth"
"github.com/usememos/memos/store"
)
// relationJSON is the JSON shape of a single memo relation in tool results.
// The handlers in this file fill Memo and RelatedMemo with resource names of
// the form "memos/<uid>".
type relationJSON struct {
// Memo is the resource name of the memo the relation originates from.
Memo string `json:"memo"`
// RelatedMemo is the resource name of the memo the relation points to.
RelatedMemo string `json:"related_memo"`
// Type is the store relation type rendered as a string.
Type string `json:"type"`
}
// registerRelationTools adds the memo-relation tools to mcpSrv and binds each
// one to its handler on s:
//
//   - list_memo_relations   -> handleListMemoRelations
//   - create_memo_relation  -> handleCreateMemoRelation
//   - delete_memo_relation  -> handleDeleteMemoRelation
func (s *MCPService) registerRelationTools(mcpSrv *mcpserver.MCPServer) {
	// Listing takes the memo name plus an optional relation-type filter.
	listOpts := readOnlyToolOptions(
		"List memo relations",
		"List all relations (references and comments) for a memo. Requires read access to the memo.",
		mcp.WithString("name", mcp.Required(), mcp.Description(`Memo resource name, e.g. "memos/abc123"`)),
		mcp.WithString("type",
			mcp.Enum("REFERENCE", "COMMENT"),
			mcp.Description("Filter by relation type (optional)"),
		),
	)
	mcpSrv.AddTool(mcp.NewTool("list_memo_relations", listOpts...), s.handleListMemoRelations)

	// Creation takes source and target memo names and declares relationJSON
	// as its output schema.
	createOpts := createToolOptions(
		"Create memo relation",
		"Create a reference relation between two memos. Requires authentication. For comments, use create_memo_comment instead.",
		true,
		mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
		mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
		mcp.WithOutputSchema[relationJSON](),
	)
	mcpSrv.AddTool(mcp.NewTool("create_memo_relation", createOpts...), s.handleCreateMemoRelation)

	// Deletion takes the same two names and declares deletedJSON as its
	// output schema.
	deleteOpts := updateToolOptions(
		"Delete memo relation",
		"Delete a reference relation between two memos. Requires authentication and ownership of the source memo.",
		mcp.WithString("name", mcp.Required(), mcp.Description(`Source memo resource name, e.g. "memos/abc123"`)),
		mcp.WithString("related_memo", mcp.Required(), mcp.Description(`Target memo resource name, e.g. "memos/def456"`)),
		mcp.WithOutputSchema[deletedJSON](),
	)
	mcpSrv.AddTool(mcp.NewTool("delete_memo_relation", deleteOpts...), s.handleDeleteMemoRelation)
}
// handleListMemoRelations implements the list_memo_relations tool.
//
// It resolves the memo named by the "name" argument, checks the caller's
// access with checkMemoAccess, and returns the memo's relations as a JSON
// array of relationJSON values. The optional "type" argument restricts the
// listing to REFERENCE or COMMENT relations; any other non-empty value is
// rejected. Relations whose source or related memo cannot be resolved, or
// fails checkMemoAccess for the caller, are left out of the result.
//
// Every failure handled here is reported as a tool error result with a nil Go
// error.
func (s *MCPService) handleListMemoRelations(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	userID := auth.GetUserID(ctx)
	uid, err := parseMemoUID(req.GetString("name", ""))
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}
	memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to get memo: %v", err)), nil
	}
	if memo == nil {
		return mcp.NewToolResultError("memo not found"), nil
	}
	if err := checkMemoAccess(memo, userID); err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	find := &store.FindMemoRelation{
		MemoIDList: []int32{memo.ID},
	}
	if typeStr := req.GetString("type", ""); typeStr != "" {
		relType := store.MemoRelationType(typeStr)
		switch relType {
		case store.MemoRelationReference, store.MemoRelationComment:
			find.Type = &relType
		default:
			return mcp.NewToolResultError(fmt.Sprintf("type must be REFERENCE or COMMENT, got %q", typeStr)), nil
		}
	}
	relations, err := s.store.ListMemoRelations(ctx, find)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to list relations: %v", err)), nil
	}
	// With no relations there is nothing to resolve. Return the empty array
	// now instead of calling ListMemos with an empty IDList, matching the
	// len(ids) > 0 guard buildReferenceRelationSet applies before the same
	// call. NOTE(review): an empty IDList is presumably treated as "no
	// filter" by the store, which would load every memo — confirm.
	if len(relations) == 0 {
		return newToolResultJSON([]relationJSON{})
	}

	// Resolve memo IDs to UIDs.
	idSet := make(map[int32]struct{}, 2*len(relations))
	for _, r := range relations {
		idSet[r.MemoID] = struct{}{}
		idSet[r.RelatedMemoID] = struct{}{}
	}
	ids := make([]int32, 0, len(idSet))
	for id := range idSet {
		ids = append(ids, id)
	}
	memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: ids, ExcludeContent: true})
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to resolve memos: %v", err)), nil
	}
	memoByID := make(map[int32]*store.Memo, len(memos))
	for _, m := range memos {
		memoByID[m.ID] = m
	}

	// A non-nil slice is used so that an all-filtered result still goes out
	// as an empty slice rather than a nil one.
	results := make([]relationJSON, 0, len(relations))
	for _, r := range relations {
		srcMemo, ok1 := memoByID[r.MemoID]
		relatedMemo, ok2 := memoByID[r.RelatedMemoID]
		if !ok1 || !ok2 {
			continue
		}
		// Omit relations that touch a memo the caller may not access.
		if checkMemoAccess(srcMemo, userID) != nil || checkMemoAccess(relatedMemo, userID) != nil {
			continue
		}
		results = append(results, relationJSON{
			Memo:        "memos/" + srcMemo.UID,
			RelatedMemo: "memos/" + relatedMemo.UID,
			Type:        string(r.Type),
		})
	}
	return newToolResultJSON(results)
}
// handleCreateMemoRelation implements the create_memo_relation tool.
//
// The caller must be identified by extractUserID, must own the source memo
// ("name") per hasMemoOwnership, and must pass checkMemoAccess for the target
// memo ("related_memo"). A memo cannot be related to itself. When
// buildReferenceRelationSet reports a change, the resulting reference set is
// passed to apiV1Service.SetMemoRelations; otherwise no write is issued. On
// success the relation is returned as a relationJSON of type REFERENCE.
//
// Every failure handled here is reported as a tool error result with a nil Go
// error.
func (s *MCPService) handleCreateMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	userID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	// Parse and validate both resource names before touching the store.
	sourceUID, err := parseMemoUID(req.GetString("name", ""))
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}
	targetUID, err := parseMemoUID(req.GetString("related_memo", ""))
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}
	if sourceUID == targetUID {
		return mcp.NewToolResultError("cannot create a relation from a memo to itself"), nil
	}

	// The source memo must exist and belong to the caller.
	source, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &sourceUID})
	switch {
	case err != nil:
		return mcp.NewToolResultError(fmt.Sprintf("failed to get source memo: %v", err)), nil
	case source == nil:
		return mcp.NewToolResultError("source memo not found"), nil
	case !hasMemoOwnership(source, userID):
		return mcp.NewToolResultError("permission denied: must own the source memo"), nil
	}

	// The target memo must exist and pass the access check for the caller.
	target, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &targetUID})
	switch {
	case err != nil:
		return mcp.NewToolResultError(fmt.Sprintf("failed to get related memo: %v", err)), nil
	case target == nil:
		return mcp.NewToolResultError("related memo not found"), nil
	}
	if accessErr := checkMemoAccess(target, userID); accessErr != nil {
		return mcp.NewToolResultError(accessErr.Error()), nil
	}

	relations, changed, err := s.buildReferenceRelationSet(ctx, source, &target.UID, nil)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
	}

	sourceName := "memos/" + sourceUID
	if changed {
		setReq := &v1pb.SetMemoRelationsRequest{
			Name:      sourceName,
			Relations: relations,
		}
		if _, setErr := s.apiV1Service.SetMemoRelations(ctx, setReq); setErr != nil {
			return mcp.NewToolResultError(fmt.Sprintf("failed to create relation: %v", setErr)), nil
		}
	}

	created := relationJSON{
		Memo:        sourceName,
		RelatedMemo: "memos/" + targetUID,
		Type:        string(store.MemoRelationReference),
	}
	return newToolResultJSON(created)
}
// handleDeleteMemoRelation implements the delete_memo_relation tool.
//
// The caller must be identified by extractUserID, must own the source memo
// ("name") per hasMemoOwnership, and must pass checkMemoAccess for the target
// memo ("related_memo"). When buildReferenceRelationSet reports that the
// target was removed from the source's reference set, the remaining set is
// passed to apiV1Service.SetMemoRelations; otherwise no write is issued. In
// both cases the result of newDeletedToolResult is returned.
//
// Every failure handled here is reported as a tool error result with a nil Go
// error.
func (s *MCPService) handleDeleteMemoRelation(ctx context.Context, req mcp.CallToolRequest) (*mcp.CallToolResult, error) {
	userID, err := extractUserID(ctx)
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	// Parse and validate both resource names before touching the store.
	sourceUID, err := parseMemoUID(req.GetString("name", ""))
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}
	targetUID, err := parseMemoUID(req.GetString("related_memo", ""))
	if err != nil {
		return mcp.NewToolResultError(err.Error()), nil
	}

	// The source memo must exist and belong to the caller.
	source, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &sourceUID})
	switch {
	case err != nil:
		return mcp.NewToolResultError(fmt.Sprintf("failed to get source memo: %v", err)), nil
	case source == nil:
		return mcp.NewToolResultError("source memo not found"), nil
	case !hasMemoOwnership(source, userID):
		return mcp.NewToolResultError("permission denied: must own the source memo"), nil
	}

	// The target memo must exist and pass the access check for the caller.
	target, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &targetUID})
	switch {
	case err != nil:
		return mcp.NewToolResultError(fmt.Sprintf("failed to get related memo: %v", err)), nil
	case target == nil:
		return mcp.NewToolResultError("related memo not found"), nil
	}
	if accessErr := checkMemoAccess(target, userID); accessErr != nil {
		return mcp.NewToolResultError(accessErr.Error()), nil
	}

	remaining, changed, err := s.buildReferenceRelationSet(ctx, source, nil, &target.UID)
	if err != nil {
		return mcp.NewToolResultError(fmt.Sprintf("failed to build relation set: %v", err)), nil
	}

	if changed {
		setReq := &v1pb.SetMemoRelationsRequest{
			Name:      "memos/" + sourceUID,
			Relations: remaining,
		}
		if _, setErr := s.apiV1Service.SetMemoRelations(ctx, setReq); setErr != nil {
			return mcp.NewToolResultError(fmt.Sprintf("failed to delete relation: %v", setErr)), nil
		}
	}
	return newDeletedToolResult()
}
// buildReferenceRelationSet computes the REFERENCE relation set that source
// should have after an optional addition and/or removal.
//
// It loads the REFERENCE relations stored for source, resolves each related
// memo ID to its UID, and emits one v1pb.MemoRelation (source -> related) per
// resolvable target:
//
//   - excludeUID, when non-nil, names a target whose relation is dropped.
//   - includeUID, when non-nil, names a target that is appended unless it is
//     already present or equals source's own UID.
//
// The boolean result is true only when an exclusion or inclusion actually
// took effect. Relations whose related memo can no longer be resolved are
// dropped without setting it. On a store error the result is (nil, false, err).
func (s *MCPService) buildReferenceRelationSet(ctx context.Context, source *store.Memo, includeUID *string, excludeUID *string) ([]*v1pb.MemoRelation, bool, error) {
	referenceType := store.MemoRelationReference
	relations, err := s.store.ListMemoRelations(ctx, &store.FindMemoRelation{
		MemoIDList: []int32{source.ID},
		Type:       &referenceType,
	})
	if err != nil {
		return nil, false, err
	}

	// Only relations that originate at source belong in its outgoing set.
	// NOTE(review): handleListMemoRelations treats the same MemoIDList query
	// as able to return rows whose MemoID differs from the queried memo; if
	// the store also matches the related side, an incoming reference would
	// otherwise be re-emitted below as a source -> source relation. The
	// MemoID checks are a no-op if the store only matches the source side.
	idSet := make(map[int32]struct{}, len(relations))
	for _, relation := range relations {
		if relation.MemoID != source.ID {
			continue
		}
		idSet[relation.RelatedMemoID] = struct{}{}
	}
	ids := make([]int32, 0, len(idSet))
	for id := range idSet {
		ids = append(ids, id)
	}

	// Skip the lookup entirely when there is nothing to resolve, so the store
	// is never queried with an empty IDList.
	memosByID := make(map[int32]*store.Memo, len(ids))
	if len(ids) > 0 {
		memos, err := s.store.ListMemos(ctx, &store.FindMemo{IDList: ids, ExcludeContent: true})
		if err != nil {
			return nil, false, err
		}
		for _, memo := range memos {
			memosByID[memo.ID] = memo
		}
	}

	// +1 leaves room for the optional inclusion.
	result := make([]*v1pb.MemoRelation, 0, len(relations)+1)
	seenUIDs := make(map[string]struct{}, len(relations))
	changed := false
	for _, relation := range relations {
		if relation.MemoID != source.ID {
			continue
		}
		relatedMemo := memosByID[relation.RelatedMemoID]
		if relatedMemo == nil {
			continue
		}
		if excludeUID != nil && relatedMemo.UID == *excludeUID {
			changed = true
			continue
		}
		result = append(result, newReferenceRelation(source.UID, relatedMemo.UID))
		seenUIDs[relatedMemo.UID] = struct{}{}
	}

	if includeUID != nil {
		if _, seen := seenUIDs[*includeUID]; !seen && source.UID != *includeUID {
			result = append(result, newReferenceRelation(source.UID, *includeUID))
			changed = true
		}
	}
	return result, changed, nil
}
// newReferenceRelation returns a REFERENCE relation pointing from the memo
// with UID sourceUID to the memo with UID relatedUID. Both ends are expressed
// as resource names by prefixing the UID with "memos/".
func newReferenceRelation(sourceUID string, relatedUID string) *v1pb.MemoRelation {
	const namePrefix = "memos/"

	relation := new(v1pb.MemoRelation)
	relation.Memo = &v1pb.MemoRelation_Memo{Name: namePrefix + sourceUID}
	relation.RelatedMemo = &v1pb.MemoRelation_Memo{Name: namePrefix + relatedUID}
	relation.Type = v1pb.MemoRelation_REFERENCE
	return relation
}