diff --git a/server/auth/authenticator.go b/server/auth/authenticator.go index d66961b1a..16e98611f 100644 --- a/server/auth/authenticator.go +++ b/server/auth/authenticator.go @@ -141,7 +141,14 @@ func (a *Authenticator) AuthenticateToUser(ctx context.Context, authHeader, cook if !strings.HasPrefix(token, PersonalAccessTokenPrefix) { claims, err := a.AuthenticateByAccessTokenV2(token) if err == nil && claims != nil { - return a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) + user, err := a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) + if err != nil { + return nil, err + } + if user == nil || user.RowStatus == store.Archived { + return nil, nil + } + return user, nil } } else { user, _, err := a.AuthenticateByPAT(ctx, token) @@ -174,6 +181,10 @@ func (a *Authenticator) Authenticate(ctx context.Context, authHeader string) *Au if token != "" && !strings.HasPrefix(token, PersonalAccessTokenPrefix) { claims, err := a.AuthenticateByAccessTokenV2(token) if err == nil && claims != nil { + user, err := a.store.GetUser(ctx, &store.FindUser{ID: &claims.UserID}) + if err != nil || user == nil || user.RowStatus == store.Archived { + return nil + } return &AuthResult{ Claims: claims, AccessToken: token, diff --git a/server/router/api/v1/attachment_service.go b/server/router/api/v1/attachment_service.go index b2f874cf1..d7ca0b0ea 100644 --- a/server/router/api/v1/attachment_service.go +++ b/server/router/api/v1/attachment_service.go @@ -140,6 +140,24 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat create.Size = int64(size) create.Blob = request.Attachment.Content + if request.Attachment.Memo != nil { + memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo) + if err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) + } + memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err) + } + if memo == nil { + return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo) + } + if !canModifyMemo(user, memo) { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } + create.MemoID = &memo.ID + } + if create.Payload == nil || create.Payload.MotionMedia == nil { if detectedMotion := detectAndroidMotionMedia(create.Blob, create.Type, attachmentUID); detectedMotion != nil { create.Payload = ensureAttachmentPayload(create.Payload) @@ -172,20 +190,6 @@ func (s *APIV1Service) CreateAttachment(ctx context.Context, request *v1pb.Creat return nil, status.Errorf(codes.Internal, "failed to save attachment blob: %v", err) } - if request.Attachment.Memo != nil { - memoUID, err := ExtractMemoUIDFromName(*request.Attachment.Memo) - if err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid memo name: %v", err) - } - memo, err := s.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to find memo: %v", err) - } - if memo == nil { - return nil, status.Errorf(codes.NotFound, "memo not found: %s", *request.Attachment.Memo) - } - create.MemoID = &memo.ID - } attachment, err := s.Store.CreateAttachment(ctx, create) if err != nil { return nil, status.Errorf(codes.Internal, "failed to create attachment: %v", err) diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index c2bfd7105..604bd414c 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -595,6 +595,9 @@ func (s *APIV1Service) fetchCurrentUser(ctx context.Context) (*store.User, error if user == nil { return nil, errors.Errorf("user %d not found", userID) } + if user.RowStatus == store.Archived { + return nil, nil + } return user, nil } diff --git a/server/router/api/v1/common.go b/server/router/api/v1/common.go index 087b0e5bc..f386d6e08 100644 --- a/server/router/api/v1/common.go +++ b/server/router/api/v1/common.go @@ -77,3 +77,7 @@ func unmarshalPageToken(s string, pageToken *v1pb.PageToken) error { func isSuperUser(user *store.User) bool { return user.Role == store.RoleAdmin } + +func canModifyMemo(user *store.User, memo *store.Memo) bool { + return user != nil && memo != nil && (memo.CreatorID == user.ID || isSuperUser(user)) +} diff --git a/server/router/api/v1/connect_interceptors_test.go b/server/router/api/v1/connect_interceptors_test.go index b4f8c79b9..62925610f 100644 --- a/server/router/api/v1/connect_interceptors_test.go +++ b/server/router/api/v1/connect_interceptors_test.go @@ -2,11 +2,16 @@ package v1 import ( "context" + "net/http" + "net/http/httptest" "testing" "connectrpc.com/connect" + "github.com/labstack/echo/v5" "google.golang.org/grpc/metadata" "google.golang.org/protobuf/types/known/emptypb" + + "github.com/usememos/memos/internal/profile" ) func TestMetadataInterceptorForwardsSecurityHeaders(t *testing.T) { @@ -37,3 +42,24 @@ func TestMetadataInterceptorForwardsSecurityHeaders(t *testing.T) { t.Fatalf("metadata interceptor returned error: %v", err) } } + +func TestAllowedConnectOrigin(t *testing.T) { + service := &APIV1Service{ + Profile: &profile.Profile{InstanceURL: "https://memos.example"}, + } + e := echo.New() + req := httptest.NewRequest(http.MethodOptions, "http://localhost/memos.api.v1.AuthService/SignIn", nil) + req.Host = "localhost" + rec := httptest.NewRecorder() + ctx := e.NewContext(req, rec) + + if !service.isAllowedConnectOrigin(ctx, "http://localhost") { + t.Fatal("expected same host origin to be allowed") + } + if !service.isAllowedConnectOrigin(ctx, "https://memos.example") { + t.Fatal("expected instance URL origin to be allowed") + } + if service.isAllowedConnectOrigin(ctx, "https://evil.example") { + t.Fatal("expected unknown origin to be denied") + } +} diff --git a/server/router/api/v1/memo_attachment_service.go b/server/router/api/v1/memo_attachment_service.go index 8d2b6a04a..6f6e31a14 100644 --- a/server/router/api/v1/memo_attachment_service.go +++ b/server/router/api/v1/memo_attachment_service.go @@ -32,7 +32,7 @@ func (s *APIV1Service) SetMemoAttachments(ctx context.Context, request *v1pb.Set if memo == nil { return nil, status.Errorf(codes.NotFound, "memo not found") } - if memo.CreatorID != user.ID && !isSuperUser(user) { + if !canModifyMemo(user, memo) { return nil, status.Errorf(codes.PermissionDenied, "permission denied") } if err := s.setMemoAttachmentsInternal(ctx, memo, request.Attachments); err != nil { diff --git a/server/router/api/v1/sso_username.go b/server/router/api/v1/sso_username.go index 4f55ea74a..0c40b92ec 100644 --- a/server/router/api/v1/sso_username.go +++ b/server/router/api/v1/sso_username.go @@ -13,7 +13,7 @@ import ( // retry loops around concurrent first-time logins. func deriveSSOUsername() (string, error) { username := util.GenUUID() - if err := validateUsername(username); err != nil { + if err := validateWritableUsername(username); err != nil { return "", errors.Wrap(err, "generated UUID did not satisfy username constraints") } return username, nil diff --git a/server/router/api/v1/test/attachment_service_test.go b/server/router/api/v1/test/attachment_service_test.go index b4cae7c81..cab318684 100644 --- a/server/router/api/v1/test/attachment_service_test.go +++ b/server/router/api/v1/test/attachment_service_test.go @@ -138,6 +138,123 @@ func TestCreateAttachment(t *testing.T) { }) } +func TestCreateAttachmentMemoPermission(t *testing.T) { + ctx := context.Background() + + t.Run("owner can create attachment directly linked to memo", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + owner, err := ts.CreateRegularUser(ctx, "attachment-owner") + require.NoError(t, err) + ownerCtx := ts.CreateUserContext(ctx, owner.ID) + + memo, err := ts.Service.CreateMemo(ownerCtx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{ + Content: "memo with direct attachment", + }, + }) + require.NoError(t, err) + + attachment, err := ts.Service.CreateAttachment(ownerCtx, &v1pb.CreateAttachmentRequest{ + Attachment: &v1pb.Attachment{ + Filename: "owner.txt", + Type: "text/plain", + Content: []byte("owner"), + Memo: &memo.Name, + }, + }) + require.NoError(t, err) + attachmentUID, err := apiv1.ExtractAttachmentUIDFromName(attachment.Name) + require.NoError(t, err) + stored, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) + require.NoError(t, err) + require.NotNil(t, stored.MemoID) + require.Equal(t, memoIDFromName(ctx, t, ts, memo.Name), *stored.MemoID) + }) + + t.Run("admin can create attachment directly linked to memo", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + owner, err := ts.CreateRegularUser(ctx, "attachment-admin-owner") + require.NoError(t, err) + ownerCtx := ts.CreateUserContext(ctx, owner.ID) + admin, err := ts.CreateHostUser(ctx, "attachment-admin") + require.NoError(t, err) + adminCtx := ts.CreateUserContext(ctx, admin.ID) + + memo, err := ts.Service.CreateMemo(ownerCtx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{ + Content: "memo with admin attachment", + }, + }) + require.NoError(t, err) + + attachment, err := ts.Service.CreateAttachment(adminCtx, &v1pb.CreateAttachmentRequest{ + Attachment: &v1pb.Attachment{ + Filename: "admin.txt", + Type: "text/plain", + Content: []byte("admin"), + Memo: &memo.Name, + }, + }) + require.NoError(t, err) + attachmentUID, err := apiv1.ExtractAttachmentUIDFromName(attachment.Name) + require.NoError(t, err) + stored, err := ts.Store.GetAttachment(ctx, &store.FindAttachment{UID: &attachmentUID}) + require.NoError(t, err) + require.NotNil(t, stored.MemoID) + require.Equal(t, memoIDFromName(ctx, t, ts, memo.Name), *stored.MemoID) + }) + + t.Run("non-owner cannot create attachment directly linked to memo", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + owner, err := ts.CreateRegularUser(ctx, "attachment-owner-denied") + require.NoError(t, err) + ownerCtx := ts.CreateUserContext(ctx, owner.ID) + other, err := ts.CreateRegularUser(ctx, "attachment-other-denied") + require.NoError(t, err) + otherCtx := ts.CreateUserContext(ctx, other.ID) + + memo, err := ts.Service.CreateMemo(ownerCtx, &v1pb.CreateMemoRequest{ + Memo: &v1pb.Memo{ + Content: "memo with blocked attachment", + }, + }) + require.NoError(t, err) + + _, err = ts.Service.CreateAttachment(otherCtx, &v1pb.CreateAttachmentRequest{ + Attachment: &v1pb.Attachment{ + Filename: "blocked.txt", + Type: "text/plain", + Content: []byte("blocked"), + Memo: &memo.Name, + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + + attachments, err := ts.Store.ListAttachments(ctx, &store.FindAttachment{ + CreatorID: &other.ID, + }) + require.NoError(t, err) + require.Empty(t, attachments) + }) +} + +func memoIDFromName(ctx context.Context, t *testing.T, ts *TestService, name string) int32 { + t.Helper() + memoUID, err := apiv1.ExtractMemoUIDFromName(name) + require.NoError(t, err) + memo, err := ts.Store.GetMemo(ctx, &store.FindMemo{UID: &memoUID}) + require.NoError(t, err) + require.NotNil(t, memo) + return memo.ID +} + func TestCreateAttachmentMotionMedia(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() diff --git a/server/router/api/v1/test/auth_service_test.go b/server/router/api/v1/test/auth_service_test.go index f41556556..5f5928d7a 100644 --- a/server/router/api/v1/test/auth_service_test.go +++ b/server/router/api/v1/test/auth_service_test.go @@ -154,6 +154,23 @@ func TestListAndDeleteLinkedIdentities(t *testing.T) { require.Empty(t, listResp.LinkedIdentities) } +func TestListLinkedIdentitiesRequiresAuthentication(t *testing.T) { + t.Parallel() + + ts := NewTestService(t) + defer ts.Cleanup() + + ctx := context.Background() + user, err := ts.CreateRegularUser(ctx, "linked-identity-auth") + require.NoError(t, err) + + _, err = ts.Service.ListLinkedIdentities(ctx, &v1pb.ListLinkedIdentitiesRequest{ + Parent: apiv1.BuildUserName(user.Username), + }) + require.Error(t, err) + require.Equal(t, codes.Unauthenticated, status.Code(err)) +} + func TestCreateLinkedIdentityRejectsSecondIdentityForSameProvider(t *testing.T) { t.Parallel() diff --git a/server/router/api/v1/test/auth_test.go b/server/router/api/v1/test/auth_test.go index 971ae4653..ee225cfac 100644 --- a/server/router/api/v1/test/auth_test.go +++ b/server/router/api/v1/test/auth_test.go @@ -78,6 +78,33 @@ func TestAuthenticatorAccessTokenV2(t *testing.T) { _, err = authenticator.AuthenticateByAccessTokenV2(token) assert.Error(t, err) }) + + t.Run("request authentication rejects archived user", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "archived-access-token") + require.NoError(t, err) + token, _, err := auth.GenerateAccessTokenV2( + user.ID, + user.Username, + string(user.Role), + string(user.RowStatus), + []byte(ts.Secret), + ) + require.NoError(t, err) + + archivedStatus := store.Archived + _, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{ + ID: user.ID, + RowStatus: &archivedStatus, + }) + require.NoError(t, err) + + authenticator := auth.NewAuthenticator(ts.Store, ts.Secret) + result := authenticator.Authenticate(ctx, "Bearer "+token) + assert.Nil(t, result) + }) } func TestAuthenticatorRefreshToken(t *testing.T) { diff --git a/server/router/api/v1/test/user_search_test.go b/server/router/api/v1/test/user_search_test.go index d108a7323..8027e74aa 100644 --- a/server/router/api/v1/test/user_search_test.go +++ b/server/router/api/v1/test/user_search_test.go @@ -52,3 +52,20 @@ func TestBatchGetUsersRejectsTooManyUsernames(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "too many usernames") } + +func TestBatchGetUsersRejectsTooManyNonEmptyUsernamesBeforeDedupe(t *testing.T) { + ctx := context.Background() + ts := NewTestService(t) + defer ts.Cleanup() + + usernames := make([]string, 0, 101) + for range 101 { + usernames = append(usernames, "legacy@example.com") + } + + _, err := ts.Service.BatchGetUsers(ctx, &apiv1.BatchGetUsersRequest{ + Usernames: usernames, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "too many usernames") +} diff --git a/server/router/api/v1/test/user_service_email_username_test.go b/server/router/api/v1/test/user_service_email_username_test.go index 4215ac868..98fc5216f 100644 --- a/server/router/api/v1/test/user_service_email_username_test.go +++ b/server/router/api/v1/test/user_service_email_username_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/require" + "golang.org/x/crypto/bcrypt" "google.golang.org/protobuf/types/known/fieldmaskpb" apiv1 "github.com/usememos/memos/proto/gen/api/v1" @@ -15,6 +16,26 @@ import ( func TestUserServiceWithEmailLikeUsername(t *testing.T) { ctx := context.Background() + t.Run("SignIn accepts email-like legacy username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user := createLegacyPasswordUser(ctx, t, ts, "signin@example.com", "password123") + + signInCtx := apiv1server.WithHeaderCarrier(ctx) + resp, err := ts.Service.SignIn(signInCtx, &apiv1.SignInRequest{ + Credentials: &apiv1.SignInRequest_PasswordCredentials_{ + PasswordCredentials: &apiv1.SignInRequest_PasswordCredentials{ + Username: user.Username, + Password: "password123", + }, + }, + }) + require.NoError(t, err) + require.Equal(t, user.Username, resp.User.Username) + require.NotEmpty(t, resp.AccessToken) + }) + t.Run("GetUser accepts email-like username in resource name", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -31,6 +52,38 @@ func TestUserServiceWithEmailLikeUsername(t *testing.T) { require.Equal(t, "users/alice@example.com", got.Name) }) + t.Run("BatchGetUsers accepts email-like legacy username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "batch@example.com") + require.NoError(t, err) + + resp, err := ts.Service.BatchGetUsers(ctx, &apiv1.BatchGetUsersRequest{ + Usernames: []string{" batch@example.com ", "missing@example.com", "batch@example.com"}, + }) + require.NoError(t, err) + require.Len(t, resp.Users, 1) + require.Equal(t, user.Username, resp.Users[0].Username) + require.Equal(t, "users/batch@example.com", resp.Users[0].Name) + }) + + t.Run("BatchGetUsers accepts underscore legacy username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "legacy_batch") + require.NoError(t, err) + + resp, err := ts.Service.BatchGetUsers(ctx, &apiv1.BatchGetUsersRequest{ + Usernames: []string{"legacy_batch"}, + }) + require.NoError(t, err) + require.Len(t, resp.Users, 1) + require.Equal(t, user.Username, resp.Users[0].Username) + require.Equal(t, "users/legacy_batch", resp.Users[0].Name) + }) + t.Run("ListUserSettings accepts email-like username in parent", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -92,14 +145,70 @@ func TestUserServiceWithEmailLikeUsername(t *testing.T) { require.Equal(t, "bob", stored.Username) }) + t.Run("UpdateUser rejects writing invalid username values", func(t *testing.T) { + for _, username := range []string{"alice@example.com", "legacy_user"} { + t.Run(username, func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "rename@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(ctx, user.ID) + _, err = ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: "users/rename@example.com", + Username: username, + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid username") + + stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID}) + require.NoError(t, err) + require.NotNil(t, stored) + require.Equal(t, "rename@example.com", stored.Username) + }) + } + }) + + t.Run("admin cannot rename user to invalid username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "admin-rename-target") + require.NoError(t, err) + admin, err := ts.CreateHostUser(ctx, "rename-admin") + require.NoError(t, err) + + adminCtx := ts.CreateUserContext(ctx, admin.ID) + _, err = ts.Service.UpdateUser(adminCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(user.Username), + Username: "admin@example.com", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid username") + + stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID}) + require.NoError(t, err) + require.NotNil(t, stored) + require.Equal(t, "admin-rename-target", stored.Username) + }) + t.Run("UpdateUser can archive email-like username account", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() user, err := ts.CreateRegularUser(ctx, "dave@example.com") require.NoError(t, err) + admin, err := ts.CreateHostUser(ctx, "email-admin") + require.NoError(t, err) - authCtx := ts.CreateUserContext(ctx, user.ID) + authCtx := ts.CreateUserContext(ctx, admin.ID) updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ User: &apiv1.User{ Name: "users/dave@example.com", @@ -134,3 +243,17 @@ func TestUserServiceWithEmailLikeUsername(t *testing.T) { require.Nil(t, deleted) }) } + +func createLegacyPasswordUser(ctx context.Context, t *testing.T, ts *TestService, username, password string) *store.User { + passwordHash, err := bcrypt.GenerateFromPassword([]byte(password), bcrypt.DefaultCost) + require.NoError(t, err) + + user, err := ts.Store.CreateUser(ctx, &store.User{ + Username: username, + Role: store.RoleUser, + Email: username, + PasswordHash: string(passwordHash), + }) + require.NoError(t, err) + return user +} diff --git a/server/router/api/v1/test/user_service_registration_test.go b/server/router/api/v1/test/user_service_registration_test.go index 052e97ead..56f8c0179 100644 --- a/server/router/api/v1/test/user_service_registration_test.go +++ b/server/router/api/v1/test/user_service_registration_test.go @@ -2,6 +2,8 @@ package test import ( "context" + "fmt" + "sync" "testing" "github.com/stretchr/testify/require" @@ -10,6 +12,7 @@ import ( apiv1 "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" apiv1server "github.com/usememos/memos/server/router/api/v1" + "github.com/usememos/memos/store" ) func TestCreateUserRegistration(t *testing.T) { @@ -218,6 +221,41 @@ func TestCreateUserRegistration(t *testing.T) { require.Contains(t, err.Error(), "password must not be empty") }) + t.Run("CreateUser rejects invalid writable usernames", func(t *testing.T) { + for _, username := range []string{"alice@example.com", "legacy_user", "123"} { + t.Run(username, func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: username, + Email: "newuser@example.com", + Password: "password123", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid username") + }) + } + }) + + t.Run("CreateUser validate only rejects invalid writable username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "alice@example.com", + Email: "newuser@example.com", + Password: "password123", + }, + ValidateOnly: true, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "invalid username") + }) + t.Run("UpdateUser rejects empty password", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -236,4 +274,114 @@ func TestCreateUserRegistration(t *testing.T) { require.Error(t, err) require.Contains(t, err.Error(), "password must not be empty") }) + + t.Run("UpdateUser rejects missing user message", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "missing-message") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(ctx, user.ID) + _, err = ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"display_name"}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "user is required") + }) + + t.Run("CreateUser concurrent first setup creates one admin", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + const workers = 12 + var wg sync.WaitGroup + for i := range workers { + wg.Go(func() { + _, _ = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: fmt.Sprintf("setup-user-%d", i), + Email: "setup-user@example.com", + Password: "password123", + }, + }) + }) + } + wg.Wait() + + users, err := ts.Store.ListUsers(ctx, &store.FindUser{}) + require.NoError(t, err) + adminCount := 0 + for _, user := range users { + if user.Role == store.RoleAdmin { + adminCount++ + } + } + require.Equal(t, 1, adminCount) + }) + + t.Run("UpdateUser state requires admin", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "state-user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + _, err = ts.Service.UpdateUser(userCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(user.Username), + State: apiv1.State_ARCHIVED, + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "permission denied") + + admin, err := ts.CreateHostUser(ctx, "state-admin") + require.NoError(t, err) + adminCtx := ts.CreateUserContext(ctx, admin.ID) + updated, err := ts.Service.UpdateUser(adminCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(user.Username), + State: apiv1.State_ARCHIVED, + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}}, + }) + require.NoError(t, err) + require.Equal(t, apiv1.State_ARCHIVED, updated.State) + }) + + t.Run("archived user context is rejected", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "archived-access-user") + require.NoError(t, err) + userCtx := ts.CreateUserContext(ctx, user.ID) + archived := store.Archived + _, err = ts.Store.UpdateUser(ctx, &store.UpdateUser{ + ID: user.ID, + RowStatus: &archived, + }) + require.NoError(t, err) + + _, err = ts.Service.GetCurrentUser(userCtx, &apiv1.GetCurrentUserRequest{}) + require.Error(t, err) + + _, err = ts.Service.CreateMemo(userCtx, &apiv1.CreateMemoRequest{ + Memo: &apiv1.Memo{ + Content: "should not be created", + }, + }) + require.Error(t, err) + + _, err = ts.Service.UpdateUser(userCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(user.Username), + State: apiv1.State_NORMAL, + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}}, + }) + require.Error(t, err) + }) } diff --git a/server/router/api/v1/user_resource_name.go b/server/router/api/v1/user_resource_name.go index dd677777a..5ac408ed4 100644 --- a/server/router/api/v1/user_resource_name.go +++ b/server/router/api/v1/user_resource_name.go @@ -26,7 +26,7 @@ func parseUsernameFromName(name string) (string, error) { return username, nil } -func validateUsername(username string) error { +func validateWritableUsername(username string) error { if username == "" || isNumericUsername(username) || !base.UIDMatcher.MatchString(username) { return errors.Errorf("invalid username %q", username) } diff --git a/server/router/api/v1/user_resource_name_test.go b/server/router/api/v1/user_resource_name_test.go new file mode 100644 index 000000000..35c8572d0 --- /dev/null +++ b/server/router/api/v1/user_resource_name_test.go @@ -0,0 +1,116 @@ +package v1 + +import ( + "testing" +) + +func TestValidateWritableUsername(t *testing.T) { + tests := []struct { + name string + username string + wantError bool + }{ + { + name: "lowercase", + username: "alice", + }, + { + name: "mixed case", + username: "Alice", + }, + { + name: "hyphenated", + username: "alice-smith", + }, + { + name: "uuid", + username: "550e8400-e29b-41d4-a716-446655440000", + }, + { + name: "empty", + username: "", + wantError: true, + }, + { + name: "numeric", + username: "123", + wantError: true, + }, + { + name: "email", + username: "alice@example.com", + wantError: true, + }, + { + name: "underscore", + username: "alice_smith", + wantError: true, + }, + { + name: "space", + username: "alice smith", + wantError: true, + }, + { + name: "slash", + username: "alice/smith", + wantError: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + err := validateWritableUsername(test.username) + if test.wantError && err == nil { + t.Fatalf("validateWritableUsername(%q) succeeded, want error", test.username) + } + if !test.wantError && err != nil { + t.Fatalf("validateWritableUsername(%q) returned error: %v", test.username, err) + } + }) + } +} + +func TestParseUsernameFromNameAllowsLegacyUsernames(t *testing.T) { + tests := []struct { + name string + want string + wantFail bool + }{ + { + name: "users/alice", + want: "alice", + }, + { + name: "users/alice@example.com", + want: "alice@example.com", + }, + { + name: "users/alice_smith", + want: "alice_smith", + }, + { + name: "users/", + wantFail: true, + }, + { + name: "invalid/alice", + wantFail: true, + }, + } + + for _, test := range tests { + t.Run(test.name, func(t *testing.T) { + got, err := parseUsernameFromName(test.name) + if test.wantFail && err == nil { + t.Fatalf("parseUsernameFromName(%q) succeeded, want error", test.name) + } + if !test.wantFail && err != nil { + t.Fatalf("parseUsernameFromName(%q) returned error: %v", test.name, err) + } + if got != test.want { + t.Fatalf("parseUsernameFromName(%q) = %q, want %q", test.name, got, test.want) + } + }) + } +} diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index 4550815f9..8fab09b8d 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -78,21 +78,23 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq return response, nil } -func normalizeBatchUsernames(usernames []string) []string { +func normalizeBatchUsernames(usernames []string) ([]string, int) { uniqueUsernames := make([]string, 0, len(usernames)) seen := make(map[string]struct{}, len(usernames)) + nonEmptyCount := 0 for _, username := range usernames { username = strings.TrimSpace(username) - if validateUsername(username) != nil { + if username == "" { continue } + nonEmptyCount++ if _, ok := seen[username]; ok { continue } seen[username] = struct{}{} uniqueUsernames = append(uniqueUsernames, username) } - return uniqueUsernames + return uniqueUsernames, nonEmptyCount } func (s *APIV1Service) BatchGetUsers(ctx context.Context, request *v1pb.BatchGetUsersRequest) (*v1pb.BatchGetUsersResponse, error) { @@ -100,8 +102,8 @@ func (s *APIV1Service) BatchGetUsers(ctx context.Context, request *v1pb.BatchGet return &v1pb.BatchGetUsersResponse{Users: []*v1pb.User{}}, nil } - uniqueUsernames := normalizeBatchUsernames(request.Usernames) - if len(uniqueUsernames) > maxBatchGetUsers { + uniqueUsernames, nonEmptyUsernameCount := normalizeBatchUsernames(request.Usernames) + if nonEmptyUsernameCount > maxBatchGetUsers { return nil, status.Errorf(codes.InvalidArgument, "too many usernames (max %d)", maxBatchGetUsers) } @@ -144,18 +146,54 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR // Get current user (might be nil for unauthenticated requests) currentUser, _ := s.fetchCurrentUser(ctx) - // Check if there are any existing users (for first-time setup detection) - limitOne := 1 - allUsers, err := s.Store.ListUsers(ctx, &store.FindUser{Limit: &limitOne}) - if err != nil { - return nil, status.Errorf(codes.Internal, "failed to list users: %v", err) + if request.User == nil { + return nil, status.Errorf(codes.InvalidArgument, "user is required") + } + if err := validateWritableUsername(request.User.Username); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username) + } + if err := validatePassword(request.User.Password); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) } - isFirstUser := len(allUsers) == 0 - // Check registration settings FIRST (unless it's the very first user) - if !isFirstUser { + roleToAssign := store.RoleUser + if currentUser != nil && currentUser.Role == store.RoleAdmin { + // Authenticated ADMIN user can create users with any role specified in request + if request.User.Role != v1pb.User_ROLE_UNSPECIFIED { + roleToAssign = convertUserRoleToStore(request.User.Role) + } + } else { + limitOne := 1 + allUsers, err := s.Store.ListUsers(ctx, &store.FindUser{Limit: &limitOne}) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to list users: %v", err) + } + if len(allUsers) == 0 { + roleToAssign = store.RoleAdmin + if !request.ValidateOnly { + passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to generate password hash: %v", err) + } + user, created, err := s.Store.CreateUserIfNoUsers(ctx, &store.User{ + Username: request.User.Username, + Role: store.RoleAdmin, + Email: request.User.Email, + Nickname: request.User.DisplayName, + PasswordHash: string(passwordHash), + }) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to create first user: %v", err) + } + if created { + return convertUserFromStore(user, user), nil + } + roleToAssign = store.RoleUser + } + } + // Only allow user registration if it is enabled in the settings, or if the user is a superuser - if currentUser == nil || !isSuperUser(currentUser) { + if roleToAssign != store.RoleAdmin { instanceGeneralSetting, err := s.Store.GetInstanceGeneralSetting(ctx) if err != nil { return nil, status.Errorf(codes.Internal, "failed to get instance general setting, error: %v", err) @@ -169,30 +207,6 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR } } - // Determine the role to assign - var roleToAssign store.Role - if isFirstUser { - // First-time setup: create the first user as ADMIN (no authentication required) - roleToAssign = store.RoleAdmin - } else if currentUser != nil && currentUser.Role == store.RoleAdmin { - // Authenticated ADMIN user can create users with any role specified in request - if request.User.Role != v1pb.User_ROLE_UNSPECIFIED { - roleToAssign = convertUserRoleToStore(request.User.Role) - } else { - roleToAssign = store.RoleUser - } - } else { - // Unauthenticated or non-ADMIN users can only create normal users - roleToAssign = store.RoleUser - } - - if err := validateUsername(request.User.Username); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username) - } - if err := validatePassword(request.User.Password); err != nil { - return nil, status.Errorf(codes.InvalidArgument, "%v", err) - } - // If validate_only is true, just validate without creating if request.ValidateOnly { // Perform validation checks without actually creating the user @@ -224,6 +238,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR } func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserRequest) (*v1pb.User, error) { + if request.User == nil { + return nil, status.Errorf(codes.InvalidArgument, "user is required") + } if request.UpdateMask == nil || len(request.UpdateMask.Paths) == 0 { return nil, status.Errorf(codes.InvalidArgument, "update mask is empty") } @@ -266,7 +283,7 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR if instanceGeneralSetting.DisallowChangeUsername { return nil, status.Errorf(codes.PermissionDenied, "permission denied: disallow change username") } - if err := validateUsername(request.User.Username); err != nil { + if err := validateWritableUsername(request.User.Username); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username) } update.Username = &request.User.Username @@ -317,6 +334,9 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR passwordHashStr := string(passwordHash) update.PasswordHash = &passwordHashStr case "state": + if currentUser.Role != store.RoleAdmin { + return nil, status.Errorf(codes.PermissionDenied, "permission denied") + } rowStatus := convertStateToStore(request.User.State) update.RowStatus = &rowStatus default: @@ -661,6 +681,20 @@ func (s *APIV1Service) ListUserSettings(ctx context.Context, request *v1pb.ListU return response, nil } +func (s *APIV1Service) authorizeUserResourceAccess(ctx context.Context, userID int32, allowAdmin bool) (*store.User, error) { + currentUser, err := s.fetchCurrentUser(ctx) + if err != nil { + return nil, status.Errorf(codes.Internal, "failed to get current user: %v", err) + } + if currentUser == nil { + return nil, status.Errorf(codes.Unauthenticated, "user not authenticated") + } + if currentUser.ID == userID || (allowAdmin && currentUser.Role == store.RoleAdmin) { + return currentUser, nil + } + return nil, status.Errorf(codes.PermissionDenied, "permission denied") +} + func (s *APIV1Service) ListLinkedIdentities(ctx context.Context, request *v1pb.ListLinkedIdentitiesRequest) (*v1pb.ListLinkedIdentitiesResponse, error) { user, err := s.resolveUserFromName(ctx, request.Parent) if err != nil { @@ -668,12 +702,8 @@ func (s *APIV1Service) ListLinkedIdentities(ctx context.Context, request *v1pb.L } userID := user.ID - claims := auth.GetUserClaims(ctx) - if claims == nil || claims.UserID != userID { - currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } + if _, err := s.authorizeUserResourceAccess(ctx, userID, true); err != nil { + return nil, err } identities, err := s.Store.ListUserIdentities(ctx, &store.FindUserIdentity{UserID: &userID}) @@ -739,12 +769,8 @@ func (s *APIV1Service) GetLinkedIdentity(ctx context.Context, request *v1pb.GetL } userID := user.ID - claims := auth.GetUserClaims(ctx) - if claims == nil || claims.UserID != userID { - currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } + if _, err := s.authorizeUserResourceAccess(ctx, userID, true); err != nil { + return nil, err } identity, err := s.Store.GetUserIdentity(ctx, &store.FindUserIdentity{ @@ -768,12 +794,8 @@ func (s *APIV1Service) DeleteLinkedIdentity(ctx context.Context, request *v1pb.D } userID := user.ID - claims := auth.GetUserClaims(ctx) - if claims == nil || claims.UserID != userID { - currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } + if _, err := s.authorizeUserResourceAccess(ctx, userID, true); err != nil { + return nil, err } existing, err := s.Store.GetUserIdentity(ctx, &store.FindUserIdentity{ @@ -819,12 +841,8 @@ func (s *APIV1Service) ListPersonalAccessTokens(ctx context.Context, request *v1 userID := user.ID // Verify permission - claims := auth.GetUserClaims(ctx) - if claims == nil || claims.UserID != userID { - currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || (currentUser.ID != userID && currentUser.Role != store.RoleAdmin) { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } + if _, err := s.authorizeUserResourceAccess(ctx, userID, true); err != nil { + return nil, err } tokens, err := s.Store.GetUserPersonalAccessTokens(ctx, userID) @@ -874,12 +892,8 @@ func (s *APIV1Service) CreatePersonalAccessToken(ctx context.Context, request *v userID := user.ID // Verify permission - claims := auth.GetUserClaims(ctx) - if claims == nil || claims.UserID != userID { - currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || currentUser.ID != userID { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } + if _, err := s.authorizeUserResourceAccess(ctx, userID, false); err != nil { + return nil, err } // Generate PAT @@ -942,12 +956,8 @@ func (s *APIV1Service) DeletePersonalAccessToken(ctx context.Context, request *v tokenID := parts[3] // Verify permission - claims := auth.GetUserClaims(ctx) - if claims == nil || claims.UserID != userID { - currentUser, _ := s.fetchCurrentUser(ctx) - if currentUser == nil || currentUser.ID != userID { - return nil, status.Errorf(codes.PermissionDenied, "permission denied") - } + if _, err := s.authorizeUserResourceAccess(ctx, userID, false); err != nil { + return nil, err } if err := s.Store.RemoveUserPersonalAccessToken(ctx, userID, tokenID); err != nil { diff --git a/server/router/api/v1/v1.go b/server/router/api/v1/v1.go index ad974b4a5..259a1b8b9 100644 --- a/server/router/api/v1/v1.go +++ b/server/router/api/v1/v1.go @@ -3,6 +3,8 @@ package v1 import ( "context" "net/http" + "net/url" + "strings" "connectrpc.com/connect" "github.com/grpc-ecosystem/grpc-gateway/v2/runtime" @@ -143,8 +145,11 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech // Wrap with CORS for browser access corsHandler := middleware.CORSWithConfig(middleware.CORSConfig{ - UnsafeAllowOriginFunc: func(_ *echo.Context, origin string) (string, bool, error) { - return origin, true, nil + UnsafeAllowOriginFunc: func(c *echo.Context, origin string) (string, bool, error) { + if s.isAllowedConnectOrigin(c, origin) { + return origin, true, nil + } + return "", false, nil }, AllowMethods: []string{http.MethodGet, http.MethodPost, http.MethodOptions}, AllowHeaders: []string{"*"}, @@ -155,3 +160,23 @@ func (s *APIV1Service) RegisterGateway(ctx context.Context, echoServer *echo.Ech return nil } + +func (s *APIV1Service) isAllowedConnectOrigin(c *echo.Context, origin string) bool { + originURL, err := url.Parse(origin) + if err != nil || originURL.Scheme == "" || originURL.Host == "" { + return false + } + + if strings.EqualFold(originURL.Host, c.Request().Host) { + return true + } + + if s.Profile == nil || s.Profile.InstanceURL == "" { + return false + } + instanceURL, err := url.Parse(s.Profile.InstanceURL) + if err != nil || instanceURL.Scheme == "" || instanceURL.Host == "" { + return false + } + return strings.EqualFold(originURL.Scheme, instanceURL.Scheme) && strings.EqualFold(originURL.Host, instanceURL.Host) +} diff --git a/store/store.go b/store/store.go index bb3e54dad..5858c9a22 100644 --- a/store/store.go +++ b/store/store.go @@ -1,6 +1,7 @@ package store import ( + "sync" "time" "github.com/usememos/memos/internal/profile" @@ -12,6 +13,8 @@ type Store struct { profile *profile.Profile driver Driver + userCreateMu sync.Mutex + // Cache settings cacheConfig cache.Config diff --git a/store/user.go b/store/user.go index 9826c0a77..d07759ea4 100644 --- a/store/user.go +++ b/store/user.go @@ -95,6 +95,30 @@ func (s *Store) CreateUser(ctx context.Context, create *User) (*User, error) { return user, nil } +// CreateUserIfNoUsers creates a user only when the instance has no users. +// The in-process lock prevents concurrent first-user setup requests from +// creating multiple admins in the same server process. +func (s *Store) CreateUserIfNoUsers(ctx context.Context, create *User) (*User, bool, error) { + s.userCreateMu.Lock() + defer s.userCreateMu.Unlock() + + limitOne := 1 + users, err := s.driver.ListUsers(ctx, &FindUser{Limit: &limitOne}) + if err != nil { + return nil, false, err + } + if len(users) > 0 { + return nil, false, nil + } + + user, err := s.driver.CreateUser(ctx, create) + if err != nil { + return nil, false, err + } + s.userCache.Set(ctx, userCacheKey(user.ID), user) + return user, true, nil +} + func (s *Store) UpdateUser(ctx context.Context, update *UpdateUser) (*User, error) { user, err := s.driver.UpdateUser(ctx, update) if err != nil {