From 8101a5e0b162044c16385bee4f12a4a653d050b9 Mon Sep 17 00:00:00 2001 From: Steven Date: Sun, 7 Apr 2024 22:15:15 +0800 Subject: [PATCH] chore: add origin flag to config cors --- bin/memos/main.go | 25 ++++++++++++++++--------- server/profile/profile.go | 2 ++ server/server.go | 17 ++++++++++++++--- 3 files changed, 32 insertions(+), 12 deletions(-) diff --git a/bin/memos/main.go b/bin/memos/main.go index 9669d8188..e8181bced 100644 --- a/bin/memos/main.go +++ b/bin/memos/main.go @@ -31,18 +31,19 @@ const ( ) var ( - profile *_profile.Profile - mode string - addr string - port int - data string - driver string - dsn string - serveFrontend bool + profile *_profile.Profile + mode string + addr string + port int + data string + driver string + dsn string + serveFrontend bool + allowedOrigins []string rootCmd = &cobra.Command{ Use: "memos", - Short: `An open-source, self-hosted memo hub with knowledge management and social networking.`, + Short: `An open source, lightweight note-taking service. Easily capture and share your great thoughts.`, Run: func(_cmd *cobra.Command, _args []string) { ctx, cancel := context.WithCancel(context.Background()) dbDriver, err := db.NewDBDriver(profile) @@ -114,6 +115,7 @@ func init() { rootCmd.PersistentFlags().StringVarP(&driver, "driver", "", "", "database driver") rootCmd.PersistentFlags().StringVarP(&dsn, "dsn", "", "", "database source name(aka. DSN)") rootCmd.PersistentFlags().BoolVarP(&serveFrontend, "frontend", "", true, "serve frontend files") + rootCmd.PersistentFlags().StringArrayVarP(&allowedOrigins, "origins", "", []string{}, "CORS allowed domain origins") err := viper.BindPFlag("mode", rootCmd.PersistentFlags().Lookup("mode")) if err != nil { @@ -143,12 +145,17 @@ func init() { if err != nil { panic(err) } + err = viper.BindPFlag("origins", rootCmd.PersistentFlags().Lookup("origins")) + if err != nil { + panic(err) + } viper.SetDefault("mode", "demo") viper.SetDefault("driver", "sqlite") viper.SetDefault("addr", "") viper.SetDefault("port", 8081) viper.SetDefault("frontend", true) + viper.SetDefault("origins", []string{}) viper.SetEnvPrefix("memos") } diff --git a/server/profile/profile.go b/server/profile/profile.go index 74cde1eb5..d9d15c27f 100644 --- a/server/profile/profile.go +++ b/server/profile/profile.go @@ -32,6 +32,8 @@ type Profile struct { Version string `json:"version"` // Frontend indicate the frontend is enabled or not Frontend bool `json:"-"` + // Origins is the list of allowed origins + Origins []string `json:"-"` } func (p *Profile) IsDev() bool { diff --git a/server/server.go b/server/server.go index 9d9fc8d0d..132ecd4bf 100644 --- a/server/server.go +++ b/server/server.go @@ -49,7 +49,7 @@ func NewServer(ctx context.Context, profile *profile.Profile, store *store.Store } // Register CORS middleware. - e.Use(CORSMiddleware()) + e.Use(CORSMiddleware(s.Profile.Origins)) serverID, err := s.getSystemServerID(ctx) if err != nil { @@ -160,7 +160,7 @@ func grpcRequestSkipper(c echo.Context) bool { return strings.HasPrefix(c.Request().URL.Path, "/memos.api.v2.") } -func CORSMiddleware() echo.MiddlewareFunc { +func CORSMiddleware(origins []string) echo.MiddlewareFunc { return func(next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { if grpcRequestSkipper(c) { @@ -170,7 +170,18 @@ func CORSMiddleware() echo.MiddlewareFunc { r := c.Request() w := c.Response().Writer - w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin")) + requestOrigin := r.Header.Get("Origin") + if len(origins) == 0 { + w.Header().Set("Access-Control-Allow-Origin", requestOrigin) + } else { + for _, origin := range origins { + if origin == requestOrigin { + w.Header().Set("Access-Control-Allow-Origin", origin) + break + } + } + } + w.Header().Set("Access-Control-Allow-Methods", "GET, POST, PUT, DELETE, PATCH, OPTIONS") w.Header().Set("Access-Control-Allow-Headers", "Content-Type, Authorization") w.Header().Set("Access-Control-Allow-Credentials", "true")