From a8f0c9a7b171f70be64bb97626aa5180e9bf49ef Mon Sep 17 00:00:00 2001 From: email Date: Fri, 4 Feb 2022 00:56:44 +0800 Subject: [PATCH] fix: get&set session --- api/auth.go | 8 ++++---- api/user.go | 12 ++++++------ server/auth.go | 21 +++++++++++++++------ server/jwt.go | 42 ++++++++++++++++++++++++++++++++---------- server/server.go | 3 ++- store/user.go | 5 +++-- 6 files changed, 62 insertions(+), 29 deletions(-) diff --git a/api/auth.go b/api/auth.go index 8191d8b1..5e8bf3e6 100644 --- a/api/auth.go +++ b/api/auth.go @@ -1,11 +1,11 @@ package api type Login struct { - Name string - Password string + Name string `jsonapi:"attr,name"` + Password string `jsonapi:"attr,password"` } type Signup struct { - Name string - Password string + Name string `jsonapi:"attr,name"` + Password string `jsonapi:"attr,password"` } diff --git a/api/user.go b/api/user.go index a2373831..80042117 100644 --- a/api/user.go +++ b/api/user.go @@ -5,25 +5,25 @@ type User struct { CreatedTs int64 `jsonapi:"attr,createdTs"` UpdatedTs int64 `jsonapi:"attr,updatedTs"` + OpenId string `jsonapi:"attr,openId"` Name string `jsonapi:"attr,name"` Password string - OpenId string `jsonapi:"attr,openId"` } type UserCreate struct { + OpenId string `jsonapi:"attr,openId"` Name string `jsonapi:"attr,name"` Password string `jsonapi:"attr,password"` - OpenId string `jsonapi:"attr,openId"` } type UserPatch struct { Id int - Name *string `jsonapi:"attr,name"` - Password *string `jsonapi:"attr,password"` - OpenId *string + OpenId *string - ResetOpenId *bool `jsonapi:"attr,resetOpenId"` + Name *string `jsonapi:"attr,name"` + Password *string `jsonapi:"attr,password"` + ResetOpenId *bool `jsonapi:"attr,resetOpenId"` } type UserFind struct { diff --git a/server/auth.go b/server/auth.go index f74a51e7..fe28f85b 100644 --- a/server/auth.go +++ b/server/auth.go @@ -34,26 +34,31 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusUnauthorized, "Incorrect password").SetInternal(err) } + err = setUserSession(c, user) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set login session").SetInternal(err) + } + c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8) if err := jsonapi.MarshalPayload(c.Response().Writer, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to marshal create user response").SetInternal(err) } - setUserSession(c, user) - return nil }) g.POST("/auth/logout", func(c echo.Context) error { - removeUserSession(c) + err := removeUserSession(c) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set logout session").SetInternal(err) + } - c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8) c.Response().WriteHeader(http.StatusOK) return nil }) g.POST("/auth/signup", func(c echo.Context) error { signup := &api.Signup{} if err := jsonapi.UnmarshalPayload(c.Request().Body, signup); err != nil { - return echo.NewHTTPError(http.StatusBadRequest, "Malformatted login request").SetInternal(err) + return echo.NewHTTPError(http.StatusBadRequest, "Malformatted signup request").SetInternal(err) } userFind := &api.UserFind{ @@ -77,12 +82,16 @@ func (s *Server) registerAuthRoutes(g *echo.Group) { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to create user").SetInternal(err) } + err = setUserSession(c, user) + if err != nil { + return echo.NewHTTPError(http.StatusInternalServerError, "Failed to set signup session").SetInternal(err) + } + c.Response().Header().Set(echo.HeaderContentType, echo.MIMEApplicationJSONCharsetUTF8) if err := jsonapi.MarshalPayload(c.Response().Writer, user); err != nil { return echo.NewHTTPError(http.StatusInternalServerError, "Failed to marshal create user response").SetInternal(err) } - setUserSession(c, user) return nil }) } diff --git a/server/jwt.go b/server/jwt.go index f61db24e..fc07ad65 100644 --- a/server/jwt.go +++ b/server/jwt.go @@ -21,33 +21,49 @@ func getUserIdContextKey() string { } // Purpose of this cookie is to store the user's id. -func setUserSession(c echo.Context, user *api.User) { - sess, _ := session.Get("session", c) +func setUserSession(c echo.Context, user *api.User) error { + sess, err := session.Get("session", c) + if err != nil { + return fmt.Errorf("failed to get session") + } sess.Options = &sessions.Options{ Path: "/", MaxAge: 1000 * 3600 * 24 * 30, HttpOnly: true, } - sess.Values[userIdContextKey] = strconv.Itoa(user.Id) - sess.Save(c.Request(), c.Response()) + sess.Values[userIdContextKey] = user.Id + err = sess.Save(c.Request(), c.Response()) + if err != nil { + return fmt.Errorf("failed to set session") + } + + return nil } -func removeUserSession(c echo.Context) { - sess, _ := session.Get("session", c) +func removeUserSession(c echo.Context) error { + sess, err := session.Get("session", c) + if err != nil { + return fmt.Errorf("failed to get session") + } sess.Options = &sessions.Options{ Path: "/", MaxAge: 0, HttpOnly: true, } sess.Values[userIdContextKey] = nil - sess.Save(c.Request(), c.Response()) + err = sess.Save(c.Request(), c.Response()) + if err != nil { + return fmt.Errorf("failed to set session") + } + + return nil } // Use session instead of jwt in the initial version func JWTMiddleware(us api.UserService, next echo.HandlerFunc) echo.HandlerFunc { return func(c echo.Context) error { - // Skips auth, test - if common.HasPrefixes(c.Path(), "/api/auth", "/api/test") { + // Skips auth + if common.HasPrefixes(c.Path(), "/api/auth") { return next(c) } @@ -55,7 +71,13 @@ func JWTMiddleware(us api.UserService, next echo.HandlerFunc) echo.HandlerFunc { if err != nil { return echo.NewHTTPError(http.StatusUnauthorized, "Missing session") } - userId, err := strconv.Atoi(fmt.Sprintf("%v", sess.Values[userIdContextKey])) + + userIdValue := sess.Values[userIdContextKey] + if userIdValue == nil { + return echo.NewHTTPError(http.StatusUnauthorized, "Missing userId in session") + } + + userId, err := strconv.Atoi(fmt.Sprintf("%v", userIdValue)) if err != nil { return echo.NewHTTPError(http.StatusUnauthorized, "Failed to malformatted user id in the session.") } diff --git a/server/server.go b/server/server.go index 9a1b32b1..633b127d 100644 --- a/server/server.go +++ b/server/server.go @@ -3,6 +3,7 @@ package server import ( "fmt" "memos/api" + "memos/common" "github.com/gorilla/sessions" "github.com/labstack/echo-contrib/session" @@ -33,7 +34,7 @@ func NewServer() *Server { HTML5: true, })) - e.Use(session.Middleware(sessions.NewCookieStore([]byte("secret")))) + e.Use(session.Middleware(sessions.NewCookieStore([]byte(common.GenUUID())))) s := &Server{ e: e, diff --git a/store/user.go b/store/user.go index 7471b2ac..2ef3f1d0 100644 --- a/store/user.go +++ b/store/user.go @@ -124,7 +124,7 @@ func patchUser(db *DB, patch *api.UserPatch) (*api.User, error) { } func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) { - where, args := []string{}, []interface{}{} + where, args := []string{"1 = 1"}, []interface{}{} if v := find.Id; v != nil { where, args = append(where, "id = ?"), append(args, *v) @@ -142,7 +142,7 @@ func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) { name, password, open_id, - created_ts + created_ts, updated_ts FROM user WHERE `+strings.Join(where, " AND "), @@ -164,6 +164,7 @@ func findUserList(db *DB, find *api.UserFind) ([]*api.User, error) { &user.CreatedTs, &user.UpdatedTs, ); err != nil { + fmt.Println(err) return nil, FormatError(err) }