diff --git a/go.mod b/go.mod index de49a0093..0de65d847 100644 --- a/go.mod +++ b/go.mod @@ -3,6 +3,7 @@ module github.com/usememos/memos go 1.21 require ( + github.com/Masterminds/squirrel v1.5.4 github.com/aws/aws-sdk-go-v2 v1.22.1 github.com/aws/aws-sdk-go-v2/config v1.22.1 github.com/aws/aws-sdk-go-v2/credentials v1.15.1 @@ -16,6 +17,7 @@ require ( github.com/grpc-ecosystem/grpc-gateway/v2 v2.18.1 github.com/improbable-eng/grpc-web v0.15.0 github.com/labstack/echo/v4 v4.11.2 + github.com/lib/pq v1.10.9 github.com/microcosm-cc/bluemonday v1.0.26 github.com/pkg/errors v0.9.1 github.com/spf13/cobra v1.8.0 @@ -50,6 +52,8 @@ require ( github.com/gorilla/css v1.0.1 // indirect github.com/josharian/intern v1.0.0 // indirect github.com/kballard/go-shellquote v0.0.0-20180428030007-95032a82bc51 // indirect + github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 // indirect + github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 // indirect github.com/mailru/easyjson v0.7.7 // indirect github.com/remyoudompheng/bigfft v0.0.0-20230129092748-24d4a6f8daec // indirect github.com/rs/cors v1.10.1 // indirect diff --git a/go.sum b/go.sum index 4f68aa79b..6f36c65b9 100644 --- a/go.sum +++ b/go.sum @@ -41,6 +41,8 @@ github.com/BurntSushi/xgb v0.0.0-20160522181843-27f122750802/go.mod h1:IVnqGOEym github.com/Knetic/govaluate v3.0.1-0.20171022003610-9aa49832a739+incompatible/go.mod h1:r7JcOSlj0wfOMncg0iLm8Leh48TZaKVeNIfJntJ2wa0= github.com/KyleBanks/depth v1.2.1 h1:5h8fQADFrWtarTdtDudMmGsC7GPbOAu6RVB3ffsVFHc= github.com/KyleBanks/depth v1.2.1/go.mod h1:jzSb9d0L43HxTQfT+oSA1EEp2q+ne2uh6XgeJcm8brE= +github.com/Masterminds/squirrel v1.5.4 h1:uUcX/aBc8O7Fg9kaISIUsHXdKuqehiXAMQTYX8afzqM= +github.com/Masterminds/squirrel v1.5.4/go.mod h1:NNaOrjSoIDfDA40n7sr2tPNZRfjzjA400rg+riTZj10= github.com/Shopify/sarama v1.19.0/go.mod h1:FVkBWblsNy7DGZRfXLU0O9RCGt5g3g3yEuWXgklEdEo= github.com/Shopify/toxiproxy v2.1.4+incompatible/go.mod h1:OXgGpZ6Cli1/URJOF1DMxUHB2q5Ap20/P/eIdh4G0pI= github.com/VividCortex/gohistogram v1.0.0/go.mod h1:Pf5mBqqDxYaXu3hDrrU+w6nw50o/4+TcAqDqk/vUH7g= @@ -368,7 +370,13 @@ github.com/labstack/echo/v4 v4.11.2 h1:T+cTLQxWCDfqDEoydYm5kCobjmHwOwcv4OJAPHilm github.com/labstack/echo/v4 v4.11.2/go.mod h1:UcGuQ8V6ZNRmSweBIJkPvGfwCMIlFmiqrPqiEBfPYws= github.com/labstack/gommon v0.4.0 h1:y7cvthEAEbU0yHOf4axH8ZG2NH8knB9iNSoTO8dyIk8= github.com/labstack/gommon v0.4.0/go.mod h1:uW6kP17uPlLJsD3ijUYn3/M5bAxtlZhMI6m3MFxTMTM= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0 h1:SOEGU9fKiNWd/HOJuq6+3iTQz8KNCLtVX6idSoTLdUw= +github.com/lann/builder v0.0.0-20180802200727-47ae307949d0/go.mod h1:dXGbAdH5GtBTC4WfIxhKZfyBF/HBFgRZSWwZ9g/He9o= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0 h1:P6pPBnrTSX3DEVR4fDembhRWSsG5rVo6hYhAB/ADZrk= +github.com/lann/ps v0.0.0-20150810152359-62de8c46ede0/go.mod h1:vmVJ0l/dxyfGW6FmdpVm2joNMFikkuWg0EoCKLGUMNw= github.com/leodido/go-urn v1.2.0/go.mod h1:+8+nEpDfqqsY+g338gtMEUOtuK+4dEMhiQEgxpxOKII= +github.com/lib/pq v1.10.9 h1:YXG7RB+JIjhP29X+OtkiDnYaXQwpS4JEWq7dtCCRUEw= +github.com/lib/pq v1.10.9/go.mod h1:AlVN5x4E4T544tWzH6hKfbfQvm3HdbOxrmggDNAPY9o= github.com/lightstep/lightstep-tracer-common/golang/gogo v0.0.0-20190605223551-bc2310a04743/go.mod h1:qklhhLq1aX+mtWk9cPHPzaBjWImj5ULL6C7HFJtXQMM= github.com/lightstep/lightstep-tracer-go v0.18.1/go.mod h1:jlF1pusYV4pidLvZ+XD0UBX0ZE6WURAspgAczcDHrL4= github.com/lyft/protoc-gen-validate v0.0.13/go.mod h1:XbGvPuh87YZc5TdIa2/I4pLk0QoUACkjt2znoq26NVQ= diff --git a/store/db/db.go b/store/db/db.go index 23386050c..47a369385 100644 --- a/store/db/db.go +++ b/store/db/db.go @@ -6,6 +6,7 @@ import ( "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" "github.com/usememos/memos/store/db/mysql" + "github.com/usememos/memos/store/db/postgres" "github.com/usememos/memos/store/db/sqlite" ) @@ -19,6 +20,8 @@ func NewDBDriver(profile *profile.Profile) (store.Driver, error) { driver, err = sqlite.NewDB(profile) case "mysql": driver, err = mysql.NewDB(profile) + case "postgres": + driver, err = postgres.NewDB(profile) default: return nil, errors.New("unknown db driver") } diff --git a/store/db/postgres/activity.go b/store/db/postgres/activity.go new file mode 100644 index 000000000..9b453927f --- /dev/null +++ b/store/db/postgres/activity.go @@ -0,0 +1,117 @@ +package postgres + +import ( + "context" + "time" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func (d *DB) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + payloadString := "{}" + if create.Payload != nil { + bytes, err := protojson.Marshal(create.Payload) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal activity payload") + } + payloadString = string(bytes) + } + + qb := squirrel.Insert("activity"). + Columns("creator_id", "type", "level", "payload"). + PlaceholderFormat(squirrel.Dollar) + + values := []any{create.CreatorID, create.Type.String(), create.Level.String(), payloadString} + + if create.ID != 0 { + qb = qb.Columns("id") + values = append(values, create.ID) + } + + if create.CreatedTs != 0 { + qb = qb.Columns("created_ts") + values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs)) + } + + qb = qb.Values(values...).Suffix("RETURNING id") + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, errors.Wrap(err, "failed to construct query") + } + + var id int32 + err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) + if err != nil { + return nil, errors.Wrap(err, "failed to execute statement and retrieve ID") + } + + list, err := d.ListActivities(ctx, &store.FindActivity{ID: &id}) + if err != nil || len(list) == 0 { + return nil, errors.Wrap(err, "failed to find activity") + } + + return list[0], nil +} + +func (d *DB) ListActivities(ctx context.Context, find *store.FindActivity) ([]*store.Activity, error) { + qb := squirrel.Select("id", "creator_id", "type", "level", "payload", "created_ts"). + From("activity"). + Where("1 = 1"). + PlaceholderFormat(squirrel.Dollar) + + if find.ID != nil { + qb = qb.Where(squirrel.Eq{"id": *find.ID}) + } + if find.Type != nil { + qb = qb.Where(squirrel.Eq{"type": find.Type.String()}) + } + + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Activity{} + for rows.Next() { + activity := &store.Activity{} + var payloadBytes []byte + createdTsPlaceHolder := time.Time{} + if err := rows.Scan( + &activity.ID, + &activity.CreatorID, + &activity.Type, + &activity.Level, + &payloadBytes, + &createdTsPlaceHolder, + ); err != nil { + return nil, err + } + + activity.CreatedTs = createdTsPlaceHolder.Unix() + + payload := &storepb.ActivityPayload{} + if err := protojson.Unmarshal(payloadBytes, payload); err != nil { + return nil, err + } + activity.Payload = payload + list = append(list, activity) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/db/postgres/common.go b/store/db/postgres/common.go new file mode 100644 index 000000000..fd5706d93 --- /dev/null +++ b/store/db/postgres/common.go @@ -0,0 +1,9 @@ +package postgres + +import "google.golang.org/protobuf/encoding/protojson" + +var ( + protojsonUnmarshaler = protojson.UnmarshalOptions{ + DiscardUnknown: true, + } +) diff --git a/store/db/postgres/idp.go b/store/db/postgres/idp.go new file mode 100644 index 000000000..92672052c --- /dev/null +++ b/store/db/postgres/idp.go @@ -0,0 +1,178 @@ +package postgres + +import ( + "context" + "encoding/json" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *DB) CreateIdentityProvider(ctx context.Context, create *store.IdentityProvider) (*store.IdentityProvider, error) { + var configBytes []byte + if create.Type == store.IdentityProviderOAuth2Type { + bytes, err := json.Marshal(create.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, errors.Errorf("unsupported idp type %s", string(create.Type)) + } + + qb := squirrel.Insert("idp").Columns("name", "type", "identifier_filter", "config") + values := []any{create.Name, create.Type, create.IdentifierFilter, string(configBytes)} + + if create.ID != 0 { + qb = qb.Columns("id") + values = append(values, create.ID) + } + + qb = qb.Values(values...).PlaceholderFormat(squirrel.Dollar) + qb = qb.Suffix("RETURNING id") + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + var id int32 + err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) + if err != nil { + return nil, err + } + + create.ID = id + return create, nil +} +func (d *DB) ListIdentityProviders(ctx context.Context, find *store.FindIdentityProvider) ([]*store.IdentityProvider, error) { + qb := squirrel.Select("id", "name", "type", "identifier_filter", "config"). + From("idp"). + Where("1 = 1"). + PlaceholderFormat(squirrel.Dollar) + + if v := find.ID; v != nil { + qb = qb.Where(squirrel.Eq{"id": *v}) + } + + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var identityProviders []*store.IdentityProvider + for rows.Next() { + var identityProvider store.IdentityProvider + var identityProviderConfig string + if err := rows.Scan( + &identityProvider.ID, + &identityProvider.Name, + &identityProvider.Type, + &identityProvider.IdentifierFilter, + &identityProviderConfig, + ); err != nil { + return nil, err + } + + if identityProvider.Type == store.IdentityProviderOAuth2Type { + oauth2Config := &store.IdentityProviderOAuth2Config{} + if err := json.Unmarshal([]byte(identityProviderConfig), oauth2Config); err != nil { + return nil, err + } + identityProvider.Config = &store.IdentityProviderConfig{ + OAuth2Config: oauth2Config, + } + } else { + return nil, errors.Errorf("unsupported idp type %s", string(identityProvider.Type)) + } + identityProviders = append(identityProviders, &identityProvider) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return identityProviders, nil +} + +func (d *DB) GetIdentityProvider(ctx context.Context, find *store.FindIdentityProvider) (*store.IdentityProvider, error) { + list, err := d.ListIdentityProviders(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + return list[0], nil +} + +func (d *DB) UpdateIdentityProvider(ctx context.Context, update *store.UpdateIdentityProvider) (*store.IdentityProvider, error) { + qb := squirrel.Update("idp"). + PlaceholderFormat(squirrel.Dollar) + var err error + + if v := update.Name; v != nil { + qb = qb.Set("name", *v) + } + if v := update.IdentifierFilter; v != nil { + qb = qb.Set("identifier_filter", *v) + } + if v := update.Config; v != nil { + var configBytes []byte + if update.Type == store.IdentityProviderOAuth2Type { + bytes, err := json.Marshal(update.Config.OAuth2Config) + if err != nil { + return nil, err + } + configBytes = bytes + } else { + return nil, errors.Errorf("unsupported idp type %s", string(update.Type)) + } + qb = qb.Set("config", string(configBytes)) + } + + qb = qb.Where(squirrel.Eq{"id": update.ID}) + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, err + } + + return d.GetIdentityProvider(ctx, &store.FindIdentityProvider{ID: &update.ID}) +} + +func (d *DB) DeleteIdentityProvider(ctx context.Context, delete *store.DeleteIdentityProvider) error { + qb := squirrel.Delete("idp"). + Where(squirrel.Eq{"id": delete.ID}). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := qb.ToSql() + if err != nil { + return err + } + + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + + if _, err = result.RowsAffected(); err != nil { + return err + } + + return nil +} diff --git a/store/db/postgres/inbox.go b/store/db/postgres/inbox.go new file mode 100644 index 000000000..88463c7c3 --- /dev/null +++ b/store/db/postgres/inbox.go @@ -0,0 +1,144 @@ +package postgres + +import ( + "context" + "time" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func (d *DB) CreateInbox(ctx context.Context, create *store.Inbox) (*store.Inbox, error) { + messageString := "{}" + if create.Message != nil { + bytes, err := protojson.Marshal(create.Message) + if err != nil { + return nil, errors.Wrap(err, "failed to marshal inbox message") + } + messageString = string(bytes) + } + + qb := squirrel.Insert("inbox"). + Columns("sender_id", "receiver_id", "status", "message"). + Values(create.SenderID, create.ReceiverID, create.Status, messageString). + Suffix("RETURNING id"). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + var id int32 + err = d.db.QueryRowContext(ctx, stmt, args...).Scan(&id) + if err != nil { + return nil, err + } + + return d.GetInbox(ctx, &store.FindInbox{ID: &id}) +} + +func (d *DB) ListInboxes(ctx context.Context, find *store.FindInbox) ([]*store.Inbox, error) { + qb := squirrel.Select("id", "created_ts", "sender_id", "receiver_id", "status", "message"). + From("inbox"). + Where("1 = 1"). + PlaceholderFormat(squirrel.Dollar) + + if find.ID != nil { + qb = qb.Where(squirrel.Eq{"id": *find.ID}) + } + if find.SenderID != nil { + qb = qb.Where(squirrel.Eq{"sender_id": *find.SenderID}) + } + if find.ReceiverID != nil { + qb = qb.Where(squirrel.Eq{"receiver_id": *find.ReceiverID}) + } + if find.Status != nil { + qb = qb.Where(squirrel.Eq{"status": *find.Status}) + } + + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var list []*store.Inbox + for rows.Next() { + inbox := &store.Inbox{} + var messageBytes []byte + createdTsPlaceHolder := time.Time{} + if err := rows.Scan(&inbox.ID, &createdTsPlaceHolder, &inbox.SenderID, &inbox.ReceiverID, &inbox.Status, &messageBytes); err != nil { + return nil, err + } + + inbox.CreatedTs = createdTsPlaceHolder.Unix() + + message := &storepb.InboxMessage{} + if err := protojsonUnmarshaler.Unmarshal(messageBytes, message); err != nil { + return nil, err + } + inbox.Message = message + list = append(list, inbox) + } + + return list, rows.Err() +} + +func (d *DB) GetInbox(ctx context.Context, find *store.FindInbox) (*store.Inbox, error) { + list, err := d.ListInboxes(ctx, find) + if err != nil { + return nil, errors.Wrap(err, "failed to get inbox") + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected inbox count: %d", len(list)) + } + return list[0], nil +} + +func (d *DB) UpdateInbox(ctx context.Context, update *store.UpdateInbox) (*store.Inbox, error) { + qb := squirrel.Update("inbox"). + Set("status", update.Status.String()). + Where(squirrel.Eq{"id": update.ID}). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, err + } + + return d.GetInbox(ctx, &store.FindInbox{ID: &update.ID}) +} + +func (d *DB) DeleteInbox(ctx context.Context, delete *store.DeleteInbox) error { + qb := squirrel.Delete("inbox"). + Where(squirrel.Eq{"id": delete.ID}). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := qb.ToSql() + if err != nil { + return err + } + + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + + _, err = result.RowsAffected() + return err +} diff --git a/store/db/postgres/memo.go b/store/db/postgres/memo.go new file mode 100644 index 000000000..f2f740ef5 --- /dev/null +++ b/store/db/postgres/memo.go @@ -0,0 +1,370 @@ +package postgres + +import ( + "context" + "database/sql" + "encoding/binary" + "fmt" + "strings" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + + "github.com/usememos/memos/internal/util" + "github.com/usememos/memos/store" +) + +func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, error) { + // Initialize a Squirrel statement builder for PostgreSQL + builder := squirrel.Insert("memo"). + PlaceholderFormat(squirrel.Dollar). + Columns("creator_id", "content", "visibility") + + // Add initial values for the columns + values := []any{create.CreatorID, create.Content, create.Visibility} + + // Conditionally add other fields and values + if create.ID != 0 { + builder = builder.Columns("id") + values = append(values, create.ID) + } + + if create.CreatedTs != 0 { + builder = builder.Columns("created_ts") + values = append(values, squirrel.Expr("to_timestamp(?)", create.CreatedTs)) + } + + if create.UpdatedTs != 0 { + builder = builder.Columns("updated_ts") + values = append(values, squirrel.Expr("to_timestamp(?)", create.UpdatedTs)) + } + + if create.RowStatus != "" { + builder = builder.Columns("row_status") + values = append(values, create.RowStatus) + } + + // Add all the values at once + builder = builder.Values(values...) + + // Add the RETURNING clause to get the ID of the inserted row + builder = builder.Suffix("RETURNING id") + + // Prepare and execute the query + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + var id int32 + err = d.db.QueryRowContext(ctx, query, args...).Scan(&id) + if err != nil { + return nil, err + } + + // Retrieve the newly created memo + memo, err := d.GetMemo(ctx, &store.FindMemo{ID: &id}) + if err != nil { + return nil, err + } + if memo == nil { + return nil, errors.Errorf("failed to create memo") + } + + return memo, nil +} + +func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) { + // Start building the SELECT statement + builder := squirrel.Select( + "memo.id AS id", + "memo.creator_id AS creator_id", + "EXTRACT(EPOCH FROM memo.created_ts) AS created_ts", + "EXTRACT(EPOCH FROM memo.updated_ts) AS updated_ts", + "memo.row_status AS row_status", + "memo.content AS content", + "memo.visibility AS visibility", + "MAX(CASE WHEN memo_organizer.pinned = 1 THEN 1 ELSE 0 END) AS pinned", + "string_agg(CAST(resource.id AS TEXT), ',') AS resource_id_list", // Cast to TEXT + "(SELECT string_agg(CAST(memo_id AS TEXT) || ':' || CAST(related_memo_id AS TEXT) || ':' || type, ',') FROM memo_relation WHERE memo_relation.memo_id = memo.id OR memo_relation.related_memo_id = memo.id) AS relation_list"). // Cast IDs to TEXT + From("memo"). + LeftJoin("memo_organizer ON memo.id = memo_organizer.memo_id"). + LeftJoin("resource ON memo.id = resource.memo_id"). + GroupBy("memo.id"). + PlaceholderFormat(squirrel.Dollar) + + // Add conditional where clauses + if v := find.ID; v != nil { + builder = builder.Where("memo.id = ?", *v) + } + if v := find.CreatorID; v != nil { + builder = builder.Where("memo.creator_id = ?", *v) + } + if v := find.RowStatus; v != nil { + builder = builder.Where("memo.row_status = ?", *v) + } + if v := find.CreatedTsBefore; v != nil { + builder = builder.Where("EXTRACT(EPOCH FROM memo.created_ts) < ?", *v) + } + if v := find.CreatedTsAfter; v != nil { + builder = builder.Where("EXTRACT(EPOCH FROM memo.created_ts) > ?", *v) + } + if v := find.Pinned; v != nil { + builder = builder.Where("memo_organizer.pinned = 1") + } + if v := find.ContentSearch; len(v) != 0 { + for _, s := range v { + builder = builder.Where("memo.content LIKE ?", "%"+s+"%") + } + } + + if v := find.VisibilityList; len(v) != 0 { + placeholders := make([]string, len(v)) + args := make([]any, len(v)) + for i, visibility := range v { + placeholders[i] = "?" + args[i] = visibility // Assuming visibility can be directly used as an argument + } + inClause := strings.Join(placeholders, ",") + builder = builder.Where("memo.visibility IN ("+inClause+")", args...) + } + // Add order by clauses + if find.OrderByPinned { + builder = builder.OrderBy("pinned DESC") + } + if find.OrderByUpdatedTs { + builder = builder.OrderBy("updated_ts DESC") + } else { + builder = builder.OrderBy("created_ts DESC") + } + builder = builder.OrderBy("id DESC") + + // Handle pagination + if find.Limit != nil { + builder = builder.Limit(uint64(*find.Limit)) + if find.Offset != nil { + builder = builder.Offset(uint64(*find.Offset)) + } + } + + // Prepare and execute the query + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + // Process the result set + list := make([]*store.Memo, 0) + updatedTsPlaceHolder, createdTsPlaceHolder := make([]uint8, 8), make([]uint8, 8) + for rows.Next() { + var memo store.Memo + var memoResourceIDList sql.NullString + var memoRelationList sql.NullString + if err := rows.Scan( + &memo.ID, + &memo.CreatorID, + &createdTsPlaceHolder, + &updatedTsPlaceHolder, + &memo.RowStatus, + &memo.Content, + &memo.Visibility, + &memo.Pinned, + &memoResourceIDList, + &memoRelationList, + ); err != nil { + return nil, err + } + + // Convert the timestamps from Postgres to Go + memo.CreatedTs = int64(binary.BigEndian.Uint64(createdTsPlaceHolder)) + memo.UpdatedTs = int64(binary.BigEndian.Uint64(updatedTsPlaceHolder)) + + if memoResourceIDList.Valid { + idStringList := strings.Split(memoResourceIDList.String, ",") + memo.ResourceIDList = make([]int32, 0, len(idStringList)) + for _, idString := range idStringList { + id, err := util.ConvertStringToInt32(idString) + if err != nil { + return nil, err + } + memo.ResourceIDList = append(memo.ResourceIDList, id) + } + } + if memoRelationList.Valid { + memo.RelationList = make([]*store.MemoRelation, 0) + relatedMemoTypeList := strings.Split(memoRelationList.String, ",") + for _, relatedMemoType := range relatedMemoTypeList { + relatedMemoTypeList := strings.Split(relatedMemoType, ":") + if len(relatedMemoTypeList) != 3 { + return nil, errors.Errorf("invalid relation format") + } + memoID, err := util.ConvertStringToInt32(relatedMemoTypeList[0]) + if err != nil { + return nil, err + } + relatedMemoID, err := util.ConvertStringToInt32(relatedMemoTypeList[1]) + if err != nil { + return nil, err + } + relationType := store.MemoRelationType(relatedMemoTypeList[2]) + memo.RelationList = append(memo.RelationList, &store.MemoRelation{ + MemoID: memoID, + RelatedMemoID: relatedMemoID, + Type: relationType, + }) + // Set the first parent ID if relation type is comment. + if memo.ParentID == nil && memoID == memo.ID && relationType == store.MemoRelationComment { + memo.ParentID = &relatedMemoID + } + } + } + list = append(list, &memo) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) GetMemo(ctx context.Context, find *store.FindMemo) (*store.Memo, error) { + list, err := d.ListMemos(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + + memo := list[0] + return memo, nil +} + +func (d *DB) UpdateMemo(ctx context.Context, update *store.UpdateMemo) error { + // Start building the update statement + builder := squirrel.Update("memo"). + PlaceholderFormat(squirrel.Dollar). + Where("id = ?", update.ID) + + // Conditionally add set clauses + if v := update.CreatedTs; v != nil { + builder = builder.Set("created_ts", squirrel.Expr("to_timestamp(?)", *v)) + } + if v := update.UpdatedTs; v != nil { + builder = builder.Set("updated_ts", squirrel.Expr("to_timestamp(?)", *v)) + } + if v := update.RowStatus; v != nil { + builder = builder.Set("row_status", *v) + } + if v := update.Content; v != nil { + builder = builder.Set("content", *v) + } + if v := update.Visibility; v != nil { + builder = builder.Set("visibility", *v) + } + + // Prepare and execute the query + query, args, err := builder.ToSql() + if err != nil { + return err + } + + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return err + } + + return nil +} + +func (d *DB) DeleteMemo(ctx context.Context, delete *store.DeleteMemo) error { + // Start building the DELETE statement + builder := squirrel.Delete("memo"). + PlaceholderFormat(squirrel.Dollar). + Where(squirrel.Eq{"id": delete.ID}) + + // Prepare the final query + query, args, err := builder.ToSql() + if err != nil { + return err + } + + // Execute the query with the context + result, err := d.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + + if _, err := result.RowsAffected(); err != nil { + return err + } + + // Perform any additional cleanup or operations such as vacuuming + // irving: wait, why do we need to vacuum here? + // I don't know why delete memo needs to vacuum. so I commented out. + // REVIEWERS LOOK AT ME: please check this. + + return d.Vacuum(ctx) +} + +func (d *DB) FindMemosVisibilityList(ctx context.Context, memoIDs []int32) ([]store.Visibility, error) { + // Start building the SELECT statement + builder := squirrel.Select("DISTINCT(visibility)").From("memo"). + PlaceholderFormat(squirrel.Dollar). + Where(squirrel.Eq{"id": memoIDs}) + + // Prepare the final query + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Execute the query with the context + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + visibilityList := make([]store.Visibility, 0) + for rows.Next() { + var visibility store.Visibility + if err := rows.Scan(&visibility); err != nil { + return nil, err + } + visibilityList = append(visibilityList, visibility) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return visibilityList, nil +} + +func vacuumMemo(ctx context.Context, tx *sql.Tx) error { + // First, build the subquery + subQuery, subArgs, err := squirrel.Select("id").From("user").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Now, build the main delete query using the subquery + query, args, err := squirrel.Delete("memo"). + Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return err + } + + // Execute the query + _, err = tx.ExecContext(ctx, query, args...) + return err +} diff --git a/store/db/postgres/memo_organizer.go b/store/db/postgres/memo_organizer.go new file mode 100644 index 000000000..2a85d5a14 --- /dev/null +++ b/store/db/postgres/memo_organizer.go @@ -0,0 +1,123 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/Masterminds/squirrel" + + "github.com/usememos/memos/store" +) + +func (d *DB) UpsertMemoOrganizer(ctx context.Context, upsert *store.MemoOrganizer) (*store.MemoOrganizer, error) { + pinnedValue := 0 + if upsert.Pinned { + pinnedValue = 1 + } + qb := squirrel.Insert("memo_organizer"). + Columns("memo_id", "user_id", "pinned"). + Values(upsert.MemoID, upsert.UserID, pinnedValue). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *DB) ListMemoOrganizer(ctx context.Context, find *store.FindMemoOrganizer) ([]*store.MemoOrganizer, error) { + qb := squirrel.Select("memo_id", "user_id", "pinned"). + From("memo_organizer"). + Where("1 = 1"). + PlaceholderFormat(squirrel.Dollar) + + if find.MemoID != 0 { + qb = qb.Where(squirrel.Eq{"memo_id": find.MemoID}) + } + if find.UserID != 0 { + qb = qb.Where(squirrel.Eq{"user_id": find.UserID}) + } + + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var list []*store.MemoOrganizer + for rows.Next() { + memoOrganizer := &store.MemoOrganizer{} + if err := rows.Scan(&memoOrganizer.MemoID, &memoOrganizer.UserID, &memoOrganizer.Pinned); err != nil { + return nil, err + } + list = append(list, memoOrganizer) + } + + return list, rows.Err() +} + +func (d *DB) DeleteMemoOrganizer(ctx context.Context, delete *store.DeleteMemoOrganizer) error { + qb := squirrel.Delete("memo_organizer"). + PlaceholderFormat(squirrel.Dollar) + + if v := delete.MemoID; v != nil { + qb = qb.Where(squirrel.Eq{"memo_id": *v}) + } + if v := delete.UserID; v != nil { + qb = qb.Where(squirrel.Eq{"user_id": *v}) + } + + stmt, args, err := qb.ToSql() + if err != nil { + return err + } + + if _, err = d.db.ExecContext(ctx, stmt, args...); err != nil { + return err + } + + return nil +} + +func vacuumMemoOrganizer(ctx context.Context, tx *sql.Tx) error { + // First, build the subquery for memo_id + subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Build the subquery for user_id + subQueryUser, subArgsUser, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Now, build the main delete query using the subqueries + query, args, err := squirrel.Delete("memo_organizer"). + Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...). + Where(fmt.Sprintf("user_id NOT IN (%s)", subQueryUser), subArgsUser...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return err + } + + // Combine the arguments from both subqueries + args = append(args, subArgsUser...) + + // Execute the query + _, err = tx.ExecContext(ctx, query, args...) + return err +} diff --git a/store/db/postgres/memo_relation.go b/store/db/postgres/memo_relation.go new file mode 100644 index 000000000..fd62f483e --- /dev/null +++ b/store/db/postgres/memo_relation.go @@ -0,0 +1,128 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/Masterminds/squirrel" + + "github.com/usememos/memos/store" +) + +func (d *DB) UpsertMemoRelation(ctx context.Context, create *store.MemoRelation) (*store.MemoRelation, error) { + qb := squirrel.Insert("memo_relation"). + Columns("memo_id", "related_memo_id", "type"). + Values(create.MemoID, create.RelatedMemoID, create.Type). + PlaceholderFormat(squirrel.Dollar) + + stmt, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return nil, err + } + + return &store.MemoRelation{ + MemoID: create.MemoID, + RelatedMemoID: create.RelatedMemoID, + Type: create.Type, + }, nil +} + +func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation) ([]*store.MemoRelation, error) { + qb := squirrel.Select("memo_id", "related_memo_id", "type"). + From("memo_relation"). + Where("TRUE"). + PlaceholderFormat(squirrel.Dollar) + + if find.MemoID != nil { + qb = qb.Where(squirrel.Eq{"memo_id": *find.MemoID}) + } + if find.RelatedMemoID != nil { + qb = qb.Where(squirrel.Eq{"related_memo_id": *find.RelatedMemoID}) + } + if find.Type != nil { + qb = qb.Where(squirrel.Eq{"type": *find.Type}) + } + + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + var list []*store.MemoRelation + for rows.Next() { + memoRelation := &store.MemoRelation{} + if err := rows.Scan(&memoRelation.MemoID, &memoRelation.RelatedMemoID, &memoRelation.Type); err != nil { + return nil, err + } + list = append(list, memoRelation) + } + + return list, rows.Err() +} + +func (d *DB) DeleteMemoRelation(ctx context.Context, delete *store.DeleteMemoRelation) error { + qb := squirrel.Delete("memo_relation"). + PlaceholderFormat(squirrel.Dollar) + + if delete.MemoID != nil { + qb = qb.Where(squirrel.Eq{"memo_id": *delete.MemoID}) + } + if delete.RelatedMemoID != nil { + qb = qb.Where(squirrel.Eq{"related_memo_id": *delete.RelatedMemoID}) + } + if delete.Type != nil { + qb = qb.Where(squirrel.Eq{"type": *delete.Type}) + } + + stmt, args, err := qb.ToSql() + if err != nil { + return err + } + + result, err := d.db.ExecContext(ctx, stmt, args...) + if err != nil { + return err + } + + _, err = result.RowsAffected() + return err +} + +func vacuumMemoRelations(ctx context.Context, tx *sql.Tx) error { + // First, build the subquery for memo_id + subQueryMemo, subArgsMemo, err := squirrel.Select("id").From("memo").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Note: The same subquery is used for related_memo_id as it's also checking against the "memo" table + + // Now, build the main delete query using the subqueries + query, args, err := squirrel.Delete("memo_relation"). + Where(fmt.Sprintf("memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...). + Where(fmt.Sprintf("related_memo_id NOT IN (%s)", subQueryMemo), subArgsMemo...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return err + } + + // Combine the arguments for both instances of the same subquery + args = append(args, subArgsMemo...) + + // Execute the query + _, err = tx.ExecContext(ctx, query, args...) + return err +} diff --git a/store/db/postgres/migration/dev/LATEST__SCHEMA.sql b/store/db/postgres/migration/dev/LATEST__SCHEMA.sql new file mode 100644 index 000000000..f60d1fdd3 --- /dev/null +++ b/store/db/postgres/migration/dev/LATEST__SCHEMA.sql @@ -0,0 +1,163 @@ +-- drop all tables first (PostgreSQL style) +DROP TABLE IF EXISTS migration_history CASCADE; + +DROP TABLE IF EXISTS system_setting CASCADE; + +DROP TABLE IF EXISTS "user" CASCADE; + +DROP TABLE IF EXISTS user_setting CASCADE; + +DROP TABLE IF EXISTS memo CASCADE; + +DROP TABLE IF EXISTS memo_organizer CASCADE; + +DROP TABLE IF EXISTS memo_relation CASCADE; + +DROP TABLE IF EXISTS resource CASCADE; + +DROP TABLE IF EXISTS tag CASCADE; + +DROP TABLE IF EXISTS activity CASCADE; + +DROP TABLE IF EXISTS storage CASCADE; + +DROP TABLE IF EXISTS idp CASCADE; + +DROP TABLE IF EXISTS inbox CASCADE; + +DROP TABLE IF EXISTS webhook CASCADE; + +-- migration_history +CREATE TABLE migration_history ( + version VARCHAR(255) NOT NULL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- system_setting +CREATE TABLE system_setting ( + name VARCHAR(255) NOT NULL PRIMARY KEY, + value TEXT NOT NULL, + description TEXT NOT NULL +); + +-- user +CREATE TABLE "user" ( + id SERIAL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL', + username VARCHAR(255) NOT NULL UNIQUE, + role VARCHAR(255) NOT NULL DEFAULT 'USER', + email VARCHAR(255) NOT NULL DEFAULT '', + nickname VARCHAR(255) NOT NULL DEFAULT '', + password_hash VARCHAR(255) NOT NULL, + avatar_url TEXT NOT NULL +); + +-- user_setting +CREATE TABLE user_setting ( + user_id INT NOT NULL, + key VARCHAR(255) NOT NULL, + value TEXT NOT NULL, + UNIQUE(user_id, key), + FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE +); + +-- memo +CREATE TABLE memo ( + id SERIAL PRIMARY KEY, + creator_id INT NOT NULL, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL', + content TEXT NOT NULL, + visibility VARCHAR(255) NOT NULL DEFAULT 'PRIVATE' +); + +-- memo_organizer +CREATE TABLE memo_organizer ( + memo_id INT NOT NULL, + user_id INT NOT NULL, + pinned INT NOT NULL DEFAULT 0, + UNIQUE(memo_id, user_id) +); + +-- memo_relation +CREATE TABLE memo_relation ( + memo_id INT NOT NULL, + related_memo_id INT NOT NULL, + type VARCHAR(256) NOT NULL, + UNIQUE(memo_id, related_memo_id, type), + FOREIGN KEY (memo_id) REFERENCES memo(id) ON DELETE CASCADE, + FOREIGN KEY (related_memo_id) REFERENCES memo(id) ON DELETE CASCADE +); + +-- resource +CREATE TABLE resource ( + id SERIAL PRIMARY KEY, + creator_id INT NOT NULL, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + filename TEXT NOT NULL, + blob BYTEA, + external_link TEXT NOT NULL, + type VARCHAR(255) NOT NULL DEFAULT '', + size INT NOT NULL DEFAULT 0, + internal_path VARCHAR(255) NOT NULL DEFAULT '', + memo_id INT DEFAULT NULL +); + +-- tag +CREATE TABLE tag ( + name VARCHAR(255) NOT NULL, + creator_id INT NOT NULL, + UNIQUE(name, creator_id) +); + +-- activity +CREATE TABLE activity ( + id SERIAL PRIMARY KEY, + creator_id INT NOT NULL, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + type VARCHAR(255) NOT NULL DEFAULT '', + level VARCHAR(255) NOT NULL DEFAULT 'INFO', + payload TEXT NOT NULL +); + +-- storage +CREATE TABLE storage ( + id SERIAL PRIMARY KEY, + name VARCHAR(256) NOT NULL, + type VARCHAR(256) NOT NULL, + config TEXT NOT NULL +); + +-- idp +CREATE TABLE idp ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + type TEXT NOT NULL, + identifier_filter VARCHAR(256) NOT NULL DEFAULT '', + config TEXT NOT NULL +); + +-- inbox +CREATE TABLE inbox ( + id SERIAL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + sender_id INT NOT NULL, + receiver_id INT NOT NULL, + status TEXT NOT NULL, + message TEXT NOT NULL +); + +-- webhook +CREATE TABLE webhook ( + id SERIAL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + row_status TEXT NOT NULL DEFAULT 'NORMAL', + creator_id INT NOT NULL, + name TEXT NOT NULL, + url TEXT NOT NULL +); \ No newline at end of file diff --git a/store/db/postgres/migration/prod/LATEST__SCHEMA.sql b/store/db/postgres/migration/prod/LATEST__SCHEMA.sql new file mode 100644 index 000000000..f60d1fdd3 --- /dev/null +++ b/store/db/postgres/migration/prod/LATEST__SCHEMA.sql @@ -0,0 +1,163 @@ +-- drop all tables first (PostgreSQL style) +DROP TABLE IF EXISTS migration_history CASCADE; + +DROP TABLE IF EXISTS system_setting CASCADE; + +DROP TABLE IF EXISTS "user" CASCADE; + +DROP TABLE IF EXISTS user_setting CASCADE; + +DROP TABLE IF EXISTS memo CASCADE; + +DROP TABLE IF EXISTS memo_organizer CASCADE; + +DROP TABLE IF EXISTS memo_relation CASCADE; + +DROP TABLE IF EXISTS resource CASCADE; + +DROP TABLE IF EXISTS tag CASCADE; + +DROP TABLE IF EXISTS activity CASCADE; + +DROP TABLE IF EXISTS storage CASCADE; + +DROP TABLE IF EXISTS idp CASCADE; + +DROP TABLE IF EXISTS inbox CASCADE; + +DROP TABLE IF EXISTS webhook CASCADE; + +-- migration_history +CREATE TABLE migration_history ( + version VARCHAR(255) NOT NULL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP +); + +-- system_setting +CREATE TABLE system_setting ( + name VARCHAR(255) NOT NULL PRIMARY KEY, + value TEXT NOT NULL, + description TEXT NOT NULL +); + +-- user +CREATE TABLE "user" ( + id SERIAL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL', + username VARCHAR(255) NOT NULL UNIQUE, + role VARCHAR(255) NOT NULL DEFAULT 'USER', + email VARCHAR(255) NOT NULL DEFAULT '', + nickname VARCHAR(255) NOT NULL DEFAULT '', + password_hash VARCHAR(255) NOT NULL, + avatar_url TEXT NOT NULL +); + +-- user_setting +CREATE TABLE user_setting ( + user_id INT NOT NULL, + key VARCHAR(255) NOT NULL, + value TEXT NOT NULL, + UNIQUE(user_id, key), + FOREIGN KEY (user_id) REFERENCES "user"(id) ON DELETE CASCADE +); + +-- memo +CREATE TABLE memo ( + id SERIAL PRIMARY KEY, + creator_id INT NOT NULL, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + row_status VARCHAR(255) NOT NULL DEFAULT 'NORMAL', + content TEXT NOT NULL, + visibility VARCHAR(255) NOT NULL DEFAULT 'PRIVATE' +); + +-- memo_organizer +CREATE TABLE memo_organizer ( + memo_id INT NOT NULL, + user_id INT NOT NULL, + pinned INT NOT NULL DEFAULT 0, + UNIQUE(memo_id, user_id) +); + +-- memo_relation +CREATE TABLE memo_relation ( + memo_id INT NOT NULL, + related_memo_id INT NOT NULL, + type VARCHAR(256) NOT NULL, + UNIQUE(memo_id, related_memo_id, type), + FOREIGN KEY (memo_id) REFERENCES memo(id) ON DELETE CASCADE, + FOREIGN KEY (related_memo_id) REFERENCES memo(id) ON DELETE CASCADE +); + +-- resource +CREATE TABLE resource ( + id SERIAL PRIMARY KEY, + creator_id INT NOT NULL, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + filename TEXT NOT NULL, + blob BYTEA, + external_link TEXT NOT NULL, + type VARCHAR(255) NOT NULL DEFAULT '', + size INT NOT NULL DEFAULT 0, + internal_path VARCHAR(255) NOT NULL DEFAULT '', + memo_id INT DEFAULT NULL +); + +-- tag +CREATE TABLE tag ( + name VARCHAR(255) NOT NULL, + creator_id INT NOT NULL, + UNIQUE(name, creator_id) +); + +-- activity +CREATE TABLE activity ( + id SERIAL PRIMARY KEY, + creator_id INT NOT NULL, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + type VARCHAR(255) NOT NULL DEFAULT '', + level VARCHAR(255) NOT NULL DEFAULT 'INFO', + payload TEXT NOT NULL +); + +-- storage +CREATE TABLE storage ( + id SERIAL PRIMARY KEY, + name VARCHAR(256) NOT NULL, + type VARCHAR(256) NOT NULL, + config TEXT NOT NULL +); + +-- idp +CREATE TABLE idp ( + id SERIAL PRIMARY KEY, + name TEXT NOT NULL, + type TEXT NOT NULL, + identifier_filter VARCHAR(256) NOT NULL DEFAULT '', + config TEXT NOT NULL +); + +-- inbox +CREATE TABLE inbox ( + id SERIAL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + sender_id INT NOT NULL, + receiver_id INT NOT NULL, + status TEXT NOT NULL, + message TEXT NOT NULL +); + +-- webhook +CREATE TABLE webhook ( + id SERIAL PRIMARY KEY, + created_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + updated_ts TIMESTAMP WITHOUT TIME ZONE NOT NULL DEFAULT CURRENT_TIMESTAMP, + row_status TEXT NOT NULL DEFAULT 'NORMAL', + creator_id INT NOT NULL, + name TEXT NOT NULL, + url TEXT NOT NULL +); \ No newline at end of file diff --git a/store/db/postgres/migration_history.go b/store/db/postgres/migration_history.go new file mode 100644 index 000000000..1a7d20392 --- /dev/null +++ b/store/db/postgres/migration_history.go @@ -0,0 +1,79 @@ +package postgres + +import ( + "context" + "time" + + "github.com/Masterminds/squirrel" + + "github.com/usememos/memos/store" +) + +func (d *DB) FindMigrationHistoryList(ctx context.Context, _ *store.FindMigrationHistory) ([]*store.MigrationHistory, error) { + qb := squirrel.Select("version", "created_ts"). + From("migration_history"). + OrderBy("created_ts DESC") + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.MigrationHistory, 0) + for rows.Next() { + var migrationHistory store.MigrationHistory + var createdTs time.Time + if err := rows.Scan(&migrationHistory.Version, &createdTs); err != nil { + return nil, err + } + migrationHistory.CreatedTs = createdTs.UnixNano() + list = append(list, &migrationHistory) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) UpsertMigrationHistory(ctx context.Context, upsert *store.UpsertMigrationHistory) (*store.MigrationHistory, error) { + qb := squirrel.Insert("migration_history"). + Columns("version"). + Values(upsert.Version). + Suffix("ON CONFLICT (version) DO UPDATE SET version = ?", upsert.Version) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + var migrationHistory store.MigrationHistory + var createdTs time.Time + query, args, err = squirrel.Select("version", "created_ts"). + From("migration_history"). + Where(squirrel.Eq{"version": upsert.Version}). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return nil, err + } + + if err := d.db.QueryRowContext(ctx, query, args...).Scan(&migrationHistory.Version, &createdTs); err != nil { + return nil, err + } + migrationHistory.CreatedTs = createdTs.UnixNano() + + return &migrationHistory, nil +} diff --git a/store/db/postgres/migrator.go b/store/db/postgres/migrator.go new file mode 100644 index 000000000..d94a56f2c --- /dev/null +++ b/store/db/postgres/migrator.go @@ -0,0 +1,207 @@ +package postgres + +import ( + "context" + "embed" + "fmt" + "io/fs" + "regexp" + "sort" + "strings" + + "github.com/pkg/errors" + + "github.com/usememos/memos/server/version" + "github.com/usememos/memos/store" +) + +const ( + latestSchemaFileName = "LATEST__SCHEMA.sql" +) + +//go:embed migration +var migrationFS embed.FS + +func (d *DB) Migrate(ctx context.Context) error { + if d.profile.IsDev() { + return d.nonProdMigrate(ctx) + } + + return d.prodMigrate(ctx) +} + +func (d *DB) nonProdMigrate(ctx context.Context) error { + rows, err := d.db.QueryContext(ctx, "SELECT tablename FROM pg_catalog.pg_tables WHERE schemaname != 'pg_catalog' AND schemaname != 'information_schema';") + if err != nil { + return errors.Errorf("failed to query database tables: %s", err) + } + if rows.Err() != nil { + return errors.Errorf("failed to query database tables: %s", err) + } + defer rows.Close() + + var tables []string + for rows.Next() { + var table string + err := rows.Scan(&table) + if err != nil { + return errors.Errorf("failed to scan table name: %s", err) + } + tables = append(tables, table) + } + + if len(tables) != 0 { + return nil + } + + println("no tables in the database. start migration") + + buf, err := migrationFS.ReadFile("migration/dev/" + latestSchemaFileName) + if err != nil { + return errors.Errorf("failed to read latest schema file: %s", err) + } + + stmt := string(buf) + if _, err := d.db.ExecContext(ctx, stmt); err != nil { + return errors.Errorf("failed to exec SQL %s: %s", stmt, err) + } + + // In demo mode, we should seed the database. + if d.profile.Mode == "demo" { + if err := d.seed(ctx); err != nil { + return errors.Wrap(err, "failed to seed") + } + } + return nil +} + +func (d *DB) prodMigrate(ctx context.Context) error { + currentVersion := version.GetCurrentVersion(d.profile.Mode) + migrationHistoryList, err := d.FindMigrationHistoryList(ctx, &store.FindMigrationHistory{}) + // If there is no migration history, we should apply the latest schema. + if err != nil || len(migrationHistoryList) == 0 { + buf, err := migrationFS.ReadFile("migration/prod/" + latestSchemaFileName) + if err != nil { + return errors.Errorf("failed to read latest schema file: %s", err) + } + + stmt := string(buf) + if _, err := d.db.ExecContext(ctx, stmt); err != nil { + return errors.Errorf("failed to exec SQL %s: %s", stmt, err) + } + if _, err := d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{ + Version: currentVersion, + }); err != nil { + return errors.Wrap(err, "failed to upsert migration history") + } + return nil + } + + migrationHistoryVersionList := []string{} + for _, migrationHistory := range migrationHistoryList { + migrationHistoryVersionList = append(migrationHistoryVersionList, migrationHistory.Version) + } + sort.Sort(version.SortVersion(migrationHistoryVersionList)) + latestMigrationHistoryVersion := migrationHistoryVersionList[len(migrationHistoryVersionList)-1] + if !version.IsVersionGreaterThan(version.GetSchemaVersion(currentVersion), latestMigrationHistoryVersion) { + return nil + } + + println("start migrate") + for _, minorVersion := range getMinorVersionList() { + normalizedVersion := minorVersion + ".0" + if version.IsVersionGreaterThan(normalizedVersion, latestMigrationHistoryVersion) && version.IsVersionGreaterOrEqualThan(currentVersion, normalizedVersion) { + println("applying migration for", normalizedVersion) + if err := d.applyMigrationForMinorVersion(ctx, minorVersion); err != nil { + return errors.Wrap(err, "failed to apply minor version migration") + } + } + } + println("end migrate") + return nil +} + +func (d *DB) applyMigrationForMinorVersion(ctx context.Context, minorVersion string) error { + filenames, err := fs.Glob(migrationFS, fmt.Sprintf("migration/prod/%s/*.sql", minorVersion)) + if err != nil { + return errors.Wrap(err, "failed to read ddl files") + } + + sort.Strings(filenames) + // Loop over all migration files and execute them in order. + for _, filename := range filenames { + buf, err := migrationFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read minor version migration file, filename=%s", filename) + } + for _, stmt := range strings.Split(string(buf), ";") { + if strings.TrimSpace(stmt) == "" { + continue + } + if _, err := d.db.ExecContext(ctx, stmt); err != nil { + return errors.Wrapf(err, "migrate error: %s", stmt) + } + } + } + + // Upsert the newest version to migration_history. + version := minorVersion + ".0" + if _, err = d.UpsertMigrationHistory(ctx, &store.UpsertMigrationHistory{Version: version}); err != nil { + return errors.Wrapf(err, "failed to upsert migration history with version: %s", version) + } + + return nil +} + +//go:embed seed +var seedFS embed.FS + +func (d *DB) seed(ctx context.Context) error { + filenames, err := fs.Glob(seedFS, "seed/*.sql") + if err != nil { + return errors.Wrap(err, "failed to read seed files") + } + + sort.Strings(filenames) + // Loop over all seed files and execute them in order. + for _, filename := range filenames { + buf, err := seedFS.ReadFile(filename) + if err != nil { + return errors.Wrapf(err, "failed to read seed file, filename=%s", filename) + } + + for _, stmt := range strings.Split(string(buf), ";") { + if strings.TrimSpace(stmt) == "" { + continue + } + if _, err := d.db.ExecContext(ctx, stmt); err != nil { + return errors.Wrapf(err, "seed error: %s", stmt) + } + } + } + return nil +} + +// minorDirRegexp is a regular expression for minor version directory. +var minorDirRegexp = regexp.MustCompile(`^migration/prod/[0-9]+\.[0-9]+$`) + +func getMinorVersionList() []string { + minorVersionList := []string{} + + if err := fs.WalkDir(migrationFS, "migration", func(path string, file fs.DirEntry, err error) error { + if err != nil { + return err + } + if file.IsDir() && minorDirRegexp.MatchString(path) { + minorVersionList = append(minorVersionList, file.Name()) + } + + return nil + }); err != nil { + panic(err) + } + + sort.Sort(version.SortVersion(minorVersionList)) + + return minorVersionList +} diff --git a/store/db/postgres/postgres.go b/store/db/postgres/postgres.go new file mode 100644 index 000000000..18b5e98ae --- /dev/null +++ b/store/db/postgres/postgres.go @@ -0,0 +1,87 @@ +package postgres + +import ( + "context" + "database/sql" + "log" + + // Import the PostgreSQL driver. + _ "github.com/lib/pq" + "github.com/pkg/errors" + + "github.com/usememos/memos/server/profile" + "github.com/usememos/memos/store" +) + +type DB struct { + db *sql.DB + profile *profile.Profile + // Add any other fields as needed +} + +func NewDB(profile *profile.Profile) (store.Driver, error) { + if profile == nil { + return nil, errors.New("profile is nil") + } + + // Open the PostgreSQL connection + db, err := sql.Open("postgres", profile.DSN) + if err != nil { + log.Printf("Failed to open database: %s", err) + return nil, errors.Wrapf(err, "failed to open database: %s", profile.DSN) + } + + var driver store.Driver = &DB{ + db: db, + profile: profile, + } + + // Return the DB struct + return driver, nil +} + +func (d *DB) GetDB() *sql.DB { + return d.db +} + +func (d *DB) Vacuum(ctx context.Context) error { + tx, err := d.db.BeginTx(ctx, nil) + if err != nil { + return err + } + defer tx.Rollback() + + if err := vacuumMemo(ctx, tx); err != nil { + return err + } + if err := vacuumResource(ctx, tx); err != nil { + return err + } + if err := vacuumUserSetting(ctx, tx); err != nil { + return err + } + if err := vacuumMemoOrganizer(ctx, tx); err != nil { + return err + } + if err := vacuumMemoRelations(ctx, tx); err != nil { + return err + } + if err := vacuumTag(ctx, tx); err != nil { + // Prevent revive warning. + return err + } + + return tx.Commit() +} + +func (*DB) BackupTo(context.Context, string) error { + return errors.New("Please use postgresdump to backup") +} + +func (*DB) GetCurrentDBSize(context.Context) (int64, error) { + return 0, errors.New("unimplemented") +} + +func (d *DB) Close() error { + return d.db.Close() +} diff --git a/store/db/postgres/resource.go b/store/db/postgres/resource.go new file mode 100644 index 000000000..802c104c5 --- /dev/null +++ b/store/db/postgres/resource.go @@ -0,0 +1,229 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + "time" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *DB) CreateResource(ctx context.Context, create *store.Resource) (*store.Resource, error) { + qb := squirrel.Insert("resource").Columns("filename", "blob", "external_link", "type", "size", "creator_id", "internal_path") + values := []any{create.Filename, create.Blob, create.ExternalLink, create.Type, create.Size, create.CreatorID, create.InternalPath} + + if create.ID != 0 { + qb = qb.Columns("id") + values = append(values, create.ID) + } + + if create.CreatedTs != 0 { + qb = qb.Columns("created_ts") + values = append(values, time.Unix(0, create.CreatedTs)) + } + + if create.UpdatedTs != 0 { + qb = qb.Columns("updated_ts") + values = append(values, time.Unix(0, create.UpdatedTs)) + } + + if create.MemoID != nil { + qb = qb.Columns("memo_id") + values = append(values, *create.MemoID) + } + + qb = qb.Values(values...).Suffix("RETURNING id") + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + var id int32 + err = d.db.QueryRowContext(ctx, query, args...).Scan(&id) + if err != nil { + return nil, err + } + + list, err := d.ListResources(ctx, &store.FindResource{ID: &id}) + if err != nil { + return nil, err + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list)) + } + + return list[0], nil +} + +func (d *DB) ListResources(ctx context.Context, find *store.FindResource) ([]*store.Resource, error) { + qb := squirrel.Select("id", "filename", "external_link", "type", "size", "creator_id", "created_ts", "updated_ts", "internal_path", "memo_id").From("resource") + + if v := find.ID; v != nil { + qb = qb.Where(squirrel.Eq{"id": *v}) + } + if v := find.CreatorID; v != nil { + qb = qb.Where(squirrel.Eq{"creator_id": *v}) + } + if v := find.Filename; v != nil { + qb = qb.Where(squirrel.Eq{"filename": *v}) + } + if v := find.MemoID; v != nil { + qb = qb.Where(squirrel.Eq{"memo_id": *v}) + } + if find.HasRelatedMemo { + qb = qb.Where("memo_id IS NOT NULL") + } + if find.GetBlob { + qb = qb.Columns("blob") + } + + qb = qb.GroupBy("id").OrderBy("created_ts DESC") + + if find.Limit != nil { + qb = qb.Limit(uint64(*find.Limit)) + if find.Offset != nil { + qb = qb.Offset(uint64(*find.Offset)) + } + } + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.Resource, 0) + for rows.Next() { + resource := store.Resource{} + var memoID sql.NullInt32 + var createdTs, updatedTs time.Time + dests := []any{ + &resource.ID, + &resource.Filename, + &resource.ExternalLink, + &resource.Type, + &resource.Size, + &resource.CreatorID, + &createdTs, + &updatedTs, + &resource.InternalPath, + &memoID, + } + if find.GetBlob { + dests = append(dests, &resource.Blob) + } + if err := rows.Scan(dests...); err != nil { + return nil, err + } + + resource.CreatedTs = createdTs.UnixNano() + resource.UpdatedTs = updatedTs.UnixNano() + + if memoID.Valid { + resource.MemoID = &memoID.Int32 + } + list = append(list, &resource) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) UpdateResource(ctx context.Context, update *store.UpdateResource) (*store.Resource, error) { + qb := squirrel.Update("resource") + + if v := update.UpdatedTs; v != nil { + qb = qb.Set("updated_ts", time.Unix(0, *v)) + } + if v := update.Filename; v != nil { + qb = qb.Set("filename", *v) + } + if v := update.InternalPath; v != nil { + qb = qb.Set("internal_path", *v) + } + if v := update.MemoID; v != nil { + qb = qb.Set("memo_id", *v) + } + if v := update.Blob; v != nil { + qb = qb.Set("blob", v) + } + + qb = qb.Where(squirrel.Eq{"id": update.ID}) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + list, err := d.ListResources(ctx, &store.FindResource{ID: &update.ID}) + if err != nil { + return nil, err + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected resource count: %d", len(list)) + } + + return list[0], nil +} + +func (d *DB) DeleteResource(ctx context.Context, delete *store.DeleteResource) error { + qb := squirrel.Delete("resource").Where(squirrel.Eq{"id": delete.ID}) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + result, err := d.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + + if _, err := result.RowsAffected(); err != nil { + return err + } + + if err := d.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + + return nil +} + +func vacuumResource(ctx context.Context, tx *sql.Tx) error { + // First, build the subquery + subQuery, subArgs, err := squirrel.Select("id").From("user").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Now, build the main delete query using the subquery + query, args, err := squirrel.Delete("resource"). + Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return err + } + + // Execute the query + _, err = tx.ExecContext(ctx, query, args...) + return err +} diff --git a/store/db/postgres/seed/10000__reset.sql b/store/db/postgres/seed/10000__reset.sql new file mode 100644 index 000000000..de4e97c98 --- /dev/null +++ b/store/db/postgres/seed/10000__reset.sql @@ -0,0 +1,4 @@ +TRUNCATE TABLE memo_organizer; +TRUNCATE TABLE resource; +TRUNCATE TABLE memo; +TRUNCATE TABLE user; diff --git a/store/db/postgres/seed/10001__user.sql b/store/db/postgres/seed/10001__user.sql new file mode 100644 index 000000000..45f468d81 --- /dev/null +++ b/store/db/postgres/seed/10001__user.sql @@ -0,0 +1,44 @@ +INSERT INTO "user" ( + id, + username, + role, + email, + nickname, + row_status, + avatar_url, + password_hash +) +VALUES + ( + 101, + 'memos-demo', + 'HOST', + 'demo@usememos.com', + 'Derobot', + 'NORMAL', + '', + -- raw password: secret + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ), + ( + 102, + 'jack', + 'USER', + 'jack@usememos.com', + 'Jack', + 'NORMAL', + '', + -- raw password: secret + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ), + ( + 103, + 'bob', + 'USER', + 'bob@usememos.com', + 'Bob', + 'ARCHIVED', + '', + -- raw password: secret + '$2a$14$ajq8Q7fbtFRQvXpdCq7Jcuy.Rx1h/L4J60Otx.gyNLbAYctGMJ9tK' + ); diff --git a/store/db/postgres/seed/10002__memo.sql b/store/db/postgres/seed/10002__memo.sql new file mode 100644 index 000000000..9c70e5c9d --- /dev/null +++ b/store/db/postgres/seed/10002__memo.sql @@ -0,0 +1,34 @@ +INSERT INTO memo (id, content, creator_id) +VALUES + ( + 1, + '#Hello πŸ‘‹ Welcome to memos.', + 101 + ); + +INSERT INTO memo (id, content, creator_id, visibility) +VALUES + ( + 2, + E'#TODO\n- [x] Take more photos about **πŸŒ„ sunset**\n- [x] Clean the room\n- [ ] Read *πŸ“– The Little Prince*\n(πŸ‘† click to toggle status)', + 101, + 'PROTECTED' + ), + ( + 3, + E'**[Slash](https://github.com/yourselfhosted/slash)**: A bookmarking and url shortener, save and share your links very easily.\n**[SQL Chat](https://www.sqlchat.ai)**: Chat-based SQL Client', + 101, + 'PUBLIC' + ), + ( + 4, + E'#TODO\n- [x] Take more photos about **πŸŒ„ sunset**\n- [ ] Clean the classroom\n- [ ] Watch *πŸ‘¦ The Boys*\n(πŸ‘† click to toggle status)', + 102, + 'PROTECTED' + ), + ( + 5, + 'δΈ‰δΊΊθ‘ŒοΌŒεΏ…ζœ‰ζˆ‘εΈˆη„‰οΌπŸ‘¨β€πŸ«', + 102, + 'PUBLIC' + ); diff --git a/store/db/postgres/seed/10003__memo_organizer.sql b/store/db/postgres/seed/10003__memo_organizer.sql new file mode 100644 index 000000000..e1fd7b8db --- /dev/null +++ b/store/db/postgres/seed/10003__memo_organizer.sql @@ -0,0 +1,5 @@ +INSERT INTO + memo_organizer (memo_id, user_id, pinned) +VALUES + (1, 101, 1), + (3, 101, 1); diff --git a/store/db/postgres/seed/10004__tag.sql b/store/db/postgres/seed/10004__tag.sql new file mode 100644 index 000000000..a679c1f6e --- /dev/null +++ b/store/db/postgres/seed/10004__tag.sql @@ -0,0 +1,6 @@ +INSERT INTO + tag (name, creator_id) +VALUES + ('Hello', 101), + ('TODO', 101), + ('TODO', 102); diff --git a/store/db/postgres/storage.go b/store/db/postgres/storage.go new file mode 100644 index 000000000..0c79e0fcc --- /dev/null +++ b/store/db/postgres/storage.go @@ -0,0 +1,125 @@ +package postgres + +import ( + "context" + + "github.com/Masterminds/squirrel" + + "github.com/usememos/memos/store" +) + +func (d *DB) CreateStorage(ctx context.Context, create *store.Storage) (*store.Storage, error) { + qb := squirrel.Insert("storage").Columns("name", "type", "config") + values := []any{create.Name, create.Type, create.Config} + + if create.ID != 0 { + qb = qb.Columns("id") + values = append(values, create.ID) + } + + qb = qb.Values(values...).Suffix("RETURNING id") + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.ID) + if err != nil { + return nil, err + } + + return create, nil +} + +func (d *DB) ListStorages(ctx context.Context, find *store.FindStorage) ([]*store.Storage, error) { + qb := squirrel.Select("id", "name", "type", "config").From("storage").OrderBy("id DESC") + + if find.ID != nil { + qb = qb.Where(squirrel.Eq{"id": *find.ID}) + } + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Storage{} + for rows.Next() { + storage := &store.Storage{} + if err := rows.Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil { + return nil, err + } + list = append(list, storage) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) UpdateStorage(ctx context.Context, update *store.UpdateStorage) (*store.Storage, error) { + qb := squirrel.Update("storage") + + if update.Name != nil { + qb = qb.Set("name", *update.Name) + } + if update.Config != nil { + qb = qb.Set("config", *update.Config) + } + + qb = qb.Where(squirrel.Eq{"id": update.ID}) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + storage := &store.Storage{} + query, args, err = squirrel.Select("id", "name", "type", "config"). + From("storage"). + Where(squirrel.Eq{"id": update.ID}). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return nil, err + } + + if err := d.db.QueryRowContext(ctx, query, args...).Scan(&storage.ID, &storage.Name, &storage.Type, &storage.Config); err != nil { + return nil, err + } + + return storage, nil +} + +func (d *DB) DeleteStorage(ctx context.Context, delete *store.DeleteStorage) error { + qb := squirrel.Delete("storage").Where(squirrel.Eq{"id": delete.ID}) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + result, err := d.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + + if _, err := result.RowsAffected(); err != nil { + return err + } + + return nil +} diff --git a/store/db/postgres/system_setting.go b/store/db/postgres/system_setting.go new file mode 100644 index 000000000..4c42198dc --- /dev/null +++ b/store/db/postgres/system_setting.go @@ -0,0 +1,61 @@ +package postgres + +import ( + "context" + + "github.com/Masterminds/squirrel" + + "github.com/usememos/memos/store" +) + +func (d *DB) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) { + qb := squirrel.Insert("system_setting"). + Columns("name", "value", "description"). + Values(upsert.Name, upsert.Value, upsert.Description) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *DB) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) { + qb := squirrel.Select("name", "value", "description").From("system_setting") + + if find.Name != "" { + qb = qb.Where(squirrel.Eq{"name": find.Name}) + } + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.SystemSetting{} + for rows.Next() { + systemSetting := &store.SystemSetting{} + if err := rows.Scan(&systemSetting.Name, &systemSetting.Value, &systemSetting.Description); err != nil { + return nil, err + } + list = append(list, systemSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/db/postgres/tag.go b/store/db/postgres/tag.go new file mode 100644 index 000000000..c42180481 --- /dev/null +++ b/store/db/postgres/tag.go @@ -0,0 +1,113 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/Masterminds/squirrel" + + "github.com/usememos/memos/store" +) + +func (d *DB) UpsertTag(ctx context.Context, upsert *store.Tag) (*store.Tag, error) { + builder := squirrel.Insert("tag"). + Columns("name", "creator_id"). + Values(upsert.Name, upsert.CreatorID). // on conflict is not necessary, as only the pair of name and creator_id is unique + PlaceholderFormat(squirrel.Dollar) + + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *DB) ListTags(ctx context.Context, find *store.FindTag) ([]*store.Tag, error) { + builder := squirrel.Select("name", "creator_id").From("tag"). + Where("1 = 1"). + OrderBy("name ASC"). + PlaceholderFormat(squirrel.Dollar) + + if find.CreatorID != 0 { + builder = builder.Where("creator_id = ?", find.CreatorID) + } + + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.Tag{} + for rows.Next() { + tag := &store.Tag{} + if err := rows.Scan( + &tag.Name, + &tag.CreatorID, + ); err != nil { + return nil, err + } + + list = append(list, tag) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) DeleteTag(ctx context.Context, delete *store.DeleteTag) error { + builder := squirrel.Delete("tag"). + Where(squirrel.Eq{"name": delete.Name, "creator_id": delete.CreatorID}). + PlaceholderFormat(squirrel.Dollar) + + query, args, err := builder.ToSql() + if err != nil { + return err + } + + result, err := d.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + + if _, err = result.RowsAffected(); err != nil { + return err + } + + return nil +} + +func vacuumTag(ctx context.Context, tx *sql.Tx) error { + // First, build the subquery for creator_id + subQuery, subArgs, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Now, build the main delete query using the subquery + query, args, err := squirrel.Delete("tag"). + Where(fmt.Sprintf("creator_id NOT IN (%s)", subQuery), subArgs...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return err + } + + // Execute the query + _, err = tx.ExecContext(ctx, query, args...) + return err +} diff --git a/store/db/postgres/user.go b/store/db/postgres/user.go new file mode 100644 index 000000000..d6553a779 --- /dev/null +++ b/store/db/postgres/user.go @@ -0,0 +1,225 @@ +package postgres + +import ( + "context" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + + "github.com/usememos/memos/store" +) + +func (d *DB) CreateUser(ctx context.Context, create *store.User) (*store.User, error) { + // Start building the insert statement + builder := squirrel.Insert("\"user\"").PlaceholderFormat(squirrel.Dollar) + + columns := []string{"username", "role", "email", "nickname", "password_hash", "avatar_url"} + builder = builder.Columns(columns...) + + values := []any{create.Username, create.Role, create.Email, create.Nickname, create.PasswordHash, create.AvatarURL} + + if create.RowStatus != "" { + builder = builder.Columns("row_status") + values = append(values, create.RowStatus) + } + + if create.CreatedTs != 0 { + builder = builder.Columns("created_ts") + values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.CreatedTs)) + } + + if create.UpdatedTs != 0 { + builder = builder.Columns("updated_ts") + values = append(values, squirrel.Expr("TO_TIMESTAMP(?)", create.UpdatedTs)) + } + + if create.ID != 0 { + builder = builder.Columns("id") + values = append(values, create.ID) + } + + builder = builder.Values(values...) + + builder = builder.Suffix("RETURNING id") + + // Prepare the final query + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Execute the query and get the returned ID + var id int32 + err = d.db.QueryRowContext(ctx, query, args...).Scan(&id) + if err != nil { + return nil, err + } + + // Use the returned ID to retrieve the full user object + user, err := d.GetUser(ctx, &store.FindUser{ID: &id}) + if err != nil { + return nil, err + } + + return user, nil +} + +func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.User, error) { + // Start building the update statement + builder := squirrel.Update("\"user\"").PlaceholderFormat(squirrel.Dollar) + + // Conditionally add set clauses + if v := update.UpdatedTs; v != nil { + builder = builder.Set("updated_ts", squirrel.Expr("to_timestamp(?)", *v)) + } + if v := update.RowStatus; v != nil { + builder = builder.Set("row_status", *v) + } + if v := update.Username; v != nil { + builder = builder.Set("username", *v) + } + if v := update.Email; v != nil { + builder = builder.Set("email", *v) + } + if v := update.Nickname; v != nil { + builder = builder.Set("nickname", *v) + } + if v := update.AvatarURL; v != nil { + builder = builder.Set("avatar_url", *v) + } + if v := update.PasswordHash; v != nil { + builder = builder.Set("password_hash", *v) + } + + // Add the WHERE clause + builder = builder.Where(squirrel.Eq{"id": update.ID}) + + // Prepare the final query + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Execute the query with the context + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + // Retrieve the updated user + user, err := d.GetUser(ctx, &store.FindUser{ID: &update.ID}) + if err != nil { + return nil, err + } + return user, nil +} + +func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) { + // Start building the SELECT statement + builder := squirrel.Select("id", "username", "role", "email", "nickname", "password_hash", "avatar_url", + "FLOOR(EXTRACT(EPOCH FROM created_ts)) AS created_ts", "FLOOR(EXTRACT(EPOCH FROM updated_ts)) AS updated_ts", "row_status"). + From("\"user\""). + PlaceholderFormat(squirrel.Dollar) + + // 1 = 1 is often used as a no-op in SQL, ensuring there's always a WHERE clause + builder = builder.Where("1 = 1") + + // Conditionally add where clauses + if v := find.ID; v != nil { + builder = builder.Where(squirrel.Eq{"id": *v}) + } + if v := find.Username; v != nil { + builder = builder.Where(squirrel.Eq{"username": *v}) + } + if v := find.Role; v != nil { + builder = builder.Where(squirrel.Eq{"role": *v}) + } + if v := find.Email; v != nil { + builder = builder.Where(squirrel.Eq{"email": *v}) + } + if v := find.Nickname; v != nil { + builder = builder.Where(squirrel.Eq{"nickname": *v}) + } + + // Add ordering + builder = builder.OrderBy("created_ts DESC", "row_status DESC") + + // Prepare the final query + query, args, err := builder.ToSql() + if err != nil { + return nil, err + } + + // Execute the query with the context + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := make([]*store.User, 0) + for rows.Next() { + var user store.User + if err := rows.Scan( + &user.ID, + &user.Username, + &user.Role, + &user.Email, + &user.Nickname, + &user.PasswordHash, + &user.AvatarURL, + &user.CreatedTs, + &user.UpdatedTs, + &user.RowStatus, + ); err != nil { + return nil, err + } + list = append(list, &user) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) GetUser(ctx context.Context, find *store.FindUser) (*store.User, error) { + list, err := d.ListUsers(ctx, find) + if err != nil { + return nil, err + } + if len(list) != 1 { + return nil, errors.Wrapf(nil, "unexpected user count: %d", len(list)) + } + return list[0], nil +} + +func (d *DB) DeleteUser(ctx context.Context, delete *store.DeleteUser) error { + // Start building the DELETE statement + builder := squirrel.Delete("\"user\""). + PlaceholderFormat(squirrel.Dollar). + Where(squirrel.Eq{"id": delete.ID}) + + // Prepare the final query + query, args, err := builder.ToSql() + if err != nil { + return err + } + + // Execute the query with the context + result, err := d.db.ExecContext(ctx, query, args...) + if err != nil { + return err + } + + if _, err := result.RowsAffected(); err != nil { + return err + } + + if err := d.Vacuum(ctx); err != nil { + // Prevent linter warning. + return err + } + + return nil +} diff --git a/store/db/postgres/user_setting.go b/store/db/postgres/user_setting.go new file mode 100644 index 000000000..5699b6788 --- /dev/null +++ b/store/db/postgres/user_setting.go @@ -0,0 +1,194 @@ +package postgres + +import ( + "context" + "database/sql" + "fmt" + + "github.com/Masterminds/squirrel" + "github.com/pkg/errors" + "google.golang.org/protobuf/encoding/protojson" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func (d *DB) UpsertUserSetting(ctx context.Context, upsert *store.UserSetting) (*store.UserSetting, error) { + // Construct the query using Squirrel + query, args, err := squirrel. + Insert("user_setting"). + Columns("user_id", "key", "value"). + Values(upsert.UserID, upsert.Key, upsert.Value). + PlaceholderFormat(squirrel.Dollar). + // no need to specify ON CONFLICT clause, as the primary key is (user_id, key) + ToSql() + if err != nil { + return nil, err + } + + // Execute the query + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *DB) ListUserSettings(ctx context.Context, find *store.FindUserSetting) ([]*store.UserSetting, error) { + // Start building the query + qb := squirrel.Select("user_id", "key", "value").From("user_setting").Where("1 = 1").PlaceholderFormat(squirrel.Dollar) + + // Add conditions based on the provided find parameters + if v := find.Key; v != "" { + qb = qb.Where(squirrel.Eq{"key": v}) + } + if v := find.UserID; v != nil { + qb = qb.Where(squirrel.Eq{"user_id": *v}) + } + + // Finalize the query + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + // Execute the query + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + // Process the rows + userSettingList := make([]*store.UserSetting, 0) + for rows.Next() { + var userSetting store.UserSetting + if err := rows.Scan( + &userSetting.UserID, + &userSetting.Key, + &userSetting.Value, + ); err != nil { + return nil, err + } + userSettingList = append(userSettingList, &userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return userSettingList, nil +} + +func (d *DB) UpsertUserSettingV1(ctx context.Context, upsert *storepb.UserSetting) (*storepb.UserSetting, error) { + var valueString string + if upsert.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + valueBytes, err := protojson.Marshal(upsert.GetAccessTokens()) + if err != nil { + return nil, err + } + valueString = string(valueBytes) + } else { + return nil, errors.New("invalid user setting key") + } + + // Construct the query using Squirrel + query, args, err := squirrel. + Insert("user_setting"). + Columns("user_id", "key", "value"). + Values(upsert.UserId, upsert.Key.String(), valueString). + Suffix("ON CONFLICT (user_id, key) DO UPDATE SET value = EXCLUDED.value"). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return nil, err + } + + // Execute the query + if _, err := d.db.ExecContext(ctx, query, args...); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *DB) ListUserSettingsV1(ctx context.Context, find *store.FindUserSettingV1) ([]*storepb.UserSetting, error) { + // Start building the query using Squirrel + qb := squirrel.Select("user_id", "key", "value").From("user_setting").PlaceholderFormat(squirrel.Dollar) + + // Add conditions based on the provided find parameters + if v := find.Key; v != storepb.UserSettingKey_USER_SETTING_KEY_UNSPECIFIED { + qb = qb.Where(squirrel.Eq{"key": v.String()}) + } + if v := find.UserID; v != nil { + qb = qb.Where(squirrel.Eq{"user_id": *v}) + } + + // Finalize the query + query, args, err := qb.ToSql() + if err != nil { + return nil, err + } + + // Execute the query + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + // Process the rows + userSettingList := make([]*storepb.UserSetting, 0) + for rows.Next() { + userSetting := &storepb.UserSetting{} + var keyString, valueString string + if err := rows.Scan( + &userSetting.UserId, + &keyString, + &valueString, + ); err != nil { + return nil, err + } + userSetting.Key = storepb.UserSettingKey(storepb.UserSettingKey_value[keyString]) + if userSetting.Key == storepb.UserSettingKey_USER_SETTING_ACCESS_TOKENS { + accessTokensUserSetting := &storepb.AccessTokensUserSetting{} + if err := protojson.Unmarshal([]byte(valueString), accessTokensUserSetting); err != nil { + return nil, err + } + userSetting.Value = &storepb.UserSetting_AccessTokens{ + AccessTokens: accessTokensUserSetting, + } + } else { + // Skip unknown user setting v1 key + continue + } + userSettingList = append(userSettingList, userSetting) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return userSettingList, nil +} + +func vacuumUserSetting(ctx context.Context, tx *sql.Tx) error { + // First, build the subquery + subQuery, subArgs, err := squirrel.Select("id").From("\"user\"").PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + // Now, build the main delete query using the subquery + query, args, err := squirrel.Delete("user_setting"). + Where(fmt.Sprintf("user_id NOT IN (%s)", subQuery), subArgs...). + PlaceholderFormat(squirrel.Dollar). + ToSql() + if err != nil { + return err + } + + // Execute the query + _, err = tx.ExecContext(ctx, query, args...) + return err +} diff --git a/store/db/postgres/webhook.go b/store/db/postgres/webhook.go new file mode 100644 index 000000000..38d2e2b9e --- /dev/null +++ b/store/db/postgres/webhook.go @@ -0,0 +1,148 @@ +package postgres + +import ( + "context" + "time" + + "github.com/Masterminds/squirrel" + + storepb "github.com/usememos/memos/proto/gen/store" + "github.com/usememos/memos/store" +) + +func (d *DB) CreateWebhook(ctx context.Context, create *storepb.Webhook) (*storepb.Webhook, error) { + qb := squirrel.Insert("webhook").Columns("name", "url", "creator_id") + values := []any{create.Name, create.Url, create.CreatorId} + + if create.Id != 0 { + qb = qb.Columns("id") + values = append(values, create.Id) + } + + qb = qb.Values(values...).Suffix("RETURNING id") + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + err = d.db.QueryRowContext(ctx, query, args...).Scan(&create.Id) + if err != nil { + return nil, err + } + + create, err = d.GetWebhook(ctx, &store.FindWebhook{ID: &create.Id}) + if err != nil { + return nil, err + } + + return create, nil +} + +func (d *DB) ListWebhooks(ctx context.Context, find *store.FindWebhook) ([]*storepb.Webhook, error) { + qb := squirrel.Select("id", "created_ts", "updated_ts", "row_status", "creator_id", "name", "url").From("webhook").OrderBy("id DESC") + + if find.ID != nil { + qb = qb.Where(squirrel.Eq{"id": *find.ID}) + } + if find.CreatorID != nil { + qb = qb.Where(squirrel.Eq{"creator_id": *find.CreatorID}) + } + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*storepb.Webhook{} + for rows.Next() { + webhook := &storepb.Webhook{} + var rowStatus string + var createdTs, updatedTs time.Time + + if err := rows.Scan( + &webhook.Id, + &createdTs, + &updatedTs, + &rowStatus, + &webhook.CreatorId, + &webhook.Name, + &webhook.Url, + ); err != nil { + return nil, err + } + + webhook.CreatedTs = createdTs.UnixNano() + webhook.UpdatedTs = updatedTs.UnixNano() + webhook.RowStatus = storepb.RowStatus(storepb.RowStatus_value[rowStatus]) + + list = append(list, webhook) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} + +func (d *DB) GetWebhook(ctx context.Context, find *store.FindWebhook) (*storepb.Webhook, error) { + list, err := d.ListWebhooks(ctx, find) + if err != nil { + return nil, err + } + if len(list) == 0 { + return nil, nil + } + return list[0], nil +} + +func (d *DB) UpdateWebhook(ctx context.Context, update *store.UpdateWebhook) (*storepb.Webhook, error) { + qb := squirrel.Update("webhook") + + if update.RowStatus != nil { + qb = qb.Set("row_status", update.RowStatus.String()) + } + if update.Name != nil { + qb = qb.Set("name", *update.Name) + } + if update.URL != nil { + qb = qb.Set("url", *update.URL) + } + + qb = qb.Where(squirrel.Eq{"id": update.ID}) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return nil, err + } + + _, err = d.db.ExecContext(ctx, query, args...) + if err != nil { + return nil, err + } + + webhook, err := d.GetWebhook(ctx, &store.FindWebhook{ID: &update.ID}) + if err != nil { + return nil, err + } + + return webhook, nil +} + +func (d *DB) DeleteWebhook(ctx context.Context, delete *store.DeleteWebhook) error { + qb := squirrel.Delete("webhook").Where(squirrel.Eq{"id": delete.ID}) + + query, args, err := qb.PlaceholderFormat(squirrel.Dollar).ToSql() + if err != nil { + return err + } + + _, err = d.db.ExecContext(ctx, query, args...) + return err +} diff --git a/test/store/memo_test.go b/test/store/memo_test.go index 95931b639..e1df1f775 100644 --- a/test/store/memo_test.go +++ b/test/store/memo_test.go @@ -49,4 +49,13 @@ func TestMemoStore(t *testing.T) { }) require.NoError(t, err) require.Equal(t, 0, len(memoList)) + + memoList, err = ts.ListMemos(ctx, &store.FindMemo{ + CreatorID: &user.ID, + VisibilityList: []store.Visibility{ + store.Public, + }, + }) + require.NoError(t, err) + require.Equal(t, 0, len(memoList)) }