diff --git a/cmd/memos.go b/cmd/memos.go index fb5df0cf..ba26e39d 100644 --- a/cmd/memos.go +++ b/cmd/memos.go @@ -17,6 +17,7 @@ import ( _profile "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" + "github.com/usememos/memos/store/sqlite3" ) const ( @@ -54,7 +55,9 @@ var ( return } - store := store.New(db.DBInstance, profile) + driver := sqlite3.NewDriver(db.DBInstance) + + store := store.New(db.DBInstance, driver, profile) s, err := server.NewServer(ctx, profile, store) if err != nil { cancel() diff --git a/cmd/mvrss.go b/cmd/mvrss.go index dd64720f..2e1722b3 100644 --- a/cmd/mvrss.go +++ b/cmd/mvrss.go @@ -10,6 +10,7 @@ import ( "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" + "github.com/usememos/memos/store/sqlite3" ) var ( @@ -49,7 +50,9 @@ var ( return } - s := store.New(db.DBInstance, profile) + driver := sqlite3.NewDriver(db.DBInstance) + + s := store.New(db.DBInstance, driver, profile) resources, err := s.ListResources(ctx, &store.FindResource{}) if err != nil { fmt.Printf("failed to list resources, error: %+v\n", err) diff --git a/cmd/setup.go b/cmd/setup.go index a3bb1bae..1f74d13b 100644 --- a/cmd/setup.go +++ b/cmd/setup.go @@ -12,6 +12,7 @@ import ( "github.com/usememos/memos/common/util" "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" + "github.com/usememos/memos/store/sqlite3" ) var ( @@ -46,7 +47,9 @@ var ( return } - store := store.New(db.DBInstance, profile) + driver := sqlite3.NewDriver(db.DBInstance) + + store := store.New(db.DBInstance, driver, profile) if err := ExecuteSetup(ctx, store, hostUsername, hostPassword); err != nil { fmt.Printf("failed to setup, error: %+v\n", err) return diff --git a/store/activity.go b/store/activity.go index 96c80c5f..a9e7cb30 100644 --- a/store/activity.go +++ b/store/activity.go @@ -18,23 +18,5 @@ type Activity struct { } func (s *Store) CreateActivity(ctx context.Context, create *Activity) (*Activity, error) { - stmt := ` - INSERT INTO activity ( - creator_id, - type, - level, - payload - ) - VALUES (?, ?, ?, ?) - RETURNING id, created_ts - ` - if err := s.db.QueryRowContext(ctx, stmt, create.CreatorID, create.Type, create.Level, create.Payload).Scan( - &create.ID, - &create.CreatedTs, - ); err != nil { - return nil, err - } - - activity := create - return activity, nil + return s.driver.CreateActivity(ctx, create) } diff --git a/store/driver.go b/store/driver.go new file mode 100644 index 00000000..6cba2187 --- /dev/null +++ b/store/driver.go @@ -0,0 +1,10 @@ +package store + +import "context" + +type Driver interface { + CreateActivity(ctx context.Context, create *Activity) (*Activity, error) + + UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) + ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) +} diff --git a/store/sqlite3/activity.go b/store/sqlite3/activity.go new file mode 100644 index 00000000..e6f82594 --- /dev/null +++ b/store/sqlite3/activity.go @@ -0,0 +1,28 @@ +package sqlite3 + +import ( + "context" + + "github.com/usememos/memos/store" +) + +func (d *Driver) CreateActivity(ctx context.Context, create *store.Activity) (*store.Activity, error) { + stmt := ` + INSERT INTO activity ( + creator_id, + type, + level, + payload + ) + VALUES (?, ?, ?, ?) + RETURNING id, created_ts + ` + if err := d.db.QueryRowContext(ctx, stmt, create.CreatorID, create.Type, create.Level, create.Payload).Scan( + &create.ID, + &create.CreatedTs, + ); err != nil { + return nil, err + } + + return create, nil +} diff --git a/store/sqlite3/driver.go b/store/sqlite3/driver.go new file mode 100644 index 00000000..034ceb9c --- /dev/null +++ b/store/sqlite3/driver.go @@ -0,0 +1,15 @@ +package sqlite3 + +import ( + "database/sql" + + "github.com/usememos/memos/store" +) + +type Driver struct { + db *sql.DB +} + +func NewDriver(db *sql.DB) store.Driver { + return &Driver{db: db} +} diff --git a/store/sqlite3/system_setting.go b/store/sqlite3/system_setting.go new file mode 100644 index 00000000..de5b55ac --- /dev/null +++ b/store/sqlite3/system_setting.go @@ -0,0 +1,66 @@ +package sqlite3 + +import ( + "context" + "strings" + + "github.com/usememos/memos/store" +) + +func (d *Driver) UpsertSystemSetting(ctx context.Context, upsert *store.SystemSetting) (*store.SystemSetting, error) { + stmt := ` + INSERT INTO system_setting ( + name, value, description + ) + VALUES (?, ?, ?) + ON CONFLICT(name) DO UPDATE + SET + value = EXCLUDED.value, + description = EXCLUDED.description + ` + if _, err := d.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil { + return nil, err + } + + return upsert, nil +} + +func (d *Driver) ListSystemSettings(ctx context.Context, find *store.FindSystemSetting) ([]*store.SystemSetting, error) { + where, args := []string{"1 = 1"}, []any{} + if find.Name != "" { + where, args = append(where, "name = ?"), append(args, find.Name) + } + + query := ` + SELECT + name, + value, + description + FROM system_setting + WHERE ` + strings.Join(where, " AND ") + + rows, err := d.db.QueryContext(ctx, query, args...) + if err != nil { + return nil, err + } + defer rows.Close() + + list := []*store.SystemSetting{} + for rows.Next() { + systemSettingMessage := &store.SystemSetting{} + if err := rows.Scan( + &systemSettingMessage.Name, + &systemSettingMessage.Value, + &systemSettingMessage.Description, + ); err != nil { + return nil, err + } + list = append(list, systemSettingMessage) + } + + if err := rows.Err(); err != nil { + return nil, err + } + + return list, nil +} diff --git a/store/store.go b/store/store.go index ed3fe8c8..9beb55a8 100644 --- a/store/store.go +++ b/store/store.go @@ -16,6 +16,7 @@ import ( type Store struct { Profile *profile.Profile db *sql.DB + driver Driver systemSettingCache sync.Map // map[string]*SystemSetting userCache sync.Map // map[int]*User userSettingCache sync.Map // map[string]*UserSetting @@ -23,10 +24,11 @@ type Store struct { } // New creates a new instance of Store. -func New(db *sql.DB, profile *profile.Profile) *Store { +func New(db *sql.DB, driver Driver, profile *profile.Profile) *Store { return &Store{ Profile: profile, db: db, + driver: driver, } } diff --git a/store/system_setting.go b/store/system_setting.go index aa94849f..6d7b48c1 100644 --- a/store/system_setting.go +++ b/store/system_setting.go @@ -2,7 +2,6 @@ package store import ( "context" - "strings" ) type SystemSetting struct { @@ -16,60 +15,14 @@ type FindSystemSetting struct { } func (s *Store) UpsertSystemSetting(ctx context.Context, upsert *SystemSetting) (*SystemSetting, error) { - stmt := ` - INSERT INTO system_setting ( - name, value, description - ) - VALUES (?, ?, ?) - ON CONFLICT(name) DO UPDATE - SET - value = EXCLUDED.value, - description = EXCLUDED.description - ` - if _, err := s.db.ExecContext(ctx, stmt, upsert.Name, upsert.Value, upsert.Description); err != nil { - return nil, err - } - - systemSetting := upsert - return systemSetting, nil + return s.driver.UpsertSystemSetting(ctx, upsert) } func (s *Store) ListSystemSettings(ctx context.Context, find *FindSystemSetting) ([]*SystemSetting, error) { - where, args := []string{"1 = 1"}, []any{} - if find.Name != "" { - where, args = append(where, "name = ?"), append(args, find.Name) - } - - query := ` - SELECT - name, - value, - description - FROM system_setting - WHERE ` + strings.Join(where, " AND ") - - rows, err := s.db.QueryContext(ctx, query, args...) + list, err := s.driver.ListSystemSettings(ctx, find) if err != nil { return nil, err } - defer rows.Close() - - list := []*SystemSetting{} - for rows.Next() { - systemSettingMessage := &SystemSetting{} - if err := rows.Scan( - &systemSettingMessage.Name, - &systemSettingMessage.Value, - &systemSettingMessage.Description, - ); err != nil { - return nil, err - } - list = append(list, systemSettingMessage) - } - - if err := rows.Err(); err != nil { - return nil, err - } for _, systemSettingMessage := range list { s.systemSettingCache.Store(systemSettingMessage.Name, systemSettingMessage) diff --git a/test/server/server.go b/test/server/server.go index 88631821..7ab47637 100644 --- a/test/server/server.go +++ b/test/server/server.go @@ -19,6 +19,7 @@ import ( "github.com/usememos/memos/server/profile" "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" + "github.com/usememos/memos/store/sqlite3" "github.com/usememos/memos/test" ) @@ -39,7 +40,9 @@ func NewTestingServer(ctx context.Context, t *testing.T) (*TestingServer, error) return nil, errors.Wrap(err, "failed to migrate db") } - store := store.New(db.DBInstance, profile) + driver := sqlite3.NewDriver(db.DBInstance) + + store := store.New(db.DBInstance, driver, profile) server, err := server.NewServer(ctx, profile, store) if err != nil { return nil, errors.Wrap(err, "failed to create server") diff --git a/test/store/store.go b/test/store/store.go index f728fcc9..de6ea759 100644 --- a/test/store/store.go +++ b/test/store/store.go @@ -7,6 +7,7 @@ import ( "github.com/usememos/memos/store" "github.com/usememos/memos/store/db" + "github.com/usememos/memos/store/sqlite3" "github.com/usememos/memos/test" // sqlite driver. @@ -23,6 +24,8 @@ func NewTestingStore(ctx context.Context, t *testing.T) *store.Store { fmt.Printf("failed to migrate db, error: %+v\n", err) } - store := store.New(db.DBInstance, profile) + driver := sqlite3.NewDriver(db.DBInstance) + + store := store.New(db.DBInstance, driver, profile) return store }