feat: add OpenAPI-driven MCP support (#6026)

pull/6028/head
boojack 2 weeks ago committed by GitHub
parent a47d04954e
commit 777d227eb9
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

File diff suppressed because it is too large Load Diff

@ -0,0 +1,241 @@
# OpenAPI-Driven MCP Support Design
## Context
Memos previously had an MCP server, but it was removed in commit `2a4638b3` and should not be used as the baseline for this work. GitHub issue #6022 reported that some old MCP tools returned a bare JSON array in `result.structuredContent`, which strict MCP clients reject because `structuredContent` must be an object.
The new MCP support should start from zero and use the generated OpenAPI document at `proto/gen/openapi.yaml` as the source of truth. The OpenAPI file already describes the public REST gateway operations generated from protobuf definitions, including operation IDs, descriptions, parameters, request schemas, and response schemas.
## Goals
- Add MCP support back through a new implementation that is mechanically tied to `proto/gen/openapi.yaml`.
- Expose a standard MCP Streamable HTTP endpoint at the `/mcp` path.
- Register tools only; do not add MCP resources or prompts in the first version.
- Expose a curated memo-focused toolset derived from OpenAPI operations, not every API endpoint.
- Execute MCP tool calls through the existing API contract instead of duplicating store or service logic.
- Return object-shaped `structuredContent` for every tool result.
- Keep authentication and authorization behavior aligned with the existing API.
## Non-Goals
- Reviving or adapting the removed MCP package design.
- Adding custom MCP route aliases, readonly endpoints, or toolset filtering headers.
- Exposing every operation in `proto/gen/openapi.yaml` as an MCP tool.
- Adding MCP tools for admin settings, users, identity providers, webhooks, personal access tokens, authentication, share-link management, AI transcription, or bulk deletion.
- Adding `list_tags` or `search_memos` unless matching proto/API operations are added and OpenAPI is regenerated.
- Hand-editing generated OpenAPI or generated protobuf outputs.
## Recommended Approach
Build a new `server/router/mcp` package that parses `proto/gen/openapi.yaml` at startup, selects a curated allowlist of operation IDs, and converts those selected OpenAPI operations into MCP tools. Tool calls map MCP arguments into the selected operation's path parameters, query parameters, and JSON body, then execute the corresponding `/api/v1/...` HTTP request through the existing Echo/gRPC-Gateway route in-process.
This approach keeps OpenAPI as the authoritative contract while avoiding the usability and safety problems of exposing a large API-mirrored tool surface.
## Architecture
### MCP Service
Create a new MCP service package under `server/router/mcp`. `server.NewServer` creates the existing `APIV1Service`, registers the normal file, RSS, and API routes, then registers a new MCP service against the same Echo server.
The service exposes one MCP endpoint path:
```text
/mcp
```
The implementation must support standard Streamable HTTP client messages on `POST /mcp`. It may also support `GET /mcp` and `DELETE /mcp` if the chosen MCP transport implementation requires those methods for standards-compliant streaming or session cleanup.
The first version advertises only the MCP tools capability. It does not advertise prompts or resources.
### OpenAPI Operation Registry
At startup, the MCP service loads `proto/gen/openapi.yaml` and builds an operation registry keyed by `operationId`. Each parsed operation stores:
- operation ID
- HTTP method
- OpenAPI route template
- description
- path parameters
- query parameters
- JSON request body schema
- HTTP 200 JSON response schema
The parser should fail fast during service construction if any curated operation ID is missing or cannot be converted into a valid MCP tool schema.
### Tool Names
Tool names are derived from OpenAPI `operationId` values and normalized for MCP clients. The exact naming convention should be deterministic and tested. A practical convention is lower snake case without the `Service` suffix in the subject:
```text
MemoService_ListMemos -> memo_list_memos
AttachmentService_GetAttachment -> attachment_get_attachment
```
The OpenAPI `operationId` remains stored in tool metadata so tests and future diagnostics can prove which OpenAPI operation produced each MCP tool.
### Tool Schemas
Each MCP tool input schema is an object built from:
- OpenAPI path parameters
- OpenAPI query parameters
- JSON request body fields, when present
Required path parameters stay required. Required request bodies stay required. Optional query parameters stay optional. The schema should preserve OpenAPI descriptions and primitive types where possible.
Each MCP tool output schema is the OpenAPI HTTP 200 JSON response schema. For empty 200 responses with no JSON schema, the MCP output schema is:
```json
{ "type": "object", "properties": { "ok": { "type": "boolean" } } }
```
## Tool Scope
The first version exposes these curated OpenAPI operations:
- `MemoService_ListMemos`
- `MemoService_CreateMemo`
- `MemoService_GetMemo`
- `MemoService_UpdateMemo`
- `MemoService_DeleteMemo`
- `MemoService_ListMemoComments`
- `MemoService_CreateMemoComment`
- `MemoService_ListMemoAttachments`
- `MemoService_SetMemoAttachments`
- `MemoService_ListMemoReactions`
- `MemoService_UpsertMemoReaction`
- `MemoService_DeleteMemoReaction`
- `MemoService_ListMemoRelations`
- `MemoService_SetMemoRelations`
- `AttachmentService_ListAttachments`
- `AttachmentService_GetAttachment`
- `AttachmentService_DeleteAttachment`
Excluded in the first version:
- auth sign-in, sign-out, and refresh
- user management
- personal access token management
- identity provider management
- webhooks
- instance settings
- share-link management
- AI transcription
- bulk delete operations
- any operation not present in generated OpenAPI
## Data Flow
### Startup
1. `server.NewServer` creates the existing `APIV1Service`.
2. The new `MCPService` loads `proto/gen/openapi.yaml`.
3. The OpenAPI parser builds the operation registry.
4. The curated operation allowlist selects supported operations.
5. Each selected operation is registered as an MCP tool with input schema, output schema, description, annotations, and operation metadata.
### Tool Call
1. The client sends a standard MCP `tools/call` request to `/mcp`.
2. The MCP server validates the tool name and arguments against the OpenAPI-derived input schema.
3. The adapter substitutes path parameters into the OpenAPI route template.
4. The adapter encodes query parameters into the query string.
5. The adapter marshals request body arguments as JSON when the operation has a request body.
6. The adapter forwards the caller's `Authorization` header and executes the matching `/api/v1/...` request through the existing Echo handler in-process.
7. The API response JSON is decoded into `map[string]any`.
8. The MCP result returns a compact JSON text fallback plus object-shaped `structuredContent`.
## Result Shape
Every MCP result must use object-shaped `structuredContent`.
Rules:
- If the API response is a JSON object, return it unchanged.
- If the API response is empty, return `{ "ok": true }`.
- If an unexpected raw JSON array appears, wrap it as `{ "result": [...] }`.
- If an unexpected scalar appears, wrap it as `{ "result": value }`.
This directly addresses issue #6022 by preventing collection tools from returning bare arrays.
## Authentication And Origin Safety
The MCP endpoint accepts the same bearer credentials as the API:
```text
Authorization: Bearer <PAT-or-access-token>
```
The MCP adapter forwards the bearer header to the in-process API request. Public API operations can work without authentication when the API allows them. Mutating operations require authentication because the API already enforces that behavior.
For browser-origin safety, `/mcp` rejects cross-origin browser requests unless the `Origin` header is same-origin or matches the configured instance URL. Requests without an `Origin` header are allowed because desktop MCP clients commonly omit it.
## Errors
The MCP server should convert failures into MCP tool errors:
- Invalid MCP arguments: concise validation message.
- API `401` or `403`: preserve the API message where available.
- API `404`: report that the resource was not found.
- Other API errors: include the HTTP status code and decoded API message.
- Internal OpenAPI or adapter errors: log server-side details and return a concise tool error.
Adapter errors should not bypass MCP result formatting unless the underlying MCP framework requires protocol-level errors for invalid protocol messages.
## Tool Annotations
Tool annotations are derived from HTTP methods:
- `GET`: read-only, non-destructive, idempotent.
- `POST`: mutating unless the operation is explicitly known to be read-only.
- `PATCH`: mutating, non-idempotent by default.
- `DELETE`: destructive and idempotent by default.
These annotations are hints for clients and do not replace API authorization.
## Testing
### OpenAPI Parsing Tests
Tests should verify:
- every curated operation ID exists in `proto/gen/openapi.yaml`
- every selected tool has an object input schema
- every selected tool has an object output schema
- selected operations do not include admin, auth, webhook, identity provider, personal access token, instance setting, share-link, AI transcription, or bulk-delete operations
- tool names are deterministic and unique
### MCP Protocol Tests
Using an Echo test server and JSON-RPC requests, tests should verify:
- `initialize` succeeds
- `tools/list` returns only curated OpenAPI-derived tools
- tool definitions include input and output schemas
- no prompts or resources capabilities are advertised
- collection tool calls return object-shaped `structuredContent`, never a bare array
### Adapter Tests
Representative adapter tests should cover:
- `GET /api/v1/memos?pageSize=...`
- `POST /api/v1/memos`
- `GET /api/v1/memos/{memo}`
- `PATCH /api/v1/memos/{memo}`
- `DELETE /api/v1/memos/{memo}`
- `GET /api/v1/memos/{memo}/comments`
- `POST /api/v1/memos/{memo}/reactions`
Before finishing implementation, run:
```bash
go test ./server/router/mcp/... ./server/router/api/v1/... ./server/...
```
## Implementation Notes
- Do not hand-edit `proto/gen/openapi.yaml`.
- If a needed MCP tool is not represented by OpenAPI, add or adjust the proto/API surface first, run `cd proto && buf generate`, then derive the MCP tool from the regenerated OpenAPI.
- Prefer existing API auth and gateway behavior over duplicating authorization logic in the MCP adapter.
- Keep the first version intentionally small. Additional OpenAPI-derived tools can be added by extending the curated allowlist and tests.

@ -11,6 +11,7 @@ require (
github.com/aws/aws-sdk-go-v2/service/s3 v1.100.0
github.com/go-sql-driver/mysql v1.9.3
github.com/google/cel-go v0.28.0
github.com/google/jsonschema-go v0.4.3
github.com/google/uuid v1.6.0
github.com/gorilla/feeds v1.2.0
github.com/grpc-ecosystem/grpc-gateway/v2 v2.29.0
@ -19,6 +20,7 @@ require (
github.com/lib/pq v1.12.3
github.com/lithammer/shortuuid/v4 v4.2.0
github.com/moby/moby/api v1.54.2
github.com/modelcontextprotocol/go-sdk v1.6.1
github.com/openai/openai-go/v3 v3.32.0
github.com/pion/opus v0.0.0-20260430223319-81a9c5dc5013
github.com/pkg/errors v0.9.1
@ -94,6 +96,8 @@ require (
github.com/power-devops/perfstat v0.0.0-20240221224432-82ca36839d55 // indirect
github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect
github.com/sagikazarmark/locafero v0.12.0 // indirect
github.com/segmentio/asm v1.1.3 // indirect
github.com/segmentio/encoding v0.5.4 // indirect
github.com/shirou/gopsutil/v4 v4.26.3 // indirect
github.com/sirupsen/logrus v1.9.4 // indirect
github.com/spf13/afero v1.15.0 // indirect
@ -106,6 +110,7 @@ require (
github.com/tidwall/sjson v1.2.5 // indirect
github.com/tklauser/go-sysconf v0.3.16 // indirect
github.com/tklauser/numcpus v0.11.0 // indirect
github.com/yosida95/uritemplate/v3 v3.0.2 // indirect
github.com/yusufpapurcu/wmi v1.2.4 // indirect
go.opentelemetry.io/auto/sdk v1.2.1 // indirect
go.opentelemetry.io/contrib/instrumentation/net/http/otelhttp v0.68.0 // indirect
@ -145,5 +150,5 @@ require (
golang.org/x/text v0.36.0
golang.org/x/time v0.15.0 // indirect
google.golang.org/protobuf v1.36.11
gopkg.in/yaml.v3 v3.0.1 // indirect
gopkg.in/yaml.v3 v3.0.1
)

@ -115,6 +115,8 @@ github.com/google/cel-go v0.28.0 h1:KjSWstCpz/MN5t4a8gnGJNIYUsJRpdi/r97xWDphIQc=
github.com/google/cel-go v0.28.0/go.mod h1:X0bD6iVNR8pkROSOoHVdgTkzmRcosof7WQqCD6wcMc8=
github.com/google/go-cmp v0.7.0 h1:wk8382ETsv4JYUZwIsn6YpYiWiBsYLSJiTsyBybVuN8=
github.com/google/go-cmp v0.7.0/go.mod h1:pXiqmnSA92OHEEa9HXL2W4E7lf9JzCmGVUdgjX3N/iU=
github.com/google/jsonschema-go v0.4.3 h1:/DBOLZTfDow7pe2GmaJNhltueGTtDKICi8V8p+DQPd0=
github.com/google/jsonschema-go v0.4.3/go.mod h1:r5quNTdLOYEz95Ru18zA0ydNbBuYoo9tgaYcxEYhJVE=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e h1:ijClszYn+mADRFY17kjQEVQ1XRhq2/JR1M3sGqeJoxs=
github.com/google/pprof v0.0.0-20250317173921-a4b03ec1a45e/go.mod h1:boTsfXsheKC2y+lKOCMpSfarhxDeIzfZG1jqGcPl3cA=
github.com/google/s2a-go v0.1.9 h1:LGD7gtMgezd8a/Xak7mEWL0PjoTQFvpRudN895yqKW0=
@ -183,6 +185,8 @@ github.com/moby/sys/userns v0.1.0 h1:tVLXkFOxVu9A64/yh59slHVv9ahO9UIev4JZusOLG/g
github.com/moby/sys/userns v0.1.0/go.mod h1:IHUYgu/kao6N8YZlp9Cf444ySSvCmDlmzUcYfDHOl28=
github.com/moby/term v0.5.2 h1:6qk3FJAFDs6i/q3W/pQ97SX192qKfZgGjCQqfCJkgzQ=
github.com/moby/term v0.5.2/go.mod h1:d3djjFCrjnB+fl8NJux+EJzu0msscUP+f8it8hPkFLc=
github.com/modelcontextprotocol/go-sdk v1.6.1 h1:0zOSupjKUxPKSocPT1Wtago+mUHU2/uZ4xSOY0FGReU=
github.com/modelcontextprotocol/go-sdk v1.6.1/go.mod h1:kzm3kzFL1/+AziGOE0nUs3gvPoNxMCvkxokMkuFapXQ=
github.com/ncruces/go-strftime v1.0.0 h1:HMFp8mLCTPp341M/ZnA4qaf7ZlsbTc+miZjCLOFAw7w=
github.com/ncruces/go-strftime v1.0.0/go.mod h1:Fwc5htZGVVkseilnfgOVb9mKy6w1naJmn9CehxcKcls=
github.com/openai/openai-go/v3 v3.32.0 h1:aHp/3wkX1W6jB8zTtf9xV0aK0qPFSVDqS7AHmlJ4hXs=
@ -208,6 +212,10 @@ github.com/rogpeppe/go-internal v1.14.1/go.mod h1:MaRKkUm5W0goXpeCfT7UZI6fk/L7L7
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
github.com/sagikazarmark/locafero v0.12.0 h1:/NQhBAkUb4+fH1jivKHWusDYFjMOOKU88eegjfxfHb4=
github.com/sagikazarmark/locafero v0.12.0/go.mod h1:sZh36u/YSZ918v0Io+U9ogLYQJ9tLLBmM4eneO6WwsI=
github.com/segmentio/asm v1.1.3 h1:WM03sfUOENvvKexOLp+pCqgb/WDjsi7EK8gIsICtzhc=
github.com/segmentio/asm v1.1.3/go.mod h1:Ld3L4ZXGNcSLRg4JBsZ3//1+f/TjYl0Mzen/DQy1EJg=
github.com/segmentio/encoding v0.5.4 h1:OW1VRern8Nw6ITAtwSZ7Idrl3MXCFwXHPgqESYfvNt0=
github.com/segmentio/encoding v0.5.4/go.mod h1:HS1ZKa3kSN32ZHVZ7ZLPLXWvOVIiZtyJnO1gPH1sKt0=
github.com/shirou/gopsutil/v4 v4.26.3 h1:2ESdQt90yU3oXF/CdOlRCJxrP+Am1aBYubTMTfxJ1qc=
github.com/shirou/gopsutil/v4 v4.26.3/go.mod h1:LZ6ewCSkBqUpvSOf+LsTGnRinC6iaNUNMGBtDkJBaLQ=
github.com/sirupsen/logrus v1.9.4 h1:TsZE7l11zFCLZnZ+teH4Umoq5BhEIfIzfRDZ1Uzql2w=
@ -250,6 +258,8 @@ github.com/tklauser/go-sysconf v0.3.16 h1:frioLaCQSsF5Cy1jgRBrzr6t502KIIwQ0MArYI
github.com/tklauser/go-sysconf v0.3.16/go.mod h1:/qNL9xxDhc7tx3HSRsLWNnuzbVfh3e7gh/BmM179nYI=
github.com/tklauser/numcpus v0.11.0 h1:nSTwhKH5e1dMNsCdVBukSZrURJRoHbSEQjdEbY+9RXw=
github.com/tklauser/numcpus v0.11.0/go.mod h1:z+LwcLq54uWZTX0u/bGobaV34u6V7KNlTZejzM6/3MQ=
github.com/yosida95/uritemplate/v3 v3.0.2 h1:Ed3Oyj9yrmi9087+NczuL5BwkIc4wvTb5zIM+UJPGz4=
github.com/yosida95/uritemplate/v3 v3.0.2/go.mod h1:ILOh0sOhIJR3+L/8afwt/kE++YT040gmv5BQTMR2HP4=
github.com/yuin/goldmark v1.8.2 h1:kEGpgqJXdgbkhcOgBxkC0X0PmoPG1ZyoZ117rDVp4zE=
github.com/yuin/goldmark v1.8.2/go.mod h1:ip/1k0VRfGynBgxOz0yCqHrbZXhcjxyuS66Brc7iBKg=
github.com/yusufpapurcu/wmi v1.2.4 h1:zFUKzehAFReQwLys1b/iSMl+JQGSCSjtVqQn9bBrPo0=

@ -0,0 +1,11 @@
package proto
import _ "embed"
//go:embed gen/openapi.yaml
var openAPIYAML []byte
// OpenAPIYAML returns the embedded generated OpenAPI specification.
func OpenAPIYAML() []byte {
return append([]byte(nil), openAPIYAML...)
}

@ -0,0 +1,160 @@
package mcp
import (
"bytes"
"context"
"encoding/json"
"io"
"net/http"
"net/http/httptest"
"net/url"
"strconv"
"strings"
"github.com/labstack/echo/v5"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/pkg/errors"
)
type apiAdapter struct {
echoServer *echo.Echo
}
func newAPIAdapter(echoServer *echo.Echo) *apiAdapter {
return &apiAdapter{echoServer: echoServer}
}
func (a *apiAdapter) execute(ctx context.Context, operation *openAPIOperation, arguments map[string]any, authorization string) (*sdkmcp.CallToolResult, error) {
req, err := buildAPIRequest(ctx, operation, arguments, authorization)
if err != nil {
return newToolErrorResult(err.Error()), nil
}
recorder := httptest.NewRecorder()
a.echoServer.ServeHTTP(recorder, req)
value, err := decodeJSONValue(recorder.Body.Bytes())
if err != nil {
return newToolErrorResult(err.Error()), nil
}
if recorder.Code < http.StatusOK || recorder.Code >= http.StatusMultipleChoices {
return newToolErrorResult(apiErrorMessage(recorder.Code, value)), nil
}
return newStructuredToolResult(value)
}
func buildAPIRequest(ctx context.Context, operation *openAPIOperation, arguments map[string]any, authorization string) (*http.Request, error) {
path, err := substitutePathParameters(operation, arguments)
if err != nil {
return nil, err
}
query := url.Values{}
for _, parameter := range operation.Parameters {
if parameter.In != "query" {
continue
}
value, ok := arguments[parameter.Name]
if !ok || value == nil {
continue
}
query.Set(parameter.Name, valueToString(value))
}
if encoded := query.Encode(); encoded != "" {
path += "?" + encoded
}
var body io.Reader
if operation.RequestBody != nil {
bodyValue, ok := arguments["body"]
if !ok || bodyValue == nil {
if operation.RequestBody.Required {
return nil, errors.New(`missing required request body "body"`)
}
bodyValue = map[string]any{}
}
data, err := json.Marshal(bodyValue)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal request body")
}
body = bytes.NewReader(data)
}
req := httptest.NewRequest(operation.Method, path, body).WithContext(ctx)
if body != nil {
req.Header.Set("Content-Type", "application/json")
}
if authorization != "" {
req.Header.Set("Authorization", authorization)
}
return req, nil
}
func substitutePathParameters(operation *openAPIOperation, arguments map[string]any) (string, error) {
path := operation.Path
for _, parameter := range operation.Parameters {
if parameter.In != "path" {
continue
}
value, ok := arguments[parameter.Name]
if !ok || value == nil || valueToString(value) == "" {
return "", errors.Errorf(`missing required path parameter "%s"`, parameter.Name)
}
path = strings.ReplaceAll(path, "{"+parameter.Name+"}", url.PathEscape(valueToString(value)))
}
return path, nil
}
func valueToString(value any) string {
switch typed := value.(type) {
case string:
return typed
case bool:
return strconv.FormatBool(typed)
case int:
return strconv.Itoa(typed)
case int8:
return strconv.FormatInt(int64(typed), 10)
case int16:
return strconv.FormatInt(int64(typed), 10)
case int32:
return strconv.FormatInt(int64(typed), 10)
case int64:
return strconv.FormatInt(typed, 10)
case uint:
return strconv.FormatUint(uint64(typed), 10)
case uint8:
return strconv.FormatUint(uint64(typed), 10)
case uint16:
return strconv.FormatUint(uint64(typed), 10)
case uint32:
return strconv.FormatUint(uint64(typed), 10)
case uint64:
return strconv.FormatUint(typed, 10)
case float32:
return strconv.FormatFloat(float64(typed), 'f', -1, 32)
case float64:
return strconv.FormatFloat(typed, 'f', -1, 64)
default:
data, err := json.Marshal(value)
if err != nil {
return ""
}
return strings.Trim(string(data), `"`)
}
}
func apiErrorMessage(statusCode int, value any) string {
status := strings.TrimSpace(strconv.Itoa(statusCode) + " " + http.StatusText(statusCode))
if object, ok := value.(map[string]any); ok {
for _, key := range []string{"message", "error"} {
message, ok := object[key].(string)
if ok && message != "" {
return status + ": " + message
}
}
}
return status
}

@ -0,0 +1,213 @@
package mcp
import (
"context"
"encoding/json"
"errors"
"io"
"net/http"
"testing"
"github.com/labstack/echo/v5"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/require"
)
func TestNormalizeStructuredContentKeepsObjects(t *testing.T) {
result := normalizeStructuredContent(map[string]any{"memos": []any{map[string]any{"name": "memos/a"}}})
require.Equal(t, map[string]any{"memos": []any{map[string]any{"name": "memos/a"}}}, result)
}
func TestNormalizeStructuredContentWrapsArrays(t *testing.T) {
result := normalizeStructuredContent([]any{map[string]any{"tag": "work"}})
require.Equal(t, map[string]any{"result": []any{map[string]any{"tag": "work"}}}, result)
}
func TestNormalizeStructuredContentUsesOKForNil(t *testing.T) {
result := normalizeStructuredContent(nil)
require.Equal(t, map[string]any{"ok": true}, result)
}
func TestNormalizeStructuredContentWrapsScalars(t *testing.T) {
result := normalizeStructuredContent("created")
require.Equal(t, map[string]any{"result": "created"}, result)
}
func TestNewStructuredToolResultUsesObjectStructuredContent(t *testing.T) {
result, err := newStructuredToolResult([]any{"one"})
require.NoError(t, err)
require.IsType(t, map[string]any{}, result.StructuredContent)
require.Equal(t, map[string]any{"result": []any{"one"}}, result.StructuredContent)
require.NotEmpty(t, result.Content)
text, ok := result.Content[0].(*sdkmcp.TextContent)
require.True(t, ok)
require.JSONEq(t, `{"result":["one"]}`, text.Text)
}
func TestNewToolErrorResult(t *testing.T) {
result := newToolErrorResult("resource not found")
require.True(t, result.IsError)
require.Equal(t, map[string]any{
"error": map[string]any{
"message": "resource not found",
},
}, result.StructuredContent)
require.NotEmpty(t, result.Content)
text, ok := result.Content[0].(*sdkmcp.TextContent)
require.True(t, ok)
require.Equal(t, "resource not found", text.Text)
}
func TestDecodeJSONValue(t *testing.T) {
value, err := decodeJSONValue([]byte(`{"ok":true}`))
require.NoError(t, err)
require.Equal(t, map[string]any{"ok": true}, value)
value, err = decodeJSONValue([]byte{})
require.NoError(t, err)
require.Nil(t, value)
value, err = decodeJSONValue([]byte(" \n\t "))
require.NoError(t, err)
require.Nil(t, value)
value, err = decodeJSONValue([]byte(`[1,"two"]`))
require.NoError(t, err)
require.Equal(t, []any{float64(1), "two"}, value)
_, err = decodeJSONValue([]byte(`{`))
require.Error(t, err)
require.True(t, errorsIsJSONSyntax(err), "wrapped syntax errors should remain inspectable")
}
func errorsIsJSONSyntax(err error) bool {
var syntaxError *json.SyntaxError
return err != nil && errors.As(err, &syntaxError)
}
func TestBuildAPIRequestMapsPathQueryAndBody(t *testing.T) {
operation := &openAPIOperation{
Method: "PATCH",
Path: "/api/v1/memos/{memo}",
Parameters: []openAPIParameter{
{Name: "memo", In: "path", Required: true, Schema: jsonSchema{"type": "string"}},
{Name: "updateMask", In: "query", Schema: jsonSchema{"type": "string"}},
},
RequestBody: &openAPIRequestBody{Required: true},
}
arguments := map[string]any{
"memo": "abc123",
"updateMask": "content",
"body": map[string]any{
"memo": map[string]any{
"name": "memos/abc123",
"content": "updated",
},
},
}
req, err := buildAPIRequest(context.Background(), operation, arguments, "Bearer pat")
require.NoError(t, err)
require.Equal(t, "PATCH", req.Method)
require.Equal(t, "/api/v1/memos/abc123", req.URL.Path)
require.Equal(t, "content", req.URL.Query().Get("updateMask"))
require.Equal(t, "Bearer pat", req.Header.Get("Authorization"))
body, err := io.ReadAll(req.Body)
require.NoError(t, err)
require.JSONEq(t, `{"memo":{"name":"memos/abc123","content":"updated"}}`, string(body))
}
func TestBuildAPIRequestRequiresPathParameters(t *testing.T) {
operation := &openAPIOperation{
Method: "GET",
Path: "/api/v1/memos/{memo}",
Parameters: []openAPIParameter{{Name: "memo", In: "path", Required: true}},
}
_, err := buildAPIRequest(context.Background(), operation, map[string]any{}, "")
require.ErrorContains(t, err, `missing required path parameter "memo"`)
}
func TestBuildAPIRequestRequiresRequestBody(t *testing.T) {
operation := &openAPIOperation{
Method: "POST",
Path: "/api/v1/memos",
RequestBody: &openAPIRequestBody{Required: true},
}
_, err := buildAPIRequest(context.Background(), operation, map[string]any{}, "")
require.ErrorContains(t, err, `missing required request body "body"`)
}
func TestBuildAPIRequestEscapesPathAndStringifiesPrimitiveQueryParameters(t *testing.T) {
operation := &openAPIOperation{
Method: "DELETE",
Path: "/api/v1/memos/{memo}",
Parameters: []openAPIParameter{
{Name: "memo", In: "path", Required: true, Schema: jsonSchema{"type": "string"}},
{Name: "force", In: "query", Schema: jsonSchema{"type": "boolean"}},
{Name: "limit", In: "query", Schema: jsonSchema{"type": "integer"}},
},
}
req, err := buildAPIRequest(context.Background(), operation, map[string]any{
"memo": "abc 123",
"force": true,
"limit": 10,
}, "")
require.NoError(t, err)
require.Equal(t, "/api/v1/memos/abc%20123", req.URL.EscapedPath())
require.Equal(t, "true", req.URL.Query().Get("force"))
require.Equal(t, "10", req.URL.Query().Get("limit"))
}
func TestExecuteOperationReturnsObjectStructuredContent(t *testing.T) {
echoServer := echo.New()
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
require.Equal(t, "Bearer token", c.Request().Header.Get("Authorization"))
return c.JSON(http.StatusOK, map[string]any{
"memos": []any{map[string]any{"name": "memos/abc123"}},
})
})
operation := &openAPIOperation{
Method: "GET",
Path: "/api/v1/memos",
}
adapter := newAPIAdapter(echoServer)
result, err := adapter.execute(context.Background(), operation, map[string]any{}, "Bearer token")
require.NoError(t, err)
require.False(t, result.IsError)
require.Equal(t, map[string]any{
"memos": []any{map[string]any{"name": "memos/abc123"}},
}, result.StructuredContent)
}
func TestExecuteOperationConvertsAPIErrorsToToolErrors(t *testing.T) {
echoServer := echo.New()
echoServer.GET("/api/v1/memos/:memo", func(c *echo.Context) error {
return c.JSON(http.StatusNotFound, map[string]any{"message": "missing memo"})
})
operation := &openAPIOperation{
Method: "GET",
Path: "/api/v1/memos/{memo}",
Parameters: []openAPIParameter{{Name: "memo", In: "path", Required: true}},
}
adapter := newAPIAdapter(echoServer)
result, err := adapter.execute(context.Background(), operation, map[string]any{"memo": "missing"}, "")
require.NoError(t, err)
require.True(t, result.IsError)
require.Equal(t, map[string]any{
"error": map[string]any{
"message": "404 Not Found: missing memo",
},
}, result.StructuredContent)
text, ok := result.Content[0].(*sdkmcp.TextContent)
require.True(t, ok)
require.Contains(t, text.Text, "404")
require.Contains(t, text.Text, "missing memo")
}

@ -0,0 +1,215 @@
package mcp
import (
"regexp"
"strings"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/pkg/errors"
)
var curatedOperationIDs = []string{
"MemoService_ListMemos",
"MemoService_CreateMemo",
"MemoService_GetMemo",
"MemoService_UpdateMemo",
"MemoService_DeleteMemo",
"MemoService_ListMemoComments",
"MemoService_CreateMemoComment",
"MemoService_ListMemoAttachments",
"MemoService_SetMemoAttachments",
"MemoService_ListMemoReactions",
"MemoService_UpsertMemoReaction",
"MemoService_DeleteMemoReaction",
"MemoService_ListMemoRelations",
"MemoService_SetMemoRelations",
"AttachmentService_ListAttachments",
"AttachmentService_GetAttachment",
"AttachmentService_DeleteAttachment",
}
type registeredOperation struct {
ToolName string
OperationID string
Method string
Path string
Operation *openAPIOperation
InputSchema jsonSchema
}
var wordBoundary = regexp.MustCompile(`([a-z0-9])([A-Z])`)
func buildCuratedTools(registry map[string]*openAPIOperation) ([]*sdkmcp.Tool, map[string]*registeredOperation, error) {
tools := make([]*sdkmcp.Tool, 0, len(curatedOperationIDs))
operations := map[string]*registeredOperation{}
for _, operationID := range curatedOperationIDs {
operation, ok := registry[operationID]
if !ok {
return nil, nil, errors.Errorf("curated OpenAPI operation %q not found", operationID)
}
tool, registered := buildToolFromOperation(operation)
if _, exists := operations[tool.Name]; exists {
return nil, nil, errors.Errorf("duplicate MCP tool name %q", tool.Name)
}
tools = append(tools, tool)
operations[tool.Name] = registered
}
return tools, operations, nil
}
func buildToolFromOperation(operation *openAPIOperation) (*sdkmcp.Tool, *registeredOperation) {
name := toolNameFromOperationID(operation.OperationID)
title := titleFromToolName(name)
inputSchema := inputSchemaForOperation(operation)
tool := &sdkmcp.Tool{
Meta: sdkmcp.Meta{
"operationId": operation.OperationID,
"method": operation.Method,
"path": operation.Path,
},
Name: name,
Title: title,
Description: operation.Description,
InputSchema: inputSchema,
OutputSchema: outputSchemaForOperation(operation),
Annotations: annotationsForMethod(operation.Method, title),
}
return tool, &registeredOperation{
ToolName: name,
OperationID: operation.OperationID,
Method: operation.Method,
Path: operation.Path,
Operation: operation,
InputSchema: inputSchema,
}
}
func toolNameFromOperationID(operationID string) string {
service, method, ok := strings.Cut(operationID, "_")
if !ok {
return camelToSnake(operationID)
}
service = strings.TrimSuffix(service, "Service")
return camelToSnake(service) + "_" + camelToSnake(method)
}
func camelToSnake(value string) string {
return strings.ToLower(wordBoundary.ReplaceAllString(value, `${1}_${2}`))
}
func titleFromToolName(name string) string {
parts := strings.Split(name, "_")
for i, part := range parts {
if part == "" {
continue
}
parts[i] = strings.ToUpper(part[:1]) + part[1:]
}
return strings.Join(parts, " ")
}
func inputSchemaForOperation(operation *openAPIOperation) jsonSchema {
properties := map[string]any{}
required := []string{}
defs := map[string]any{}
for _, parameter := range operation.Parameters {
schema := cloneSchema(parameter.Schema)
if parameter.Description != "" {
schema["description"] = parameter.Description
}
properties[parameter.Name] = schema
if parameter.Required {
required = append(required, parameter.Name)
}
}
if operation.RequestBody != nil {
bodySchema := requestBodySchema(operation)
for name, definition := range extractSchemaDefs(bodySchema) {
defs[name] = definition
}
properties["body"] = bodySchema
if operation.RequestBody.Required {
required = append(required, "body")
}
}
schema := jsonSchema{
"type": "object",
"properties": properties,
"additionalProperties": false,
}
if len(required) > 0 {
schema["required"] = required
}
if len(defs) > 0 {
schema["$defs"] = defs
}
return schema
}
func requestBodySchema(operation *openAPIOperation) jsonSchema {
if operation.RequestBodySchema == nil {
return jsonSchema{"type": "object"}
}
return cloneSchema(operation.RequestBodySchema)
}
func outputSchemaForOperation(operation *openAPIOperation) jsonSchema {
if operation.ResponseSchema == nil {
return okSchema()
}
return cloneSchema(operation.ResponseSchema)
}
func cloneSchema(schema jsonSchema) jsonSchema {
clone := jsonSchema{}
for key, value := range schema {
clone[key] = value
}
return clone
}
func extractSchemaDefs(schema jsonSchema) map[string]any {
defs, ok := schema["$defs"].(map[string]any)
if !ok {
return nil
}
delete(schema, "$defs")
return defs
}
func annotationsForMethod(method string, title string) *sdkmcp.ToolAnnotations {
openWorld := false
destructive := false
switch strings.ToUpper(method) {
case "GET":
return &sdkmcp.ToolAnnotations{
Title: title,
ReadOnlyHint: true,
DestructiveHint: &destructive,
IdempotentHint: true,
OpenWorldHint: &openWorld,
}
case "DELETE":
destructive = true
return &sdkmcp.ToolAnnotations{
Title: title,
ReadOnlyHint: false,
DestructiveHint: &destructive,
IdempotentHint: true,
OpenWorldHint: &openWorld,
}
default:
return &sdkmcp.ToolAnnotations{
Title: title,
ReadOnlyHint: false,
DestructiveHint: &destructive,
IdempotentHint: false,
OpenWorldHint: &openWorld,
}
}
}

@ -0,0 +1,152 @@
package mcp
import (
"encoding/json"
"testing"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/require"
)
func TestCuratedOperationIDsStayMemoFocused(t *testing.T) {
require.Len(t, curatedOperationIDs, 17)
for _, operationID := range curatedOperationIDs {
require.NotContains(t, operationID, "Admin")
require.NotContains(t, operationID, "AuthService_")
require.NotContains(t, operationID, "UserService_")
require.NotContains(t, operationID, "AIService_")
require.NotContains(t, operationID, "IdentityProviderService_")
require.NotContains(t, operationID, "InstanceService_")
require.NotContains(t, operationID, "PersonalAccessToken")
require.NotContains(t, operationID, "PAT")
require.NotContains(t, operationID, "Webhook")
require.NotContains(t, operationID, "Share")
require.NotContains(t, operationID, "BatchDelete")
require.NotContains(t, operationID, "Transcribe")
}
}
func TestToolNameFromOperationID(t *testing.T) {
require.Equal(t, "memo_list_memos", toolNameFromOperationID("MemoService_ListMemos"))
require.Equal(t, "attachment_get_attachment", toolNameFromOperationID("AttachmentService_GetAttachment"))
}
func TestBuildToolFromOperationIncludesSchemasAndMetadata(t *testing.T) {
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
require.NoError(t, err)
registry, err := buildOperationRegistry(spec)
require.NoError(t, err)
tool, operation := buildToolFromOperation(registry["MemoService_ListMemos"])
require.Equal(t, "memo_list_memos", tool.Name)
require.Equal(t, "Memo List Memos", tool.Title)
require.Equal(t, "MemoService_ListMemos", operation.OperationID)
require.Equal(t, "GET", operation.Method)
require.Equal(t, "/api/v1/memos", operation.Path)
require.Equal(t, "MemoService_ListMemos", tool.Meta["operationId"])
require.Equal(t, "GET", tool.Meta["method"])
require.Equal(t, "/api/v1/memos", tool.Meta["path"])
require.NotEmpty(t, tool.Description)
require.NotNil(t, tool.InputSchema)
require.NotNil(t, tool.OutputSchema)
require.NotNil(t, tool.Annotations)
require.True(t, tool.Annotations.ReadOnlyHint)
require.False(t, *tool.Annotations.DestructiveHint)
require.True(t, tool.Annotations.IdempotentHint)
require.False(t, *tool.Annotations.OpenWorldHint)
inputBytes, err := json.Marshal(tool.InputSchema)
require.NoError(t, err)
require.Contains(t, string(inputBytes), `"pageSize"`)
require.Contains(t, string(inputBytes), `"additionalProperties":false`)
outputBytes, err := json.Marshal(tool.OutputSchema)
require.NoError(t, err)
require.Contains(t, string(outputBytes), `"memos"`)
}
func TestBuildToolFromOperationIncludesRequestBodySchema(t *testing.T) {
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
require.NoError(t, err)
registry, err := buildOperationRegistry(spec)
require.NoError(t, err)
tool, operation := buildToolFromOperation(registry["MemoService_CreateMemo"])
require.Equal(t, "POST", operation.Method)
require.False(t, tool.Annotations.ReadOnlyHint)
require.False(t, *tool.Annotations.DestructiveHint)
require.False(t, tool.Annotations.IdempotentHint)
input, ok := tool.InputSchema.(jsonSchema)
require.True(t, ok)
require.Contains(t, input["required"], "body")
properties, ok := input["properties"].(map[string]any)
require.True(t, ok)
require.Contains(t, properties, "memoId")
require.Contains(t, properties, "body")
body, ok := properties["body"].(jsonSchema)
require.True(t, ok)
require.Equal(t, "object", body["type"])
require.Contains(t, body["properties"], "content")
err = validateToolArguments(input, map[string]any{
"body": map[string]any{
"state": "NORMAL",
"content": "hello",
"visibility": "PRIVATE",
},
})
require.NoError(t, err)
}
func TestBuildCuratedToolsHasUniqueNames(t *testing.T) {
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
require.NoError(t, err)
registry, err := buildOperationRegistry(spec)
require.NoError(t, err)
tools, operations, err := buildCuratedTools(registry)
require.NoError(t, err)
require.Len(t, tools, len(curatedOperationIDs))
require.Len(t, operations, len(curatedOperationIDs))
names := map[string]struct{}{}
for _, tool := range tools {
require.IsType(t, &sdkmcp.Tool{}, tool)
require.NotEmpty(t, tool.Name)
require.NotContains(t, names, tool.Name)
names[tool.Name] = struct{}{}
require.Equal(t, tool.Name, operations[tool.Name].ToolName)
inputBytes, err := json.Marshal(tool.InputSchema)
require.NoError(t, err)
require.NotContains(t, string(inputBytes), "#/components/schemas")
outputBytes, err := json.Marshal(tool.OutputSchema)
require.NoError(t, err)
require.NotContains(t, string(outputBytes), "#/components/schemas")
}
}
func TestBuildCuratedToolsRejectsMissingOperation(t *testing.T) {
_, _, err := buildCuratedTools(map[string]*openAPIOperation{})
require.ErrorContains(t, err, "curated OpenAPI operation")
require.ErrorContains(t, err, "not found")
}
func TestBuildCuratedToolsRejectsDuplicateToolNames(t *testing.T) {
registry := make(map[string]*openAPIOperation, len(curatedOperationIDs))
for _, operationID := range curatedOperationIDs {
registry[operationID] = &openAPIOperation{
OperationID: operationID,
Description: operationID,
Method: "GET",
Path: "/api/v1/test",
ResponseSchema: okSchema(),
}
}
registry["MemoService_ListMemos"].OperationID = "MemoService_GetMemo"
_, _, err := buildCuratedTools(registry)
require.ErrorContains(t, err, "duplicate MCP tool name")
}

@ -0,0 +1,252 @@
package mcp
import (
"os"
"strings"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
)
type jsonSchema map[string]any
type openAPISpec struct {
OpenAPI string `yaml:"openapi"`
Paths map[string]map[string]*openAPIOperation `yaml:"paths"`
Components openAPIComponents `yaml:"components"`
}
type openAPIComponents struct {
Schemas map[string]jsonSchema `yaml:"schemas"`
}
type openAPIOperation struct {
OperationID string `yaml:"operationId"`
Description string `yaml:"description"`
Parameters []openAPIParameter `yaml:"parameters"`
RequestBody *openAPIRequestBody `yaml:"requestBody"`
Responses map[string]openAPIResponse `yaml:"responses"`
Method string `yaml:"-"`
Path string `yaml:"-"`
ResponseSchema jsonSchema `yaml:"-"`
RequestBodySchema jsonSchema `yaml:"-"`
}
type openAPIParameter struct {
Name string `yaml:"name"`
In string `yaml:"in"`
Description string `yaml:"description"`
Required bool `yaml:"required"`
Schema jsonSchema `yaml:"schema"`
}
type openAPIRequestBody struct {
Required bool `yaml:"required"`
Content map[string]openAPIMediaType `yaml:"content"`
}
type openAPIResponse struct {
Description string `yaml:"description"`
Content map[string]openAPIMediaType `yaml:"content"`
}
type openAPIMediaType struct {
Schema jsonSchema `yaml:"schema"`
}
func loadOpenAPISpec(path string) (*openAPISpec, error) {
data, err := os.ReadFile(path)
if err != nil {
return nil, errors.Wrap(err, "failed to read OpenAPI spec")
}
spec := &openAPISpec{}
if err := yaml.Unmarshal(data, spec); err != nil {
return nil, errors.Wrap(err, "failed to parse OpenAPI spec")
}
if spec.Paths == nil {
return nil, errors.New("OpenAPI spec has no paths")
}
return spec, nil
}
func buildOperationRegistry(spec *openAPISpec) (map[string]*openAPIOperation, error) {
registry := map[string]*openAPIOperation{}
for path, pathItem := range spec.Paths {
for method, operation := range pathItem {
if operation == nil || operation.OperationID == "" {
continue
}
if _, exists := registry[operation.OperationID]; exists {
return nil, errors.Errorf("duplicate OpenAPI operationId %q", operation.OperationID)
}
operation.Method = strings.ToUpper(method)
operation.Path = path
responseSchema, err := operationSuccessResponseSchema(spec, operation)
if err != nil {
return nil, errors.Wrapf(err, "failed to resolve response schema for %s", operation.OperationID)
}
operation.ResponseSchema = responseSchema
requestBodySchema, err := operationRequestBodySchema(spec, operation)
if err != nil {
return nil, errors.Wrapf(err, "failed to resolve request body schema for %s", operation.OperationID)
}
operation.RequestBodySchema = requestBodySchema
registry[operation.OperationID] = operation
}
}
return registry, nil
}
func operationSuccessResponseSchema(spec *openAPISpec, operation *openAPIOperation) (jsonSchema, error) {
response, ok := operation.Responses["200"]
if !ok || response.Content == nil {
return okSchema(), nil
}
mediaType, ok := response.Content["application/json"]
if !ok || mediaType.Schema == nil {
return okSchema(), nil
}
return resolveSchemaRef(spec, mediaType.Schema)
}
func operationRequestBodySchema(spec *openAPISpec, operation *openAPIOperation) (jsonSchema, error) {
if operation.RequestBody == nil {
return nil, nil
}
mediaType, ok := operation.RequestBody.Content["application/json"]
if !ok || mediaType.Schema == nil {
return jsonSchema{"type": "object"}, nil
}
return resolveSchemaRef(spec, mediaType.Schema)
}
func resolveSchemaRef(spec *openAPISpec, schema jsonSchema) (jsonSchema, error) {
defs := map[string]any{}
resolved, err := resolveSchemaValue(spec, schema, defs, map[string]bool{}, true)
if err != nil {
return nil, err
}
resolvedSchema, ok := resolved.(map[string]any)
if !ok {
return nil, errors.New("resolved schema is not an object")
}
if len(defs) > 0 {
resolvedSchema["$defs"] = defs
}
return jsonSchema(resolvedSchema), nil
}
func resolveSchemaValue(spec *openAPISpec, value any, defs map[string]any, resolving map[string]bool, inlineRef bool) (any, error) {
switch typed := value.(type) {
case jsonSchema:
return resolveSchemaMap(spec, map[string]any(typed), defs, resolving, inlineRef)
case map[string]any:
return resolveSchemaMap(spec, typed, defs, resolving, inlineRef)
case []any:
resolved := make([]any, 0, len(typed))
for _, item := range typed {
resolvedItem, err := resolveSchemaValue(spec, item, defs, resolving, false)
if err != nil {
return nil, err
}
resolved = append(resolved, resolvedItem)
}
return resolved, nil
default:
return value, nil
}
}
func resolveSchemaMap(spec *openAPISpec, schema map[string]any, defs map[string]any, resolving map[string]bool, inlineRef bool) (map[string]any, error) {
if ref, ok := schema["$ref"].(string); ok && ref != "" {
name, err := schemaComponentName(ref)
if err != nil {
return nil, err
}
if inlineRef {
return resolveComponentSchema(spec, name, defs, resolving)
}
if err := addSchemaDef(spec, name, defs, resolving); err != nil {
return nil, err
}
return map[string]any{"$ref": "#/$defs/" + name}, nil
}
resolved := make(map[string]any, len(schema))
for key, value := range schema {
resolvedValue, err := resolveSchemaValue(spec, value, defs, resolving, false)
if err != nil {
return nil, err
}
resolved[key] = resolvedValue
}
return resolved, nil
}
func resolveComponentSchema(spec *openAPISpec, name string, defs map[string]any, resolving map[string]bool) (map[string]any, error) {
component, ok := spec.Components.Schemas[name]
if !ok {
return nil, errors.Errorf("schema ref %q not found", schemaComponentRef(name))
}
resolving[name] = true
resolved, err := resolveSchemaMap(spec, map[string]any(component), defs, resolving, false)
delete(resolving, name)
if err != nil {
return nil, err
}
if _, ok := defs[name]; ok {
defs[name] = resolved
}
return resolved, nil
}
func addSchemaDef(spec *openAPISpec, name string, defs map[string]any, resolving map[string]bool) error {
if _, ok := defs[name]; ok {
return nil
}
component, ok := spec.Components.Schemas[name]
if !ok {
return errors.Errorf("schema ref %q not found", schemaComponentRef(name))
}
defs[name] = map[string]any{}
if resolving[name] {
return nil
}
resolving[name] = true
resolved, err := resolveSchemaMap(spec, map[string]any(component), defs, resolving, false)
delete(resolving, name)
if err != nil {
return err
}
defs[name] = resolved
return nil
}
func schemaComponentName(ref string) (string, error) {
const prefix = "#/components/schemas/"
if !strings.HasPrefix(ref, prefix) {
return "", errors.Errorf("unsupported schema ref %q", ref)
}
return strings.TrimPrefix(ref, prefix), nil
}
func schemaComponentRef(name string) string {
return "#/components/schemas/" + name
}
func okSchema() jsonSchema {
return jsonSchema{
"type": "object",
"properties": map[string]any{
"ok": map[string]any{"type": "boolean"},
},
}
}

@ -0,0 +1,192 @@
package mcp
import (
"encoding/json"
"testing"
"github.com/stretchr/testify/require"
)
func TestLoadOpenAPIOperationsIncludesCuratedIDs(t *testing.T) {
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
require.NoError(t, err)
registry, err := buildOperationRegistry(spec)
require.NoError(t, err)
curatedIDs := []string{
"MemoService_ListMemos",
"MemoService_CreateMemo",
"MemoService_GetMemo",
"MemoService_UpdateMemo",
"MemoService_DeleteMemo",
"MemoService_ListMemoComments",
"MemoService_CreateMemoComment",
"MemoService_ListMemoAttachments",
"MemoService_SetMemoAttachments",
"MemoService_ListMemoReactions",
"MemoService_UpsertMemoReaction",
"MemoService_DeleteMemoReaction",
"MemoService_ListMemoRelations",
"MemoService_SetMemoRelations",
"AttachmentService_ListAttachments",
"AttachmentService_GetAttachment",
"AttachmentService_DeleteAttachment",
}
for _, operationID := range curatedIDs {
operation, ok := registry[operationID]
require.True(t, ok, "missing curated operation %s", operationID)
require.NotEmpty(t, operation.Method, operationID)
require.NotEmpty(t, operation.Path, operationID)
require.NotEmpty(t, operation.Description, operationID)
require.NotNil(t, operation.ResponseSchema, operationID)
}
createMemo := registry["MemoService_CreateMemo"]
require.NotNil(t, createMemo.RequestBodySchema)
require.Equal(t, "object", createMemo.RequestBodySchema["type"])
}
func TestBuildOperationRegistryRejectsDuplicateOperationIDs(t *testing.T) {
spec := &openAPISpec{
Paths: map[string]map[string]*openAPIOperation{
"/a": {
"get": {OperationID: "MemoService_GetMemo", Description: "first"},
},
"/b": {
"get": {OperationID: "MemoService_GetMemo", Description: "second"},
},
},
}
_, err := buildOperationRegistry(spec)
require.ErrorContains(t, err, "duplicate OpenAPI operationId")
}
func TestResolveSchemaRef(t *testing.T) {
spec := &openAPISpec{
Components: openAPIComponents{
Schemas: map[string]jsonSchema{
"ListMemosResponse": {
"type": "object",
"properties": map[string]any{
"memos": map[string]any{"type": "array"},
},
},
},
},
}
schema, err := resolveSchemaRef(spec, jsonSchema{"$ref": "#/components/schemas/ListMemosResponse"})
require.NoError(t, err)
require.Equal(t, "object", schema["type"])
require.Contains(t, schema["properties"], "memos")
}
func TestResolveSchemaRefRewritesNestedComponentRefs(t *testing.T) {
spec := &openAPISpec{
Components: openAPIComponents{
Schemas: map[string]jsonSchema{
"ListMemosResponse": {
"type": "object",
"properties": map[string]any{
"memos": map[string]any{
"type": "array",
"items": map[string]any{"$ref": "#/components/schemas/Memo"},
},
},
},
"Memo": {
"type": "object",
"properties": map[string]any{
"attachment": map[string]any{"$ref": "#/components/schemas/Attachment"},
},
},
"Attachment": {
"type": "object",
"properties": map[string]any{
"name": map[string]any{"type": "string"},
},
},
},
},
}
schema, err := resolveSchemaRef(spec, jsonSchema{"$ref": "#/components/schemas/ListMemosResponse"})
require.NoError(t, err)
data, err := json.Marshal(schema)
require.NoError(t, err)
require.NotContains(t, string(data), "#/components/schemas")
require.Contains(t, string(data), `"#/$defs/Memo"`)
require.Contains(t, string(data), `"#/$defs/Attachment"`)
}
func TestBuildOperationRegistryResolvesRequestBodySchema(t *testing.T) {
spec := &openAPISpec{
Paths: map[string]map[string]*openAPIOperation{
"/memos": {
"post": {
OperationID: "MemoService_CreateMemo",
RequestBody: &openAPIRequestBody{
Content: map[string]openAPIMediaType{
"application/json": {Schema: jsonSchema{"$ref": "#/components/schemas/CreateMemoRequest"}},
},
},
Responses: map[string]openAPIResponse{
"200": {
Content: map[string]openAPIMediaType{
"application/json": {Schema: jsonSchema{"$ref": "#/components/schemas/Memo"}},
},
},
},
},
},
},
Components: openAPIComponents{
Schemas: map[string]jsonSchema{
"CreateMemoRequest": {
"type": "object",
"properties": map[string]any{
"content": map[string]any{"type": "string"},
},
},
"Memo": {
"type": "object",
},
},
},
}
registry, err := buildOperationRegistry(spec)
require.NoError(t, err)
requestSchema := registry["MemoService_CreateMemo"].RequestBodySchema
require.Equal(t, "object", requestSchema["type"])
require.Contains(t, requestSchema["properties"], "content")
}
func TestBuildOperationRegistryUsesOKSchemaForEmptySuccessResponse(t *testing.T) {
spec := &openAPISpec{
Paths: map[string]map[string]*openAPIOperation{
"/memos/{memo}": {
"delete": {
OperationID: "MemoService_DeleteMemo",
Responses: map[string]openAPIResponse{
"200": {
Content: map[string]openAPIMediaType{},
},
},
},
},
},
}
registry, err := buildOperationRegistry(spec)
require.NoError(t, err)
responseSchema := registry["MemoService_DeleteMemo"].ResponseSchema
require.Equal(t, "object", responseSchema["type"])
require.Contains(t, responseSchema["properties"], "ok")
}

@ -0,0 +1,31 @@
package mcp
import (
"net/url"
"strings"
"github.com/usememos/memos/internal/profile"
)
func isAllowedMCPOrigin(host string, origin string, profile *profile.Profile) bool {
if origin == "" {
return true
}
originURL, err := url.Parse(origin)
if err != nil || originURL.Scheme == "" || originURL.Host == "" {
return false
}
if strings.EqualFold(originURL.Host, host) {
return true
}
if profile == nil || profile.InstanceURL == "" {
return false
}
instanceURL, err := url.Parse(profile.InstanceURL)
if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" {
return false
}
return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && strings.EqualFold(originURL.Host, instanceURL.Host)
}

@ -0,0 +1,62 @@
package mcp
import (
"bytes"
"encoding/json"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/pkg/errors"
)
func normalizeStructuredContent(value any) map[string]any {
switch typed := value.(type) {
case nil:
return map[string]any{"ok": true}
case map[string]any:
return typed
case []any:
return map[string]any{"result": typed}
default:
return map[string]any{"result": typed}
}
}
func newStructuredToolResult(value any) (*sdkmcp.CallToolResult, error) {
structured := normalizeStructuredContent(value)
text, err := json.Marshal(structured)
if err != nil {
return nil, errors.Wrap(err, "failed to marshal MCP structured result")
}
return &sdkmcp.CallToolResult{
Content: []sdkmcp.Content{
&sdkmcp.TextContent{Text: string(text)},
},
StructuredContent: structured,
}, nil
}
func newToolErrorResult(message string) *sdkmcp.CallToolResult {
structured := map[string]any{
"error": map[string]any{
"message": message,
},
}
return &sdkmcp.CallToolResult{
Content: []sdkmcp.Content{
&sdkmcp.TextContent{Text: message},
},
StructuredContent: structured,
IsError: true,
}
}
func decodeJSONValue(data []byte) (any, error) {
if len(bytes.TrimSpace(data)) == 0 {
return nil, nil
}
var value any
if err := json.Unmarshal(data, &value); err != nil {
return nil, errors.Wrap(err, "failed to decode API JSON response")
}
return value, nil
}

@ -0,0 +1,110 @@
package mcp
import (
"context"
"encoding/json"
"net/http"
"github.com/labstack/echo/v5"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/pkg/errors"
"gopkg.in/yaml.v3"
"github.com/usememos/memos/internal/profile"
memosproto "github.com/usememos/memos/proto"
)
// MCPService serves the OpenAPI-driven MCP endpoint.
type MCPService struct {
profile *profile.Profile
operationsByTool map[string]*registeredOperation
handler http.Handler
}
// NewMCPService creates an MCP service backed by the in-process API routes.
func NewMCPService(profile *profile.Profile, echoServer *echo.Echo) (*MCPService, error) {
spec, err := loadMCPServiceOpenAPISpec()
if err != nil {
return nil, err
}
registry, err := buildOperationRegistry(spec)
if err != nil {
return nil, err
}
tools, operationsByTool, err := buildCuratedTools(registry)
if err != nil {
return nil, err
}
version := "dev"
if profile != nil && profile.Version != "" {
version = profile.Version
}
server := sdkmcp.NewServer(&sdkmcp.Implementation{
Name: "memos",
Version: version,
}, nil)
adapter := newAPIAdapter(echoServer)
for _, tool := range tools {
operation := operationsByTool[tool.Name]
server.AddTool(tool, newMCPToolHandler(adapter, operation))
}
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
return server
}, &sdkmcp.StreamableHTTPOptions{
Stateless: true,
JSONResponse: true,
})
return &MCPService{
profile: profile,
operationsByTool: operationsByTool,
handler: handler,
}, nil
}
func loadMCPServiceOpenAPISpec() (*openAPISpec, error) {
spec := &openAPISpec{}
if err := yaml.Unmarshal(memosproto.OpenAPIYAML(), spec); err != nil {
return nil, errors.Wrap(err, "failed to parse embedded OpenAPI spec")
}
if spec.Paths == nil {
return nil, errors.New("embedded OpenAPI spec has no paths")
}
return spec, nil
}
func newMCPToolHandler(adapter *apiAdapter, operation *registeredOperation) sdkmcp.ToolHandler {
return func(ctx context.Context, request *sdkmcp.CallToolRequest) (*sdkmcp.CallToolResult, error) {
arguments := map[string]any{}
if request.Params != nil && len(request.Params.Arguments) > 0 {
if err := json.Unmarshal(request.Params.Arguments, &arguments); err != nil {
return newToolErrorResult(errors.Wrap(err, "failed to decode MCP tool arguments").Error()), nil
}
}
if err := validateToolArguments(operation.InputSchema, arguments); err != nil {
return newToolErrorResult(err.Error()), nil
}
authorization := ""
if request.Extra != nil {
authorization = request.Extra.Header.Get("Authorization")
}
return adapter.execute(ctx, operation.Operation, arguments, authorization)
}
}
// RegisterRoutes registers the streamable HTTP MCP endpoint.
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
echoServer.Any("/mcp", func(c *echo.Context) error {
request := c.Request()
if !isAllowedMCPOrigin(request.Host, request.Header.Get("Origin"), s.profile) {
return c.NoContent(http.StatusForbidden)
}
s.handler.ServeHTTP(c.Response(), request)
return nil
})
}

