diff --git a/api/idp.go b/api/idp.go new file mode 100644 index 00000000..8e7753db --- /dev/null +++ b/api/idp.go @@ -0,0 +1,56 @@ +package api + +type IdentityProviderType string + +const ( + IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" +) + +type IdentityProviderConfig interface{} + +type IdentityProviderOAuth2Config struct { + ClientID string `json:"clientId"` + ClientSecret string `json:"clientSecret"` + AuthURL string `json:"authUrl"` + TokenURL string `json:"tokenUrl"` + UserInfoURL string `json:"userInfoUrl"` + Scopes []string `json:"scopes"` + FieldMapping *FieldMapping `json:"fieldMapping"` +} + +type FieldMapping struct { + Identifier string `json:"identifier"` + DisplayName string `json:"displayName"` + Email string `json:"email"` +} + +type IdentityProvider struct { + ID int `json:"id"` + Name string `json:"name"` + Type IdentityProviderType `json:"type"` + IdentifierFilter string `json:"identifierFilter"` + Config *IdentityProviderConfig `json:"config"` +} + +type IdentityProviderCreate struct { + Name string `json:"name"` + Type IdentityProviderType `json:"type"` + IdentifierFilter string `json:"identifierFilter"` + Config *IdentityProviderConfig `json:"config"` +} + +type IdentityProviderFind struct { + ID *int +} + +type IdentityProviderPatch struct { + ID int + Type IdentityProviderType `json:"type"` + Name *string `json:"name"` + IdentifierFilter *string `json:"identifierFilter"` + Config *IdentityProviderConfig `json:"config"` +} + +type IdentityProviderDelete struct { + ID int +} diff --git a/server/idp.go b/server/idp.go new file mode 100644 index 00000000..d1d87a01 --- /dev/null +++ b/server/idp.go @@ -0,0 +1,178 @@ +package server + +import ( + "encoding/json" + "fmt" + "net/http" + "strconv" + + "github.com/labstack/echo/v4" + "github.com/usememos/memos/api" + "github.com/usememos/memos/common" + "github.com/usememos/memos/store" +) + +func (s *Server) registerIdentityProviderRoutes(g *echo.Group) { + g.POST("/idp", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.FindUser(ctx, &api.UserFind{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if user == nil || user.Role != api.Host { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + identityProviderCreate := &api.IdentityProviderCreate{} + if err := json.NewDecoder(c.Request().Body).Decode(identityProviderCreate); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted post identity provider request").SetInternal(err) + } + + identityProvider, err := s.Store.CreateIdentityProvider(ctx, &store.IdentityProviderMessage{ + Name: identityProviderCreate.Name, + Type: store.IdentityProviderType(identityProviderCreate.Type), + IdentifierFilter: identityProviderCreate.IdentifierFilter, + Config: (*store.IdentityProviderConfig)(identityProviderCreate.Config), + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create identity provider").SetInternal(err) + } + return c.JSON(http.StatusOK, composeResponse(identityProvider)) + }) + + g.PATCH("/idp/:idpId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.FindUser(ctx, &api.UserFind{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if user == nil || user.Role != api.Host { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + identityProviderID, err := strconv.Atoi(c.Param("idpId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) + } + + identityProviderPatch := &api.IdentityProviderPatch{ + ID: identityProviderID, + } + if err := json.NewDecoder(c.Request().Body).Decode(identityProviderPatch); err != nil { + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted patch identity provider request").SetInternal(err) + } + + identityProvider, err := s.Store.UpdateIdentityProvider(ctx, &store.UpdateIdentityProviderMessage{ + ID: identityProviderPatch.ID, + Type: store.IdentityProviderType(identityProviderPatch.Type), + Name: identityProviderPatch.Name, + IdentifierFilter: identityProviderPatch.IdentifierFilter, + Config: (*store.IdentityProviderConfig)(identityProviderPatch.Config), + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to patch identity provider").SetInternal(err) + } + return c.JSON(http.StatusOK, identityProvider) + }) + + g.GET("/idp", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.FindUser(ctx, &api.UserFind{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + // We should only show identity provider list to host user. + if user == nil || user.Role != api.Host { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + identityProviderMessageList, err := s.Store.ListIdentityProviders(ctx, &store.FindIdentityProviderMessage{}) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find identity provider list").SetInternal(err) + } + + var identityProviderList []*api.IdentityProvider + for _, identityProviderMessage := range identityProviderMessageList { + identityProviderList = append(identityProviderList, convertIdentityProviderFromStore(identityProviderMessage)) + } + return c.JSON(http.StatusOK, composeResponse(identityProviderList)) + }) + + g.DELETE("/idp/:idpId", func(c echo.Context) error { + ctx := c.Request().Context() + userID, ok := c.Get(getUserIDContextKey()).(int) + if !ok { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing user in session") + } + + user, err := s.Store.FindUser(ctx, &api.UserFind{ + ID: &userID, + }) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to find user").SetInternal(err) + } + if user == nil || user.Role != api.Host { + return echo.NewHTTPError(http.StatusUnauthorized, "Unauthorized") + } + + identityProviderID, err := strconv.Atoi(c.Param("idpId")) + if err != nil { + return echo.NewHTTPError(http.StatusBadRequest, fmt.Sprintf("ID is not a number: %s", c.Param("idpId"))).SetInternal(err) + } + + if err = s.Store.DeleteIdentityProvider(ctx, &store.DeleteIdentityProviderMessage{ID: identityProviderID}); err != nil { + if common.ErrorCode(err) == common.NotFound { + return echo.NewHTTPError(http.StatusNotFound, fmt.Sprintf("Identity provider ID not found: %d", identityProviderID)) + } + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to delete identity provider").SetInternal(err) + } + return c.JSON(http.StatusOK, true) + }) +} + +func convertIdentityProviderFromStore(identityProviderMessage *store.IdentityProviderMessage) *api.IdentityProvider { + identityProvider := &api.IdentityProvider{ + ID: identityProviderMessage.ID, + Name: identityProviderMessage.Name, + Type: api.IdentityProviderType(identityProviderMessage.Type), + IdentifierFilter: identityProviderMessage.IdentifierFilter, + } + if identityProvider.Type == api.IdentityProviderOAuth2 { + configMessage := any(identityProviderMessage.Config).(*store.IdentityProviderOAuth2Config) + identityProvider.Config = any(&api.IdentityProviderOAuth2Config{ + ClientID: configMessage.ClientID, + ClientSecret: configMessage.ClientSecret, + AuthURL: configMessage.AuthURL, + TokenURL: configMessage.TokenURL, + UserInfoURL: configMessage.UserInfoURL, + Scopes: configMessage.Scopes, + FieldMapping: &api.FieldMapping{ + Identifier: configMessage.FieldMapping.Identifier, + DisplayName: configMessage.FieldMapping.DisplayName, + Email: configMessage.FieldMapping.Email, + }, + }).(*api.IdentityProviderConfig) + } + return identityProvider +} diff --git a/server/server.go b/server/server.go index 86e55c25..30e4ed42 100644 --- a/server/server.go +++ b/server/server.go @@ -116,6 +116,7 @@ func NewServer(ctx context.Context, profile *profile.Profile) (*Server, error) { s.registerResourceRoutes(apiGroup) s.registerTagRoutes(apiGroup) s.registerStorageRoutes(apiGroup) + s.registerIdentityProviderRoutes(apiGroup) return s, nil } diff --git a/store/idp.go b/store/idp.go index f64e38d5..cf556090 100644 --- a/store/idp.go +++ b/store/idp.go @@ -10,10 +10,10 @@ import ( "github.com/usememos/memos/common" ) -type IdentityProvideType string +type IdentityProviderType string const ( - IdentityProviderOAuth2 IdentityProvideType = "OAUTH2" + IdentityProviderOAuth2 IdentityProviderType = "OAUTH2" ) type IdentityProviderConfig interface{} @@ -29,15 +29,15 @@ type IdentityProviderOAuth2Config struct { } type FieldMapping struct { - Identifier string - DisplayName string - Email string + Identifier string `json:"identifier"` + DisplayName string `json:"displayName"` + Email string `json:"email"` } type IdentityProviderMessage struct { ID int Name string - Type IdentityProvideType + Type IdentityProviderType IdentifierFilter string Config *IdentityProviderConfig } @@ -48,7 +48,7 @@ type FindIdentityProviderMessage struct { type UpdateIdentityProviderMessage struct { ID int - Type IdentityProvideType + Type IdentityProviderType Name *string IdentifierFilter *string Config *IdentityProviderConfig