From 30c0611a82f9254952a74650095105254f2940e4 Mon Sep 17 00:00:00 2001 From: boojack Date: Thu, 23 Apr 2026 22:35:38 +0800 Subject: [PATCH] fix: fix legacy username auth flows (#5885) --- server/router/api/v1/auth_service.go | 44 ++++-- .../api/v1/auth_service_client_info_test.go | 49 +++++++ server/router/api/v1/connect_interceptors.go | 9 ++ .../api/v1/connect_interceptors_test.go | 39 +++++ .../api/v1/test/shortcut_service_test.go | 4 +- .../api/v1/test/user_notification_test.go | 2 +- .../api/v1/test/user_resource_name_test.go | 85 ++++++++++- .../test/user_service_email_username_test.go | 136 ++++++++++++++++++ .../v1/test/user_service_registration_test.go | 64 +++++++++ .../api/v1/test/user_service_stats_test.go | 2 +- server/router/api/v1/user_resource_name.go | 8 +- server/router/api/v1/user_service.go | 16 +++ web/src/pages/SignUp.tsx | 5 +- 13 files changed, 437 insertions(+), 26 deletions(-) create mode 100644 server/router/api/v1/connect_interceptors_test.go create mode 100644 server/router/api/v1/test/user_service_email_username_test.go diff --git a/server/router/api/v1/auth_service.go b/server/router/api/v1/auth_service.go index 85d904dc8..86ae6cd9e 100644 --- a/server/router/api/v1/auth_service.go +++ b/server/router/api/v1/auth_service.go @@ -529,6 +529,36 @@ func (s *APIV1Service) clearAuthCookies(ctx context.Context) error { return nil } +func isSecureRequest(ctx context.Context) bool { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + return false + } + + for _, value := range md.Get("x-forwarded-proto") { + for _, proto := range strings.Split(value, ",") { + if strings.EqualFold(strings.TrimSpace(proto), "https") { + return true + } + } + } + + for _, value := range md.Get("forwarded") { + lowerValue := strings.ToLower(value) + if strings.Contains(lowerValue, "proto=https") { + return true + } + } + + for _, value := range md.Get("origin") { + if strings.HasPrefix(strings.ToLower(strings.TrimSpace(value)), "https://") { + return true + } + } + + return false +} + func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken string, expireTime time.Time) string { attrs := []string{ fmt.Sprintf("%s=%s", auth.RefreshTokenCookieName, refreshToken), @@ -543,19 +573,7 @@ func (*APIV1Service) buildRefreshTokenCookie(ctx context.Context, refreshToken s attrs = append(attrs, "Expires="+expireTime.UTC().Format("Mon, 02 Jan 2006 15:04:05 GMT")) } - // Try to determine if the request is HTTPS by checking the origin header - // Default to non-HTTPS (Lax SameSite) if metadata is not available - isHTTPS := false - if md, ok := metadata.FromIncomingContext(ctx); ok { - for _, v := range md.Get("origin") { - if strings.HasPrefix(v, "https://") { - isHTTPS = true - break - } - } - } - - if isHTTPS { + if isSecureRequest(ctx) { attrs = append(attrs, "SameSite=Lax", "Secure") } else { attrs = append(attrs, "SameSite=Lax") diff --git a/server/router/api/v1/auth_service_client_info_test.go b/server/router/api/v1/auth_service_client_info_test.go index 300663c04..682fd7f58 100644 --- a/server/router/api/v1/auth_service_client_info_test.go +++ b/server/router/api/v1/auth_service_client_info_test.go @@ -2,7 +2,9 @@ package v1 import ( "context" + "strings" "testing" + "time" "google.golang.org/grpc/metadata" @@ -177,3 +179,50 @@ func TestClientInfoExamples(t *testing.T) { }) } } + +func TestBuildRefreshTokenCookieSecureFlag(t *testing.T) { + service := &APIV1Service{} + + t.Run("sets Secure for https origin", func(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs( + "origin", "https://memos.example", + )) + cookie := service.buildRefreshTokenCookie(ctx, "token", testCookieExpiry()) + if !containsCookieAttribute(cookie, "Secure") { + t.Fatalf("expected Secure attribute in cookie: %s", cookie) + } + }) + + t.Run("sets Secure for forwarded proto", func(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs( + "x-forwarded-proto", "https", + )) + cookie := service.buildRefreshTokenCookie(ctx, "token", testCookieExpiry()) + if !containsCookieAttribute(cookie, "Secure") { + t.Fatalf("expected Secure attribute in cookie: %s", cookie) + } + }) + + t.Run("omits Secure for plain http", func(t *testing.T) { + ctx := metadata.NewIncomingContext(context.Background(), metadata.Pairs( + "origin", "http://memos.example", + )) + cookie := service.buildRefreshTokenCookie(ctx, "token", testCookieExpiry()) + if containsCookieAttribute(cookie, "Secure") { + t.Fatalf("did not expect Secure attribute in cookie: %s", cookie) + } + }) +} + +func testCookieExpiry() time.Time { + return time.Date(2030, time.January, 2, 3, 4, 5, 0, time.UTC) +} + +func containsCookieAttribute(cookie, attr string) bool { + for _, part := range strings.Split(cookie, ";") { + if strings.EqualFold(strings.TrimSpace(part), attr) { + return true + } + } + return false +} diff --git a/server/router/api/v1/connect_interceptors.go b/server/router/api/v1/connect_interceptors.go index 9ea26f3b0..84d736e7a 100644 --- a/server/router/api/v1/connect_interceptors.go +++ b/server/router/api/v1/connect_interceptors.go @@ -38,12 +38,21 @@ func (*MetadataInterceptor) WrapUnary(next connect.UnaryFunc) connect.UnaryFunc if ua := header.Get("User-Agent"); ua != "" { md.Set("user-agent", ua) } + if origin := header.Get("Origin"); origin != "" { + md.Set("origin", origin) + } if xff := header.Get("X-Forwarded-For"); xff != "" { md.Set("x-forwarded-for", xff) } + if xfp := header.Get("X-Forwarded-Proto"); xfp != "" { + md.Set("x-forwarded-proto", xfp) + } if xri := header.Get("X-Real-Ip"); xri != "" { md.Set("x-real-ip", xri) } + if forwarded := header.Get("Forwarded"); forwarded != "" { + md.Set("forwarded", forwarded) + } // Forward Cookie header for authentication methods that need it (e.g., RefreshToken) if cookie := header.Get("Cookie"); cookie != "" { md.Set("cookie", cookie) diff --git a/server/router/api/v1/connect_interceptors_test.go b/server/router/api/v1/connect_interceptors_test.go new file mode 100644 index 000000000..b4f8c79b9 --- /dev/null +++ b/server/router/api/v1/connect_interceptors_test.go @@ -0,0 +1,39 @@ +package v1 + +import ( + "context" + "testing" + + "connectrpc.com/connect" + "google.golang.org/grpc/metadata" + "google.golang.org/protobuf/types/known/emptypb" +) + +func TestMetadataInterceptorForwardsSecurityHeaders(t *testing.T) { + interceptor := NewMetadataInterceptor() + req := connect.NewRequest(&emptypb.Empty{}) + req.Header().Set("Origin", "https://memos.example") + req.Header().Set("X-Forwarded-Proto", "https") + req.Header().Set("Forwarded", "for=203.0.113.1;proto=https") + + handler := interceptor.WrapUnary(func(ctx context.Context, _ connect.AnyRequest) (connect.AnyResponse, error) { + md, ok := metadata.FromIncomingContext(ctx) + if !ok { + t.Fatal("expected metadata in context") + } + if got := md.Get("origin"); len(got) != 1 || got[0] != "https://memos.example" { + t.Fatalf("unexpected origin metadata: %v", got) + } + if got := md.Get("x-forwarded-proto"); len(got) != 1 || got[0] != "https" { + t.Fatalf("unexpected x-forwarded-proto metadata: %v", got) + } + if got := md.Get("forwarded"); len(got) != 1 || got[0] != "for=203.0.113.1;proto=https" { + t.Fatalf("unexpected forwarded metadata: %v", got) + } + return connect.NewResponse(&emptypb.Empty{}), nil + }) + + if _, err := handler(context.Background(), req); err != nil { + t.Fatalf("metadata interceptor returned error: %v", err) + } +} diff --git a/server/router/api/v1/test/shortcut_service_test.go b/server/router/api/v1/test/shortcut_service_test.go index bc695c451..a0e1c9920 100644 --- a/server/router/api/v1/test/shortcut_service_test.go +++ b/server/router/api/v1/test/shortcut_service_test.go @@ -94,7 +94,7 @@ func TestListShortcuts(t *testing.T) { require.Contains(t, err.Error(), "permission denied") }) - t.Run("ListShortcuts rejects numeric parent", func(t *testing.T) { + t.Run("ListShortcuts returns not found for numeric parent", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -107,7 +107,7 @@ func TestListShortcuts(t *testing.T) { Parent: "users/1", }) require.Error(t, err) - require.Contains(t, err.Error(), "invalid user name") + require.Contains(t, err.Error(), "user not found") }) } diff --git a/server/router/api/v1/test/user_notification_test.go b/server/router/api/v1/test/user_notification_test.go index 7b58cd2cd..fe5a91a92 100644 --- a/server/router/api/v1/test/user_notification_test.go +++ b/server/router/api/v1/test/user_notification_test.go @@ -202,7 +202,7 @@ func TestListUserNotificationsRejectsNumericParent(t *testing.T) { Parent: "users/1", }) require.Error(t, err) - require.Contains(t, err.Error(), "invalid user name") + require.Contains(t, err.Error(), "user not found") } func TestListUserNotificationsIncludesMemoMentionPayload(t *testing.T) { diff --git a/server/router/api/v1/test/user_resource_name_test.go b/server/router/api/v1/test/user_resource_name_test.go index 799cf63eb..b3463ba9b 100644 --- a/server/router/api/v1/test/user_resource_name_test.go +++ b/server/router/api/v1/test/user_resource_name_test.go @@ -5,8 +5,11 @@ import ( "testing" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" apiv1 "github.com/usememos/memos/proto/gen/api/v1" + apiv1server "github.com/usememos/memos/server/router/api/v1" + "github.com/usememos/memos/store" ) func TestUserResourceName(t *testing.T) { @@ -100,7 +103,7 @@ func TestUserResourceName(t *testing.T) { require.Contains(t, err.Error(), "invalid username") }) - t.Run("GetUser rejects numeric user resource names", func(t *testing.T) { + t.Run("GetUser returns not found for numeric user resource names", func(t *testing.T) { ts := NewTestService(t) defer ts.Cleanup() @@ -111,6 +114,84 @@ func TestUserResourceName(t *testing.T) { Name: "users/1", }) require.Error(t, err) - require.Contains(t, err.Error(), "invalid user name") + require.Contains(t, err.Error(), "user not found") + }) + + t.Run("legacy invalid username remains addressable for get update and delete", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + legacyUser, err := ts.CreateRegularUser(ctx, "legacy_user") + require.NoError(t, err) + + got, err := ts.Service.GetUser(ctx, &apiv1.GetUserRequest{ + Name: "users/legacy_user", + }) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, "users/legacy_user", got.Name) + + authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), legacyUser.ID) + updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(legacyUser.Username), + DisplayName: "Legacy User", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"display_name"}}, + }) + require.NoError(t, err) + require.Equal(t, "Legacy User", updated.DisplayName) + + _, err = ts.Service.DeleteUser(authCtx, &apiv1.DeleteUserRequest{ + Name: apiv1server.BuildUserName(legacyUser.Username), + }) + require.NoError(t, err) + + deleted, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &legacyUser.ID}) + require.NoError(t, err) + require.Nil(t, deleted) + }) + + t.Run("email-like legacy username can be renamed to a valid username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + legacyUser, err := ts.CreateRegularUser(ctx, "alice@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), legacyUser.ID) + updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(legacyUser.Username), + Username: "alice", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}}, + }) + require.NoError(t, err) + require.Equal(t, "users/alice", updated.Name) + require.Equal(t, "alice", updated.Username) + + renamed, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &legacyUser.ID}) + require.NoError(t, err) + require.NotNil(t, renamed) + require.Equal(t, "alice", renamed.Username) + }) + + t.Run("email-like legacy username can be deleted", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + legacyUser, err := ts.CreateRegularUser(ctx, "bob@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), legacyUser.ID) + _, err = ts.Service.DeleteUser(authCtx, &apiv1.DeleteUserRequest{ + Name: apiv1server.BuildUserName(legacyUser.Username), + }) + require.NoError(t, err) + + deleted, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &legacyUser.ID}) + require.NoError(t, err) + require.Nil(t, deleted) }) } diff --git a/server/router/api/v1/test/user_service_email_username_test.go b/server/router/api/v1/test/user_service_email_username_test.go new file mode 100644 index 000000000..4215ac868 --- /dev/null +++ b/server/router/api/v1/test/user_service_email_username_test.go @@ -0,0 +1,136 @@ +package test + +import ( + "context" + "testing" + + "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" + + apiv1 "github.com/usememos/memos/proto/gen/api/v1" + apiv1server "github.com/usememos/memos/server/router/api/v1" + "github.com/usememos/memos/store" +) + +func TestUserServiceWithEmailLikeUsername(t *testing.T) { + ctx := context.Background() + + t.Run("GetUser accepts email-like username in resource name", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "alice@example.com") + require.NoError(t, err) + + got, err := ts.Service.GetUser(ctx, &apiv1.GetUserRequest{ + Name: "users/alice@example.com", + }) + require.NoError(t, err) + require.NotNil(t, got) + require.Equal(t, user.Username, got.Username) + require.Equal(t, "users/alice@example.com", got.Name) + }) + + t.Run("ListUserSettings accepts email-like username in parent", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "alice@example.com") + require.NoError(t, err) + + userCtx := ts.CreateUserContext(ctx, user.ID) + resp, err := ts.Service.ListUserSettings(userCtx, &apiv1.ListUserSettingsRequest{ + Parent: "users/alice@example.com", + }) + require.NoError(t, err) + require.NotNil(t, resp) + require.NotEmpty(t, resp.Settings) + }) + + t.Run("UpdateUser can change non-username fields for email-like username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "alice@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(ctx, user.ID) + updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: "users/alice@example.com", + DisplayName: "Alice Example", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"display_name"}}, + }) + require.NoError(t, err) + require.Equal(t, "Alice Example", updated.DisplayName) + require.Equal(t, "users/alice@example.com", updated.Name) + }) + + t.Run("UpdateUser can rename email-like username to valid username", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "bob@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(ctx, user.ID) + updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: "users/bob@example.com", + Username: "bob", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"username"}}, + }) + require.NoError(t, err) + require.Equal(t, "bob", updated.Username) + require.Equal(t, apiv1server.BuildUserName("bob"), updated.Name) + + stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID}) + require.NoError(t, err) + require.NotNil(t, stored) + require.Equal(t, "bob", stored.Username) + }) + + t.Run("UpdateUser can archive email-like username account", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "dave@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(ctx, user.ID) + updated, err := ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: "users/dave@example.com", + State: apiv1.State_ARCHIVED, + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"state"}}, + }) + require.NoError(t, err) + require.Equal(t, apiv1.State_ARCHIVED, updated.State) + + stored, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID}) + require.NoError(t, err) + require.NotNil(t, stored) + require.Equal(t, store.Archived, stored.RowStatus) + }) + + t.Run("DeleteUser can remove email-like username account", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "carol@example.com") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(apiv1server.WithHeaderCarrier(ctx), user.ID) + _, err = ts.Service.DeleteUser(authCtx, &apiv1.DeleteUserRequest{ + Name: "users/carol@example.com", + }) + require.NoError(t, err) + + deleted, err := ts.Store.GetUser(ctx, &store.FindUser{ID: &user.ID}) + require.NoError(t, err) + require.Nil(t, deleted) + }) +} diff --git a/server/router/api/v1/test/user_service_registration_test.go b/server/router/api/v1/test/user_service_registration_test.go index e2a5e7e09..052e97ead 100644 --- a/server/router/api/v1/test/user_service_registration_test.go +++ b/server/router/api/v1/test/user_service_registration_test.go @@ -5,9 +5,11 @@ import ( "testing" "github.com/stretchr/testify/require" + "google.golang.org/protobuf/types/known/fieldmaskpb" apiv1 "github.com/usememos/memos/proto/gen/api/v1" storepb "github.com/usememos/memos/proto/gen/store" + apiv1server "github.com/usememos/memos/server/router/api/v1" ) func TestCreateUserRegistration(t *testing.T) { @@ -172,4 +174,66 @@ func TestCreateUserRegistration(t *testing.T) { require.Equal(t, "users/wannabeadmin", createdUser.Name) require.Equal(t, apiv1.User_USER, createdUser.Role, "Unauthenticated users can only create USER role") }) + + t.Run("CreateUser blocked when password auth disabled for self signup", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.CreateHostUser(ctx, "admin") + require.NoError(t, err) + + _, err = ts.Store.UpsertInstanceSetting(ctx, &storepb.InstanceSetting{ + Key: storepb.InstanceSettingKey_GENERAL, + Value: &storepb.InstanceSetting_GeneralSetting{ + GeneralSetting: &storepb.InstanceGeneralSetting{ + DisallowPasswordAuth: true, + }, + }, + }) + require.NoError(t, err) + + _, err = ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newuser", + Email: "newuser@example.com", + Password: "password123", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "password signup is not allowed") + }) + + t.Run("CreateUser rejects empty password", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + _, err := ts.Service.CreateUser(ctx, &apiv1.CreateUserRequest{ + User: &apiv1.User{ + Username: "newuser", + Email: "newuser@example.com", + Password: "", + }, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "password must not be empty") + }) + + t.Run("UpdateUser rejects empty password", func(t *testing.T) { + ts := NewTestService(t) + defer ts.Cleanup() + + user, err := ts.CreateRegularUser(ctx, "alice") + require.NoError(t, err) + + authCtx := ts.CreateUserContext(ctx, user.ID) + _, err = ts.Service.UpdateUser(authCtx, &apiv1.UpdateUserRequest{ + User: &apiv1.User{ + Name: apiv1server.BuildUserName(user.Username), + Password: "", + }, + UpdateMask: &fieldmaskpb.FieldMask{Paths: []string{"password"}}, + }) + require.Error(t, err) + require.Contains(t, err.Error(), "password must not be empty") + }) } diff --git a/server/router/api/v1/test/user_service_stats_test.go b/server/router/api/v1/test/user_service_stats_test.go index 3e92013bf..901191fb6 100644 --- a/server/router/api/v1/test/user_service_stats_test.go +++ b/server/router/api/v1/test/user_service_stats_test.go @@ -108,5 +108,5 @@ func TestGetUserStats_TagCount(t *testing.T) { Name: "users/1", }) require.Error(t, err) - require.Contains(t, err.Error(), "invalid user name") + require.Contains(t, err.Error(), "user not found") } diff --git a/server/router/api/v1/user_resource_name.go b/server/router/api/v1/user_resource_name.go index e5d7716ee..dd677777a 100644 --- a/server/router/api/v1/user_resource_name.go +++ b/server/router/api/v1/user_resource_name.go @@ -14,8 +14,7 @@ func BuildUserName(username string) string { return UserNamePrefix + username } -// ExtractUsernameFromName extracts the username token from a user resource name. -func ExtractUsernameFromName(name string) (string, error) { +func parseUsernameFromName(name string) (string, error) { tokens, err := GetNameParentTokens(name, UserNamePrefix) if err != nil { return "", err @@ -24,9 +23,6 @@ func ExtractUsernameFromName(name string) (string, error) { if username == "" { return "", errors.Errorf("invalid user name %q", name) } - if err := validateUsername(username); err != nil { - return "", err - } return username, nil } @@ -51,7 +47,7 @@ func isNumericUsername(username string) bool { // ResolveUserByName resolves a username-based user resource name to a store user. func ResolveUserByName(ctx context.Context, stores *store.Store, name string) (*store.User, error) { - username, err := ExtractUsernameFromName(name) + username, err := parseUsernameFromName(name) if err != nil { return nil, err } diff --git a/server/router/api/v1/user_service.go b/server/router/api/v1/user_service.go index fdb477bd8..9e587859c 100644 --- a/server/router/api/v1/user_service.go +++ b/server/router/api/v1/user_service.go @@ -30,6 +30,13 @@ import ( const maxBatchGetUsers = 100 +func validatePassword(password string) error { + if password == "" { + return errors.New("password must not be empty") + } + return nil +} + func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersRequest) (*v1pb.ListUsersResponse, error) { currentUser, err := s.fetchCurrentUser(ctx) if err != nil { @@ -156,6 +163,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR if instanceGeneralSetting.DisallowUserRegistration { return nil, status.Errorf(codes.PermissionDenied, "user registration is not allowed") } + if instanceGeneralSetting.DisallowPasswordAuth { + return nil, status.Errorf(codes.PermissionDenied, "password signup is not allowed") + } } } @@ -179,6 +189,9 @@ func (s *APIV1Service) CreateUser(ctx context.Context, request *v1pb.CreateUserR if err := validateUsername(request.User.Username); err != nil { return nil, status.Errorf(codes.InvalidArgument, "invalid username: %s", request.User.Username) } + if err := validatePassword(request.User.Password); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } // If validate_only is true, just validate without creating if request.ValidateOnly { @@ -294,6 +307,9 @@ func (s *APIV1Service) UpdateUser(ctx context.Context, request *v1pb.UpdateUserR role := convertUserRoleToStore(request.User.Role) update.Role = &role case "password": + if err := validatePassword(request.User.Password); err != nil { + return nil, status.Errorf(codes.InvalidArgument, "%v", err) + } passwordHash, err := bcrypt.GenerateFromPassword([]byte(request.User.Password), bcrypt.DefaultCost) if err != nil { return nil, status.Errorf(codes.Internal, "failed to generate password hash: %v", err) diff --git a/web/src/pages/SignUp.tsx b/web/src/pages/SignUp.tsx index 9c7196f0f..9b01d5729 100644 --- a/web/src/pages/SignUp.tsx +++ b/web/src/pages/SignUp.tsx @@ -30,6 +30,7 @@ const SignUp = () => { const [searchParams] = useSearchParams(); const redirectTarget = getSafeRedirectPath(searchParams.get(AUTH_REDIRECT_PARAM)); const signInPath = searchParams.toString() ? `${ROUTES.AUTH}?${searchParams.toString()}` : ROUTES.AUTH; + const canUsePasswordSignUp = !instanceGeneralSetting.disallowUserRegistration && !instanceGeneralSetting.disallowPasswordAuth; const handleUsernameInputChanged = (e: React.ChangeEvent) => { const text = e.target.value as string; @@ -93,7 +94,7 @@ const SignUp = () => {

{instanceGeneralSetting.customProfile?.title || "Memos"}

- {!instanceGeneralSetting.disallowUserRegistration ? ( + {canUsePasswordSignUp ? ( <>

{t("auth.create-your-account")}

@@ -137,6 +138,8 @@ const SignUp = () => {
+ ) : instanceGeneralSetting.disallowPasswordAuth ? ( +

Password sign up is not allowed.

) : (

Sign up is not allowed.

)}