diff --git a/api/v1/auth.go b/api/v1/auth.go index 755588e95..310e9a6f6 100644 --- a/api/v1/auth.go +++ b/api/v1/auth.go @@ -74,16 +74,19 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signin request").SetInternal(err) } - identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{ + identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ ID: &signin.IdentityProviderID, }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider").SetInternal(err) } + if identityProvider == nil { + return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found") + } var userInfo *idp.IdentityProviderUserInfo - if identityProviderMessage.Type == store.IdentityProviderOAuth2 { - oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProviderMessage.Config.OAuth2Config) + if identityProvider.Type == store.IdentityProviderOAuth2 { + oauth2IdentityProvider, err := oauth2.NewIdentityProvider(identityProvider.Config.OAuth2Config) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider instance").SetInternal(err) } @@ -97,7 +100,7 @@ func (s *APIV1Service) registerAuthRoutes(g *echo.Group, secret string) { } } - identifierFilter := identityProviderMessage.IdentifierFilter + identifierFilter := identityProvider.IdentifierFilter if identifierFilter != "" { identifierFilterRegex, err := regexp.Compile(identifierFilter) if err != nil { diff --git a/api/v1/idp.go b/api/v1/idp.go index 650c7ac27..03aa5e211 100644 --- a/api/v1/idp.go +++ b/api/v1/idp.go @@ -83,7 +83,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err) } - identityProviderMessage, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{ + identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProvider{ Name: identityProviderCreate.Name, Type: store.IdentityProviderType(identityProviderCreate.Type), IdentifierFilter: identityProviderCreate.IdentifierFilter, @@ -92,7 +92,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err) } - return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage)) + return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider)) }) g.PATCH("/idp/:idpId", func(c echo.Context) error { @@ -124,7 +124,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err) } - identityProviderMessage, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{ + identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{ ID: identityProviderPatch.ID, Type: store.IdentityProviderType(identityProviderPatch.Type), Name: identityProviderPatch.Name, @@ -134,12 +134,12 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err) } - return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage)) + return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider)) }) g.GET("/idp", func(c echo.Context) error { ctx := c.Request().Context() - identityProviderMessageList, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{}) + list, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err) } @@ -159,8 +159,8 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { } identityProviderList := []*IdentityProvider{} - for _, identityProviderMessage := range identityProviderMessageList { - identityProvider := convertIdentityProviderFromStore(identityProviderMessage) + for _, item := range list { + identityProvider := convertIdentityProviderFromStore(item) // data desensitize if !isHostUser { identityProvider.Config.OAuth2Config.ClientSecret = "" @@ -191,13 +191,17 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { if err != nil { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) } - identityProviderMessage, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{ + identityProvider, err := s.Store.GetIdentityProvider(ctx, &store.FindIdentityProvider{ ID: &identityProviderID, }) if err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to get identity provider").SetInternal(err) } - return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProviderMessage)) + if identityProvider == nil { + return echo.NewHTTPError(http.StatusNotFound, "Identity provider not found") + } + + return c.JSON(http.StatusOK, convertIdentityProviderFromStore(identityProvider)) }) g.DELETE("/idp/:idpId", func(c echo.Context) error { @@ -222,7 +226,7 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) } - if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{ID: identityProviderID}); err != nil { + if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ID: identityProviderID}); err != nil { if common.ErrorCode(err) == common.NotFound { return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID)) } @@ -232,13 +236,13 @@ func (s *APIV1Service) registerIdentityProviderRoutes(g *echo.Group) { }) } -func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *IdentityProvider { +func convertIdentityProviderFromStore(identityProvider *store.IdentityProvider) *IdentityProvider { return &IdentityProvider{ - ID: identityProviderMessage.ID, - Name: identityProviderMessage.Name, - Type: IdentityProviderType(identityProviderMessage.Type), - IdentifierFilter: identityProviderMessage.IdentifierFilter, - Config: convertIdentityProviderConfigFromStore(identityProviderMessage.Config), + ID: identityProvider.ID, + Name: identityProvider.Name, + Type: IdentityProviderType(identityProvider.Type), + IdentifierFilter: identityProvider.IdentifierFilter, + Config: convertIdentityProviderConfigFromStore(identityProvider.Config), } } diff --git a/store/idp.go b/store/idp.go index 9b51e14ac..0167d22bc 100644 --- a/store/idp.go +++ b/store/idp.go @@ -6,8 +6,6 @@ import ( "encoding/json" "fmt" "strings" - - "github.com/usememos/memos/common" ) type IdentityProviderType string @@ -36,7 +34,7 @@ type FieldMapping struct { Email string `json:"email"` } -type IdentityProviderMessage struct { +type IdentityProvider struct { ID int Name string Type IdentityProviderType @@ -44,11 +42,11 @@ type IdentityProviderMessage struct { Config *IdentityProviderConfig } -type FindIdentityProviderMessage struct { +type FindIdentityProvider struct { ID *int } -type UpdateIdentityProviderMessage struct { +type UpdateIdentityProvider struct { ID int Type IdentityProviderType Name *string @@ -56,14 +54,14 @@ type UpdateIdentityProviderMessage struct { Config *IdentityProviderConfig } -type DeleteIdentityProviderMessage struct { +type DeleteIdentityProvider struct { ID int } -func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProviderMessage) (*IdentityProviderMessage, error) { +func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProvider) (*IdentityProvider, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -76,6 +74,7 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv } else { return nil, fmt.Errorf("unsupported idp type %s", string(create.Type)) } + query := ` INSERT INTO idp ( name, @@ -96,20 +95,22 @@ func (s *Store) CreateIdentityProvider(ctx context.Context, create *IdentityProv ).Scan( &create.ID, ); err != nil { - return nil, FormatError(err) + return nil, err } + if err := tx.Commit(); err != nil { - return nil, FormatError(err) + return nil, err } - identityProviderMessage := create - s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage) - return identityProviderMessage, nil + + identityProvider := create + s.idpCache.Store(identityProvider.ID, identityProvider) + return identityProvider, nil } -func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) { +func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityProvider) ([]*IdentityProvider, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -124,16 +125,16 @@ func (s *Store) ListIdentityProviders(ctx context.Context, find *FindIdentityPro return list, nil } -func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProviderMessage) (*IdentityProviderMessage, error) { +func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvider) (*IdentityProvider, error) { if find.ID != nil { if cache, ok := s.idpCache.Load(*find.ID); ok { - return cache.(*IdentityProviderMessage), nil + return cache.(*IdentityProvider), nil } } tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -142,18 +143,18 @@ func (s *Store) GetIdentityProvider(ctx context.Context, find *FindIdentityProvi return nil, err } if len(list) == 0 { - return nil, &common.Error{Code: common.NotFound, Err: fmt.Errorf("not found")} + return nil, nil } - identityProviderMessage := list[0] - s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage) - return identityProviderMessage, nil + identityProvider := list[0] + s.idpCache.Store(identityProvider.ID, identityProvider) + return identityProvider, nil } -func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProviderMessage) (*IdentityProviderMessage, error) { +func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdentityProvider) (*IdentityProvider, error) { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return nil, FormatError(err) + return nil, err } defer tx.Rollback() @@ -184,39 +185,42 @@ func (s *Store) UpdateIdentityProvider(ctx context.Context, update *UpdateIdenti WHERE id = ? RETURNING id, name, type, identifier_filter, config ` - var identityProviderMessage IdentityProviderMessage + var identityProvider IdentityProvider var identityProviderConfig string if err := tx.QueryRowContext(ctx, query, args...).Scan( - &identityProviderMessage.ID, - &identityProviderMessage.Name, - &identityProviderMessage.Type, - &identityProviderMessage.IdentifierFilter, + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, &identityProviderConfig, ); err != nil { - return nil, FormatError(err) + return nil, err } - if identityProviderMessage.Type == IdentityProviderOAuth2 { + + if identityProvider.Type == IdentityProviderOAuth2 { oauth2Config := &IdentityProviderOAuth2Config{} if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { return nil, err } - identityProviderMessage.Config = &IdentityProviderConfig{ + identityProvider.Config = &IdentityProviderConfig{ OAuth2Config: oauth2Config, } } else { - return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type)) + return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) } + if err := tx.Commit(); err != nil { - return nil, FormatError(err) + return nil, err } - s.idpCache.Store(identityProviderMessage.ID, identityProviderMessage) - return &identityProviderMessage, nil + + s.idpCache.Store(identityProvider.ID, identityProvider) + return &identityProvider, nil } -func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProviderMessage) error { +func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdentityProvider) error { tx, err := s.db.BeginTx(ctx, nil) if err != nil { - return FormatError(err) + return err } defer tx.Rollback() @@ -224,24 +228,22 @@ func (s *Store) DeleteIdentityProvider(ctx context.Context, delete *DeleteIdenti stmt := `DELETE FROM idp WHERE ` + strings.Join(where, " AND ") result, err := tx.ExecContext(ctx, stmt, args...) if err != nil { - return FormatError(err) + return err } - rows, err := result.RowsAffected() - if err != nil { + if _, err = result.RowsAffected(); err != nil { return err } - if rows == 0 { - return &common.Error{Code: common.NotFound, Err: fmt.Errorf("idp not found")} - } + if err := tx.Commit(); err != nil { return err } + s.idpCache.Delete(delete.ID) return nil } -func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProviderMessage) ([]*IdentityProviderMessage, error) { +func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityProvider) ([]*IdentityProvider, error) { where, args := []string{"TRUE"}, []any{} if v := find.ID; v != nil { where, args = append(where, fmt.Sprintf("id = $%d", len(args)+1)), append(args, *v) @@ -259,40 +261,41 @@ func listIdentityProviders(ctx context.Context, tx *sql.Tx, find *FindIdentityPr args..., ) if err != nil { - return nil, FormatError(err) + return nil, err } defer rows.Close() - var identityProviderMessages []*IdentityProviderMessage + var identityProviders []*IdentityProvider for rows.Next() { - var identityProviderMessage IdentityProviderMessage + var identityProvider IdentityProvider var identityProviderConfig string if err := rows.Scan( - &identityProviderMessage.ID, - &identityProviderMessage.Name, - &identityProviderMessage.Type, - &identityProviderMessage.IdentifierFilter, + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, &identityProviderConfig, ); err != nil { - return nil, FormatError(err) + return nil, err } - if identityProviderMessage.Type == IdentityProviderOAuth2 { + + if identityProvider.Type == IdentityProviderOAuth2 { oauth2Config := &IdentityProviderOAuth2Config{} if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { return nil, err } - identityProviderMessage.Config = &IdentityProviderConfig{ + identityProvider.Config = &IdentityProviderConfig{ OAuth2Config: oauth2Config, } } else { - return nil, fmt.Errorf("unsupported idp type %s", string(identityProviderMessage.Type)) + return nil, fmt.Errorf("unsupported idp type %s", string(identityProvider.Type)) } - identityProviderMessages = append(identityProviderMessages, &identityProviderMessage) + identityProviders = append(identityProviders, &identityProvider) } if err := rows.Err(); err != nil { return nil, err } - return identityProviderMessages, nil + return identityProviders, nil } diff --git a/store/store.go b/store/store.go index 5c4f55f5f..c21049d5f 100644 --- a/store/store.go +++ b/store/store.go @@ -16,7 +16,7 @@ type Store struct { userCache sync.Map // map[int]*userRaw userSettingCache sync.Map // map[string]*UserSettingMessage shortcutCache sync.Map // map[int]*shortcutRaw - idpCache sync.Map // map[int]*IdentityProviderMessage + idpCache sync.Map // map[int]*IdentityProvider resourceCache sync.Map // map[int]*resourceRaw } diff --git a/test/store/idp_test.go b/test/store/idp_test.go index 483c30608..21d1d83bc 100644 --- a/test/store/idp_test.go +++ b/test/store/idp_test.go @@ -12,14 +12,14 @@ import ( func TestIdentityProviderStore(t *testing.T) { ctx := context.Background() ts := NewTestingStore(ctx, t) - createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{ + createdIDP, err := ts.CreateIdentityProvider(ctx, &store.IdentityProvider{ Name: "GitHub OAuth", Type: store.IdentityProviderOAuth2, IdentifierFilter: "", Config: &store.IdentityProviderConfig{ OAuth2Config: &store.IdentityProviderOAuth2Config{ - ClientID: "asd", - ClientSecret: "123", + ClientID: "client_id", + ClientSecret: "client_secret", AuthURL: "https://github.com/auth", TokenURL: "https://github.com/token", UserInfoURL: "https://github.com/user", @@ -33,16 +33,23 @@ func TestIdentityProviderStore(t *testing.T) { }, }) require.NoError(t, err) - idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProviderMessage{ + idp, err := ts.GetIdentityProvider(ctx, &store.FindIdentityProvider{ ID: &createdIDP.ID, }) require.NoError(t, err) require.Equal(t, createdIDP, idp) - err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{ + newName := "My GitHub OAuth" + updatedIdp, err := ts.UpdateIdentityProvider(ctx, &store.UpdateIdentityProvider{ + ID: idp.ID, + Name: &newName, + }) + require.NoError(t, err) + require.Equal(t, newName, updatedIdp.Name) + err = ts.DeleteIdentityProvider(ctx, &store.DeleteIdentityProvider{ ID: idp.ID, }) require.NoError(t, err) - idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{}) + idpList, err := ts.ListIdentityProviders(ctx, &store.FindIdentityProvider{}) require.NoError(t, err) require.Equal(t, 0, len(idpList)) }