@ -0,0 +1,283 @@
package mcp
import (
"bytes"
"context"
"encoding/json"
"net/http"
"net/http/httptest"
"os"
"testing"
"github.com/labstack/echo/v5"
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/internal/profile"
memosproto "github.com/usememos/memos/proto"
)
func TestIsAllowedMCPOrigin(t *testing.T) {
profile := &profile.Profile{InstanceURL: "https://memos.example.com/app"}
tests := []struct {
name string
host string
origin string
want bool
}{
{name: "empty origin", host: "localhost:5230", origin: "", want: true},
{name: "same http host", host: "localhost:5230", origin: "http://localhost:5230", want: true},
{name: "same https host", host: "memos.example.com", origin: "https://memos.example.com", want: true},
{name: "configured instance URL origin", host: "127.0.0.1:5230", origin: "https://memos.example.com", want: true},
{name: "configured instance URL ignores path", host: "127.0.0.1:5230", origin: "https://memos.example.com", want: true},
{name: "different host", host: "localhost:5230", origin: "https://evil.example.com", want: false},
{name: "invalid origin", host: "localhost:5230", origin: "not a url", want: false},
}
for _, test := range tests {
t.Run(test.name, func(t *testing.T) {
require.Equal(t, test.want, isAllowedMCPOrigin(test.host, test.origin, profile))
})
}
}
func TestNewMCPServiceRegistersCuratedTools(t *testing.T) {
echoServer := echo.New()
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
require.NoError(t, err)
require.NotNil(t, service.handler)
require.Len(t, service.operationsByTool, len(curatedOperationIDs))
operation := service.operationsByTool["memo_list_memos"]
require.NotNil(t, operation)
require.Equal(t, "MemoService_ListMemos", operation.OperationID)
require.Equal(t, "GET", operation.Method)
require.Equal(t, "/api/v1/memos", operation.Path)
}
func TestNewMCPServiceUsesEmbeddedOpenAPISpec(t *testing.T) {
t.Chdir(t.TempDir())
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echo.New())
require.NoError(t, err)
require.NotNil(t, service.handler)
require.Len(t, service.operationsByTool, len(curatedOperationIDs))
}
func TestEmbeddedOpenAPISpecMatchesGeneratedFile(t *testing.T) {
generated, err := os.ReadFile("../../../proto/gen/openapi.yaml")
require.NoError(t, err)
require.Equal(t, generated, memosproto.OpenAPIYAML())
}
func TestMCPToolHandlerForwardsArgumentsAndAuthorization(t *testing.T) {
echoServer := echo.New()
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
require.Equal(t, "Bearer token", c.Request().Header.Get("Authorization"))
require.Equal(t, "7", c.QueryParam("pageSize"))
return c.JSON(http.StatusOK, map[string]any{
"memos": []any{map[string]any{"name": "memos/test"}},
})
})
operation := &registeredOperation{
Operation: &openAPIOperation{
Method: "GET",
Path: "/api/v1/memos",
Parameters: []openAPIParameter{{Name: "pageSize", In: "query", Schema: jsonSchema{"type": "integer"}}},
},
}
handler := newMCPToolHandler(newAPIAdapter(echoServer), operation)
arguments, err := json.Marshal(map[string]any{"pageSize": 7})
require.NoError(t, err)
result, err := handler(context.Background(), &sdkmcp.CallToolRequest{
Params: &sdkmcp.CallToolParamsRaw{
Name: "memo_list_memos",
Arguments: arguments,
},
Extra: &sdkmcp.RequestExtra{
Header: http.Header{"Authorization": []string{"Bearer token"}},
},
})
require.NoError(t, err)
require.False(t, result.IsError)
require.Equal(t, map[string]any{
"memos": []any{map[string]any{"name": "memos/test"}},
}, result.StructuredContent)
}
func TestMCPProtocolListsCuratedToolsOnly(t *testing.T) {
echoServer := echo.New()
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
require.NoError(t, err)
service.RegisterRoutes(echoServer)
initializeMCP(t, echoServer)
response := postMCP(t, echoServer, map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/list",
})
result, ok := response["result"].(map[string]any)
require.True(t, ok)
tools, ok := result["tools"].([]any)
require.True(t, ok)
require.Len(t, tools, len(curatedOperationIDs))
names := map[string]struct{}{}
for _, rawTool := range tools {
tool, ok := rawTool.(map[string]any)
require.True(t, ok)
name, ok := tool["name"].(string)
require.True(t, ok)
names[name] = struct{}{}
require.Contains(t, tool, "inputSchema")
require.Contains(t, tool, "outputSchema")
}
require.Contains(t, names, "memo_list_memos")
require.Contains(t, names, "memo_create_memo")
require.NotContains(t, names, "auth_sign_in")
require.NotContains(t, names, "user_create_user")
}
func TestMCPToolCallReturnsObjectStructuredContent(t *testing.T) {
echoServer := echo.New()
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
return c.JSON(http.StatusOK, map[string]any{
"memos": []any{map[string]any{"name": "memos/abc123"}},
})
})
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
require.NoError(t, err)
service.RegisterRoutes(echoServer)
initializeMCP(t, echoServer)
response := postMCP(t, echoServer, map[string]any{
"jsonrpc": "2.0",
"id": 2,
"method": "tools/call",
"params": map[string]any{
"name": "memo_list_memos",
"arguments": map[string]any{
"pageSize": 1,
},
},
})
result, ok := response["result"].(map[string]any)
require.True(t, ok)
require.Equal(t, map[string]any{
"memos": []any{map[string]any{"name": "memos/abc123"}},
}, result["structuredContent"])
}
func TestMCPToolCallRejectsInvalidArguments(t *testing.T) {
echoServer := echo.New()
routeHits := 0
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
routeHits++
return c.JSON(http.StatusOK, map[string]any{"memos": []any{}})
})
echoServer.GET("/api/v1/memos/:memo", func(c *echo.Context) error {
routeHits++
return c.JSON(http.StatusOK, map[string]any{"name": c.Param("memo")})
})
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
require.NoError(t, err)
service.RegisterRoutes(echoServer)
initializeMCP(t, echoServer)
tests := []struct {
name string
toolName string
arguments map[string]any
wantError string
}{
{
name: "unknown argument",
toolName: "memo_list_memos",
arguments: map[string]any{"unexpected": true},
wantError: `unknown argument "unexpected"`,
},
{
name: "missing required argument",
toolName: "memo_get_memo",
arguments: map[string]any{},
wantError: `missing required argument "memo"`,
},
{
name: "wrong primitive type",
toolName: "memo_list_memos",
arguments: map[string]any{"pageSize": "ten"},
wantError: `argument "pageSize" must be integer`,
},
}
for index, test := range tests {
t.Run(test.name, func(t *testing.T) {
response := postMCP(t, echoServer, map[string]any{
"jsonrpc": "2.0",
"id": index + 2,
"method": "tools/call",
"params": map[string]any{
"name": test.toolName,
"arguments": test.arguments,
},
})
result, ok := response["result"].(map[string]any)
require.True(t, ok)
require.Equal(t, true, result["isError"])
structured, ok := result["structuredContent"].(map[string]any)
require.True(t, ok)
errorObject, ok := structured["error"].(map[string]any)
require.True(t, ok)
require.Contains(t, errorObject["message"], test.wantError)
})
}
require.Zero(t, routeHits)
}
func initializeMCP(t *testing.T, echoServer *echo.Echo) {
t.Helper()
response := postMCP(t, echoServer, map[string]any{
"jsonrpc": "2.0",
"id": 1,
"method": "initialize",
"params": map[string]any{
"protocolVersion": "2025-06-18",
"capabilities": map[string]any{},
"clientInfo": map[string]any{
"name": "memos-test",
"version": "1.0.0",
},
},
})
require.NotNil(t, response["result"])
}
func postMCP(t *testing.T, echoServer *echo.Echo, payload map[string]any) map[string]any {
t.Helper()
data, err := json.Marshal(payload)
require.NoError(t, err)
request := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(data))
request.Header.Set("Content-Type", "application/json")
request.Header.Set("Accept", "application/json, text/event-stream")
recorder := httptest.NewRecorder()
echoServer.ServeHTTP(recorder, request)
require.Equal(t, http.StatusOK, recorder.Code, recorder.Body.String())
var response map[string]any
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
return response
}

