mirror of https://github.com/usememos/memos
feat: add OpenAPI-driven MCP support (#6026)
parent
a47d04954e
commit
777d227eb9
File diff suppressed because it is too large
Load Diff
@ -0,0 +1,241 @@
|
||||
# OpenAPI-Driven MCP Support Design
|
||||
|
||||
## Context
|
||||
|
||||
Memos previously had an MCP server, but it was removed in commit `2a4638b3` and should not be used as the baseline for this work. GitHub issue #6022 reported that some old MCP tools returned a bare JSON array in `result.structuredContent`, which strict MCP clients reject because `structuredContent` must be an object.
|
||||
|
||||
The new MCP support should start from zero and use the generated OpenAPI document at `proto/gen/openapi.yaml` as the source of truth. The OpenAPI file already describes the public REST gateway operations generated from protobuf definitions, including operation IDs, descriptions, parameters, request schemas, and response schemas.
|
||||
|
||||
## Goals
|
||||
|
||||
- Add MCP support back through a new implementation that is mechanically tied to `proto/gen/openapi.yaml`.
|
||||
- Expose a standard MCP Streamable HTTP endpoint at the `/mcp` path.
|
||||
- Register tools only; do not add MCP resources or prompts in the first version.
|
||||
- Expose a curated memo-focused toolset derived from OpenAPI operations, not every API endpoint.
|
||||
- Execute MCP tool calls through the existing API contract instead of duplicating store or service logic.
|
||||
- Return object-shaped `structuredContent` for every tool result.
|
||||
- Keep authentication and authorization behavior aligned with the existing API.
|
||||
|
||||
## Non-Goals
|
||||
|
||||
- Reviving or adapting the removed MCP package design.
|
||||
- Adding custom MCP route aliases, readonly endpoints, or toolset filtering headers.
|
||||
- Exposing every operation in `proto/gen/openapi.yaml` as an MCP tool.
|
||||
- Adding MCP tools for admin settings, users, identity providers, webhooks, personal access tokens, authentication, share-link management, AI transcription, or bulk deletion.
|
||||
- Adding `list_tags` or `search_memos` unless matching proto/API operations are added and OpenAPI is regenerated.
|
||||
- Hand-editing generated OpenAPI or generated protobuf outputs.
|
||||
|
||||
## Recommended Approach
|
||||
|
||||
Build a new `server/router/mcp` package that parses `proto/gen/openapi.yaml` at startup, selects a curated allowlist of operation IDs, and converts those selected OpenAPI operations into MCP tools. Tool calls map MCP arguments into the selected operation's path parameters, query parameters, and JSON body, then execute the corresponding `/api/v1/...` HTTP request through the existing Echo/gRPC-Gateway route in-process.
|
||||
|
||||
This approach keeps OpenAPI as the authoritative contract while avoiding the usability and safety problems of exposing a large API-mirrored tool surface.
|
||||
|
||||
## Architecture
|
||||
|
||||
### MCP Service
|
||||
|
||||
Create a new MCP service package under `server/router/mcp`. `server.NewServer` creates the existing `APIV1Service`, registers the normal file, RSS, and API routes, then registers a new MCP service against the same Echo server.
|
||||
|
||||
The service exposes one MCP endpoint path:
|
||||
|
||||
```text
|
||||
/mcp
|
||||
```
|
||||
|
||||
The implementation must support standard Streamable HTTP client messages on `POST /mcp`. It may also support `GET /mcp` and `DELETE /mcp` if the chosen MCP transport implementation requires those methods for standards-compliant streaming or session cleanup.
|
||||
|
||||
The first version advertises only the MCP tools capability. It does not advertise prompts or resources.
|
||||
|
||||
### OpenAPI Operation Registry
|
||||
|
||||
At startup, the MCP service loads `proto/gen/openapi.yaml` and builds an operation registry keyed by `operationId`. Each parsed operation stores:
|
||||
|
||||
- operation ID
|
||||
- HTTP method
|
||||
- OpenAPI route template
|
||||
- description
|
||||
- path parameters
|
||||
- query parameters
|
||||
- JSON request body schema
|
||||
- HTTP 200 JSON response schema
|
||||
|
||||
The parser should fail fast during service construction if any curated operation ID is missing or cannot be converted into a valid MCP tool schema.
|
||||
|
||||
### Tool Names
|
||||
|
||||
Tool names are derived from OpenAPI `operationId` values and normalized for MCP clients. The exact naming convention should be deterministic and tested. A practical convention is lower snake case without the `Service` suffix in the subject:
|
||||
|
||||
```text
|
||||
MemoService_ListMemos -> memo_list_memos
|
||||
AttachmentService_GetAttachment -> attachment_get_attachment
|
||||
```
|
||||
|
||||
The OpenAPI `operationId` remains stored in tool metadata so tests and future diagnostics can prove which OpenAPI operation produced each MCP tool.
|
||||
|
||||
### Tool Schemas
|
||||
|
||||
Each MCP tool input schema is an object built from:
|
||||
|
||||
- OpenAPI path parameters
|
||||
- OpenAPI query parameters
|
||||
- JSON request body fields, when present
|
||||
|
||||
Required path parameters stay required. Required request bodies stay required. Optional query parameters stay optional. The schema should preserve OpenAPI descriptions and primitive types where possible.
|
||||
|
||||
Each MCP tool output schema is the OpenAPI HTTP 200 JSON response schema. For empty 200 responses with no JSON schema, the MCP output schema is:
|
||||
|
||||
```json
|
||||
{ "type": "object", "properties": { "ok": { "type": "boolean" } } }
|
||||
```
|
||||
|
||||
## Tool Scope
|
||||
|
||||
The first version exposes these curated OpenAPI operations:
|
||||
|
||||
- `MemoService_ListMemos`
|
||||
- `MemoService_CreateMemo`
|
||||
- `MemoService_GetMemo`
|
||||
- `MemoService_UpdateMemo`
|
||||
- `MemoService_DeleteMemo`
|
||||
- `MemoService_ListMemoComments`
|
||||
- `MemoService_CreateMemoComment`
|
||||
- `MemoService_ListMemoAttachments`
|
||||
- `MemoService_SetMemoAttachments`
|
||||
- `MemoService_ListMemoReactions`
|
||||
- `MemoService_UpsertMemoReaction`
|
||||
- `MemoService_DeleteMemoReaction`
|
||||
- `MemoService_ListMemoRelations`
|
||||
- `MemoService_SetMemoRelations`
|
||||
- `AttachmentService_ListAttachments`
|
||||
- `AttachmentService_GetAttachment`
|
||||
- `AttachmentService_DeleteAttachment`
|
||||
|
||||
Excluded in the first version:
|
||||
|
||||
- auth sign-in, sign-out, and refresh
|
||||
- user management
|
||||
- personal access token management
|
||||
- identity provider management
|
||||
- webhooks
|
||||
- instance settings
|
||||
- share-link management
|
||||
- AI transcription
|
||||
- bulk delete operations
|
||||
- any operation not present in generated OpenAPI
|
||||
|
||||
## Data Flow
|
||||
|
||||
### Startup
|
||||
|
||||
1. `server.NewServer` creates the existing `APIV1Service`.
|
||||
2. The new `MCPService` loads `proto/gen/openapi.yaml`.
|
||||
3. The OpenAPI parser builds the operation registry.
|
||||
4. The curated operation allowlist selects supported operations.
|
||||
5. Each selected operation is registered as an MCP tool with input schema, output schema, description, annotations, and operation metadata.
|
||||
|
||||
### Tool Call
|
||||
|
||||
1. The client sends a standard MCP `tools/call` request to `/mcp`.
|
||||
2. The MCP server validates the tool name and arguments against the OpenAPI-derived input schema.
|
||||
3. The adapter substitutes path parameters into the OpenAPI route template.
|
||||
4. The adapter encodes query parameters into the query string.
|
||||
5. The adapter marshals request body arguments as JSON when the operation has a request body.
|
||||
6. The adapter forwards the caller's `Authorization` header and executes the matching `/api/v1/...` request through the existing Echo handler in-process.
|
||||
7. The API response JSON is decoded into `map[string]any`.
|
||||
8. The MCP result returns a compact JSON text fallback plus object-shaped `structuredContent`.
|
||||
|
||||
## Result Shape
|
||||
|
||||
Every MCP result must use object-shaped `structuredContent`.
|
||||
|
||||
Rules:
|
||||
|
||||
- If the API response is a JSON object, return it unchanged.
|
||||
- If the API response is empty, return `{ "ok": true }`.
|
||||
- If an unexpected raw JSON array appears, wrap it as `{ "result": [...] }`.
|
||||
- If an unexpected scalar appears, wrap it as `{ "result": value }`.
|
||||
|
||||
This directly addresses issue #6022 by preventing collection tools from returning bare arrays.
|
||||
|
||||
## Authentication And Origin Safety
|
||||
|
||||
The MCP endpoint accepts the same bearer credentials as the API:
|
||||
|
||||
```text
|
||||
Authorization: Bearer <PAT-or-access-token>
|
||||
```
|
||||
|
||||
The MCP adapter forwards the bearer header to the in-process API request. Public API operations can work without authentication when the API allows them. Mutating operations require authentication because the API already enforces that behavior.
|
||||
|
||||
For browser-origin safety, `/mcp` rejects cross-origin browser requests unless the `Origin` header is same-origin or matches the configured instance URL. Requests without an `Origin` header are allowed because desktop MCP clients commonly omit it.
|
||||
|
||||
## Errors
|
||||
|
||||
The MCP server should convert failures into MCP tool errors:
|
||||
|
||||
- Invalid MCP arguments: concise validation message.
|
||||
- API `401` or `403`: preserve the API message where available.
|
||||
- API `404`: report that the resource was not found.
|
||||
- Other API errors: include the HTTP status code and decoded API message.
|
||||
- Internal OpenAPI or adapter errors: log server-side details and return a concise tool error.
|
||||
|
||||
Adapter errors should not bypass MCP result formatting unless the underlying MCP framework requires protocol-level errors for invalid protocol messages.
|
||||
|
||||
## Tool Annotations
|
||||
|
||||
Tool annotations are derived from HTTP methods:
|
||||
|
||||
- `GET`: read-only, non-destructive, idempotent.
|
||||
- `POST`: mutating unless the operation is explicitly known to be read-only.
|
||||
- `PATCH`: mutating, non-idempotent by default.
|
||||
- `DELETE`: destructive and idempotent by default.
|
||||
|
||||
These annotations are hints for clients and do not replace API authorization.
|
||||
|
||||
## Testing
|
||||
|
||||
### OpenAPI Parsing Tests
|
||||
|
||||
Tests should verify:
|
||||
|
||||
- every curated operation ID exists in `proto/gen/openapi.yaml`
|
||||
- every selected tool has an object input schema
|
||||
- every selected tool has an object output schema
|
||||
- selected operations do not include admin, auth, webhook, identity provider, personal access token, instance setting, share-link, AI transcription, or bulk-delete operations
|
||||
- tool names are deterministic and unique
|
||||
|
||||
### MCP Protocol Tests
|
||||
|
||||
Using an Echo test server and JSON-RPC requests, tests should verify:
|
||||
|
||||
- `initialize` succeeds
|
||||
- `tools/list` returns only curated OpenAPI-derived tools
|
||||
- tool definitions include input and output schemas
|
||||
- no prompts or resources capabilities are advertised
|
||||
- collection tool calls return object-shaped `structuredContent`, never a bare array
|
||||
|
||||
### Adapter Tests
|
||||
|
||||
Representative adapter tests should cover:
|
||||
|
||||
- `GET /api/v1/memos?pageSize=...`
|
||||
- `POST /api/v1/memos`
|
||||
- `GET /api/v1/memos/{memo}`
|
||||
- `PATCH /api/v1/memos/{memo}`
|
||||
- `DELETE /api/v1/memos/{memo}`
|
||||
- `GET /api/v1/memos/{memo}/comments`
|
||||
- `POST /api/v1/memos/{memo}/reactions`
|
||||
|
||||
Before finishing implementation, run:
|
||||
|
||||
```bash
|
||||
go test ./server/router/mcp/... ./server/router/api/v1/... ./server/...
|
||||
```
|
||||
|
||||
## Implementation Notes
|
||||
|
||||
- Do not hand-edit `proto/gen/openapi.yaml`.
|
||||
- If a needed MCP tool is not represented by OpenAPI, add or adjust the proto/API surface first, run `cd proto && buf generate`, then derive the MCP tool from the regenerated OpenAPI.
|
||||
- Prefer existing API auth and gateway behavior over duplicating authorization logic in the MCP adapter.
|
||||
- Keep the first version intentionally small. Additional OpenAPI-derived tools can be added by extending the curated allowlist and tests.
|
||||
@ -0,0 +1,11 @@
|
||||
package proto
|
||||
|
||||
import _ "embed"
|
||||
|
||||
//go:embed gen/openapi.yaml
|
||||
var openAPIYAML []byte
|
||||
|
||||
// OpenAPIYAML returns the embedded generated OpenAPI specification.
|
||||
func OpenAPIYAML() []byte {
|
||||
return append([]byte(nil), openAPIYAML...)
|
||||
}
|
||||
@ -0,0 +1,160 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"io"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"net/url"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type apiAdapter struct {
|
||||
echoServer *echo.Echo
|
||||
}
|
||||
|
||||
func newAPIAdapter(echoServer *echo.Echo) *apiAdapter {
|
||||
return &apiAdapter{echoServer: echoServer}
|
||||
}
|
||||
|
||||
func (a *apiAdapter) execute(ctx context.Context, operation *openAPIOperation, arguments map[string]any, authorization string) (*sdkmcp.CallToolResult, error) {
|
||||
req, err := buildAPIRequest(ctx, operation, arguments, authorization)
|
||||
if err != nil {
|
||||
return newToolErrorResult(err.Error()), nil
|
||||
}
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
a.echoServer.ServeHTTP(recorder, req)
|
||||
|
||||
value, err := decodeJSONValue(recorder.Body.Bytes())
|
||||
if err != nil {
|
||||
return newToolErrorResult(err.Error()), nil
|
||||
}
|
||||
if recorder.Code < http.StatusOK || recorder.Code >= http.StatusMultipleChoices {
|
||||
return newToolErrorResult(apiErrorMessage(recorder.Code, value)), nil
|
||||
}
|
||||
return newStructuredToolResult(value)
|
||||
}
|
||||
|
||||
func buildAPIRequest(ctx context.Context, operation *openAPIOperation, arguments map[string]any, authorization string) (*http.Request, error) {
|
||||
path, err := substitutePathParameters(operation, arguments)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
query := url.Values{}
|
||||
for _, parameter := range operation.Parameters {
|
||||
if parameter.In != "query" {
|
||||
continue
|
||||
}
|
||||
value, ok := arguments[parameter.Name]
|
||||
if !ok || value == nil {
|
||||
continue
|
||||
}
|
||||
query.Set(parameter.Name, valueToString(value))
|
||||
}
|
||||
if encoded := query.Encode(); encoded != "" {
|
||||
path += "?" + encoded
|
||||
}
|
||||
|
||||
var body io.Reader
|
||||
if operation.RequestBody != nil {
|
||||
bodyValue, ok := arguments["body"]
|
||||
if !ok || bodyValue == nil {
|
||||
if operation.RequestBody.Required {
|
||||
return nil, errors.New(`missing required request body "body"`)
|
||||
}
|
||||
bodyValue = map[string]any{}
|
||||
}
|
||||
|
||||
data, err := json.Marshal(bodyValue)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal request body")
|
||||
}
|
||||
body = bytes.NewReader(data)
|
||||
}
|
||||
|
||||
req := httptest.NewRequest(operation.Method, path, body).WithContext(ctx)
|
||||
if body != nil {
|
||||
req.Header.Set("Content-Type", "application/json")
|
||||
}
|
||||
if authorization != "" {
|
||||
req.Header.Set("Authorization", authorization)
|
||||
}
|
||||
return req, nil
|
||||
}
|
||||
|
||||
func substitutePathParameters(operation *openAPIOperation, arguments map[string]any) (string, error) {
|
||||
path := operation.Path
|
||||
for _, parameter := range operation.Parameters {
|
||||
if parameter.In != "path" {
|
||||
continue
|
||||
}
|
||||
|
||||
value, ok := arguments[parameter.Name]
|
||||
if !ok || value == nil || valueToString(value) == "" {
|
||||
return "", errors.Errorf(`missing required path parameter "%s"`, parameter.Name)
|
||||
}
|
||||
path = strings.ReplaceAll(path, "{"+parameter.Name+"}", url.PathEscape(valueToString(value)))
|
||||
}
|
||||
return path, nil
|
||||
}
|
||||
|
||||
func valueToString(value any) string {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return typed
|
||||
case bool:
|
||||
return strconv.FormatBool(typed)
|
||||
case int:
|
||||
return strconv.Itoa(typed)
|
||||
case int8:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
case int16:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
case int32:
|
||||
return strconv.FormatInt(int64(typed), 10)
|
||||
case int64:
|
||||
return strconv.FormatInt(typed, 10)
|
||||
case uint:
|
||||
return strconv.FormatUint(uint64(typed), 10)
|
||||
case uint8:
|
||||
return strconv.FormatUint(uint64(typed), 10)
|
||||
case uint16:
|
||||
return strconv.FormatUint(uint64(typed), 10)
|
||||
case uint32:
|
||||
return strconv.FormatUint(uint64(typed), 10)
|
||||
case uint64:
|
||||
return strconv.FormatUint(typed, 10)
|
||||
case float32:
|
||||
return strconv.FormatFloat(float64(typed), 'f', -1, 32)
|
||||
case float64:
|
||||
return strconv.FormatFloat(typed, 'f', -1, 64)
|
||||
default:
|
||||
data, err := json.Marshal(value)
|
||||
if err != nil {
|
||||
return ""
|
||||
}
|
||||
return strings.Trim(string(data), `"`)
|
||||
}
|
||||
}
|
||||
|
||||
func apiErrorMessage(statusCode int, value any) string {
|
||||
status := strings.TrimSpace(strconv.Itoa(statusCode) + " " + http.StatusText(statusCode))
|
||||
if object, ok := value.(map[string]any); ok {
|
||||
for _, key := range []string{"message", "error"} {
|
||||
message, ok := object[key].(string)
|
||||
if ok && message != "" {
|
||||
return status + ": " + message
|
||||
}
|
||||
}
|
||||
}
|
||||
return status
|
||||
}
|
||||
@ -0,0 +1,213 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io"
|
||||
"net/http"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestNormalizeStructuredContentKeepsObjects(t *testing.T) {
|
||||
result := normalizeStructuredContent(map[string]any{"memos": []any{map[string]any{"name": "memos/a"}}})
|
||||
require.Equal(t, map[string]any{"memos": []any{map[string]any{"name": "memos/a"}}}, result)
|
||||
}
|
||||
|
||||
func TestNormalizeStructuredContentWrapsArrays(t *testing.T) {
|
||||
result := normalizeStructuredContent([]any{map[string]any{"tag": "work"}})
|
||||
require.Equal(t, map[string]any{"result": []any{map[string]any{"tag": "work"}}}, result)
|
||||
}
|
||||
|
||||
func TestNormalizeStructuredContentUsesOKForNil(t *testing.T) {
|
||||
result := normalizeStructuredContent(nil)
|
||||
require.Equal(t, map[string]any{"ok": true}, result)
|
||||
}
|
||||
|
||||
func TestNormalizeStructuredContentWrapsScalars(t *testing.T) {
|
||||
result := normalizeStructuredContent("created")
|
||||
require.Equal(t, map[string]any{"result": "created"}, result)
|
||||
}
|
||||
|
||||
func TestNewStructuredToolResultUsesObjectStructuredContent(t *testing.T) {
|
||||
result, err := newStructuredToolResult([]any{"one"})
|
||||
require.NoError(t, err)
|
||||
require.IsType(t, map[string]any{}, result.StructuredContent)
|
||||
require.Equal(t, map[string]any{"result": []any{"one"}}, result.StructuredContent)
|
||||
require.NotEmpty(t, result.Content)
|
||||
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
||||
require.True(t, ok)
|
||||
require.JSONEq(t, `{"result":["one"]}`, text.Text)
|
||||
}
|
||||
|
||||
func TestNewToolErrorResult(t *testing.T) {
|
||||
result := newToolErrorResult("resource not found")
|
||||
require.True(t, result.IsError)
|
||||
require.Equal(t, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "resource not found",
|
||||
},
|
||||
}, result.StructuredContent)
|
||||
require.NotEmpty(t, result.Content)
|
||||
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "resource not found", text.Text)
|
||||
}
|
||||
|
||||
func TestDecodeJSONValue(t *testing.T) {
|
||||
value, err := decodeJSONValue([]byte(`{"ok":true}`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, map[string]any{"ok": true}, value)
|
||||
|
||||
value, err = decodeJSONValue([]byte{})
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, value)
|
||||
|
||||
value, err = decodeJSONValue([]byte(" \n\t "))
|
||||
require.NoError(t, err)
|
||||
require.Nil(t, value)
|
||||
|
||||
value, err = decodeJSONValue([]byte(`[1,"two"]`))
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, []any{float64(1), "two"}, value)
|
||||
|
||||
_, err = decodeJSONValue([]byte(`{`))
|
||||
require.Error(t, err)
|
||||
require.True(t, errorsIsJSONSyntax(err), "wrapped syntax errors should remain inspectable")
|
||||
}
|
||||
|
||||
func errorsIsJSONSyntax(err error) bool {
|
||||
var syntaxError *json.SyntaxError
|
||||
return err != nil && errors.As(err, &syntaxError)
|
||||
}
|
||||
|
||||
func TestBuildAPIRequestMapsPathQueryAndBody(t *testing.T) {
|
||||
operation := &openAPIOperation{
|
||||
Method: "PATCH",
|
||||
Path: "/api/v1/memos/{memo}",
|
||||
Parameters: []openAPIParameter{
|
||||
{Name: "memo", In: "path", Required: true, Schema: jsonSchema{"type": "string"}},
|
||||
{Name: "updateMask", In: "query", Schema: jsonSchema{"type": "string"}},
|
||||
},
|
||||
RequestBody: &openAPIRequestBody{Required: true},
|
||||
}
|
||||
arguments := map[string]any{
|
||||
"memo": "abc123",
|
||||
"updateMask": "content",
|
||||
"body": map[string]any{
|
||||
"memo": map[string]any{
|
||||
"name": "memos/abc123",
|
||||
"content": "updated",
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := buildAPIRequest(context.Background(), operation, arguments, "Bearer pat")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "PATCH", req.Method)
|
||||
require.Equal(t, "/api/v1/memos/abc123", req.URL.Path)
|
||||
require.Equal(t, "content", req.URL.Query().Get("updateMask"))
|
||||
require.Equal(t, "Bearer pat", req.Header.Get("Authorization"))
|
||||
|
||||
body, err := io.ReadAll(req.Body)
|
||||
require.NoError(t, err)
|
||||
require.JSONEq(t, `{"memo":{"name":"memos/abc123","content":"updated"}}`, string(body))
|
||||
}
|
||||
|
||||
func TestBuildAPIRequestRequiresPathParameters(t *testing.T) {
|
||||
operation := &openAPIOperation{
|
||||
Method: "GET",
|
||||
Path: "/api/v1/memos/{memo}",
|
||||
Parameters: []openAPIParameter{{Name: "memo", In: "path", Required: true}},
|
||||
}
|
||||
|
||||
_, err := buildAPIRequest(context.Background(), operation, map[string]any{}, "")
|
||||
require.ErrorContains(t, err, `missing required path parameter "memo"`)
|
||||
}
|
||||
|
||||
func TestBuildAPIRequestRequiresRequestBody(t *testing.T) {
|
||||
operation := &openAPIOperation{
|
||||
Method: "POST",
|
||||
Path: "/api/v1/memos",
|
||||
RequestBody: &openAPIRequestBody{Required: true},
|
||||
}
|
||||
|
||||
_, err := buildAPIRequest(context.Background(), operation, map[string]any{}, "")
|
||||
require.ErrorContains(t, err, `missing required request body "body"`)
|
||||
}
|
||||
|
||||
func TestBuildAPIRequestEscapesPathAndStringifiesPrimitiveQueryParameters(t *testing.T) {
|
||||
operation := &openAPIOperation{
|
||||
Method: "DELETE",
|
||||
Path: "/api/v1/memos/{memo}",
|
||||
Parameters: []openAPIParameter{
|
||||
{Name: "memo", In: "path", Required: true, Schema: jsonSchema{"type": "string"}},
|
||||
{Name: "force", In: "query", Schema: jsonSchema{"type": "boolean"}},
|
||||
{Name: "limit", In: "query", Schema: jsonSchema{"type": "integer"}},
|
||||
},
|
||||
}
|
||||
|
||||
req, err := buildAPIRequest(context.Background(), operation, map[string]any{
|
||||
"memo": "abc 123",
|
||||
"force": true,
|
||||
"limit": 10,
|
||||
}, "")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "/api/v1/memos/abc%20123", req.URL.EscapedPath())
|
||||
require.Equal(t, "true", req.URL.Query().Get("force"))
|
||||
require.Equal(t, "10", req.URL.Query().Get("limit"))
|
||||
}
|
||||
|
||||
func TestExecuteOperationReturnsObjectStructuredContent(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
|
||||
require.Equal(t, "Bearer token", c.Request().Header.Get("Authorization"))
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"memos": []any{map[string]any{"name": "memos/abc123"}},
|
||||
})
|
||||
})
|
||||
|
||||
operation := &openAPIOperation{
|
||||
Method: "GET",
|
||||
Path: "/api/v1/memos",
|
||||
}
|
||||
adapter := newAPIAdapter(echoServer)
|
||||
|
||||
result, err := adapter.execute(context.Background(), operation, map[string]any{}, "Bearer token")
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.IsError)
|
||||
require.Equal(t, map[string]any{
|
||||
"memos": []any{map[string]any{"name": "memos/abc123"}},
|
||||
}, result.StructuredContent)
|
||||
}
|
||||
|
||||
func TestExecuteOperationConvertsAPIErrorsToToolErrors(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
echoServer.GET("/api/v1/memos/:memo", func(c *echo.Context) error {
|
||||
return c.JSON(http.StatusNotFound, map[string]any{"message": "missing memo"})
|
||||
})
|
||||
|
||||
operation := &openAPIOperation{
|
||||
Method: "GET",
|
||||
Path: "/api/v1/memos/{memo}",
|
||||
Parameters: []openAPIParameter{{Name: "memo", In: "path", Required: true}},
|
||||
}
|
||||
adapter := newAPIAdapter(echoServer)
|
||||
|
||||
result, err := adapter.execute(context.Background(), operation, map[string]any{"memo": "missing"}, "")
|
||||
require.NoError(t, err)
|
||||
require.True(t, result.IsError)
|
||||
require.Equal(t, map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": "404 Not Found: missing memo",
|
||||
},
|
||||
}, result.StructuredContent)
|
||||
text, ok := result.Content[0].(*sdkmcp.TextContent)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, text.Text, "404")
|
||||
require.Contains(t, text.Text, "missing memo")
|
||||
}
|
||||
@ -0,0 +1,215 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
var curatedOperationIDs = []string{
|
||||
"MemoService_ListMemos",
|
||||
"MemoService_CreateMemo",
|
||||
"MemoService_GetMemo",
|
||||
"MemoService_UpdateMemo",
|
||||
"MemoService_DeleteMemo",
|
||||
"MemoService_ListMemoComments",
|
||||
"MemoService_CreateMemoComment",
|
||||
"MemoService_ListMemoAttachments",
|
||||
"MemoService_SetMemoAttachments",
|
||||
"MemoService_ListMemoReactions",
|
||||
"MemoService_UpsertMemoReaction",
|
||||
"MemoService_DeleteMemoReaction",
|
||||
"MemoService_ListMemoRelations",
|
||||
"MemoService_SetMemoRelations",
|
||||
"AttachmentService_ListAttachments",
|
||||
"AttachmentService_GetAttachment",
|
||||
"AttachmentService_DeleteAttachment",
|
||||
}
|
||||
|
||||
type registeredOperation struct {
|
||||
ToolName string
|
||||
OperationID string
|
||||
Method string
|
||||
Path string
|
||||
Operation *openAPIOperation
|
||||
InputSchema jsonSchema
|
||||
}
|
||||
|
||||
var wordBoundary = regexp.MustCompile(`([a-z0-9])([A-Z])`)
|
||||
|
||||
func buildCuratedTools(registry map[string]*openAPIOperation) ([]*sdkmcp.Tool, map[string]*registeredOperation, error) {
|
||||
tools := make([]*sdkmcp.Tool, 0, len(curatedOperationIDs))
|
||||
operations := map[string]*registeredOperation{}
|
||||
for _, operationID := range curatedOperationIDs {
|
||||
operation, ok := registry[operationID]
|
||||
if !ok {
|
||||
return nil, nil, errors.Errorf("curated OpenAPI operation %q not found", operationID)
|
||||
}
|
||||
|
||||
tool, registered := buildToolFromOperation(operation)
|
||||
if _, exists := operations[tool.Name]; exists {
|
||||
return nil, nil, errors.Errorf("duplicate MCP tool name %q", tool.Name)
|
||||
}
|
||||
|
||||
tools = append(tools, tool)
|
||||
operations[tool.Name] = registered
|
||||
}
|
||||
return tools, operations, nil
|
||||
}
|
||||
|
||||
func buildToolFromOperation(operation *openAPIOperation) (*sdkmcp.Tool, *registeredOperation) {
|
||||
name := toolNameFromOperationID(operation.OperationID)
|
||||
title := titleFromToolName(name)
|
||||
inputSchema := inputSchemaForOperation(operation)
|
||||
tool := &sdkmcp.Tool{
|
||||
Meta: sdkmcp.Meta{
|
||||
"operationId": operation.OperationID,
|
||||
"method": operation.Method,
|
||||
"path": operation.Path,
|
||||
},
|
||||
Name: name,
|
||||
Title: title,
|
||||
Description: operation.Description,
|
||||
InputSchema: inputSchema,
|
||||
OutputSchema: outputSchemaForOperation(operation),
|
||||
Annotations: annotationsForMethod(operation.Method, title),
|
||||
}
|
||||
|
||||
return tool, ®isteredOperation{
|
||||
ToolName: name,
|
||||
OperationID: operation.OperationID,
|
||||
Method: operation.Method,
|
||||
Path: operation.Path,
|
||||
Operation: operation,
|
||||
InputSchema: inputSchema,
|
||||
}
|
||||
}
|
||||
|
||||
func toolNameFromOperationID(operationID string) string {
|
||||
service, method, ok := strings.Cut(operationID, "_")
|
||||
if !ok {
|
||||
return camelToSnake(operationID)
|
||||
}
|
||||
service = strings.TrimSuffix(service, "Service")
|
||||
return camelToSnake(service) + "_" + camelToSnake(method)
|
||||
}
|
||||
|
||||
func camelToSnake(value string) string {
|
||||
return strings.ToLower(wordBoundary.ReplaceAllString(value, `${1}_${2}`))
|
||||
}
|
||||
|
||||
func titleFromToolName(name string) string {
|
||||
parts := strings.Split(name, "_")
|
||||
for i, part := range parts {
|
||||
if part == "" {
|
||||
continue
|
||||
}
|
||||
parts[i] = strings.ToUpper(part[:1]) + part[1:]
|
||||
}
|
||||
return strings.Join(parts, " ")
|
||||
}
|
||||
|
||||
func inputSchemaForOperation(operation *openAPIOperation) jsonSchema {
|
||||
properties := map[string]any{}
|
||||
required := []string{}
|
||||
defs := map[string]any{}
|
||||
for _, parameter := range operation.Parameters {
|
||||
schema := cloneSchema(parameter.Schema)
|
||||
if parameter.Description != "" {
|
||||
schema["description"] = parameter.Description
|
||||
}
|
||||
properties[parameter.Name] = schema
|
||||
if parameter.Required {
|
||||
required = append(required, parameter.Name)
|
||||
}
|
||||
}
|
||||
|
||||
if operation.RequestBody != nil {
|
||||
bodySchema := requestBodySchema(operation)
|
||||
for name, definition := range extractSchemaDefs(bodySchema) {
|
||||
defs[name] = definition
|
||||
}
|
||||
properties["body"] = bodySchema
|
||||
if operation.RequestBody.Required {
|
||||
required = append(required, "body")
|
||||
}
|
||||
}
|
||||
|
||||
schema := jsonSchema{
|
||||
"type": "object",
|
||||
"properties": properties,
|
||||
"additionalProperties": false,
|
||||
}
|
||||
if len(required) > 0 {
|
||||
schema["required"] = required
|
||||
}
|
||||
if len(defs) > 0 {
|
||||
schema["$defs"] = defs
|
||||
}
|
||||
return schema
|
||||
}
|
||||
|
||||
func requestBodySchema(operation *openAPIOperation) jsonSchema {
|
||||
if operation.RequestBodySchema == nil {
|
||||
return jsonSchema{"type": "object"}
|
||||
}
|
||||
return cloneSchema(operation.RequestBodySchema)
|
||||
}
|
||||
|
||||
func outputSchemaForOperation(operation *openAPIOperation) jsonSchema {
|
||||
if operation.ResponseSchema == nil {
|
||||
return okSchema()
|
||||
}
|
||||
return cloneSchema(operation.ResponseSchema)
|
||||
}
|
||||
|
||||
func cloneSchema(schema jsonSchema) jsonSchema {
|
||||
clone := jsonSchema{}
|
||||
for key, value := range schema {
|
||||
clone[key] = value
|
||||
}
|
||||
return clone
|
||||
}
|
||||
|
||||
func extractSchemaDefs(schema jsonSchema) map[string]any {
|
||||
defs, ok := schema["$defs"].(map[string]any)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
delete(schema, "$defs")
|
||||
return defs
|
||||
}
|
||||
|
||||
func annotationsForMethod(method string, title string) *sdkmcp.ToolAnnotations {
|
||||
openWorld := false
|
||||
destructive := false
|
||||
switch strings.ToUpper(method) {
|
||||
case "GET":
|
||||
return &sdkmcp.ToolAnnotations{
|
||||
Title: title,
|
||||
ReadOnlyHint: true,
|
||||
DestructiveHint: &destructive,
|
||||
IdempotentHint: true,
|
||||
OpenWorldHint: &openWorld,
|
||||
}
|
||||
case "DELETE":
|
||||
destructive = true
|
||||
return &sdkmcp.ToolAnnotations{
|
||||
Title: title,
|
||||
ReadOnlyHint: false,
|
||||
DestructiveHint: &destructive,
|
||||
IdempotentHint: true,
|
||||
OpenWorldHint: &openWorld,
|
||||
}
|
||||
default:
|
||||
return &sdkmcp.ToolAnnotations{
|
||||
Title: title,
|
||||
ReadOnlyHint: false,
|
||||
DestructiveHint: &destructive,
|
||||
IdempotentHint: false,
|
||||
OpenWorldHint: &openWorld,
|
||||
}
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,152 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestCuratedOperationIDsStayMemoFocused(t *testing.T) {
|
||||
require.Len(t, curatedOperationIDs, 17)
|
||||
|
||||
for _, operationID := range curatedOperationIDs {
|
||||
require.NotContains(t, operationID, "Admin")
|
||||
require.NotContains(t, operationID, "AuthService_")
|
||||
require.NotContains(t, operationID, "UserService_")
|
||||
require.NotContains(t, operationID, "AIService_")
|
||||
require.NotContains(t, operationID, "IdentityProviderService_")
|
||||
require.NotContains(t, operationID, "InstanceService_")
|
||||
require.NotContains(t, operationID, "PersonalAccessToken")
|
||||
require.NotContains(t, operationID, "PAT")
|
||||
require.NotContains(t, operationID, "Webhook")
|
||||
require.NotContains(t, operationID, "Share")
|
||||
require.NotContains(t, operationID, "BatchDelete")
|
||||
require.NotContains(t, operationID, "Transcribe")
|
||||
}
|
||||
}
|
||||
|
||||
func TestToolNameFromOperationID(t *testing.T) {
|
||||
require.Equal(t, "memo_list_memos", toolNameFromOperationID("MemoService_ListMemos"))
|
||||
require.Equal(t, "attachment_get_attachment", toolNameFromOperationID("AttachmentService_GetAttachment"))
|
||||
}
|
||||
|
||||
func TestBuildToolFromOperationIncludesSchemasAndMetadata(t *testing.T) {
|
||||
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
|
||||
require.NoError(t, err)
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, operation := buildToolFromOperation(registry["MemoService_ListMemos"])
|
||||
require.Equal(t, "memo_list_memos", tool.Name)
|
||||
require.Equal(t, "Memo List Memos", tool.Title)
|
||||
require.Equal(t, "MemoService_ListMemos", operation.OperationID)
|
||||
require.Equal(t, "GET", operation.Method)
|
||||
require.Equal(t, "/api/v1/memos", operation.Path)
|
||||
require.Equal(t, "MemoService_ListMemos", tool.Meta["operationId"])
|
||||
require.Equal(t, "GET", tool.Meta["method"])
|
||||
require.Equal(t, "/api/v1/memos", tool.Meta["path"])
|
||||
require.NotEmpty(t, tool.Description)
|
||||
require.NotNil(t, tool.InputSchema)
|
||||
require.NotNil(t, tool.OutputSchema)
|
||||
require.NotNil(t, tool.Annotations)
|
||||
require.True(t, tool.Annotations.ReadOnlyHint)
|
||||
require.False(t, *tool.Annotations.DestructiveHint)
|
||||
require.True(t, tool.Annotations.IdempotentHint)
|
||||
require.False(t, *tool.Annotations.OpenWorldHint)
|
||||
|
||||
inputBytes, err := json.Marshal(tool.InputSchema)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(inputBytes), `"pageSize"`)
|
||||
require.Contains(t, string(inputBytes), `"additionalProperties":false`)
|
||||
|
||||
outputBytes, err := json.Marshal(tool.OutputSchema)
|
||||
require.NoError(t, err)
|
||||
require.Contains(t, string(outputBytes), `"memos"`)
|
||||
}
|
||||
|
||||
func TestBuildToolFromOperationIncludesRequestBodySchema(t *testing.T) {
|
||||
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
|
||||
require.NoError(t, err)
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
tool, operation := buildToolFromOperation(registry["MemoService_CreateMemo"])
|
||||
require.Equal(t, "POST", operation.Method)
|
||||
require.False(t, tool.Annotations.ReadOnlyHint)
|
||||
require.False(t, *tool.Annotations.DestructiveHint)
|
||||
require.False(t, tool.Annotations.IdempotentHint)
|
||||
|
||||
input, ok := tool.InputSchema.(jsonSchema)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, input["required"], "body")
|
||||
properties, ok := input["properties"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, properties, "memoId")
|
||||
require.Contains(t, properties, "body")
|
||||
body, ok := properties["body"].(jsonSchema)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, "object", body["type"])
|
||||
require.Contains(t, body["properties"], "content")
|
||||
|
||||
err = validateToolArguments(input, map[string]any{
|
||||
"body": map[string]any{
|
||||
"state": "NORMAL",
|
||||
"content": "hello",
|
||||
"visibility": "PRIVATE",
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
}
|
||||
|
||||
func TestBuildCuratedToolsHasUniqueNames(t *testing.T) {
|
||||
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
|
||||
require.NoError(t, err)
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
tools, operations, err := buildCuratedTools(registry)
|
||||
require.NoError(t, err)
|
||||
require.Len(t, tools, len(curatedOperationIDs))
|
||||
require.Len(t, operations, len(curatedOperationIDs))
|
||||
|
||||
names := map[string]struct{}{}
|
||||
for _, tool := range tools {
|
||||
require.IsType(t, &sdkmcp.Tool{}, tool)
|
||||
require.NotEmpty(t, tool.Name)
|
||||
require.NotContains(t, names, tool.Name)
|
||||
names[tool.Name] = struct{}{}
|
||||
require.Equal(t, tool.Name, operations[tool.Name].ToolName)
|
||||
|
||||
inputBytes, err := json.Marshal(tool.InputSchema)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(inputBytes), "#/components/schemas")
|
||||
outputBytes, err := json.Marshal(tool.OutputSchema)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(outputBytes), "#/components/schemas")
|
||||
}
|
||||
}
|
||||
|
||||
func TestBuildCuratedToolsRejectsMissingOperation(t *testing.T) {
|
||||
_, _, err := buildCuratedTools(map[string]*openAPIOperation{})
|
||||
require.ErrorContains(t, err, "curated OpenAPI operation")
|
||||
require.ErrorContains(t, err, "not found")
|
||||
}
|
||||
|
||||
func TestBuildCuratedToolsRejectsDuplicateToolNames(t *testing.T) {
|
||||
registry := make(map[string]*openAPIOperation, len(curatedOperationIDs))
|
||||
for _, operationID := range curatedOperationIDs {
|
||||
registry[operationID] = &openAPIOperation{
|
||||
OperationID: operationID,
|
||||
Description: operationID,
|
||||
Method: "GET",
|
||||
Path: "/api/v1/test",
|
||||
ResponseSchema: okSchema(),
|
||||
}
|
||||
}
|
||||
registry["MemoService_ListMemos"].OperationID = "MemoService_GetMemo"
|
||||
|
||||
_, _, err := buildCuratedTools(registry)
|
||||
require.ErrorContains(t, err, "duplicate MCP tool name")
|
||||
}
|
||||
@ -0,0 +1,252 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type jsonSchema map[string]any
|
||||
|
||||
type openAPISpec struct {
|
||||
OpenAPI string `yaml:"openapi"`
|
||||
Paths map[string]map[string]*openAPIOperation `yaml:"paths"`
|
||||
Components openAPIComponents `yaml:"components"`
|
||||
}
|
||||
|
||||
type openAPIComponents struct {
|
||||
Schemas map[string]jsonSchema `yaml:"schemas"`
|
||||
}
|
||||
|
||||
type openAPIOperation struct {
|
||||
OperationID string `yaml:"operationId"`
|
||||
Description string `yaml:"description"`
|
||||
Parameters []openAPIParameter `yaml:"parameters"`
|
||||
RequestBody *openAPIRequestBody `yaml:"requestBody"`
|
||||
Responses map[string]openAPIResponse `yaml:"responses"`
|
||||
Method string `yaml:"-"`
|
||||
Path string `yaml:"-"`
|
||||
ResponseSchema jsonSchema `yaml:"-"`
|
||||
RequestBodySchema jsonSchema `yaml:"-"`
|
||||
}
|
||||
|
||||
type openAPIParameter struct {
|
||||
Name string `yaml:"name"`
|
||||
In string `yaml:"in"`
|
||||
Description string `yaml:"description"`
|
||||
Required bool `yaml:"required"`
|
||||
Schema jsonSchema `yaml:"schema"`
|
||||
}
|
||||
|
||||
type openAPIRequestBody struct {
|
||||
Required bool `yaml:"required"`
|
||||
Content map[string]openAPIMediaType `yaml:"content"`
|
||||
}
|
||||
|
||||
type openAPIResponse struct {
|
||||
Description string `yaml:"description"`
|
||||
Content map[string]openAPIMediaType `yaml:"content"`
|
||||
}
|
||||
|
||||
type openAPIMediaType struct {
|
||||
Schema jsonSchema `yaml:"schema"`
|
||||
}
|
||||
|
||||
func loadOpenAPISpec(path string) (*openAPISpec, error) {
|
||||
data, err := os.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to read OpenAPI spec")
|
||||
}
|
||||
|
||||
spec := &openAPISpec{}
|
||||
if err := yaml.Unmarshal(data, spec); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse OpenAPI spec")
|
||||
}
|
||||
if spec.Paths == nil {
|
||||
return nil, errors.New("OpenAPI spec has no paths")
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
func buildOperationRegistry(spec *openAPISpec) (map[string]*openAPIOperation, error) {
|
||||
registry := map[string]*openAPIOperation{}
|
||||
for path, pathItem := range spec.Paths {
|
||||
for method, operation := range pathItem {
|
||||
if operation == nil || operation.OperationID == "" {
|
||||
continue
|
||||
}
|
||||
if _, exists := registry[operation.OperationID]; exists {
|
||||
return nil, errors.Errorf("duplicate OpenAPI operationId %q", operation.OperationID)
|
||||
}
|
||||
|
||||
operation.Method = strings.ToUpper(method)
|
||||
operation.Path = path
|
||||
|
||||
responseSchema, err := operationSuccessResponseSchema(spec, operation)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to resolve response schema for %s", operation.OperationID)
|
||||
}
|
||||
operation.ResponseSchema = responseSchema
|
||||
|
||||
requestBodySchema, err := operationRequestBodySchema(spec, operation)
|
||||
if err != nil {
|
||||
return nil, errors.Wrapf(err, "failed to resolve request body schema for %s", operation.OperationID)
|
||||
}
|
||||
operation.RequestBodySchema = requestBodySchema
|
||||
|
||||
registry[operation.OperationID] = operation
|
||||
}
|
||||
}
|
||||
return registry, nil
|
||||
}
|
||||
|
||||
func operationSuccessResponseSchema(spec *openAPISpec, operation *openAPIOperation) (jsonSchema, error) {
|
||||
response, ok := operation.Responses["200"]
|
||||
if !ok || response.Content == nil {
|
||||
return okSchema(), nil
|
||||
}
|
||||
mediaType, ok := response.Content["application/json"]
|
||||
if !ok || mediaType.Schema == nil {
|
||||
return okSchema(), nil
|
||||
}
|
||||
return resolveSchemaRef(spec, mediaType.Schema)
|
||||
}
|
||||
|
||||
func operationRequestBodySchema(spec *openAPISpec, operation *openAPIOperation) (jsonSchema, error) {
|
||||
if operation.RequestBody == nil {
|
||||
return nil, nil
|
||||
}
|
||||
mediaType, ok := operation.RequestBody.Content["application/json"]
|
||||
if !ok || mediaType.Schema == nil {
|
||||
return jsonSchema{"type": "object"}, nil
|
||||
}
|
||||
return resolveSchemaRef(spec, mediaType.Schema)
|
||||
}
|
||||
|
||||
func resolveSchemaRef(spec *openAPISpec, schema jsonSchema) (jsonSchema, error) {
|
||||
defs := map[string]any{}
|
||||
resolved, err := resolveSchemaValue(spec, schema, defs, map[string]bool{}, true)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
resolvedSchema, ok := resolved.(map[string]any)
|
||||
if !ok {
|
||||
return nil, errors.New("resolved schema is not an object")
|
||||
}
|
||||
if len(defs) > 0 {
|
||||
resolvedSchema["$defs"] = defs
|
||||
}
|
||||
return jsonSchema(resolvedSchema), nil
|
||||
}
|
||||
|
||||
func resolveSchemaValue(spec *openAPISpec, value any, defs map[string]any, resolving map[string]bool, inlineRef bool) (any, error) {
|
||||
switch typed := value.(type) {
|
||||
case jsonSchema:
|
||||
return resolveSchemaMap(spec, map[string]any(typed), defs, resolving, inlineRef)
|
||||
case map[string]any:
|
||||
return resolveSchemaMap(spec, typed, defs, resolving, inlineRef)
|
||||
case []any:
|
||||
resolved := make([]any, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
resolvedItem, err := resolveSchemaValue(spec, item, defs, resolving, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolved = append(resolved, resolvedItem)
|
||||
}
|
||||
return resolved, nil
|
||||
default:
|
||||
return value, nil
|
||||
}
|
||||
}
|
||||
|
||||
func resolveSchemaMap(spec *openAPISpec, schema map[string]any, defs map[string]any, resolving map[string]bool, inlineRef bool) (map[string]any, error) {
|
||||
if ref, ok := schema["$ref"].(string); ok && ref != "" {
|
||||
name, err := schemaComponentName(ref)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if inlineRef {
|
||||
return resolveComponentSchema(spec, name, defs, resolving)
|
||||
}
|
||||
if err := addSchemaDef(spec, name, defs, resolving); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return map[string]any{"$ref": "#/$defs/" + name}, nil
|
||||
}
|
||||
|
||||
resolved := make(map[string]any, len(schema))
|
||||
for key, value := range schema {
|
||||
resolvedValue, err := resolveSchemaValue(spec, value, defs, resolving, false)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
resolved[key] = resolvedValue
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func resolveComponentSchema(spec *openAPISpec, name string, defs map[string]any, resolving map[string]bool) (map[string]any, error) {
|
||||
component, ok := spec.Components.Schemas[name]
|
||||
if !ok {
|
||||
return nil, errors.Errorf("schema ref %q not found", schemaComponentRef(name))
|
||||
}
|
||||
resolving[name] = true
|
||||
resolved, err := resolveSchemaMap(spec, map[string]any(component), defs, resolving, false)
|
||||
delete(resolving, name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, ok := defs[name]; ok {
|
||||
defs[name] = resolved
|
||||
}
|
||||
return resolved, nil
|
||||
}
|
||||
|
||||
func addSchemaDef(spec *openAPISpec, name string, defs map[string]any, resolving map[string]bool) error {
|
||||
if _, ok := defs[name]; ok {
|
||||
return nil
|
||||
}
|
||||
component, ok := spec.Components.Schemas[name]
|
||||
if !ok {
|
||||
return errors.Errorf("schema ref %q not found", schemaComponentRef(name))
|
||||
}
|
||||
|
||||
defs[name] = map[string]any{}
|
||||
if resolving[name] {
|
||||
return nil
|
||||
}
|
||||
|
||||
resolving[name] = true
|
||||
resolved, err := resolveSchemaMap(spec, map[string]any(component), defs, resolving, false)
|
||||
delete(resolving, name)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
defs[name] = resolved
|
||||
return nil
|
||||
}
|
||||
|
||||
func schemaComponentName(ref string) (string, error) {
|
||||
const prefix = "#/components/schemas/"
|
||||
if !strings.HasPrefix(ref, prefix) {
|
||||
return "", errors.Errorf("unsupported schema ref %q", ref)
|
||||
}
|
||||
return strings.TrimPrefix(ref, prefix), nil
|
||||
}
|
||||
|
||||
func schemaComponentRef(name string) string {
|
||||
return "#/components/schemas/" + name
|
||||
}
|
||||
|
||||
func okSchema() jsonSchema {
|
||||
return jsonSchema{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"ok": map[string]any{"type": "boolean"},
|
||||
},
|
||||
}
|
||||
}
|
||||
@ -0,0 +1,192 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestLoadOpenAPIOperationsIncludesCuratedIDs(t *testing.T) {
|
||||
spec, err := loadOpenAPISpec("../../../proto/gen/openapi.yaml")
|
||||
require.NoError(t, err)
|
||||
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
curatedIDs := []string{
|
||||
"MemoService_ListMemos",
|
||||
"MemoService_CreateMemo",
|
||||
"MemoService_GetMemo",
|
||||
"MemoService_UpdateMemo",
|
||||
"MemoService_DeleteMemo",
|
||||
"MemoService_ListMemoComments",
|
||||
"MemoService_CreateMemoComment",
|
||||
"MemoService_ListMemoAttachments",
|
||||
"MemoService_SetMemoAttachments",
|
||||
"MemoService_ListMemoReactions",
|
||||
"MemoService_UpsertMemoReaction",
|
||||
"MemoService_DeleteMemoReaction",
|
||||
"MemoService_ListMemoRelations",
|
||||
"MemoService_SetMemoRelations",
|
||||
"AttachmentService_ListAttachments",
|
||||
"AttachmentService_GetAttachment",
|
||||
"AttachmentService_DeleteAttachment",
|
||||
}
|
||||
|
||||
for _, operationID := range curatedIDs {
|
||||
operation, ok := registry[operationID]
|
||||
require.True(t, ok, "missing curated operation %s", operationID)
|
||||
require.NotEmpty(t, operation.Method, operationID)
|
||||
require.NotEmpty(t, operation.Path, operationID)
|
||||
require.NotEmpty(t, operation.Description, operationID)
|
||||
require.NotNil(t, operation.ResponseSchema, operationID)
|
||||
}
|
||||
|
||||
createMemo := registry["MemoService_CreateMemo"]
|
||||
require.NotNil(t, createMemo.RequestBodySchema)
|
||||
require.Equal(t, "object", createMemo.RequestBodySchema["type"])
|
||||
}
|
||||
|
||||
func TestBuildOperationRegistryRejectsDuplicateOperationIDs(t *testing.T) {
|
||||
spec := &openAPISpec{
|
||||
Paths: map[string]map[string]*openAPIOperation{
|
||||
"/a": {
|
||||
"get": {OperationID: "MemoService_GetMemo", Description: "first"},
|
||||
},
|
||||
"/b": {
|
||||
"get": {OperationID: "MemoService_GetMemo", Description: "second"},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
_, err := buildOperationRegistry(spec)
|
||||
require.ErrorContains(t, err, "duplicate OpenAPI operationId")
|
||||
}
|
||||
|
||||
func TestResolveSchemaRef(t *testing.T) {
|
||||
spec := &openAPISpec{
|
||||
Components: openAPIComponents{
|
||||
Schemas: map[string]jsonSchema{
|
||||
"ListMemosResponse": {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"memos": map[string]any{"type": "array"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema, err := resolveSchemaRef(spec, jsonSchema{"$ref": "#/components/schemas/ListMemosResponse"})
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, "object", schema["type"])
|
||||
require.Contains(t, schema["properties"], "memos")
|
||||
}
|
||||
|
||||
func TestResolveSchemaRefRewritesNestedComponentRefs(t *testing.T) {
|
||||
spec := &openAPISpec{
|
||||
Components: openAPIComponents{
|
||||
Schemas: map[string]jsonSchema{
|
||||
"ListMemosResponse": {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"memos": map[string]any{
|
||||
"type": "array",
|
||||
"items": map[string]any{"$ref": "#/components/schemas/Memo"},
|
||||
},
|
||||
},
|
||||
},
|
||||
"Memo": {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"attachment": map[string]any{"$ref": "#/components/schemas/Attachment"},
|
||||
},
|
||||
},
|
||||
"Attachment": {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"name": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
schema, err := resolveSchemaRef(spec, jsonSchema{"$ref": "#/components/schemas/ListMemosResponse"})
|
||||
require.NoError(t, err)
|
||||
|
||||
data, err := json.Marshal(schema)
|
||||
require.NoError(t, err)
|
||||
require.NotContains(t, string(data), "#/components/schemas")
|
||||
require.Contains(t, string(data), `"#/$defs/Memo"`)
|
||||
require.Contains(t, string(data), `"#/$defs/Attachment"`)
|
||||
}
|
||||
|
||||
func TestBuildOperationRegistryResolvesRequestBodySchema(t *testing.T) {
|
||||
spec := &openAPISpec{
|
||||
Paths: map[string]map[string]*openAPIOperation{
|
||||
"/memos": {
|
||||
"post": {
|
||||
OperationID: "MemoService_CreateMemo",
|
||||
RequestBody: &openAPIRequestBody{
|
||||
Content: map[string]openAPIMediaType{
|
||||
"application/json": {Schema: jsonSchema{"$ref": "#/components/schemas/CreateMemoRequest"}},
|
||||
},
|
||||
},
|
||||
Responses: map[string]openAPIResponse{
|
||||
"200": {
|
||||
Content: map[string]openAPIMediaType{
|
||||
"application/json": {Schema: jsonSchema{"$ref": "#/components/schemas/Memo"}},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
Components: openAPIComponents{
|
||||
Schemas: map[string]jsonSchema{
|
||||
"CreateMemoRequest": {
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"content": map[string]any{"type": "string"},
|
||||
},
|
||||
},
|
||||
"Memo": {
|
||||
"type": "object",
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
requestSchema := registry["MemoService_CreateMemo"].RequestBodySchema
|
||||
require.Equal(t, "object", requestSchema["type"])
|
||||
require.Contains(t, requestSchema["properties"], "content")
|
||||
}
|
||||
|
||||
func TestBuildOperationRegistryUsesOKSchemaForEmptySuccessResponse(t *testing.T) {
|
||||
spec := &openAPISpec{
|
||||
Paths: map[string]map[string]*openAPIOperation{
|
||||
"/memos/{memo}": {
|
||||
"delete": {
|
||||
OperationID: "MemoService_DeleteMemo",
|
||||
Responses: map[string]openAPIResponse{
|
||||
"200": {
|
||||
Content: map[string]openAPIMediaType{},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
require.NoError(t, err)
|
||||
|
||||
responseSchema := registry["MemoService_DeleteMemo"].ResponseSchema
|
||||
require.Equal(t, "object", responseSchema["type"])
|
||||
require.Contains(t, responseSchema["properties"], "ok")
|
||||
}
|
||||
@ -0,0 +1,31 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"net/url"
|
||||
"strings"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
)
|
||||
|
||||
func isAllowedMCPOrigin(host string, origin string, profile *profile.Profile) bool {
|
||||
if origin == "" {
|
||||
return true
|
||||
}
|
||||
|
||||
originURL, err := url.Parse(origin)
|
||||
if err != nil || originURL.Scheme == "" || originURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
if strings.EqualFold(originURL.Host, host) {
|
||||
return true
|
||||
}
|
||||
|
||||
if profile == nil || profile.InstanceURL == "" {
|
||||
return false
|
||||
}
|
||||
instanceURL, err := url.Parse(profile.InstanceURL)
|
||||
if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" {
|
||||
return false
|
||||
}
|
||||
return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && strings.EqualFold(originURL.Host, instanceURL.Host)
|
||||
}
|
||||
@ -0,0 +1,62 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/json"
|
||||
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func normalizeStructuredContent(value any) map[string]any {
|
||||
switch typed := value.(type) {
|
||||
case nil:
|
||||
return map[string]any{"ok": true}
|
||||
case map[string]any:
|
||||
return typed
|
||||
case []any:
|
||||
return map[string]any{"result": typed}
|
||||
default:
|
||||
return map[string]any{"result": typed}
|
||||
}
|
||||
}
|
||||
|
||||
func newStructuredToolResult(value any) (*sdkmcp.CallToolResult, error) {
|
||||
structured := normalizeStructuredContent(value)
|
||||
text, err := json.Marshal(structured)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to marshal MCP structured result")
|
||||
}
|
||||
return &sdkmcp.CallToolResult{
|
||||
Content: []sdkmcp.Content{
|
||||
&sdkmcp.TextContent{Text: string(text)},
|
||||
},
|
||||
StructuredContent: structured,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func newToolErrorResult(message string) *sdkmcp.CallToolResult {
|
||||
structured := map[string]any{
|
||||
"error": map[string]any{
|
||||
"message": message,
|
||||
},
|
||||
}
|
||||
return &sdkmcp.CallToolResult{
|
||||
Content: []sdkmcp.Content{
|
||||
&sdkmcp.TextContent{Text: message},
|
||||
},
|
||||
StructuredContent: structured,
|
||||
IsError: true,
|
||||
}
|
||||
}
|
||||
|
||||
func decodeJSONValue(data []byte) (any, error) {
|
||||
if len(bytes.TrimSpace(data)) == 0 {
|
||||
return nil, nil
|
||||
}
|
||||
var value any
|
||||
if err := json.Unmarshal(data, &value); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to decode API JSON response")
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
@ -0,0 +1,110 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/pkg/errors"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
memosproto "github.com/usememos/memos/proto"
|
||||
)
|
||||
|
||||
// MCPService serves the OpenAPI-driven MCP endpoint.
|
||||
type MCPService struct {
|
||||
profile *profile.Profile
|
||||
|
||||
operationsByTool map[string]*registeredOperation
|
||||
handler http.Handler
|
||||
}
|
||||
|
||||
// NewMCPService creates an MCP service backed by the in-process API routes.
|
||||
func NewMCPService(profile *profile.Profile, echoServer *echo.Echo) (*MCPService, error) {
|
||||
spec, err := loadMCPServiceOpenAPISpec()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
registry, err := buildOperationRegistry(spec)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
tools, operationsByTool, err := buildCuratedTools(registry)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
version := "dev"
|
||||
if profile != nil && profile.Version != "" {
|
||||
version = profile.Version
|
||||
}
|
||||
server := sdkmcp.NewServer(&sdkmcp.Implementation{
|
||||
Name: "memos",
|
||||
Version: version,
|
||||
}, nil)
|
||||
|
||||
adapter := newAPIAdapter(echoServer)
|
||||
for _, tool := range tools {
|
||||
operation := operationsByTool[tool.Name]
|
||||
server.AddTool(tool, newMCPToolHandler(adapter, operation))
|
||||
}
|
||||
|
||||
handler := sdkmcp.NewStreamableHTTPHandler(func(*http.Request) *sdkmcp.Server {
|
||||
return server
|
||||
}, &sdkmcp.StreamableHTTPOptions{
|
||||
Stateless: true,
|
||||
JSONResponse: true,
|
||||
})
|
||||
|
||||
return &MCPService{
|
||||
profile: profile,
|
||||
operationsByTool: operationsByTool,
|
||||
handler: handler,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func loadMCPServiceOpenAPISpec() (*openAPISpec, error) {
|
||||
spec := &openAPISpec{}
|
||||
if err := yaml.Unmarshal(memosproto.OpenAPIYAML(), spec); err != nil {
|
||||
return nil, errors.Wrap(err, "failed to parse embedded OpenAPI spec")
|
||||
}
|
||||
if spec.Paths == nil {
|
||||
return nil, errors.New("embedded OpenAPI spec has no paths")
|
||||
}
|
||||
return spec, nil
|
||||
}
|
||||
|
||||
func newMCPToolHandler(adapter *apiAdapter, operation *registeredOperation) sdkmcp.ToolHandler {
|
||||
return func(ctx context.Context, request *sdkmcp.CallToolRequest) (*sdkmcp.CallToolResult, error) {
|
||||
arguments := map[string]any{}
|
||||
if request.Params != nil && len(request.Params.Arguments) > 0 {
|
||||
if err := json.Unmarshal(request.Params.Arguments, &arguments); err != nil {
|
||||
return newToolErrorResult(errors.Wrap(err, "failed to decode MCP tool arguments").Error()), nil
|
||||
}
|
||||
}
|
||||
if err := validateToolArguments(operation.InputSchema, arguments); err != nil {
|
||||
return newToolErrorResult(err.Error()), nil
|
||||
}
|
||||
|
||||
authorization := ""
|
||||
if request.Extra != nil {
|
||||
authorization = request.Extra.Header.Get("Authorization")
|
||||
}
|
||||
return adapter.execute(ctx, operation.Operation, arguments, authorization)
|
||||
}
|
||||
}
|
||||
|
||||
// RegisterRoutes registers the streamable HTTP MCP endpoint.
|
||||
func (s *MCPService) RegisterRoutes(echoServer *echo.Echo) {
|
||||
echoServer.Any("/mcp", func(c *echo.Context) error {
|
||||
request := c.Request()
|
||||
if !isAllowedMCPOrigin(request.Host, request.Header.Get("Origin"), s.profile) {
|
||||
return c.NoContent(http.StatusForbidden)
|
||||
}
|
||||
s.handler.ServeHTTP(c.Response(), request)
|
||||
return nil
|
||||
})
|
||||
}
|
||||
@ -0,0 +1,283 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"encoding/json"
|
||||
"net/http"
|
||||
"net/http/httptest"
|
||||
"os"
|
||||
"testing"
|
||||
|
||||
"github.com/labstack/echo/v5"
|
||||
sdkmcp "github.com/modelcontextprotocol/go-sdk/mcp"
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/internal/profile"
|
||||
memosproto "github.com/usememos/memos/proto"
|
||||
)
|
||||
|
||||
func TestIsAllowedMCPOrigin(t *testing.T) {
|
||||
profile := &profile.Profile{InstanceURL: "https://memos.example.com/app"}
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
host string
|
||||
origin string
|
||||
want bool
|
||||
}{
|
||||
{name: "empty origin", host: "localhost:5230", origin: "", want: true},
|
||||
{name: "same http host", host: "localhost:5230", origin: "http://localhost:5230", want: true},
|
||||
{name: "same https host", host: "memos.example.com", origin: "https://memos.example.com", want: true},
|
||||
{name: "configured instance URL origin", host: "127.0.0.1:5230", origin: "https://memos.example.com", want: true},
|
||||
{name: "configured instance URL ignores path", host: "127.0.0.1:5230", origin: "https://memos.example.com", want: true},
|
||||
{name: "different host", host: "localhost:5230", origin: "https://evil.example.com", want: false},
|
||||
{name: "invalid origin", host: "localhost:5230", origin: "not a url", want: false},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
require.Equal(t, test.want, isAllowedMCPOrigin(test.host, test.origin, profile))
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestNewMCPServiceRegistersCuratedTools(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
|
||||
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, service.handler)
|
||||
require.Len(t, service.operationsByTool, len(curatedOperationIDs))
|
||||
|
||||
operation := service.operationsByTool["memo_list_memos"]
|
||||
require.NotNil(t, operation)
|
||||
require.Equal(t, "MemoService_ListMemos", operation.OperationID)
|
||||
require.Equal(t, "GET", operation.Method)
|
||||
require.Equal(t, "/api/v1/memos", operation.Path)
|
||||
}
|
||||
|
||||
func TestNewMCPServiceUsesEmbeddedOpenAPISpec(t *testing.T) {
|
||||
t.Chdir(t.TempDir())
|
||||
|
||||
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echo.New())
|
||||
require.NoError(t, err)
|
||||
require.NotNil(t, service.handler)
|
||||
require.Len(t, service.operationsByTool, len(curatedOperationIDs))
|
||||
}
|
||||
|
||||
func TestEmbeddedOpenAPISpecMatchesGeneratedFile(t *testing.T) {
|
||||
generated, err := os.ReadFile("../../../proto/gen/openapi.yaml")
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, generated, memosproto.OpenAPIYAML())
|
||||
}
|
||||
|
||||
func TestMCPToolHandlerForwardsArgumentsAndAuthorization(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
|
||||
require.Equal(t, "Bearer token", c.Request().Header.Get("Authorization"))
|
||||
require.Equal(t, "7", c.QueryParam("pageSize"))
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"memos": []any{map[string]any{"name": "memos/test"}},
|
||||
})
|
||||
})
|
||||
|
||||
operation := ®isteredOperation{
|
||||
Operation: &openAPIOperation{
|
||||
Method: "GET",
|
||||
Path: "/api/v1/memos",
|
||||
Parameters: []openAPIParameter{{Name: "pageSize", In: "query", Schema: jsonSchema{"type": "integer"}}},
|
||||
},
|
||||
}
|
||||
handler := newMCPToolHandler(newAPIAdapter(echoServer), operation)
|
||||
arguments, err := json.Marshal(map[string]any{"pageSize": 7})
|
||||
require.NoError(t, err)
|
||||
|
||||
result, err := handler(context.Background(), &sdkmcp.CallToolRequest{
|
||||
Params: &sdkmcp.CallToolParamsRaw{
|
||||
Name: "memo_list_memos",
|
||||
Arguments: arguments,
|
||||
},
|
||||
Extra: &sdkmcp.RequestExtra{
|
||||
Header: http.Header{"Authorization": []string{"Bearer token"}},
|
||||
},
|
||||
})
|
||||
require.NoError(t, err)
|
||||
require.False(t, result.IsError)
|
||||
require.Equal(t, map[string]any{
|
||||
"memos": []any{map[string]any{"name": "memos/test"}},
|
||||
}, result.StructuredContent)
|
||||
}
|
||||
|
||||
func TestMCPProtocolListsCuratedToolsOnly(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
|
||||
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
|
||||
require.NoError(t, err)
|
||||
service.RegisterRoutes(echoServer)
|
||||
|
||||
initializeMCP(t, echoServer)
|
||||
response := postMCP(t, echoServer, map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/list",
|
||||
})
|
||||
|
||||
result, ok := response["result"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
tools, ok := result["tools"].([]any)
|
||||
require.True(t, ok)
|
||||
require.Len(t, tools, len(curatedOperationIDs))
|
||||
|
||||
names := map[string]struct{}{}
|
||||
for _, rawTool := range tools {
|
||||
tool, ok := rawTool.(map[string]any)
|
||||
require.True(t, ok)
|
||||
name, ok := tool["name"].(string)
|
||||
require.True(t, ok)
|
||||
names[name] = struct{}{}
|
||||
require.Contains(t, tool, "inputSchema")
|
||||
require.Contains(t, tool, "outputSchema")
|
||||
}
|
||||
require.Contains(t, names, "memo_list_memos")
|
||||
require.Contains(t, names, "memo_create_memo")
|
||||
require.NotContains(t, names, "auth_sign_in")
|
||||
require.NotContains(t, names, "user_create_user")
|
||||
}
|
||||
|
||||
func TestMCPToolCallReturnsObjectStructuredContent(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
|
||||
return c.JSON(http.StatusOK, map[string]any{
|
||||
"memos": []any{map[string]any{"name": "memos/abc123"}},
|
||||
})
|
||||
})
|
||||
|
||||
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
|
||||
require.NoError(t, err)
|
||||
service.RegisterRoutes(echoServer)
|
||||
|
||||
initializeMCP(t, echoServer)
|
||||
response := postMCP(t, echoServer, map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 2,
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": "memo_list_memos",
|
||||
"arguments": map[string]any{
|
||||
"pageSize": 1,
|
||||
},
|
||||
},
|
||||
})
|
||||
|
||||
result, ok := response["result"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, map[string]any{
|
||||
"memos": []any{map[string]any{"name": "memos/abc123"}},
|
||||
}, result["structuredContent"])
|
||||
}
|
||||
|
||||
func TestMCPToolCallRejectsInvalidArguments(t *testing.T) {
|
||||
echoServer := echo.New()
|
||||
routeHits := 0
|
||||
echoServer.GET("/api/v1/memos", func(c *echo.Context) error {
|
||||
routeHits++
|
||||
return c.JSON(http.StatusOK, map[string]any{"memos": []any{}})
|
||||
})
|
||||
echoServer.GET("/api/v1/memos/:memo", func(c *echo.Context) error {
|
||||
routeHits++
|
||||
return c.JSON(http.StatusOK, map[string]any{"name": c.Param("memo")})
|
||||
})
|
||||
|
||||
service, err := NewMCPService(&profile.Profile{Version: "test-version"}, echoServer)
|
||||
require.NoError(t, err)
|
||||
service.RegisterRoutes(echoServer)
|
||||
|
||||
initializeMCP(t, echoServer)
|
||||
|
||||
tests := []struct {
|
||||
name string
|
||||
toolName string
|
||||
arguments map[string]any
|
||||
wantError string
|
||||
}{
|
||||
{
|
||||
name: "unknown argument",
|
||||
toolName: "memo_list_memos",
|
||||
arguments: map[string]any{"unexpected": true},
|
||||
wantError: `unknown argument "unexpected"`,
|
||||
},
|
||||
{
|
||||
name: "missing required argument",
|
||||
toolName: "memo_get_memo",
|
||||
arguments: map[string]any{},
|
||||
wantError: `missing required argument "memo"`,
|
||||
},
|
||||
{
|
||||
name: "wrong primitive type",
|
||||
toolName: "memo_list_memos",
|
||||
arguments: map[string]any{"pageSize": "ten"},
|
||||
wantError: `argument "pageSize" must be integer`,
|
||||
},
|
||||
}
|
||||
|
||||
for index, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
response := postMCP(t, echoServer, map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": index + 2,
|
||||
"method": "tools/call",
|
||||
"params": map[string]any{
|
||||
"name": test.toolName,
|
||||
"arguments": test.arguments,
|
||||
},
|
||||
})
|
||||
|
||||
result, ok := response["result"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Equal(t, true, result["isError"])
|
||||
structured, ok := result["structuredContent"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
errorObject, ok := structured["error"].(map[string]any)
|
||||
require.True(t, ok)
|
||||
require.Contains(t, errorObject["message"], test.wantError)
|
||||
})
|
||||
}
|
||||
require.Zero(t, routeHits)
|
||||
}
|
||||
|
||||
func initializeMCP(t *testing.T, echoServer *echo.Echo) {
|
||||
t.Helper()
|
||||
response := postMCP(t, echoServer, map[string]any{
|
||||
"jsonrpc": "2.0",
|
||||
"id": 1,
|
||||
"method": "initialize",
|
||||
"params": map[string]any{
|
||||
"protocolVersion": "2025-06-18",
|
||||
"capabilities": map[string]any{},
|
||||
"clientInfo": map[string]any{
|
||||
"name": "memos-test",
|
||||
"version": "1.0.0",
|
||||
},
|
||||
},
|
||||
})
|
||||
require.NotNil(t, response["result"])
|
||||
}
|
||||
|
||||
func postMCP(t *testing.T, echoServer *echo.Echo, payload map[string]any) map[string]any {
|
||||
t.Helper()
|
||||
data, err := json.Marshal(payload)
|
||||
require.NoError(t, err)
|
||||
|
||||
request := httptest.NewRequest(http.MethodPost, "/mcp", bytes.NewReader(data))
|
||||
request.Header.Set("Content-Type", "application/json")
|
||||
request.Header.Set("Accept", "application/json, text/event-stream")
|
||||
|
||||
recorder := httptest.NewRecorder()
|
||||
echoServer.ServeHTTP(recorder, request)
|
||||
require.Equal(t, http.StatusOK, recorder.Code, recorder.Body.String())
|
||||
|
||||
var response map[string]any
|
||||
require.NoError(t, json.Unmarshal(recorder.Body.Bytes(), &response))
|
||||
return response
|
||||
}
|
||||
@ -0,0 +1,226 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"math"
|
||||
"strings"
|
||||
|
||||
googlejsonschema "github.com/google/jsonschema-go/jsonschema"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
func validateToolArguments(schema jsonSchema, arguments map[string]any) error {
|
||||
if schema == nil {
|
||||
return nil
|
||||
}
|
||||
if err := validateSchemaValue(schema, "argument", "argument", arguments, schema); err != nil {
|
||||
return err
|
||||
}
|
||||
return validateArgumentsWithJSONSchema(schema, arguments)
|
||||
}
|
||||
|
||||
func validateArgumentsWithJSONSchema(schema jsonSchema, arguments map[string]any) error {
|
||||
data, err := json.Marshal(schema)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to marshal MCP tool input schema")
|
||||
}
|
||||
jsonSchema := &googlejsonschema.Schema{}
|
||||
if err := json.Unmarshal(data, jsonSchema); err != nil {
|
||||
return errors.Wrap(err, "failed to parse MCP tool input schema")
|
||||
}
|
||||
resolved, err := jsonSchema.Resolve(nil)
|
||||
if err != nil {
|
||||
return errors.Wrap(err, "failed to resolve MCP tool input schema")
|
||||
}
|
||||
if err := resolved.Validate(arguments); err != nil {
|
||||
return errors.Wrap(err, "MCP tool arguments do not match input schema")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func validateSchemaValue(schemaValue any, path string, label string, value any, root jsonSchema) error {
|
||||
schema, ok := asSchemaMap(schemaValue)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
if ref, ok := schema["$ref"].(string); ok && ref != "" {
|
||||
resolved, ok := localSchemaDef(root, ref)
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
return validateSchemaValue(resolved, path, label, value, root)
|
||||
}
|
||||
if value == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
types := schemaTypes(schema["type"])
|
||||
if len(types) == 0 && schema["properties"] != nil {
|
||||
types = []string{"object"}
|
||||
}
|
||||
for _, schemaType := range types {
|
||||
if schemaTypeMatchesValue(schemaType, value) {
|
||||
if schemaType == "object" {
|
||||
return validateObjectSchema(schema, path, value, root)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
}
|
||||
if len(types) == 0 {
|
||||
return nil
|
||||
}
|
||||
return errors.Errorf(`%s "%s" must be %s`, label, path, types[0])
|
||||
}
|
||||
|
||||
func validateObjectSchema(schema map[string]any, path string, value any, root jsonSchema) error {
|
||||
object, ok := value.(map[string]any)
|
||||
if !ok {
|
||||
return errors.Errorf(`argument "%s" must be object`, path)
|
||||
}
|
||||
|
||||
properties := schemaProperties(schema["properties"])
|
||||
for _, required := range requiredNames(schema["required"]) {
|
||||
if child, ok := object[required]; !ok || child == nil {
|
||||
return errors.Errorf(`missing required argument "%s"`, joinSchemaPath(path, required))
|
||||
}
|
||||
}
|
||||
|
||||
if additionalProperties, ok := schema["additionalProperties"].(bool); ok && !additionalProperties {
|
||||
for name := range object {
|
||||
if _, ok := properties[name]; !ok {
|
||||
return errors.Errorf(`unknown argument "%s"`, joinSchemaPath(path, name))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
for name, childSchema := range properties {
|
||||
childValue, ok := object[name]
|
||||
if !ok {
|
||||
continue
|
||||
}
|
||||
if err := validateSchemaValue(childSchema, joinSchemaPath(path, name), "argument", childValue, root); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func asSchemaMap(value any) (map[string]any, bool) {
|
||||
switch typed := value.(type) {
|
||||
case jsonSchema:
|
||||
return map[string]any(typed), true
|
||||
case map[string]any:
|
||||
return typed, true
|
||||
default:
|
||||
return nil, false
|
||||
}
|
||||
}
|
||||
|
||||
func schemaProperties(value any) map[string]any {
|
||||
switch typed := value.(type) {
|
||||
case map[string]any:
|
||||
return typed
|
||||
case jsonSchema:
|
||||
return map[string]any(typed)
|
||||
default:
|
||||
return map[string]any{}
|
||||
}
|
||||
}
|
||||
|
||||
func schemaTypes(value any) []string {
|
||||
switch typed := value.(type) {
|
||||
case string:
|
||||
return []string{typed}
|
||||
case []any:
|
||||
types := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
if typeName, ok := item.(string); ok {
|
||||
types = append(types, typeName)
|
||||
}
|
||||
}
|
||||
return types
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func requiredNames(value any) []string {
|
||||
switch typed := value.(type) {
|
||||
case []string:
|
||||
return typed
|
||||
case []any:
|
||||
names := make([]string, 0, len(typed))
|
||||
for _, item := range typed {
|
||||
if name, ok := item.(string); ok {
|
||||
names = append(names, name)
|
||||
}
|
||||
}
|
||||
return names
|
||||
default:
|
||||
return nil
|
||||
}
|
||||
}
|
||||
|
||||
func schemaTypeMatchesValue(schemaType string, value any) bool {
|
||||
switch schemaType {
|
||||
case "array":
|
||||
_, ok := value.([]any)
|
||||
return ok
|
||||
case "boolean":
|
||||
_, ok := value.(bool)
|
||||
return ok
|
||||
case "integer":
|
||||
return isInteger(value)
|
||||
case "number":
|
||||
return isNumber(value)
|
||||
case "object":
|
||||
_, ok := value.(map[string]any)
|
||||
return ok
|
||||
case "string":
|
||||
_, ok := value.(string)
|
||||
return ok
|
||||
default:
|
||||
return true
|
||||
}
|
||||
}
|
||||
|
||||
func isInteger(value any) bool {
|
||||
switch typed := value.(type) {
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64:
|
||||
return true
|
||||
case float32:
|
||||
return math.Trunc(float64(typed)) == float64(typed)
|
||||
case float64:
|
||||
return math.Trunc(typed) == typed
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func isNumber(value any) bool {
|
||||
switch value.(type) {
|
||||
case int, int8, int16, int32, int64, uint, uint8, uint16, uint32, uint64, float32, float64:
|
||||
return true
|
||||
default:
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
func localSchemaDef(root jsonSchema, ref string) (any, bool) {
|
||||
const prefix = "#/$defs/"
|
||||
if !strings.HasPrefix(ref, prefix) {
|
||||
return nil, false
|
||||
}
|
||||
defs, ok := root["$defs"].(map[string]any)
|
||||
if !ok {
|
||||
return nil, false
|
||||
}
|
||||
return defs[strings.TrimPrefix(ref, prefix)], true
|
||||
}
|
||||
|
||||
func joinSchemaPath(parent string, child string) string {
|
||||
if parent == "" || parent == "argument" {
|
||||
return child
|
||||
}
|
||||
return parent + "." + child
|
||||
}
|
||||
@ -0,0 +1,63 @@
|
||||
package mcp
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
)
|
||||
|
||||
func TestValidateToolArgumentsRejectsUnknownArguments(t *testing.T) {
|
||||
schema := jsonSchema{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pageSize": map[string]any{"type": "integer"},
|
||||
},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
err := validateToolArguments(schema, map[string]any{"unexpected": true})
|
||||
require.ErrorContains(t, err, `unknown argument "unexpected"`)
|
||||
}
|
||||
|
||||
func TestValidateToolArgumentsRejectsMissingRequiredArguments(t *testing.T) {
|
||||
schema := jsonSchema{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"memo": map[string]any{"type": "string"},
|
||||
},
|
||||
"required": []string{"memo"},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
err := validateToolArguments(schema, map[string]any{})
|
||||
require.ErrorContains(t, err, `missing required argument "memo"`)
|
||||
}
|
||||
|
||||
func TestValidateToolArgumentsRejectsWrongPrimitiveTypes(t *testing.T) {
|
||||
schema := jsonSchema{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"pageSize": map[string]any{"type": "integer"},
|
||||
},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
err := validateToolArguments(schema, map[string]any{"pageSize": "ten"})
|
||||
require.ErrorContains(t, err, `argument "pageSize" must be integer`)
|
||||
}
|
||||
|
||||
func TestValidateToolArgumentsUsesJSONSchemaValidation(t *testing.T) {
|
||||
schema := jsonSchema{
|
||||
"type": "object",
|
||||
"properties": map[string]any{
|
||||
"state": map[string]any{
|
||||
"type": "string",
|
||||
"enum": []any{"NORMAL", "ARCHIVED"},
|
||||
},
|
||||
},
|
||||
"additionalProperties": false,
|
||||
}
|
||||
|
||||
err := validateToolArguments(schema, map[string]any{"state": "DELETED"})
|
||||
require.ErrorContains(t, err, "MCP tool arguments do not match input schema")
|
||||
}
|
||||
Loading…
Reference in New Issue