mirror of https://github.com/usememos/memos
chore: harden MCP access control and origin validation
parent
7c708ee27e
commit
d3f6e8ee31
@ -0,0 +1,113 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"net/http"
|
||||||
|
"net/url"
|
||||||
|
"strconv"
|
||||||
|
"strings"
|
||||||
|
|
||||||
|
"github.com/pkg/errors"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/store"
|
||||||
|
)
|
||||||
|
|
||||||
|
// checkMemoAccess returns an error if the caller cannot read the memo.
|
||||||
|
// userID == 0 means anonymous.
|
||||||
|
func checkMemoAccess(memo *store.Memo, userID int32) error {
|
||||||
|
if memo.RowStatus == store.Archived && memo.CreatorID != userID {
|
||||||
|
return errors.New("permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
switch memo.Visibility {
|
||||||
|
case store.Protected:
|
||||||
|
if userID == 0 {
|
||||||
|
return errors.New("permission denied")
|
||||||
|
}
|
||||||
|
case store.Private:
|
||||||
|
if memo.CreatorID != userID {
|
||||||
|
return errors.New("permission denied")
|
||||||
|
}
|
||||||
|
default:
|
||||||
|
// store.Public and any unknown visibility: allow.
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
func checkMemoOwnership(memo *store.Memo, userID int32) error {
|
||||||
|
if memo.CreatorID != userID {
|
||||||
|
return errors.New("permission denied")
|
||||||
|
}
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
|
||||||
|
// applyVisibilityFilter restricts find to memos the caller may see.
|
||||||
|
func applyVisibilityFilter(find *store.FindMemo, userID int32, rowStatus *store.RowStatus) {
|
||||||
|
if rowStatus != nil && *rowStatus == store.Archived {
|
||||||
|
if userID == 0 {
|
||||||
|
impossibleCreatorID := int32(-1)
|
||||||
|
find.CreatorID = &impossibleCreatorID
|
||||||
|
return
|
||||||
|
}
|
||||||
|
find.CreatorID = &userID
|
||||||
|
return
|
||||||
|
}
|
||||||
|
if userID == 0 {
|
||||||
|
find.VisibilityList = []store.Visibility{store.Public}
|
||||||
|
return
|
||||||
|
}
|
||||||
|
find.Filters = append(find.Filters, "creator_id == "+itoa32(userID)+` || visibility in ["PUBLIC", "PROTECTED"]`)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MCPService) checkAttachmentAccess(ctx context.Context, attachment *store.Attachment, userID int32) error {
|
||||||
|
if attachment.CreatorID == userID {
|
||||||
|
return nil
|
||||||
|
}
|
||||||
|
if attachment.MemoID == nil {
|
||||||
|
return errors.New("permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
memo, err := s.store.GetMemo(ctx, &store.FindMemo{ID: attachment.MemoID})
|
||||||
|
if err != nil {
|
||||||
|
return errors.Wrap(err, "failed to get linked memo")
|
||||||
|
}
|
||||||
|
if memo == nil {
|
||||||
|
return errors.New("linked memo not found")
|
||||||
|
}
|
||||||
|
return checkMemoAccess(memo, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *MCPService) isAllowedOrigin(r *http.Request) bool {
|
||||||
|
origin := r.Header.Get("Origin")
|
||||||
|
if origin == "" {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
originURL, err := url.Parse(origin)
|
||||||
|
if err != nil || originURL.Scheme == "" || originURL.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
if sameOriginHost(originURL.Host, r.Host) {
|
||||||
|
return true
|
||||||
|
}
|
||||||
|
|
||||||
|
if s.profile.InstanceURL == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
instanceURL, err := url.Parse(s.profile.InstanceURL)
|
||||||
|
if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" {
|
||||||
|
return false
|
||||||
|
}
|
||||||
|
|
||||||
|
return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && sameOriginHost(originURL.Host, instanceURL.Host)
|
||||||
|
}
|
||||||
|
|
||||||
|
func sameOriginHost(a, b string) bool {
|
||||||
|
return strings.EqualFold(a, b)
|
||||||
|
}
|
||||||
|
|
||||||
|
func itoa32(v int32) string {
|
||||||
|
return strconv.FormatInt(int64(v), 10)
|
||||||
|
}
|
||||||
@ -0,0 +1,275 @@
|
|||||||
|
package mcp
|
||||||
|
|
||||||
|
import (
|
||||||
|
"context"
|
||||||
|
"encoding/json"
|
||||||
|
"net/http/httptest"
|
||||||
|
"testing"
|
||||||
|
|
||||||
|
"github.com/lithammer/shortuuid/v4"
|
||||||
|
"github.com/mark3labs/mcp-go/mcp"
|
||||||
|
"github.com/stretchr/testify/require"
|
||||||
|
|
||||||
|
"github.com/usememos/memos/internal/profile"
|
||||||
|
storepb "github.com/usememos/memos/proto/gen/store"
|
||||||
|
"github.com/usememos/memos/server/auth"
|
||||||
|
"github.com/usememos/memos/store"
|
||||||
|
teststore "github.com/usememos/memos/store/test"
|
||||||
|
)
|
||||||
|
|
||||||
|
type testMCPService struct {
|
||||||
|
service *MCPService
|
||||||
|
store *store.Store
|
||||||
|
}
|
||||||
|
|
||||||
|
func newTestMCPService(t *testing.T) *testMCPService {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
ctx := context.Background()
|
||||||
|
stores := teststore.NewTestingStore(ctx, t)
|
||||||
|
t.Cleanup(func() {
|
||||||
|
require.NoError(t, stores.Close())
|
||||||
|
})
|
||||||
|
|
||||||
|
svc := NewMCPService(&profile.Profile{
|
||||||
|
Driver: "sqlite",
|
||||||
|
InstanceURL: "https://notes.example.com",
|
||||||
|
}, stores, "test-secret")
|
||||||
|
return &testMCPService{
|
||||||
|
service: svc,
|
||||||
|
store: stores,
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testMCPService) createUser(t *testing.T, username string) *store.User {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
user, err := s.store.CreateUser(context.Background(), &store.User{
|
||||||
|
Username: username,
|
||||||
|
Role: store.RoleUser,
|
||||||
|
Email: username + "@example.com",
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return user
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testMCPService) createMemo(t *testing.T, creatorID int32, visibility store.Visibility, content string) *store.Memo {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
memo, err := s.store.CreateMemo(context.Background(), &store.Memo{
|
||||||
|
UID: shortuuid.New(),
|
||||||
|
CreatorID: creatorID,
|
||||||
|
RowStatus: store.Normal,
|
||||||
|
Visibility: visibility,
|
||||||
|
Content: content,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return memo
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testMCPService) archiveMemo(t *testing.T, memoID int32) {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
rowStatus := store.Archived
|
||||||
|
require.NoError(t, s.store.UpdateMemo(context.Background(), &store.UpdateMemo{
|
||||||
|
ID: memoID,
|
||||||
|
RowStatus: &rowStatus,
|
||||||
|
}))
|
||||||
|
}
|
||||||
|
|
||||||
|
func (s *testMCPService) createAttachment(t *testing.T, creatorID int32, memoID *int32) *store.Attachment {
|
||||||
|
t.Helper()
|
||||||
|
|
||||||
|
attachment, err := s.store.CreateAttachment(context.Background(), &store.Attachment{
|
||||||
|
UID: shortuuid.New(),
|
||||||
|
CreatorID: creatorID,
|
||||||
|
Filename: "note.txt",
|
||||||
|
Type: "text/plain",
|
||||||
|
Size: 4,
|
||||||
|
StorageType: storepb.AttachmentStorageType_ATTACHMENT_STORAGE_TYPE_UNSPECIFIED,
|
||||||
|
Reference: "db://attachment/note.txt",
|
||||||
|
MemoID: memoID,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
return attachment
|
||||||
|
}
|
||||||
|
|
||||||
|
func withUser(ctx context.Context, userID int32) context.Context {
|
||||||
|
return context.WithValue(ctx, auth.UserIDContextKey, userID)
|
||||||
|
}
|
||||||
|
|
||||||
|
func toolRequest(name string, arguments map[string]any) mcp.CallToolRequest {
|
||||||
|
return mcp.CallToolRequest{
|
||||||
|
Params: mcp.CallToolParams{
|
||||||
|
Name: name,
|
||||||
|
Arguments: arguments,
|
||||||
|
},
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func firstText(t *testing.T, result *mcp.CallToolResult) string {
|
||||||
|
t.Helper()
|
||||||
|
require.NotEmpty(t, result.Content)
|
||||||
|
text, ok := result.Content[0].(mcp.TextContent)
|
||||||
|
require.True(t, ok)
|
||||||
|
return text.Text
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleGetMemoAndReadResourceDenyArchivedMemoToNonCreator(t *testing.T) {
|
||||||
|
ts := newTestMCPService(t)
|
||||||
|
owner := ts.createUser(t, "owner")
|
||||||
|
other := ts.createUser(t, "other")
|
||||||
|
|
||||||
|
memo := ts.createMemo(t, owner.ID, store.Public, "archived")
|
||||||
|
ts.archiveMemo(t, memo.ID)
|
||||||
|
|
||||||
|
ctx := withUser(context.Background(), other.ID)
|
||||||
|
result, err := ts.service.handleGetMemo(ctx, toolRequest("get_memo", map[string]any{
|
||||||
|
"name": "memos/" + memo.UID,
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.IsError)
|
||||||
|
require.Contains(t, firstText(t, result), "permission denied")
|
||||||
|
|
||||||
|
_, err = ts.service.handleReadMemoResource(ctx, mcp.ReadResourceRequest{
|
||||||
|
Params: mcp.ReadResourceParams{
|
||||||
|
URI: "memo://memos/" + memo.UID,
|
||||||
|
},
|
||||||
|
})
|
||||||
|
require.ErrorContains(t, err, "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleListMemosArchivedOnlyReturnsCreatorMemos(t *testing.T) {
|
||||||
|
ts := newTestMCPService(t)
|
||||||
|
owner := ts.createUser(t, "owner")
|
||||||
|
other := ts.createUser(t, "other")
|
||||||
|
|
||||||
|
ownerMemo := ts.createMemo(t, owner.ID, store.Public, "owner archived")
|
||||||
|
ts.archiveMemo(t, ownerMemo.ID)
|
||||||
|
otherMemo := ts.createMemo(t, other.ID, store.Public, "other archived")
|
||||||
|
ts.archiveMemo(t, otherMemo.ID)
|
||||||
|
|
||||||
|
result, err := ts.service.handleListMemos(withUser(context.Background(), owner.ID), toolRequest("list_memos", map[string]any{
|
||||||
|
"state": "ARCHIVED",
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result.IsError)
|
||||||
|
|
||||||
|
var payload struct {
|
||||||
|
Memos []memoJSON `json:"memos"`
|
||||||
|
}
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &payload))
|
||||||
|
require.Len(t, payload.Memos, 1)
|
||||||
|
require.Equal(t, "memos/"+ownerMemo.UID, payload.Memos[0].Name)
|
||||||
|
|
||||||
|
anonResult, err := ts.service.handleListMemos(context.Background(), toolRequest("list_memos", map[string]any{
|
||||||
|
"state": "ARCHIVED",
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, anonResult)), &payload))
|
||||||
|
require.Empty(t, payload.Memos)
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleListMemoRelationsFiltersUnreadableTargets(t *testing.T) {
|
||||||
|
ts := newTestMCPService(t)
|
||||||
|
owner := ts.createUser(t, "owner")
|
||||||
|
privateUser := ts.createUser(t, "private-user")
|
||||||
|
publicUser := ts.createUser(t, "public-user")
|
||||||
|
|
||||||
|
source := ts.createMemo(t, owner.ID, store.Public, "source")
|
||||||
|
privateTarget := ts.createMemo(t, privateUser.ID, store.Private, "private")
|
||||||
|
publicTarget := ts.createMemo(t, publicUser.ID, store.Public, "public")
|
||||||
|
|
||||||
|
_, err := ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
||||||
|
MemoID: source.ID,
|
||||||
|
RelatedMemoID: privateTarget.ID,
|
||||||
|
Type: store.MemoRelationReference,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
_, err = ts.store.UpsertMemoRelation(context.Background(), &store.MemoRelation{
|
||||||
|
MemoID: source.ID,
|
||||||
|
RelatedMemoID: publicTarget.ID,
|
||||||
|
Type: store.MemoRelationReference,
|
||||||
|
})
|
||||||
|
require.NoError(t, err)
|
||||||
|
|
||||||
|
result, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
||||||
|
"name": "memos/" + source.UID,
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.False(t, result.IsError)
|
||||||
|
|
||||||
|
var relations []relationJSON
|
||||||
|
require.NoError(t, json.Unmarshal([]byte(firstText(t, result)), &relations))
|
||||||
|
require.Len(t, relations, 1)
|
||||||
|
require.Equal(t, "memos/"+publicTarget.UID, relations[0].RelatedMemo)
|
||||||
|
|
||||||
|
denied, err := ts.service.handleListMemoRelations(context.Background(), toolRequest("list_memo_relations", map[string]any{
|
||||||
|
"name": "memos/" + privateTarget.UID,
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, denied.IsError)
|
||||||
|
require.Contains(t, firstText(t, denied), "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleLinkAttachmentToMemoRequiresMemoOwnership(t *testing.T) {
|
||||||
|
ts := newTestMCPService(t)
|
||||||
|
attachmentOwner := ts.createUser(t, "attachment-owner")
|
||||||
|
memoOwner := ts.createUser(t, "memo-owner")
|
||||||
|
|
||||||
|
attachment := ts.createAttachment(t, attachmentOwner.ID, nil)
|
||||||
|
memo := ts.createMemo(t, memoOwner.ID, store.Public, "target")
|
||||||
|
|
||||||
|
result, err := ts.service.handleLinkAttachmentToMemo(withUser(context.Background(), attachmentOwner.ID), toolRequest("link_attachment_to_memo", map[string]any{
|
||||||
|
"name": "attachments/" + attachment.UID,
|
||||||
|
"memo": "memos/" + memo.UID,
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.IsError)
|
||||||
|
require.Contains(t, firstText(t, result), "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestHandleGetAttachmentDeniesArchivedLinkedMemoToNonCreator(t *testing.T) {
|
||||||
|
ts := newTestMCPService(t)
|
||||||
|
owner := ts.createUser(t, "owner")
|
||||||
|
other := ts.createUser(t, "other")
|
||||||
|
|
||||||
|
memo := ts.createMemo(t, owner.ID, store.Public, "memo")
|
||||||
|
ts.archiveMemo(t, memo.ID)
|
||||||
|
attachment := ts.createAttachment(t, owner.ID, &memo.ID)
|
||||||
|
|
||||||
|
result, err := ts.service.handleGetAttachment(withUser(context.Background(), other.ID), toolRequest("get_attachment", map[string]any{
|
||||||
|
"name": "attachments/" + attachment.UID,
|
||||||
|
}))
|
||||||
|
require.NoError(t, err)
|
||||||
|
require.True(t, result.IsError)
|
||||||
|
require.Contains(t, firstText(t, result), "permission denied")
|
||||||
|
}
|
||||||
|
|
||||||
|
func TestIsAllowedOrigin(t *testing.T) {
|
||||||
|
ts := newTestMCPService(t)
|
||||||
|
|
||||||
|
t.Run("allow missing origin", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||||
|
require.True(t, ts.service.isAllowedOrigin(req))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allow same origin as request host", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||||
|
req.Header.Set("Origin", "http://localhost:5230")
|
||||||
|
require.True(t, ts.service.isAllowedOrigin(req))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("allow configured instance origin", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "http://127.0.0.1:5230/mcp", nil)
|
||||||
|
req.Host = "127.0.0.1:5230"
|
||||||
|
req.Header.Set("Origin", "https://notes.example.com")
|
||||||
|
require.True(t, ts.service.isAllowedOrigin(req))
|
||||||
|
})
|
||||||
|
|
||||||
|
t.Run("reject cross origin", func(t *testing.T) {
|
||||||
|
req := httptest.NewRequest("POST", "http://localhost:5230/mcp", nil)
|
||||||
|
req.Header.Set("Origin", "https://evil.example.com")
|
||||||
|
require.False(t, ts.service.isAllowedOrigin(req))
|
||||||
|
})
|
||||||
|
}
|
||||||
Loading…
Reference in New Issue