mirror of https://github.com/usememos/memos
feat(mcp): refactor MCP server to standard protocol structure
- Replace PAT-only auth with optional auth supporting both PAT and JWT
via auth.Authenticator.Authenticate(); unauthenticated requests see
only public memos, matching REST API visibility semantics
- Inline auth middleware into mcp.go following fileserver pattern;
remove auth_middleware.go
- Introduce memoJSON response type that correctly serialises store.Memo
(including Payload.Tags and Payload.Property) without proto marshalling
- Add tools: list_memo_comments, create_memo_comment, list_tags
- Extend list_memos with state (NORMAL/ARCHIVED), order_by_pinned, and
page parameters
- Extend update_memo with pinned and state parameters
- Extract #tags from content on create/update via regex to pre-populate
Payload.Tags without requiring a full markdown service rebuild
- Add MCP Resources: memo://memos/{uid} template returns memo as
Markdown with YAML frontmatter, allowing clients to read memos by URI
- Add MCP Prompts: capture (save a thought) and review (search + summarise)
pull/5638/head
parent
16576be111
commit
803d488a5f
@ -1,31 +0,0 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func newAuthMiddleware(s *store.Store, secret string) echo.MiddlewareFunc {
|
||||
authenticator := auth.NewAuthenticator(s, secret)
|
||||
return func(next echo.HandlerFunc) echo.HandlerFunc {
|
||||
return func(c *echo.Context) error {
|
||||
token := auth.ExtractBearerToken(c.Request().Header.Get("Authorization"))
|
||||
if token == "" {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "a personal access token is required"})
|
||||
}
|
||||
|
||||
user, pat, err := authenticator.AuthenticateByPAT(c.Request().Context(), token)
|
||||
if err != nil || user == nil {
|
||||
return c.JSON(http.StatusUnauthorized, map[string]string{"message": "invalid or expired personal access token"})
|
||||
}
|
||||
|
||||
ctx := auth.SetUserInContext(c.Request().Context(), user, pat.GetTokenId())
|
||||
c.SetRequest(c.Request().WithContext(ctx))
|
||||
return next(c)
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,84 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
)
|
||||
|
||||
func (s *MCPService) registerPrompts(mcpSrv *mcpserver.MCPServer) {
|
||||
// capture — turns free-form user input into a structured create_memo call.
|
||||
mcpSrv.AddPrompt(
|
||||
mcp.NewPrompt("capture",
|
||||
mcp.WithPromptDescription("Capture a thought, idea, or note as a new memo. "+
|
||||
"Use this prompt when the user wants to quickly save something. "+
|
||||
"The assistant will call create_memo with the provided content."),
|
||||
mcp.WithArgument("content",
|
||||
mcp.ArgumentDescription("The text to save as a memo"),
|
||||
mcp.RequiredArgument(),
|
||||
),
|
||||
mcp.WithArgument("tags",
|
||||
mcp.ArgumentDescription("Comma-separated tags to apply, e.g. \"work,project\""),
|
||||
),
|
||||
),
|
||||
s.handleCapturePrompt,
|
||||
)
|
||||
|
||||
// review — surfaces existing memos on a topic for summarisation.
|
||||
mcpSrv.AddPrompt(
|
||||
mcp.NewPrompt("review",
|
||||
mcp.WithPromptDescription("Search and review memos on a given topic. "+
|
||||
"The assistant will call search_memos and summarise the results."),
|
||||
mcp.WithArgument("topic",
|
||||
mcp.ArgumentDescription("Topic or keyword to search for"),
|
||||
mcp.RequiredArgument(),
|
||||
),
|
||||
),
|
||||
s.handleReviewPrompt,
|
||||
)
|
||||
}
|
||||
|
||||
func (*MCPService) handleCapturePrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
content := req.Params.Arguments["content"]
|
||||
if content == "" {
|
||||
return nil, errors.New("content argument is required")
|
||||
}
|
||||
|
||||
tags := req.Params.Arguments["tags"]
|
||||
instruction := fmt.Sprintf(
|
||||
"Please save the following as a new private memo using the create_memo tool.\n\nContent:\n%s",
|
||||
content,
|
||||
)
|
||||
if tags != "" {
|
||||
instruction += fmt.Sprintf("\n\nAppend these tags inline using #tag syntax: %s", tags)
|
||||
}
|
||||
|
||||
return &mcp.GetPromptResult{
|
||||
Description: "Capture a memo",
|
||||
Messages: []mcp.PromptMessage{
|
||||
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (*MCPService) handleReviewPrompt(_ context.Context, req mcp.GetPromptRequest) (*mcp.GetPromptResult, error) {
|
||||
topic := req.Params.Arguments["topic"]
|
||||
if topic == "" {
|
||||
return nil, errors.New("topic argument is required")
|
||||
}
|
||||
|
||||
instruction := fmt.Sprintf(
|
||||
"Please use the search_memos tool to find memos about %q, then provide a concise summary of what has been written on this topic, grouped by theme. Include the memo names so the user can reference them.",
|
||||
topic,
|
||||
)
|
||||
|
||||
return &mcp.GetPromptResult{
|
||||
Description: fmt.Sprintf("Review memos about %q", topic),
|
||||
Messages: []mcp.PromptMessage{
|
||||
mcp.NewPromptMessage(mcp.RoleUser, mcp.NewTextContent(instruction)),
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
@ -0,0 +1,85 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
"github.com/pkg/errors"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
// Memo resource URI scheme: memo://memos/{uid}
|
||||
// Clients can read any memo they have access to by URI without calling a tool.
|
||||
|
||||
func (s *MCPService) registerMemoResources(mcpSrv *mcpserver.MCPServer) {
|
||||
mcpSrv.AddResourceTemplate(
|
||||
mcp.NewResourceTemplate(
|
||||
"memo://memos/{uid}",
|
||||
"Memo",
|
||||
mcp.WithTemplateDescription("A single Memos note identified by its UID. Returns the memo content as Markdown with a YAML frontmatter header containing metadata."),
|
||||
mcp.WithTemplateMIMEType("text/markdown"),
|
||||
),
|
||||
s.handleReadMemoResource,
|
||||
)
|
||||
}
|
||||
|
||||
func (s *MCPService) handleReadMemoResource(ctx context.Context, req mcp.ReadResourceRequest) ([]mcp.ResourceContents, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
// URI format: memo://memos/{uid}
|
||||
uid := strings.TrimPrefix(req.Params.URI, "memo://memos/")
|
||||
if uid == req.Params.URI || uid == "" {
|
||||
return nil, errors.Errorf("invalid memo URI %q: expected memo://memos/<uid>", req.Params.URI)
|
||||
}
|
||||
|
||||
memo, err := s.store.GetMemo(ctx, &store.FindMemo{UID: &uid})
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to get memo")
|
||||
}
|
||||
if memo == nil {
|
||||
return nil, errors.Errorf("memo not found: %s", uid)
|
||||
}
|
||||
if err := checkMemoAccess(memo, userID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
j := storeMemoToJSON(memo)
|
||||
text := formatMemoMarkdown(j)
|
||||
|
||||
return []mcp.ResourceContents{
|
||||
mcp.TextResourceContents{
|
||||
URI: req.Params.URI,
|
||||
MIMEType: "text/markdown",
|
||||
Text: text,
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
// formatMemoMarkdown renders a memo as Markdown with a YAML frontmatter header.
|
||||
func formatMemoMarkdown(j memoJSON) string {
|
||||
var sb strings.Builder
|
||||
|
||||
sb.WriteString("---\n")
|
||||
fmt.Fprintf(&sb, "name: %s\n", j.Name)
|
||||
fmt.Fprintf(&sb, "creator: %s\n", j.Creator)
|
||||
fmt.Fprintf(&sb, "visibility: %s\n", j.Visibility)
|
||||
fmt.Fprintf(&sb, "state: %s\n", j.State)
|
||||
fmt.Fprintf(&sb, "pinned: %v\n", j.Pinned)
|
||||
if len(j.Tags) > 0 {
|
||||
fmt.Fprintf(&sb, "tags: [%s]\n", strings.Join(j.Tags, ", "))
|
||||
}
|
||||
fmt.Fprintf(&sb, "create_time: %d\n", j.CreateTime)
|
||||
fmt.Fprintf(&sb, "update_time: %d\n", j.UpdateTime)
|
||||
if j.Parent != "" {
|
||||
fmt.Fprintf(&sb, "parent: %s\n", j.Parent)
|
||||
}
|
||||
sb.WriteString("---\n\n")
|
||||
sb.WriteString(j.Content)
|
||||
|
||||
return sb.String()
|
||||
}
|
||||
@ -0,0 +1,68 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"sort"
|
||||
|
||||
"github.com/mark3labs/mcp-go/mcp"
|
||||
mcpserver "github.com/mark3labs/mcp-go/server"
|
||||
|
||||
"github.com/usememos/memos/server/auth"
|
||||
"github.com/usememos/memos/store"
|
||||
)
|
||||
|
||||
func (s *MCPService) registerTagTools(mcpSrv *mcpserver.MCPServer) {
|
||||
mcpSrv.AddTool(mcp.NewTool("list_tags",
|
||||
mcp.WithDescription("List all tags with their memo counts. Authenticated users see tags from their own and visible memos; unauthenticated callers see tags from public memos only. Results are sorted by count descending, then alphabetically."),
|
||||
), s.handleListTags)
|
||||
}
|
||||
|
||||
type tagEntry struct {
|
||||
Tag string `json:"tag"`
|
||||
Count int `json:"count"`
|
||||
}
|
||||
|
||||
func (s *MCPService) handleListTags(ctx context.Context, _ mcp.CallToolRequest) (*mcp.CallToolResult, error) {
|
||||
userID := auth.GetUserID(ctx)
|
||||
|
||||
rowStatus := store.Normal
|
||||
find := &store.FindMemo{
|
||||
ExcludeComments: true,
|
||||
ExcludeContent: true,
|
||||
RowStatus: &rowStatus,
|
||||
}
|
||||
applyVisibilityFilter(find, userID)
|
||||
|
||||
memos, err := s.store.ListMemos(ctx, find)
|
||||
if err != nil {
|
||||
return mcp.NewToolResultError(fmt.Sprintf("failed to list memos: %v", err)), nil
|
||||
}
|
||||
|
||||
counts := make(map[string]int)
|
||||
for _, m := range memos {
|
||||
if m.Payload == nil {
|
||||
continue
|
||||
}
|
||||
for _, tag := range m.Payload.Tags {
|
||||
counts[tag]++
|
||||
}
|
||||
}
|
||||
|
||||
entries := make([]tagEntry, 0, len(counts))
|
||||
for tag, count := range counts {
|
||||
entries = append(entries, tagEntry{Tag: tag, Count: count})
|
||||
}
|
||||
sort.Slice(entries, func(i, j int) bool {
|
||||
if entries[i].Count != entries[j].Count {
|
||||
return entries[i].Count > entries[j].Count
|
||||
}
|
||||
return entries[i].Tag < entries[j].Tag
|
||||
})
|
||||
|
||||
out, err := marshalJSON(entries)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mcp.NewToolResultText(out), nil
|
||||
}
|
||||
Loading…
Reference in New Issue