@ -0,0 +1,226 @@
package mcp
import (
"encoding/json"
"math"
"strings"
googlejsonschema "github.com/google/jsonschema-go/jsonschema"
"github.com/pkg/errors"
)
func validateToolArguments(schema jsonSchema, arguments map[string]any) error {
if schema == nil {
return nil
}
if err := validateSchemaValue(schema, "argument", "argument", arguments, schema); err != nil {
return err
}
return validateArgumentsWithJSONSchema(schema, arguments)
}
func validateArgumentsWithJSONSchema(schema jsonSchema, arguments map[string]any) error {
data, err := json.Marshal(schema)
if err != nil {
return errors.Wrap(err, "failed to marshal MCP tool input schema")
}
jsonSchema := &googlejsonschema.Schema{}
if err := json.Unmarshal(data, jsonSchema); err != nil {
return errors.Wrap(err, "failed to parse MCP tool input schema")
}
resolved, err := jsonSchema.Resolve(nil)
if err != nil {
return errors.Wrap(err, "failed to resolve MCP tool input schema")
}
if err := resolved.Validate(arguments); err != nil {
return errors.Wrap(err, "MCP tool arguments do not match input schema")
}
return nil
}
func validateSchemaValue(schemaValue any, path string, label string, value any, root jsonSchema) error {
schema, ok := asSchemaMap(schemaValue)
if !ok {
return nil
}
if ref, ok := schema["$ref"].(string); ok && ref != "" {
resolved, ok := localSchemaDef(root, ref)
if !ok {
return nil
}
return validateSchemaValue(resolved, path, label, value, root)
}
if value == nil {
return nil
}
types := schemaTypes(schema["type"])
if len(types) == 0 && schema["properties"] != nil {
types = []string{"object"}
}
for _, schemaType := range types {
if schemaTypeMatchesValue(schemaType, value) {
if schemaType == "object" {
return validateObjectSchema(schema, path, value, root)
}
return nil
}
}
if len(types) == 0 {
return nil
}
return errors.Errorf(`%s "%s" must be %s`, label, path, types[0])
}
func validateObjectSchema(schema map[string]any, path string, value any, root jsonSchema) error {
object, ok := value.(map[string]any)
if !ok {
return errors.Errorf(`argument "%s" must be object`, path)
}
properties := schemaProperties(schema["properties"])
for _, required := range requiredNames(schema["required"]) {
if child, ok := object[required]; !ok || child == nil {
return errors.Errorf(`missing required argument "%s"`, joinSchemaPath(path, required))
}
}
if additionalProperties, ok := schema["additionalProperties"].(bool); ok && !additionalProperties {
for name := range object {
if _, ok := properties[name]; !ok {
return errors.Errorf(`unknown argument "%s"`, joinSchemaPath(path, name))
}
}
}
for name, childSchema := range properties {
childValue, ok := object[name]
if !ok {
continue
}
if err := validateSchemaValue(childSchema, joinSchemaPath(path, name), "argument", childValue, root); err != nil {
return err
}
}
return nil
}
func asSchemaMap(value any) (map[string]any, bool) {
switch typed := value.(type) {
case jsonSchema:
return map[string]any(typed), true
case map[string]any:
return typed, true
default:
return nil, false
}
}
func schemaProperties(value any) map[string]any {
switch typed := value.(type) {
case map[string]any:
return typed
case jsonSchema:
return map[string]any(typed)
default:
return map[string]any{}
}
}
func schemaTypes(value any) []string {
switch typed := value.(type) {
case string:
return []string{typed}
case []any:
types := make([]string, 0, len(typed))
for _, item := range typed {
if typeName, ok := item.(string); ok {
types = append(types, typeName)
}
}
return types
default:
return nil
}
}
func requiredNames(value any) []string {
switch typed := value.(type) {
case []string:
return typed
case []any:
names := make([]string, 0, len(typed))
for _, item := range typed {
if name, ok := item.(string); ok {
names = append(names, name)
}
}
return names
default:
return nil
}
}
func schemaTypeMatchesValue(schemaType string, value any) bool {
switch schemaType {
case "array":
_, ok := value.([]any)
return ok
case "boolean":
_, ok := value.(bool)
return ok
case "integer":
return isInteger(value)
case "number":
return isNumber(value)
case "object":
_, ok := value.(map[string]any)
return ok
case "string":
_, ok := value.(string)
return ok
default:
return true
}
}
func isInteger(value any) bool {
switch typed := value.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
return true
case float32:
return math.Trunc(float64(typed)) == float64(typed)
case float64:
return math.Trunc(typed) == typed
default:
return false
}
}
func isNumber(value any) bool {
switch value.(type) {
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
return true
default:
return false
}
}
func localSchemaDef(root jsonSchema, ref string) (any, bool) {
const prefix = "#/$defs/"
if !strings.HasPrefix(ref, prefix) {
return nil, false
}
defs, ok := root["$defs"].(map[string]any)
if !ok {
return nil, false
}
return defs[strings.TrimPrefix(ref, prefix)], true
}
func joinSchemaPath(parent string, child string) string {
if parent == "" || parent == "argument" {
return child
}
return parent + "." + child
}

@ -0,0 +1,63 @@
package mcp
import (
"testing"
"github.com/stretchr/testify/require"
)
func TestValidateToolArgumentsRejectsUnknownArguments(t *testing.T) {
schema := jsonSchema{
"type": "object",
"properties": map[string]any{
"pageSize": map[string]any{"type": "integer"},
},
"additionalProperties": false,
}
err := validateToolArguments(schema, map[string]any{"unexpected": true})
require.ErrorContains(t, err, `unknown argument "unexpected"`)
}
func TestValidateToolArgumentsRejectsMissingRequiredArguments(t *testing.T) {
schema := jsonSchema{
"type": "object",
"properties": map[string]any{
"memo": map[string]any{"type": "string"},
},
"required": []string{"memo"},
"additionalProperties": false,
}
err := validateToolArguments(schema, map[string]any{})
require.ErrorContains(t, err, `missing required argument "memo"`)
}
func TestValidateToolArgumentsRejectsWrongPrimitiveTypes(t *testing.T) {
schema := jsonSchema{
"type": "object",
"properties": map[string]any{
"pageSize": map[string]any{"type": "integer"},
},
"additionalProperties": false,
}
err := validateToolArguments(schema, map[string]any{"pageSize": "ten"})
require.ErrorContains(t, err, `argument "pageSize" must be integer`)
}
func TestValidateToolArgumentsUsesJSONSchemaValidation(t *testing.T) {
schema := jsonSchema{
"type": "object",
"properties": map[string]any{
"state": map[string]any{
"type": "string",
"enum": []any{"NORMAL", "ARCHIVED"},
},
},
"additionalProperties": false,
}
err := validateToolArguments(schema, map[string]any{"state": "DELETED"})
require.ErrorContains(t, err, "MCP tool arguments do not match input schema")
}

@ -20,6 +20,7 @@ import (
apiv1 "github.com/usememos/memos/server/router/api/v1"
"github.com/usememos/memos/server/router/fileserver"
"github.com/usememos/memos/server/router/frontend"
"github.com/usememos/memos/server/router/mcp"
"github.com/usememos/memos/server/router/rss"
"github.com/usememos/memos/server/runner/s3presign"
"github.com/usememos/memos/store"
@ -88,6 +89,12 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store
return nil, errors.Wrap(err, "failed to register gRPC gateway")
}
mcpService, err := mcp.NewMCPService(profile, echoServer)
if err != nil {
return nil, errors.Wrap(err, "failed to create MCP service")
}
mcpService.RegisterRoutes(echoServer)
return s, nil
}

Loading…
Cancel
Save