refactor: memo filter

- Updated memo and reaction filtering logic to use a unified engine for compiling filter expressions into SQL statements.
- Removed redundant filter parsing and conversion code from ListMemoRelations, ListReactions, and ListAttachments methods.
- Introduced IDList and UIDList fields in FindMemo and FindReaction structs to support filtering by multiple IDs.
- Removed old filter test files for reactions and attachments, as the filtering logic has been centralized.
- Updated tests for memo filtering to reflect the new SQL statement compilation approach.
- Ensured that unsupported user filters return an error in ListUsers method.
pull/5091/merge
Copilot 3 weeks ago
parent 228cc6105d
commit b685ffacdf

@ -0,0 +1,50 @@
# Maintaining the Memo Filter Engine
The engine is memo-specific; any future field or behavior changes must stay
consistent with the memo schema and store implementations. Use this guide when
extending or debugging the package.
## Adding a New Memo Field
1. **Update the schema**
- Add the field entry in `schema.go`.
- Define the backing column (`Column`), JSON path (if applicable), type, and
allowed operators.
- Include the CEL variable in `EnvOptions`.
2. **Adjust parser or renderer (if needed)**
- For non-scalar fields (JSON booleans, lists), add handling in
`parser.go` or extend the renderer helpers.
- Keep validation in the parser (e.g., reject unsupported operators).
3. **Write a golden test**
- Extend the dialect-specific memo filter tests under
`store/db/{sqlite,mysql,postgres}/memo_filter_test.go` with a case that
exercises the new field.
4. **Run `go test ./...`** to ensure the SQL output matches expectations across
all dialects.
## Supporting Dialect Nuances
- Centralize differences inside `render.go`. If a new dialect-specific behavior
emerges (e.g., JSON operators), add the logic there rather than leaking it
into store code.
- Use the renderer helpers (`jsonExtractExpr`, `jsonArrayExpr`, etc.) rather than
sprinkling ad-hoc SQL strings.
- When placeholders change, adjust `addArg` so that argument numbering stays in
sync with store queries.
## Debugging Tips
- **Parser errors** Most originate in `buildCondition` or schema validation.
Enable logging around `parser.go` when diagnosing unknown identifier/operator
messages.
- **Renderer output** Temporary printf/log statements in `renderCondition` help
identify which IR node produced unexpected SQL.
- **Store integration** Ensure drivers call `filter.DefaultEngine()` exactly once
per process; the singleton caches the parsed CEL environment.
## Testing Checklist
- `go test ./store/...` ensures all dialect tests consume the engine correctly.
- Add targeted unit tests whenever new IR nodes or renderer paths are introduced.
- When changing boolean or JSON handling, verify all three dialect test suites
(SQLite, MySQL, Postgres) to avoid regression.

@ -0,0 +1,63 @@
# Memo Filter Engine
This package houses the memo-only filter engine that turns CEL expressions into
SQL fragments. The engine follows a three phase pipeline inspired by systems
such as Calcite or Prisma:
1. **Parsing** CEL expressions are parsed with `cel-go` and validated against
the memo-specific environment declared in `schema.go`. Only fields that
exist in the schema can surface in the filter.
2. **Normalization** the raw CEL AST is converted into an intermediate
representation (IR) defined in `ir.go`. The IR is a dialect-agnostic tree of
conditions (logical operators, comparisons, list membership, etc.). This
step enforces schema rules (e.g. operator compatibility, type checks).
3. **Rendering** the renderer in `render.go` walks the IR and produces a SQL
fragment plus placeholder arguments tailored to a target dialect
(`sqlite`, `mysql`, or `postgres`). Dialect differences such as JSON access,
boolean semantics, placeholders, and `LIKE` vs `ILIKE` are encapsulated in
renderer helpers.
The entry point is `filter.DefaultEngine()` from `engine.go`. It lazily constructs
an `Engine` configured with the memo schema and exposes:
```go
engine, _ := filter.DefaultEngine()
stmt, _ := engine.CompileToStatement(ctx, `has_task_list && visibility == "PUBLIC"`, filter.RenderOptions{
Dialect: filter.DialectPostgres,
})
// stmt.SQL -> "((memo.payload->'property'->>'hasTaskList')::boolean IS TRUE AND memo.visibility = $1)"
// stmt.Args -> ["PUBLIC"]
```
## Core Files
| File | Responsibility |
| ------------- | ------------------------------------------------------------------------------- |
| `schema.go` | Declares memo fields, their types, backing columns, CEL environment options |
| `ir.go` | IR node definitions used across the pipeline |
| `parser.go` | Converts CEL `Expr` into IR while applying schema validation |
| `render.go` | Translates IR into SQL, handling dialect-specific behavior |
| `engine.go` | Glue between the phases; exposes `Compile`, `CompileToStatement`, and `DefaultEngine` |
| `helpers.go` | Convenience helpers for store integration (appending conditions) |
## SQL Generation Notes
- **Placeholders**`?` is used for SQLite/MySQL, `$n` for Postgres. The renderer
tracks offsets to compose queries with pre-existing arguments.
- **JSON Fields** — Memo metadata lives in `memo.payload`. The renderer handles
`JSON_EXTRACT`/`json_extract`/`->`/`->>` variations and boolean coercion.
- **Tag Operations**`tag in [...]` and `"tag" in tags` become JSON array
predicates. SQLite uses `LIKE` patterns, MySQL uses `JSON_CONTAINS`, and
Postgres uses `@>`.
- **Boolean Flags** — Fields such as `has_task_list` render as `IS TRUE` equality
checks, or comparisons against `CAST('true' AS JSON)` depending on the dialect.
## Typical Integration
1. Fetch the engine with `filter.DefaultEngine()`.
2. Call `CompileToStatement` using the appropriate dialect enum.
3. Append the emitted SQL fragment/args to the existing `WHERE` clause.
4. Execute the resulting query through the store driver.
The `helpers.AppendConditions` helper encapsulates steps 23 when a driver needs
to process an array of filters.

@ -1,746 +0,0 @@
package filter
import (
"fmt"
"slices"
"strings"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// CommonSQLConverter handles the common CEL to SQL conversion logic.
type CommonSQLConverter struct {
dialect SQLDialect
paramIndex int
allowedFields []string
entityType string
}
// NewCommonSQLConverter creates a new converter with the specified dialect for memo filters.
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
return &CommonSQLConverter{
dialect: dialect,
paramIndex: 1,
allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"},
entityType: "memo",
}
}
// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset for memo filters.
func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter {
return &CommonSQLConverter{
dialect: dialect,
paramIndex: offset + 1,
allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"},
entityType: "memo",
}
}
// NewUserSQLConverter creates a new converter for user filters.
func NewUserSQLConverter(dialect SQLDialect) *CommonSQLConverter {
return &CommonSQLConverter{
dialect: dialect,
paramIndex: 1,
allowedFields: []string{"username"},
entityType: "user",
}
}
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect.
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
switch v.CallExpr.Function {
case "_||_", "_&&_":
return c.handleLogicalOperator(ctx, v.CallExpr)
case "!_":
return c.handleNotOperator(ctx, v.CallExpr)
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
return c.handleComparisonOperator(ctx, v.CallExpr)
case "@in":
return c.handleInOperator(ctx, v.CallExpr)
case "contains":
return c.handleContainsOperator(ctx, v.CallExpr)
default:
return errors.Errorf("unsupported call expression function: %s", v.CallExpr.Function)
}
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
return c.handleIdentifier(ctx, v.IdentExpr)
}
return nil
}
func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
if _, err := ctx.Buffer.WriteString("("); err != nil {
return err
}
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
return err
}
operator := "AND"
if callExpr.Function == "_||_" {
operator = "OR"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
return err
}
if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
return nil
}
func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
return err
}
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
return err
}
if _, err := ctx.Buffer.WriteString(")"); err != nil {
return err
}
return nil
}
func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
// Check if the left side is a function call like size(tags)
if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
if leftCallExpr.CallExpr.Function == "size" {
return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr)
}
}
identifier, err := GetIdentExprName(callExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains(c.allowedFields, identifier) {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
value, err := GetExprValue(callExpr.Args[1])
if err != nil {
return err
}
operator := c.getComparisonOperator(callExpr.Function)
// Handle memo fields
if c.entityType == "memo" {
switch identifier {
case "created_ts", "updated_ts":
return c.handleTimestampComparison(ctx, identifier, operator, value)
case "visibility", "content":
return c.handleStringComparison(ctx, identifier, operator, value)
case "creator_id":
return c.handleIntComparison(ctx, identifier, operator, value)
case "pinned":
return c.handlePinnedComparison(ctx, operator, value)
case "has_task_list", "has_link", "has_code", "has_incomplete_tasks":
return c.handleBooleanComparison(ctx, identifier, operator, value)
default:
return errors.Errorf("unsupported identifier in comparison: %s", identifier)
}
}
// Handle user fields
if c.entityType == "user" {
switch identifier {
case "username":
return c.handleUserStringComparison(ctx, identifier, operator, value)
default:
return errors.Errorf("unsupported user identifier in comparison: %s", identifier)
}
}
return errors.Errorf("unsupported entity type: %s", c.entityType)
}
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
if len(sizeCall.Args) != 1 {
return errors.New("size function requires exactly one argument")
}
identifier, err := GetIdentExprName(sizeCall.Args[0])
if err != nil {
return err
}
if identifier != "tags" {
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
}
value, err := GetExprValue(callExpr.Args[1])
if err != nil {
return err
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("size comparison value must be an integer")
}
operator := c.getComparisonOperator(callExpr.Function)
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
c.dialect.GetJSONArrayLength("$.tags"),
operator,
c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 2 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
// Check if this is "element in collection" syntax
if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil {
if identifier == "tags" {
return c.handleElementInTags(ctx, callExpr.Args[0])
}
return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier)
}
// Original logic for "identifier in [list]" syntax
identifier, err := GetIdentExprName(callExpr.Args[0])
if err != nil {
return err
}
if !slices.Contains([]string{"tag", "visibility", "content_id", "memo_id"}, identifier) {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
values := []any{}
for _, element := range callExpr.Args[1].GetListExpr().Elements {
value, err := GetConstValue(element)
if err != nil {
return err
}
values = append(values, value)
}
if identifier == "tag" {
return c.handleTagInList(ctx, values)
} else if identifier == "visibility" {
return c.handleVisibilityInList(ctx, values)
} else if identifier == "content_id" {
return c.handleContentIDInList(ctx, values)
} else if identifier == "memo_id" {
return c.handleMemoIDInList(ctx, values)
}
return nil
}
func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error {
element, err := GetConstValue(elementExpr)
if err != nil {
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
}
// Use dialect-specific JSON contains logic
template := c.dialect.GetJSONContains("$.tags", "element")
sqlExpr := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1)
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
// Handle args based on dialect
if _, ok := c.dialect.(*SQLiteDialect); ok {
// SQLite uses LIKE with pattern
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
} else {
// MySQL and PostgreSQL expect plain values
ctx.Args = append(ctx.Args, element)
}
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error {
subconditions := []string{}
args := []any{}
for _, v := range values {
if _, ok := c.dialect.(*SQLiteDialect); ok {
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
} else {
// Replace ? with proper placeholder for each dialect
template := c.dialect.GetJSONContains("$.tags", "element")
sql := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1)
subconditions = append(subconditions, sql)
args = append(args, fmt.Sprintf(`"%s"`, v))
}
c.paramIndex++
}
if len(subconditions) == 1 {
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, args...)
return nil
}
func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error {
placeholders := []string{}
for range values {
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
c.paramIndex++
}
tablePrefix := c.dialect.GetTablePrefix("memo")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, values...)
return nil
}
func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values []any) error {
placeholders := []string{}
for range values {
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
c.paramIndex++
}
tablePrefix := c.dialect.GetTablePrefix("reaction")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, values...)
return nil
}
func (c *CommonSQLConverter) handleMemoIDInList(ctx *ConvertContext, values []any) error {
placeholders := []string{}
for range values {
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
c.paramIndex++
}
tablePrefix := c.dialect.GetTablePrefix("resource")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.memo_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`memo_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, values...)
return nil
}
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
if len(callExpr.Args) != 1 {
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
}
identifier, err := GetIdentExprName(callExpr.Target)
if err != nil {
return err
}
if identifier != "content" {
return errors.Errorf("invalid identifier for %s", callExpr.Function)
}
arg, err := GetConstValue(callExpr.Args[0])
if err != nil {
return err
}
tablePrefix := c.dialect.GetTablePrefix("memo")
// PostgreSQL uses ILIKE and no backticks
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content ILIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
identifier := identExpr.GetName()
// Only memo entity has boolean identifiers that can be used standalone
if c.entityType != "memo" {
return errors.Errorf("invalid identifier %s for entity type %s", identifier, c.entityType)
}
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
return errors.Errorf("invalid identifier %s", identifier)
}
if identifier == "pinned" {
tablePrefix := c.dialect.GetTablePrefix("memo")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil {
return err
}
} else {
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
return err
}
}
} else if identifier == "has_task_list" {
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
return err
}
} else if identifier == "has_link" {
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasLink")); err != nil {
return err
}
} else if identifier == "has_code" {
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasCode")); err != nil {
return err
}
} else if identifier == "has_incomplete_tasks" {
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasIncompleteTasks")); err != nil {
return err
}
}
return nil
}
func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid integer timestamp value")
}
timestampField := c.dialect.GetTimestampComparison(field)
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueInt)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueStr, ok := value.(string)
if !ok {
return errors.New("invalid string value")
}
tablePrefix := c.dialect.GetTablePrefix("memo")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
// PostgreSQL doesn't use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
} else {
// MySQL and SQLite use backticks
fieldName := field
if field == "visibility" {
fieldName = "`visibility`"
} else if field == "content" {
fieldName = "`content`"
}
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, valueStr)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleUserStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueStr, ok := value.(string)
if !ok {
return errors.New("invalid string value")
}
tablePrefix := c.dialect.GetTablePrefix("user")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
// PostgreSQL doesn't use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
} else {
// MySQL and SQLite use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, valueStr)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueInt, ok := value.(int64)
if !ok {
return errors.New("invalid int value")
}
tablePrefix := c.dialect.GetTablePrefix("memo")
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
// PostgreSQL doesn't use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
} else {
// MySQL and SQLite use backticks
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
return err
}
}
ctx.Args = append(ctx.Args, valueInt)
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handlePinnedComparison(ctx *ConvertContext, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for pinned field")
}
valueBool, ok := value.(bool)
if !ok {
return errors.New("invalid boolean value for pinned field")
}
tablePrefix := c.dialect.GetTablePrefix("memo")
var sqlExpr string
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
sqlExpr = fmt.Sprintf("%s.pinned %s %s", tablePrefix, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))
} else {
sqlExpr = fmt.Sprintf("%s.`pinned` %s %s", tablePrefix, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))
}
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
ctx.Args = append(ctx.Args, c.dialect.GetBooleanValue(valueBool))
c.paramIndex++
return nil
}
func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
if operator != "=" && operator != "!=" {
return errors.Errorf("invalid operator for %s", field)
}
valueBool, ok := value.(bool)
if !ok {
return errors.Errorf("invalid boolean value for %s", field)
}
// Map field name to JSON path
var jsonPath string
switch field {
case "has_task_list":
jsonPath = "$.property.hasTaskList"
case "has_link":
jsonPath = "$.property.hasLink"
case "has_code":
jsonPath = "$.property.hasCode"
case "has_incomplete_tasks":
jsonPath = "$.property.hasIncompleteTasks"
default:
return errors.Errorf("unsupported boolean field: %s", field)
}
// Special handling for SQLite based on field
if _, ok := c.dialect.(*SQLiteDialect); ok {
if field == "has_task_list" {
// has_task_list uses = 1 / = 0 / != 1 / != 0
var sqlExpr string
if operator == "=" {
if valueBool {
sqlExpr = fmt.Sprintf("%s = 1", c.dialect.GetJSONExtract(jsonPath))
} else {
sqlExpr = fmt.Sprintf("%s = 0", c.dialect.GetJSONExtract(jsonPath))
}
} else { // operator == "!="
if valueBool {
sqlExpr = fmt.Sprintf("%s != 1", c.dialect.GetJSONExtract(jsonPath))
} else {
sqlExpr = fmt.Sprintf("%s != 0", c.dialect.GetJSONExtract(jsonPath))
}
}
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
return nil
}
// Other fields use IS TRUE / NOT(... IS TRUE)
var sqlExpr string
if operator == "=" {
if valueBool {
sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath))
} else {
sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath))
}
} else { // operator == "!="
if valueBool {
sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath))
} else {
sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath))
}
}
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
return nil
}
// Special handling for MySQL - use raw operator with CAST
if _, ok := c.dialect.(*MySQLDialect); ok {
var sqlExpr string
boolStr := "false"
if valueBool {
boolStr = "true"
}
sqlExpr = fmt.Sprintf("%s %s CAST('%s' AS JSON)", c.dialect.GetJSONExtract(jsonPath), operator, boolStr)
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
return nil
}
// Handle PostgreSQL differently - it uses the raw operator
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
jsonExtract := c.dialect.GetJSONExtract(jsonPath)
sqlExpr := fmt.Sprintf("(%s)::boolean %s %s",
jsonExtract,
operator,
c.dialect.GetParameterPlaceholder(c.paramIndex))
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
ctx.Args = append(ctx.Args, valueBool)
c.paramIndex++
return nil
}
// Handle other dialects
if operator == "!=" {
valueBool = !valueBool
}
sqlExpr := c.dialect.GetBooleanComparison(jsonPath, valueBool)
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
return err
}
return nil
}
func (*CommonSQLConverter) getComparisonOperator(function string) string {
switch function {
case "_==_":
return "="
case "_!=_":
return "!="
case "_<_":
return "<"
case "_>_":
return ">"
case "_<=_":
return "<="
case "_>=_":
return ">="
default:
return "="
}
}

@ -1,20 +0,0 @@
package filter
import (
"strings"
)
type ConvertContext struct {
Buffer strings.Builder
Args []any
// The offset of the next argument in the condition string.
// Mainly using for PostgreSQL.
ArgsOffset int
}
func NewConvertContext() *ConvertContext {
return &ConvertContext{
Buffer: strings.Builder{},
Args: []any{},
}
}

@ -1,215 +0,0 @@
package filter
import (
"fmt"
"strings"
)
// SQLDialect defines database-specific SQL generation methods.
type SQLDialect interface {
// Basic field access
GetTablePrefix(entityName string) string
GetParameterPlaceholder(index int) string
// JSON operations
GetJSONExtract(path string) string
GetJSONArrayLength(path string) string
GetJSONContains(path, element string) string
GetJSONLike(path, pattern string) string
// Boolean operations
GetBooleanValue(value bool) interface{}
GetBooleanComparison(path string, value bool) string
GetBooleanCheck(path string) string
// Timestamp operations
GetTimestampComparison(field string) string
GetCurrentTimestamp() string
}
// DatabaseType represents the type of database.
type DatabaseType string
const (
SQLite DatabaseType = "sqlite"
MySQL DatabaseType = "mysql"
PostgreSQL DatabaseType = "postgres"
)
// GetDialect returns the appropriate dialect for the database type.
func GetDialect(dbType DatabaseType) SQLDialect {
switch dbType {
case SQLite:
return &SQLiteDialect{}
case MySQL:
return &MySQLDialect{}
case PostgreSQL:
return &PostgreSQLDialect{}
default:
return &SQLiteDialect{} // default fallback
}
}
// SQLiteDialect implements SQLDialect for SQLite.
type SQLiteDialect struct{}
func (*SQLiteDialect) GetTablePrefix(entityName string) string {
return fmt.Sprintf("`%s`", entityName)
}
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
return "?"
}
func (d *SQLiteDialect) GetJSONExtract(path string) string {
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
}
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetJSONContains(path, _ string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetJSONLike(path, _ string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (*SQLiteDialect) GetBooleanValue(value bool) interface{} {
if value {
return 1
}
return 0
}
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
if value {
return fmt.Sprintf("%s = 1", d.GetJSONExtract(path))
}
return fmt.Sprintf("%s = 0", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
}
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix("memo"), field)
}
func (*SQLiteDialect) GetCurrentTimestamp() string {
return "strftime('%s', 'now')"
}
// MySQLDialect implements SQLDialect for MySQL.
type MySQLDialect struct{}
func (*MySQLDialect) GetTablePrefix(entityName string) string {
return fmt.Sprintf("`%s`", entityName)
}
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
return "?"
}
func (d *MySQLDialect) GetJSONExtract(path string) string {
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
}
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetJSONContains(path, _ string) string {
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetJSONLike(path, _ string) string {
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
}
func (*MySQLDialect) GetBooleanValue(value bool) interface{} {
return value
}
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
if value {
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
}
return fmt.Sprintf("%s != CAST('true' AS JSON)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
}
func (d *MySQLDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix("memo"), field)
}
func (*MySQLDialect) GetCurrentTimestamp() string {
return "UNIX_TIMESTAMP()"
}
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
type PostgreSQLDialect struct{}
func (*PostgreSQLDialect) GetTablePrefix(entityName string) string {
return entityName
}
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
return fmt.Sprintf("$%d", index)
}
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
// Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList'
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
result := fmt.Sprintf("%s.payload", d.GetTablePrefix("memo"))
for i, part := range parts {
if i == len(parts)-1 {
result += fmt.Sprintf("->>'%s'", part)
} else {
result += fmt.Sprintf("->'%s'", part)
}
}
return result
}
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix("memo"), jsonPath)
}
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
}
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
}
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
return value
}
func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string {
// Note: The parameter placeholder will be replaced by the caller
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
}
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
}
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix("memo"), field)
}
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
return "EXTRACT(EPOCH FROM NOW())"
}

@ -0,0 +1,180 @@
package filter
import (
"context"
"fmt"
"strings"
"sync"
"github.com/google/cel-go/cel"
"github.com/pkg/errors"
)
// Engine parses CEL filters into a dialect-agnostic condition tree.
type Engine struct {
schema Schema
env *cel.Env
}
// NewEngine builds a new Engine for the provided schema.
func NewEngine(schema Schema) (*Engine, error) {
env, err := cel.NewEnv(schema.EnvOptions...)
if err != nil {
return nil, errors.Wrap(err, "failed to create CEL environment")
}
return &Engine{
schema: schema,
env: env,
}, nil
}
// Program stores a compiled filter condition.
type Program struct {
schema Schema
condition Condition
}
// ConditionTree exposes the underlying condition tree.
func (p *Program) ConditionTree() Condition {
return p.condition
}
// Compile parses the filter string into an executable program.
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
if strings.TrimSpace(filter) == "" {
return nil, errors.New("filter expression is empty")
}
filter = normalizeLegacyFilter(filter)
ast, issues := e.env.Compile(filter)
if issues != nil && issues.Err() != nil {
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
}
parsed, err := cel.AstToParsedExpr(ast)
if err != nil {
return nil, errors.Wrap(err, "failed to convert AST")
}
cond, err := buildCondition(parsed.GetExpr(), e.schema)
if err != nil {
return nil, err
}
return &Program{
schema: e.schema,
condition: cond,
}, nil
}
// CompileToStatement compiles and renders the filter in a single step.
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
program, err := e.Compile(ctx, filter)
if err != nil {
return Statement{}, err
}
return program.Render(opts)
}
// RenderOptions configure SQL rendering.
type RenderOptions struct {
Dialect DialectName
PlaceholderOffset int
DisableNullChecks bool
}
// Statement contains the rendered SQL fragment and its args.
type Statement struct {
SQL string
Args []any
}
// Render converts the program into a dialect-specific SQL fragment.
func (p *Program) Render(opts RenderOptions) (Statement, error) {
renderer := newRenderer(p.schema, opts)
return renderer.Render(p.condition)
}
var (
defaultOnce sync.Once
defaultInst *Engine
defaultErr error
)
// DefaultEngine returns the process-wide memo filter engine.
func DefaultEngine() (*Engine, error) {
defaultOnce.Do(func() {
defaultInst, defaultErr = NewEngine(NewSchema())
})
return defaultInst, defaultErr
}
func normalizeLegacyFilter(expr string) string {
expr = rewriteNumericLogicalOperand(expr, "&&")
expr = rewriteNumericLogicalOperand(expr, "||")
return expr
}
func rewriteNumericLogicalOperand(expr, op string) string {
var builder strings.Builder
n := len(expr)
i := 0
var inQuote rune
for i < n {
ch := expr[i]
if inQuote != 0 {
builder.WriteByte(ch)
if ch == '\\' && i+1 < n {
builder.WriteByte(expr[i+1])
i += 2
continue
}
if ch == byte(inQuote) {
inQuote = 0
}
i++
continue
}
if ch == '\'' || ch == '"' {
inQuote = rune(ch)
builder.WriteByte(ch)
i++
continue
}
if strings.HasPrefix(expr[i:], op) {
builder.WriteString(op)
i += len(op)
// Preserve whitespace following the operator.
wsStart := i
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
i++
}
builder.WriteString(expr[wsStart:i])
signStart := i
if i < n && (expr[i] == '+' || expr[i] == '-') {
i++
}
for i < n && expr[i] >= '0' && expr[i] <= '9' {
i++
}
if i > signStart {
numLiteral := expr[signStart:i]
builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
} else {
builder.WriteString(expr[signStart:i])
}
continue
}
builder.WriteByte(ch)
i++
}
return builder.String()
}

@ -1,127 +0,0 @@
package filter
import (
"errors"
"time"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// GetConstValue returns the constant value of the expression.
func GetConstValue(expr *exprv1.Expr) (any, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("invalid constant expression")
}
switch v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return v.ConstExpr.GetUint64Value(), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
default:
return nil, errors.New("unexpected constant type")
}
}
// GetIdentExprName returns the name of the identifier expression.
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
if !ok {
return "", errors.New("invalid identifier expression")
}
return expr.GetIdentExpr().GetName(), nil
}
// GetFunctionValue evaluates CEL function calls and returns their value.
// This is specifically for time functions like now().
func GetFunctionValue(expr *exprv1.Expr) (any, error) {
callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr)
if !ok {
return nil, errors.New("invalid function call expression")
}
switch callExpr.CallExpr.Function {
case "now":
if len(callExpr.CallExpr.Args) != 0 {
return nil, errors.New("now() function takes no arguments")
}
return time.Now().Unix(), nil
case "_-_":
// Handle subtraction for expressions like "now() - 60 * 60 * 24"
if len(callExpr.CallExpr.Args) != 2 {
return nil, errors.New("subtraction requires exactly two arguments")
}
left, err := GetExprValue(callExpr.CallExpr.Args[0])
if err != nil {
return nil, err
}
right, err := GetExprValue(callExpr.CallExpr.Args[1])
if err != nil {
return nil, err
}
leftInt, ok1 := left.(int64)
rightInt, ok2 := right.(int64)
if !ok1 || !ok2 {
return nil, errors.New("subtraction operands must be integers")
}
return leftInt - rightInt, nil
case "_*_":
// Handle multiplication for expressions like "60 * 60 * 24"
if len(callExpr.CallExpr.Args) != 2 {
return nil, errors.New("multiplication requires exactly two arguments")
}
left, err := GetExprValue(callExpr.CallExpr.Args[0])
if err != nil {
return nil, err
}
right, err := GetExprValue(callExpr.CallExpr.Args[1])
if err != nil {
return nil, err
}
leftInt, ok1 := left.(int64)
rightInt, ok2 := right.(int64)
if !ok1 || !ok2 {
return nil, errors.New("multiplication operands must be integers")
}
return leftInt * rightInt, nil
case "_+_":
// Handle addition
if len(callExpr.CallExpr.Args) != 2 {
return nil, errors.New("addition requires exactly two arguments")
}
left, err := GetExprValue(callExpr.CallExpr.Args[0])
if err != nil {
return nil, err
}
right, err := GetExprValue(callExpr.CallExpr.Args[1])
if err != nil {
return nil, err
}
leftInt, ok1 := left.(int64)
rightInt, ok2 := right.(int64)
if !ok1 || !ok2 {
return nil, errors.New("addition operands must be integers")
}
return leftInt + rightInt, nil
default:
return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function)
}
}
// GetExprValue attempts to get a value from an expression, trying constants first, then functions.
func GetExprValue(expr *exprv1.Expr) (any, error) {
// Try to get constant value first
if constValue, err := GetConstValue(expr); err == nil {
return constValue, nil
}
// If not a constant, try to evaluate as a function
return GetFunctionValue(expr)
}

@ -1,66 +0,0 @@
package filter
import (
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// MemoFilterCELAttributes are the CEL attributes for memo.
var MemoFilterCELAttributes = []cel.EnvOption{
cel.Variable("content", cel.StringType),
cel.Variable("creator_id", cel.IntType),
cel.Variable("created_ts", cel.IntType),
cel.Variable("updated_ts", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("tag", cel.StringType),
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
cel.Variable("has_link", cel.BoolType),
cel.Variable("has_code", cel.BoolType),
cel.Variable("has_incomplete_tasks", cel.BoolType),
// Current timestamp function.
cel.Function("now",
cel.Overload("now",
[]*cel.Type{},
cel.IntType,
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
return types.Int(time.Now().Unix())
}),
),
),
}
// ReactionFilterCELAttributes are the CEL attributes for reaction.
var ReactionFilterCELAttributes = []cel.EnvOption{
cel.Variable("content_id", cel.StringType),
}
// UserFilterCELAttributes are the CEL attributes for user.
var UserFilterCELAttributes = []cel.EnvOption{
cel.Variable("username", cel.StringType),
}
// AttachmentFilterCELAttributes are the CEL attributes for user.
var AttachmentFilterCELAttributes = []cel.EnvOption{
cel.Variable("memo_id", cel.StringType),
}
// Parse parses the filter string and returns the parsed expression.
// The filter string should be a CEL expression.
func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) {
e, err := cel.NewEnv(opts...)
if err != nil {
return nil, errors.Wrap(err, "failed to create CEL environment")
}
ast, issues := e.Compile(filter)
if issues != nil {
return nil, errors.Errorf("failed to compile filter: %v", issues)
}
return cel.AstToParsedExpr(ast)
}

@ -0,0 +1,25 @@
package filter
import (
"context"
"fmt"
)
// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args.
func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error {
for _, filterStr := range filters {
stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{
Dialect: dialect,
PlaceholderOffset: len(*args),
})
if err != nil {
return err
}
if stmt.SQL == "" {
continue
}
*where = append(*where, fmt.Sprintf("(%s)", stmt.SQL))
*args = append(*args, stmt.Args...)
}
return nil
}

@ -0,0 +1,116 @@
package filter
// Condition represents a boolean expression derived from the CEL filter.
type Condition interface {
isCondition()
}
// LogicalOperator enumerates the supported logical operators.
type LogicalOperator string
const (
LogicalAnd LogicalOperator = "AND"
LogicalOr LogicalOperator = "OR"
)
// LogicalCondition composes two conditions with a logical operator.
type LogicalCondition struct {
Operator LogicalOperator
Left Condition
Right Condition
}
func (*LogicalCondition) isCondition() {}
// NotCondition negates a child condition.
type NotCondition struct {
Expr Condition
}
func (*NotCondition) isCondition() {}
// FieldPredicateCondition asserts that a field evaluates to true.
type FieldPredicateCondition struct {
Field string
}
func (*FieldPredicateCondition) isCondition() {}
// ComparisonOperator lists supported comparison operators.
type ComparisonOperator string
const (
CompareEq ComparisonOperator = "="
CompareNeq ComparisonOperator = "!="
CompareLt ComparisonOperator = "<"
CompareLte ComparisonOperator = "<="
CompareGt ComparisonOperator = ">"
CompareGte ComparisonOperator = ">="
)
// ComparisonCondition represents a binary comparison.
type ComparisonCondition struct {
Left ValueExpr
Operator ComparisonOperator
Right ValueExpr
}
func (*ComparisonCondition) isCondition() {}
// InCondition represents an IN predicate with literal list values.
type InCondition struct {
Left ValueExpr
Values []ValueExpr
}
func (*InCondition) isCondition() {}
// ElementInCondition represents the CEL syntax `"value" in field`.
type ElementInCondition struct {
Element ValueExpr
Field string
}
func (*ElementInCondition) isCondition() {}
// ContainsCondition models the <field>.contains(<value>) call.
type ContainsCondition struct {
Field string
Value string
}
func (*ContainsCondition) isCondition() {}
// ConstantCondition captures a literal boolean outcome.
type ConstantCondition struct {
Value bool
}
func (*ConstantCondition) isCondition() {}
// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison.
type ValueExpr interface {
isValueExpr()
}
// FieldRef references a named schema field.
type FieldRef struct {
Name string
}
func (*FieldRef) isValueExpr() {}
// LiteralValue holds a literal scalar.
type LiteralValue struct {
Value interface{}
}
func (*LiteralValue) isValueExpr() {}
// FunctionValue captures simple function calls like size(tags).
type FunctionValue struct {
Name string
Args []ValueExpr
}
func (*FunctionValue) isValueExpr() {}

@ -0,0 +1,413 @@
package filter
import (
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
switch v := expr.ExprKind.(type) {
case *exprv1.Expr_CallExpr:
return buildCallCondition(v.CallExpr, schema)
case *exprv1.Expr_ConstExpr:
val, err := getConstValue(expr)
if err != nil {
return nil, err
}
switch v := val.(type) {
case bool:
return &ConstantCondition{Value: v}, nil
case int64:
return &ConstantCondition{Value: v != 0}, nil
case float64:
return &ConstantCondition{Value: v != 0}, nil
default:
return nil, errors.New("filter must evaluate to a boolean value")
}
case *exprv1.Expr_IdentExpr:
name := v.IdentExpr.GetName()
field, ok := schema.Field(name)
if !ok {
return nil, errors.Errorf("unknown identifier %q", name)
}
if field.Type != FieldTypeBool {
return nil, errors.Errorf("identifier %q is not boolean", name)
}
return &FieldPredicateCondition{Field: name}, nil
default:
return nil, errors.New("unsupported top-level expression")
}
}
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
switch call.Function {
case "_&&_":
if len(call.Args) != 2 {
return nil, errors.New("logical AND expects two arguments")
}
left, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildCondition(call.Args[1], schema)
if err != nil {
return nil, err
}
return &LogicalCondition{
Operator: LogicalAnd,
Left: left,
Right: right,
}, nil
case "_||_":
if len(call.Args) != 2 {
return nil, errors.New("logical OR expects two arguments")
}
left, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildCondition(call.Args[1], schema)
if err != nil {
return nil, err
}
return &LogicalCondition{
Operator: LogicalOr,
Left: left,
Right: right,
}, nil
case "!_":
if len(call.Args) != 1 {
return nil, errors.New("logical NOT expects one argument")
}
child, err := buildCondition(call.Args[0], schema)
if err != nil {
return nil, err
}
return &NotCondition{Expr: child}, nil
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
return buildComparisonCondition(call, schema)
case "@in":
return buildInCondition(call, schema)
case "contains":
return buildContainsCondition(call, schema)
default:
val, ok, err := evaluateBool(call)
if err != nil {
return nil, err
}
if ok {
return &ConstantCondition{Value: val}, nil
}
return nil, errors.Errorf("unsupported call expression %q", call.Function)
}
}
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if len(call.Args) != 2 {
return nil, errors.New("comparison expects two arguments")
}
op, err := toComparisonOperator(call.Function)
if err != nil {
return nil, err
}
left, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
right, err := buildValueExpr(call.Args[1], schema)
if err != nil {
return nil, err
}
// If the left side is a field, validate allowed operators.
if field, ok := left.(*FieldRef); ok {
def, exists := schema.Field(field.Name)
if !exists {
return nil, errors.Errorf("unknown identifier %q", field.Name)
}
if def.Kind == FieldKindVirtualAlias {
def, exists = schema.ResolveAlias(field.Name)
if !exists {
return nil, errors.Errorf("invalid alias %q", field.Name)
}
}
if def.AllowedComparisonOps != nil {
if _, allowed := def.AllowedComparisonOps[op]; !allowed {
return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name)
}
}
}
return &ComparisonCondition{
Left: left,
Operator: op,
Right: right,
}, nil
}
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if len(call.Args) != 2 {
return nil, errors.New("in operator expects two arguments")
}
// Handle identifier in list syntax.
if identName, err := getIdentName(call.Args[0]); err == nil {
if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias {
if _, aliasOk := schema.ResolveAlias(identName); !aliasOk {
return nil, errors.Errorf("invalid alias %q", identName)
}
} else if !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
if listExpr := call.Args[1].GetListExpr(); listExpr != nil {
values := make([]ValueExpr, 0, len(listExpr.Elements))
for _, element := range listExpr.Elements {
value, err := buildValueExpr(element, schema)
if err != nil {
return nil, err
}
values = append(values, value)
}
return &InCondition{
Left: &FieldRef{Name: identName},
Values: values,
}, nil
}
}
// Handle "value in identifier" syntax.
if identName, err := getIdentName(call.Args[1]); err == nil {
if _, ok := schema.Field(identName); !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
element, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
return &ElementInCondition{
Element: element,
Field: identName,
}, nil
}
return nil, errors.New("invalid use of in operator")
}
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
if call.Target == nil {
return nil, errors.New("contains requires a target")
}
targetName, err := getIdentName(call.Target)
if err != nil {
return nil, err
}
field, ok := schema.Field(targetName)
if !ok {
return nil, errors.Errorf("unknown identifier %q", targetName)
}
if !field.SupportsContains {
return nil, errors.Errorf("identifier %q does not support contains()", targetName)
}
if len(call.Args) != 1 {
return nil, errors.New("contains expects exactly one argument")
}
value, err := getConstValue(call.Args[0])
if err != nil {
return nil, errors.Wrap(err, "contains only supports literal arguments")
}
str, ok := value.(string)
if !ok {
return nil, errors.New("contains argument must be a string")
}
return &ContainsCondition{
Field: targetName,
Value: str,
}, nil
}
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
if identName, err := getIdentName(expr); err == nil {
if _, ok := schema.Field(identName); !ok {
return nil, errors.Errorf("unknown identifier %q", identName)
}
return &FieldRef{Name: identName}, nil
}
if literal, err := getConstValue(expr); err == nil {
return &LiteralValue{Value: literal}, nil
}
if value, ok, err := evaluateNumeric(expr); err != nil {
return nil, err
} else if ok {
return &LiteralValue{Value: value}, nil
}
if boolVal, ok, err := evaluateBoolExpr(expr); err != nil {
return nil, err
} else if ok {
return &LiteralValue{Value: boolVal}, nil
}
if call := expr.GetCallExpr(); call != nil {
switch call.Function {
case "size":
if len(call.Args) != 1 {
return nil, errors.New("size() expects one argument")
}
arg, err := buildValueExpr(call.Args[0], schema)
if err != nil {
return nil, err
}
return &FunctionValue{
Name: "size",
Args: []ValueExpr{arg},
}, nil
case "now":
return &LiteralValue{Value: timeNowUnix()}, nil
case "_+_", "_-_", "_*_":
value, ok, err := evaluateNumeric(expr)
if err != nil {
return nil, err
}
if ok {
return &LiteralValue{Value: value}, nil
}
}
}
return nil, errors.New("unsupported value expression")
}
func toComparisonOperator(fn string) (ComparisonOperator, error) {
switch fn {
case "_==_":
return CompareEq, nil
case "_!=_":
return CompareNeq, nil
case "_<_":
return CompareLt, nil
case "_>_":
return CompareGt, nil
case "_<=_":
return CompareLte, nil
case "_>=_":
return CompareGte, nil
default:
return "", errors.Errorf("unsupported comparison operator %q", fn)
}
}
func getIdentName(expr *exprv1.Expr) (string, error) {
if ident := expr.GetIdentExpr(); ident != nil {
return ident.GetName(), nil
}
return "", errors.New("expression is not an identifier")
}
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
if !ok {
return nil, errors.New("expression is not a literal")
}
switch x := v.ConstExpr.ConstantKind.(type) {
case *exprv1.Constant_StringValue:
return v.ConstExpr.GetStringValue(), nil
case *exprv1.Constant_Int64Value:
return v.ConstExpr.GetInt64Value(), nil
case *exprv1.Constant_Uint64Value:
return int64(v.ConstExpr.GetUint64Value()), nil
case *exprv1.Constant_DoubleValue:
return v.ConstExpr.GetDoubleValue(), nil
case *exprv1.Constant_BoolValue:
return v.ConstExpr.GetBoolValue(), nil
case *exprv1.Constant_NullValue:
return nil, nil
default:
return nil, errors.Errorf("unsupported constant %T", x)
}
}
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}})
return val, ok, err
}
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
if literal, err := getConstValue(expr); err == nil {
if b, ok := literal.(bool); ok {
return b, true, nil
}
return false, false, nil
}
if call := expr.GetCallExpr(); call != nil && call.Function == "!_" {
if len(call.Args) != 1 {
return false, false, errors.New("NOT expects exactly one argument")
}
val, ok, err := evaluateBoolExpr(call.Args[0])
if err != nil || !ok {
return false, false, err
}
return !val, true, nil
}
return false, false, nil
}
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
if literal, err := getConstValue(expr); err == nil {
switch v := literal.(type) {
case int64:
return v, true, nil
case float64:
return int64(v), true, nil
}
return 0, false, nil
}
call := expr.GetCallExpr()
if call == nil {
return 0, false, nil
}
switch call.Function {
case "now":
return timeNowUnix(), true, nil
case "_+_", "_-_", "_*_":
if len(call.Args) != 2 {
return 0, false, errors.New("arithmetic requires two arguments")
}
left, ok, err := evaluateNumeric(call.Args[0])
if err != nil {
return 0, false, err
}
if !ok {
return 0, false, nil
}
right, ok, err := evaluateNumeric(call.Args[1])
if err != nil {
return 0, false, err
}
if !ok {
return 0, false, nil
}
switch call.Function {
case "_+_":
return left + right, true, nil
case "_-_":
return left - right, true, nil
case "_*_":
return left * right, true, nil
}
}
return 0, false, nil
}
func timeNowUnix() int64 {
return time.Now().Unix()
}

@ -0,0 +1,626 @@
package filter
import (
"fmt"
"strings"
"github.com/pkg/errors"
)
type renderer struct {
schema Schema
dialect DialectName
placeholderOffset int
placeholderCounter int
args []any
}
type renderResult struct {
sql string
trivial bool
unsatisfiable bool
}
func newRenderer(schema Schema, opts RenderOptions) *renderer {
return &renderer{
schema: schema,
dialect: opts.Dialect,
placeholderOffset: opts.PlaceholderOffset,
}
}
func (r *renderer) Render(cond Condition) (Statement, error) {
result, err := r.renderCondition(cond)
if err != nil {
return Statement{}, err
}
args := r.args
if args == nil {
args = []any{}
}
switch {
case result.unsatisfiable:
return Statement{
SQL: "1 = 0",
Args: args,
}, nil
case result.trivial:
return Statement{
SQL: "",
Args: args,
}, nil
default:
return Statement{
SQL: result.sql,
Args: args,
}, nil
}
}
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
switch c := cond.(type) {
case *LogicalCondition:
return r.renderLogicalCondition(c)
case *NotCondition:
return r.renderNotCondition(c)
case *FieldPredicateCondition:
return r.renderFieldPredicate(c)
case *ComparisonCondition:
return r.renderComparison(c)
case *InCondition:
return r.renderInCondition(c)
case *ElementInCondition:
return r.renderElementInCondition(c)
case *ContainsCondition:
return r.renderContainsCondition(c)
case *ConstantCondition:
if c.Value {
return renderResult{trivial: true}, nil
}
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
default:
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
}
}
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
left, err := r.renderCondition(cond.Left)
if err != nil {
return renderResult{}, err
}
right, err := r.renderCondition(cond.Right)
if err != nil {
return renderResult{}, err
}
switch cond.Operator {
case LogicalAnd:
return combineAnd(left, right), nil
case LogicalOr:
return combineOr(left, right), nil
default:
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
}
}
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
child, err := r.renderCondition(cond.Expr)
if err != nil {
return renderResult{}, err
}
if child.trivial {
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
}
if child.unsatisfiable {
return renderResult{trivial: true}, nil
}
return renderResult{
sql: fmt.Sprintf("NOT (%s)", child.sql),
}, nil
}
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
switch field.Kind {
case FieldKindBoolColumn:
column := qualifyColumn(r.dialect, field.Column)
return renderResult{
sql: fmt.Sprintf("%s IS TRUE", column),
}, nil
case FieldKindJSONBool:
sql, err := r.jsonBoolPredicate(field)
if err != nil {
return renderResult{}, err
}
return renderResult{sql: sql}, nil
default:
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
}
}
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
switch left := cond.Left.(type) {
case *FieldRef:
field, ok := r.schema.Field(left.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
}
switch field.Kind {
case FieldKindBoolColumn:
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
case FieldKindJSONBool:
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
case FieldKindScalar:
return r.renderScalarComparison(field, cond.Operator, cond.Right)
default:
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
}
case *FunctionValue:
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
default:
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
}
}
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
if fn.Name != "size" {
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
}
if len(fn.Args) != 1 {
return renderResult{}, errors.New("size() expects one argument")
}
fieldArg, ok := fn.Args[0].(*FieldRef)
if !ok {
return renderResult{}, errors.New("size() argument must be a field")
}
field, ok := r.schema.Field(fieldArg.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
}
value, err := expectNumericLiteral(right)
if err != nil {
return renderResult{}, err
}
expr := jsonArrayLengthExpr(r.dialect, field)
placeholder := r.addArg(value)
return renderResult{
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
lit, err := expectLiteral(right)
if err != nil {
return renderResult{}, err
}
columnExpr := field.columnExpr(r.dialect)
placeholder := ""
switch field.Type {
case FieldTypeString:
value, ok := lit.(string)
if !ok {
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
}
placeholder = r.addArg(value)
case FieldTypeInt, FieldTypeTimestamp:
num, err := toInt64(lit)
if err != nil {
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
}
placeholder = r.addArg(num)
default:
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
}
return renderResult{
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
value, err := expectBool(right)
if err != nil {
return renderResult{}, err
}
placeholder := r.addBoolArg(value)
column := qualifyColumn(r.dialect, field.Column)
return renderResult{
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
}, nil
}
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
value, err := expectBool(right)
if err != nil {
return renderResult{}, err
}
jsonExpr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
switch op {
case CompareEq:
if field.Name == "has_task_list" {
target := "0"
if value {
target = "1"
}
return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
}
if value {
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
}
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
case CompareNeq:
if field.Name == "has_task_list" {
target := "0"
if value {
target = "1"
}
return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
}
if value {
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
}
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
default:
return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
}
case DialectMySQL:
boolStr := "false"
if value {
boolStr = "true"
}
return renderResult{
sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
}, nil
case DialectPostgres:
placeholder := r.addArg(value)
return renderResult{
sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
fieldRef, ok := cond.Left.(*FieldRef)
if !ok {
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
}
if fieldRef.Name == "tag" {
return r.renderTagInList(cond.Values)
}
field, ok := r.schema.Field(fieldRef.Name)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
}
if field.Kind != FieldKindScalar {
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
}
return r.renderScalarInCondition(field, cond.Values)
}
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
field, ok := r.schema.ResolveAlias("tag")
if !ok {
return renderResult{}, errors.New("tag attribute is not configured")
}
conditions := make([]string, 0, len(values))
for _, v := range values {
lit, err := expectLiteral(v)
if err != nil {
return renderResult{}, err
}
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.New("tags must be compared with string literals")
}
switch r.dialect {
case DialectSQLite:
expr := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
conditions = append(conditions, expr)
case DialectMySQL:
expr := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
conditions = append(conditions, expr)
case DialectPostgres:
expr := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
conditions = append(conditions, expr)
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
if len(conditions) == 1 {
return renderResult{sql: conditions[0]}, nil
}
return renderResult{
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
}, nil
}
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
if field.Kind != FieldKindJSONList {
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
}
lit, err := expectLiteral(cond.Element)
if err != nil {
return renderResult{}, err
}
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.New("tags membership requires string literal")
}
switch r.dialect {
case DialectSQLite:
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
return renderResult{sql: sql}, nil
case DialectMySQL:
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(str))
return renderResult{sql: sql}, nil
case DialectPostgres:
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(str))
return renderResult{sql: sql}, nil
default:
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
placeholders := make([]string, 0, len(values))
for _, v := range values {
lit, err := expectLiteral(v)
if err != nil {
return renderResult{}, err
}
switch field.Type {
case FieldTypeString:
str, ok := lit.(string)
if !ok {
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
}
placeholders = append(placeholders, r.addArg(str))
case FieldTypeInt:
num, err := toInt64(lit)
if err != nil {
return renderResult{}, err
}
placeholders = append(placeholders, r.addArg(num))
default:
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
}
}
column := field.columnExpr(r.dialect)
return renderResult{
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
}, nil
}
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
field, ok := r.schema.Field(cond.Field)
if !ok {
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
}
column := field.columnExpr(r.dialect)
arg := fmt.Sprintf("%%%s%%", cond.Value)
switch r.dialect {
case DialectPostgres:
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
return renderResult{sql: sql}, nil
default:
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
return renderResult{sql: sql}, nil
}
}
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
expr := jsonExtractExpr(r.dialect, field)
switch r.dialect {
case DialectSQLite:
return fmt.Sprintf("%s IS TRUE", expr), nil
case DialectMySQL:
return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil
case DialectPostgres:
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
default:
return "", errors.Errorf("unsupported dialect %s", r.dialect)
}
}
func combineAnd(left, right renderResult) renderResult {
if left.unsatisfiable || right.unsatisfiable {
return renderResult{sql: "1 = 0", unsatisfiable: true}
}
if left.trivial {
return right
}
if right.trivial {
return left
}
return renderResult{
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
}
}
func combineOr(left, right renderResult) renderResult {
if left.trivial || right.trivial {
return renderResult{trivial: true}
}
if left.unsatisfiable {
return right
}
if right.unsatisfiable {
return left
}
return renderResult{
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
}
}
func (r *renderer) addArg(value any) string {
r.placeholderCounter++
r.args = append(r.args, value)
if r.dialect == DialectPostgres {
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
}
return "?"
}
func (r *renderer) addBoolArg(value bool) string {
var v any
switch r.dialect {
case DialectSQLite:
if value {
v = 1
} else {
v = 0
}
default:
v = value
}
return r.addArg(v)
}
func expectLiteral(expr ValueExpr) (any, error) {
lit, ok := expr.(*LiteralValue)
if !ok {
return nil, errors.New("expression must be a literal")
}
return lit.Value, nil
}
func expectBool(expr ValueExpr) (bool, error) {
lit, err := expectLiteral(expr)
if err != nil {
return false, err
}
value, ok := lit.(bool)
if !ok {
return false, errors.New("boolean literal required")
}
return value, nil
}
func expectNumericLiteral(expr ValueExpr) (int64, error) {
lit, err := expectLiteral(expr)
if err != nil {
return 0, err
}
return toInt64(lit)
}
func toInt64(value any) (int64, error) {
switch v := value.(type) {
case int:
return int64(v), nil
case int32:
return int64(v), nil
case int64:
return v, nil
case uint32:
return int64(v), nil
case uint64:
return int64(v), nil
case float32:
return int64(v), nil
case float64:
return int64(v), nil
default:
return 0, errors.Errorf("cannot convert %T to int64", value)
}
}
func sqlOperator(op ComparisonOperator) string {
return string(op)
}
func qualifyColumn(d DialectName, col Column) string {
switch d {
case DialectPostgres:
return fmt.Sprintf("%s.%s", col.Table, col.Name)
default:
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
}
}
func jsonPath(field Field) string {
return "$." + strings.Join(field.JSONPath, ".")
}
func jsonExtractExpr(d DialectName, field Field) string {
column := qualifyColumn(d, field.Column)
switch d {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
case DialectPostgres:
return buildPostgresJSONAccessor(column, field.JSONPath, true)
default:
return ""
}
}
func jsonArrayExpr(d DialectName, field Field) string {
column := qualifyColumn(d, field.Column)
switch d {
case DialectSQLite, DialectMySQL:
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
case DialectPostgres:
return buildPostgresJSONAccessor(column, field.JSONPath, false)
default:
return ""
}
}
func jsonArrayLengthExpr(d DialectName, field Field) string {
arrayExpr := jsonArrayExpr(d, field)
switch d {
case DialectSQLite:
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
case DialectMySQL:
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
case DialectPostgres:
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
default:
return ""
}
}
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
expr := base
for idx, part := range path {
if idx == len(path)-1 && terminalText {
expr = fmt.Sprintf("%s->>'%s'", expr, part)
} else {
expr = fmt.Sprintf("%s->'%s'", expr, part)
}
}
return expr
}

@ -0,0 +1,254 @@
package filter
import (
"fmt"
"time"
"github.com/google/cel-go/cel"
"github.com/google/cel-go/common/types"
"github.com/google/cel-go/common/types/ref"
)
// DialectName enumerates supported SQL dialects.
type DialectName string
const (
DialectSQLite DialectName = "sqlite"
DialectMySQL DialectName = "mysql"
DialectPostgres DialectName = "postgres"
)
// FieldType represents the logical type of a field.
type FieldType string
const (
FieldTypeString FieldType = "string"
FieldTypeInt FieldType = "int"
FieldTypeBool FieldType = "bool"
FieldTypeTimestamp FieldType = "timestamp"
)
// FieldKind describes how a field is stored.
type FieldKind string
const (
FieldKindScalar FieldKind = "scalar"
FieldKindBoolColumn FieldKind = "bool_column"
FieldKindJSONBool FieldKind = "json_bool"
FieldKindJSONList FieldKind = "json_list"
FieldKindVirtualAlias FieldKind = "virtual_alias"
)
// Column identifies the backing table column.
type Column struct {
Table string
Name string
}
// Field captures the schema metadata for an exposed CEL identifier.
type Field struct {
Name string
Kind FieldKind
Type FieldType
Column Column
JSONPath []string
AliasFor string
SupportsContains bool
Expressions map[DialectName]string
AllowedComparisonOps map[ComparisonOperator]bool
}
// Schema collects CEL environment options and field metadata.
type Schema struct {
Name string
Fields map[string]Field
EnvOptions []cel.EnvOption
}
// Field returns the field metadata if present.
func (s Schema) Field(name string) (Field, bool) {
f, ok := s.Fields[name]
return f, ok
}
// ResolveAlias resolves a virtual alias to its target field.
func (s Schema) ResolveAlias(name string) (Field, bool) {
field, ok := s.Fields[name]
if !ok {
return Field{}, false
}
if field.Kind == FieldKindVirtualAlias {
target, ok := s.Fields[field.AliasFor]
if !ok {
return Field{}, false
}
return target, true
}
return field, true
}
var nowFunction = cel.Function("now",
cel.Overload("now",
[]*cel.Type{},
cel.IntType,
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
return types.Int(time.Now().Unix())
}),
),
)
// NewSchema constructs the memo filter schema and CEL environment.
func NewSchema() Schema {
fields := map[string]Field{
"content": {
Name: "content",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "content"},
SupportsContains: true,
Expressions: map[DialectName]string{},
},
"creator_id": {
Name: "creator_id",
Kind: FieldKindScalar,
Type: FieldTypeInt,
Column: Column{Table: "memo", Name: "creator_id"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"created_ts": {
Name: "created_ts",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "memo", Name: "created_ts"},
Expressions: map[DialectName]string{
DialectMySQL: "UNIX_TIMESTAMP(%s)",
DialectPostgres: "EXTRACT(EPOCH FROM TO_TIMESTAMP(%s))",
},
},
"updated_ts": {
Name: "updated_ts",
Kind: FieldKindScalar,
Type: FieldTypeTimestamp,
Column: Column{Table: "memo", Name: "updated_ts"},
Expressions: map[DialectName]string{
DialectMySQL: "UNIX_TIMESTAMP(%s)",
DialectPostgres: "EXTRACT(EPOCH FROM TO_TIMESTAMP(%s))",
},
},
"pinned": {
Name: "pinned",
Kind: FieldKindBoolColumn,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "pinned"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"visibility": {
Name: "visibility",
Kind: FieldKindScalar,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "visibility"},
Expressions: map[DialectName]string{},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"tags": {
Name: "tags",
Kind: FieldKindJSONList,
Type: FieldTypeString,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"tags"},
},
"tag": {
Name: "tag",
Kind: FieldKindVirtualAlias,
Type: FieldTypeString,
AliasFor: "tags",
},
"has_task_list": {
Name: "has_task_list",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasTaskList"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_link": {
Name: "has_link",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasLink"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_code": {
Name: "has_code",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasCode"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
"has_incomplete_tasks": {
Name: "has_incomplete_tasks",
Kind: FieldKindJSONBool,
Type: FieldTypeBool,
Column: Column{Table: "memo", Name: "payload"},
JSONPath: []string{"property", "hasIncompleteTasks"},
AllowedComparisonOps: map[ComparisonOperator]bool{
CompareEq: true,
CompareNeq: true,
},
},
}
envOptions := []cel.EnvOption{
cel.Variable("content", cel.StringType),
cel.Variable("creator_id", cel.IntType),
cel.Variable("created_ts", cel.IntType),
cel.Variable("updated_ts", cel.IntType),
cel.Variable("pinned", cel.BoolType),
cel.Variable("tag", cel.StringType),
cel.Variable("tags", cel.ListType(cel.StringType)),
cel.Variable("visibility", cel.StringType),
cel.Variable("has_task_list", cel.BoolType),
cel.Variable("has_link", cel.BoolType),
cel.Variable("has_code", cel.BoolType),
cel.Variable("has_incomplete_tasks", cel.BoolType),
nowFunction,
}
return Schema{
Name: "memo",
Fields: fields,
EnvOptions: envOptions,
}
}
// columnExpr returns the field expression for the given dialect, applying
// any schema-specific overrides (e.g. UNIX timestamp conversions).
func (f Field) columnExpr(d DialectName) string {
base := qualifyColumn(d, f.Column)
if expr, ok := f.Expressions[d]; ok && expr != "" {
return fmt.Sprintf(expr, base)
}
return base
}

@ -1,146 +0,0 @@
package filter
import (
"fmt"
)
// SQLTemplate holds database-specific SQL fragments.
type SQLTemplate struct {
SQLite string
MySQL string
PostgreSQL string
}
// TemplateDBType represents the database type for templates.
type TemplateDBType string
const (
SQLiteTemplate TemplateDBType = "sqlite"
MySQLTemplate TemplateDBType = "mysql"
PostgreSQLTemplate TemplateDBType = "postgres"
)
// SQLTemplates contains common SQL patterns for different databases.
var SQLTemplates = map[string]SQLTemplate{
"json_extract": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
PostgreSQL: "memo.payload%s",
},
"json_array_length": {
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
},
"json_contains_element": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
},
"json_contains_tag": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
},
"boolean_true": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
},
"boolean_false": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
},
"boolean_not_true": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
},
"boolean_not_false": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
},
"boolean_compare": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
},
"boolean_check": {
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
},
"table_prefix": {
SQLite: "`memo`",
MySQL: "`memo`",
PostgreSQL: "memo",
},
"timestamp_field": {
SQLite: "`memo`.`%s`",
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
},
"content_like": {
SQLite: "`memo`.`content` LIKE ?",
MySQL: "`memo`.`content` LIKE ?",
PostgreSQL: "memo.content ILIKE ?",
},
"visibility_in": {
SQLite: "`memo`.`visibility` IN (%s)",
MySQL: "`memo`.`visibility` IN (%s)",
PostgreSQL: "memo.visibility IN (%s)",
},
}
// GetSQL returns the appropriate SQL for the given template and database type.
func GetSQL(templateName string, dbType TemplateDBType) string {
template, exists := SQLTemplates[templateName]
if !exists {
return ""
}
switch dbType {
case SQLiteTemplate:
return template.SQLite
case MySQLTemplate:
return template.MySQL
case PostgreSQLTemplate:
return template.PostgreSQL
default:
return template.SQLite
}
}
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database.
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
switch dbType {
case PostgreSQLTemplate:
return fmt.Sprintf("$%d", index)
default:
return "?"
}
}
// GetParameterValue returns the appropriate parameter value for the database.
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
switch templateName {
case "json_contains_element", "json_contains_tag":
if dbType == SQLiteTemplate {
return fmt.Sprintf(`%%"%s"%%`, value)
}
return value
default:
return value
}
}
// FormatPlaceholders formats a list of placeholders for the given database type.
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
placeholders := make([]string, count)
for i := 0; i < count; i++ {
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
}
return placeholders
}

@ -200,20 +200,18 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
}
reactionMap := make(map[string][]*store.Reaction)
memoNames := make([]string, 0, len(memos))
contentIDs := make([]string, 0, len(memos))
attachmentMap := make(map[int32][]*store.Attachment)
memoIDs := make([]string, 0, len(memos))
memoIDs := make([]int32, 0, len(memos))
for _, m := range memos {
memoNames = append(memoNames, fmt.Sprintf("'%s%s'", MemoNamePrefix, m.UID))
memoIDs = append(memoIDs, fmt.Sprintf("'%d'", m.ID))
contentIDs = append(contentIDs, fmt.Sprintf("%s%s", MemoNamePrefix, m.UID))
memoIDs = append(memoIDs, m.ID)
}
// REACTIONS
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNames, ", "))},
})
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
@ -222,9 +220,7 @@ func (s *APIV1Service) ListMemos(ctx context.Context, request *v1pb.ListMemosReq
}
// ATTACHMENTS
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDs, ", "))},
})
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}
@ -630,30 +626,26 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
return response, nil
}
memoRelationIDs := make([]string, 0, len(memoRelations))
memoRelationIDs := make([]int32, 0, len(memoRelations))
for _, m := range memoRelations {
memoRelationIDs = append(memoRelationIDs, fmt.Sprintf("%d", m.MemoID))
memoRelationIDs = append(memoRelationIDs, m.MemoID)
}
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{
Filters: []string{fmt.Sprintf("id in [%s]", strings.Join(memoRelationIDs, ", "))},
})
memos, err := s.Store.ListMemos(ctx, &store.FindMemo{IDList: memoRelationIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list memos")
}
memoIDToNameMap := make(map[int32]string)
memoNamesForQuery := make([]string, 0, len(memos))
memoIDsForQuery := make([]string, 0, len(memos))
contentIDs := make([]string, 0, len(memos))
memoIDsForAttachments := make([]int32, 0, len(memos))
for _, memo := range memos {
memoName := fmt.Sprintf("%s%s", MemoNamePrefix, memo.UID)
memoIDToNameMap[memo.ID] = memoName
memoNamesForQuery = append(memoNamesForQuery, fmt.Sprintf("'%s'", memoName))
memoIDsForQuery = append(memoIDsForQuery, fmt.Sprintf("'%d'", memo.ID))
contentIDs = append(contentIDs, memoName)
memoIDsForAttachments = append(memoIDsForAttachments, memo.ID)
}
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{
Filters: []string{fmt.Sprintf("content_id in [%s]", strings.Join(memoNamesForQuery, ", "))},
})
reactions, err := s.Store.ListReactions(ctx, &store.FindReaction{ContentIDList: contentIDs})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list reactions")
}
@ -663,9 +655,7 @@ func (s *APIV1Service) ListMemoComments(ctx context.Context, request *v1pb.ListM
memoReactionsMap[reaction.ContentID] = append(memoReactionsMap[reaction.ContentID], reaction)
}
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{
Filters: []string{fmt.Sprintf("memo_id in [%s]", strings.Join(memoIDsForQuery, ", "))},
})
attachments, err := s.Store.ListAttachments(ctx, &store.FindAttachment{MemoIDList: memoIDsForAttachments})
if err != nil {
return nil, status.Errorf(codes.Internal, "failed to list attachments")
}

@ -319,35 +319,30 @@ func (s *APIV1Service) DeleteShortcut(ctx context.Context, request *v1pb.DeleteS
return &emptypb.Empty{}, nil
}
func (s *APIV1Service) validateFilter(_ context.Context, filterStr string) error {
func (s *APIV1Service) validateFilter(ctx context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
// Validate the filter.
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
engine, err := filter.DefaultEngine()
if err != nil {
return errors.Wrap(err, "failed to parse filter")
return err
}
convertCtx := filter.NewConvertContext()
// Determine the dialect based on the actual database driver
var dialect filter.SQLDialect
var dialect filter.DialectName
switch s.Profile.Driver {
case "sqlite":
dialect = &filter.SQLiteDialect{}
dialect = filter.DialectSQLite
case "mysql":
dialect = &filter.MySQLDialect{}
dialect = filter.DialectMySQL
case "postgres":
dialect = &filter.PostgreSQLDialect{}
dialect = filter.DialectPostgres
default:
// Default to SQLite for unknown drivers
dialect = &filter.SQLiteDialect{}
dialect = filter.DialectSQLite
}
converter := filter.NewCommonSQLConverter(dialect)
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil {
return errors.Wrap(err, "failed to convert filter to SQL")
if _, err := engine.CompileToStatement(ctx, filterStr, filter.RenderOptions{Dialect: dialect}); err != nil {
return errors.Wrap(err, "failed to compile filter")
}
return nil
}

@ -1,68 +0,0 @@
package v1
import (
"testing"
"github.com/usememos/memos/plugin/filter"
)
func TestUserFilterValidation(t *testing.T) {
testCases := []struct {
name string
filter string
expectErr bool
}{
{
name: "valid username filter with equals",
filter: `username == "testuser"`,
expectErr: false,
},
{
name: "valid username filter with contains",
filter: `username.contains("admin")`,
expectErr: false,
},
{
name: "invalid filter - unknown field",
filter: `invalid_field == "test"`,
expectErr: true,
},
{
name: "empty filter",
filter: "",
expectErr: true,
},
{
name: "invalid syntax",
filter: `username ==`,
expectErr: true,
},
}
for _, tc := range testCases {
t.Run(tc.name, func(t *testing.T) {
// Test the filter parsing directly
_, err := filter.Parse(tc.filter, filter.UserFilterCELAttributes...)
if tc.expectErr && err == nil {
t.Errorf("Expected error for filter %q, but got none", tc.filter)
}
if !tc.expectErr && err != nil {
t.Errorf("Expected no error for filter %q, but got: %v", tc.filter, err)
}
})
}
}
func TestUserFilterCELAttributes(t *testing.T) {
// Test that our UserFilterCELAttributes contains the username variable
expectedAttributes := map[string]bool{
"username": true,
}
// This is a basic test to ensure the attributes are defined
// In a real test, you would create a CEL environment and verify the attributes
for attrName := range expectedAttributes {
t.Logf("Expected attribute %s should be available in UserFilterCELAttributes", attrName)
}
}

@ -25,7 +25,6 @@ import (
"github.com/usememos/memos/internal/base"
"github.com/usememos/memos/internal/util"
"github.com/usememos/memos/plugin/filter"
v1pb "github.com/usememos/memos/proto/gen/api/v1"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
@ -49,7 +48,6 @@ func (s *APIV1Service) ListUsers(ctx context.Context, request *v1pb.ListUsersReq
if err := s.validateUserFilter(ctx, request.Filter); err != nil {
return nil, status.Errorf(codes.InvalidArgument, "invalid filter: %v", err)
}
userFind.Filters = append(userFind.Filters, request.Filter)
}
users, err := s.Store.ListUsers(ctx, userFind)
@ -1368,34 +1366,8 @@ func extractWebhookIDFromName(name string) string {
// validateUserFilter validates the user filter string.
func (s *APIV1Service) validateUserFilter(_ context.Context, filterStr string) error {
if filterStr == "" {
return errors.New("filter cannot be empty")
}
// Validate the filter.
parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...)
if err != nil {
return errors.Wrap(err, "failed to parse filter")
}
convertCtx := filter.NewConvertContext()
// Determine the dialect based on the actual database driver
var dialect filter.SQLDialect
switch s.Profile.Driver {
case "sqlite":
dialect = &filter.SQLiteDialect{}
case "mysql":
dialect = &filter.MySQLDialect{}
case "postgres":
dialect = &filter.PostgreSQLDialect{}
default:
// Default to SQLite for unknown drivers
dialect = &filter.SQLiteDialect{}
}
converter := filter.NewUserSQLConverter(dialect)
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
if err != nil {
return errors.Wrap(err, "failed to convert filter to SQL")
if strings.TrimSpace(filterStr) != "" {
return errors.New("user filters are not supported")
}
return nil
}

@ -48,11 +48,11 @@ type FindAttachment struct {
Filename *string
FilenameSearch *string
MemoID *int32
MemoIDList []int32
HasRelatedMemo bool
StorageType *storepb.AttachmentStorageType
Limit *int
Offset *int
Filters []string
}
type UpdateAttachment struct {

@ -9,7 +9,6 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -49,26 +48,6 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
if v := find.ID; v != nil {
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
}
@ -87,6 +66,16 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
if v := find.MemoID; v != nil {
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
}
if len(find.MemoIDList) > 0 {
placeholders := make([]string, 0, len(find.MemoIDList))
for range find.MemoIDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.MemoIDList {
args = append(args, id)
}
}
if find.HasRelatedMemo {
where = append(where, "`resource`.`memo_id` IS NOT NULL")
}

@ -1,39 +0,0 @@
package mysql
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestAttachmentConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
want: "`resource`.`memo_id` IN (?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "`resource`.`memo_id` IN (?,?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -50,31 +50,39 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
where, having, args := []string{"1 = 1"}, []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectMySQL, &where, &args); err != nil {
return nil, err
}
if v := find.ID; v != nil {
where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
}
if len(find.IDList) > 0 {
placeholders := make([]string, 0, len(find.IDList))
for range find.IDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`memo`.`id` IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.IDList {
args = append(args, id)
}
}
if v := find.UID; v != nil {
where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
}
if len(find.UIDList) > 0 {
placeholders := make([]string, 0, len(find.UIDList))
for range find.UIDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`memo`.`uid` IN ("+strings.Join(placeholders, ",")+")")
for _, uid := range find.UIDList {
args = append(args, uid)
}
}
if v := find.CreatorID; v != nil {
where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
}

@ -1,6 +1,7 @@
package mysql
import (
"context"
"testing"
"time"
@ -147,14 +148,15 @@ func TestConvertExprToSQL(t *testing.T) {
},
}
engine, err := filter.DefaultEngine()
require.NoError(t, err)
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{
Dialect: filter.DialectMySQL,
})
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
require.Equal(t, tt.want, stmt.SQL)
require.Equal(t, tt.args, stmt.Args)
}
}

@ -43,23 +43,21 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
where, args = append(where, "`type` = ?"), append(args, find.Type)
}
if find.MemoFilter != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
Dialect: filter.DialectMySQL,
PlaceholderOffset: 0,
})
if err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
args = append(args, append(convertCtx.Args, convertCtx.Args...)...)
if stmt.SQL != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
args = append(args, append(stmt.Args, stmt.Args...)...)
}
}

@ -2,12 +2,10 @@ package mysql
import (
"context"
"fmt"
"strings"
"github.com/pkg/errors"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/store"
)
@ -37,27 +35,7 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
}
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
where, args := []string{"1 = 1"}, []interface{}{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "`id` = ?"), append(args, *find.ID)
@ -68,6 +46,14 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st
if find.ContentID != nil {
where, args = append(where, "`content_id` = ?"), append(args, *find.ContentID)
}
if len(find.ContentIDList) > 0 {
placeholders := make([]string, 0, len(find.ContentIDList))
for _, id := range find.ContentIDList {
placeholders = append(placeholders, "?")
args = append(args, id)
}
where = append(where, "`content_id` IN ("+strings.Join(placeholders, ",")+")")
}
rows, err := d.db.QueryContext(ctx, `
SELECT

@ -1,39 +0,0 @@
package mysql
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestReactionConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
want: "`reaction`.`content_id` IN (?)",
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "`reaction`.`content_id` IN (?,?)",
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -7,7 +7,6 @@ import (
"github.com/pkg/errors"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/store"
)
@ -85,24 +84,8 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewUserSQLConverter(&filter.MySQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
if len(find.Filters) > 0 {
return nil, errors.Errorf("user filters are not supported")
}
if v := find.ID; v != nil {

@ -9,7 +9,6 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -40,26 +39,6 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
if v := find.ID; v != nil {
where, args = append(where, "resource.id = "+placeholder(len(args)+1)), append(args, *v)
}
@ -78,6 +57,16 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
if v := find.MemoID; v != nil {
where, args = append(where, "resource.memo_id = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.MemoIDList) > 0 {
holders := make([]string, 0, len(find.MemoIDList))
for range find.MemoIDList {
holders = append(holders, placeholder(len(args)+1))
}
where = append(where, "resource.memo_id IN ("+strings.Join(holders, ", ")+")")
for _, id := range find.MemoIDList {
args = append(args, id)
}
}
if find.HasRelatedMemo {
where = append(where, "resource.memo_id IS NOT NULL")
}

@ -1,39 +0,0 @@
package postgres
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestAttachmentConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
want: "resource.memo_id IN ($1)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "resource.memo_id IN ($1,$2)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -41,32 +41,39 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
convertCtx.ArgsOffset = len(args)
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args))
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectPostgres, &where, &args); err != nil {
return nil, err
}
if v := find.ID; v != nil {
where, args = append(where, "memo.id = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.IDList) > 0 {
holders := make([]string, 0, len(find.IDList))
for range find.IDList {
holders = append(holders, placeholder(len(args)+1))
}
where = append(where, "memo.id IN ("+strings.Join(holders, ", ")+")")
for _, id := range find.IDList {
args = append(args, id)
}
}
if v := find.UID; v != nil {
where, args = append(where, "memo.uid = "+placeholder(len(args)+1)), append(args, *v)
}
if len(find.UIDList) > 0 {
holders := make([]string, 0, len(find.UIDList))
for range find.UIDList {
holders = append(holders, placeholder(len(args)+1))
}
where = append(where, "memo.uid IN ("+strings.Join(holders, ", ")+")")
for _, uid := range find.UIDList {
args = append(args, uid)
}
}
if v := find.CreatorID; v != nil {
where, args = append(where, "memo.creator_id = "+placeholder(len(args)+1)), append(args, *v)
}

@ -1,6 +1,7 @@
package postgres
import (
"context"
"testing"
"time"
@ -147,14 +148,13 @@ func TestConvertExprToSQL(t *testing.T) {
},
}
engine, err := filter.DefaultEngine()
require.NoError(t, err)
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args))
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectPostgres})
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
require.Equal(t, tt.want, stmt.SQL)
require.Equal(t, tt.args, stmt.Args)
}
}

@ -49,24 +49,32 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
where, args = append(where, "type = "+placeholder(len(args)+1)), append(args, find.Type)
}
if find.MemoFilter != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
convertCtx.ArgsOffset = len(args)
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverterWithOffset(&filter.PostgreSQLDialect{}, convertCtx.ArgsOffset+len(convertCtx.Args))
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
Dialect: filter.DialectPostgres,
PlaceholderOffset: len(args),
})
if err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
args = append(args, convertCtx.Args...)
if stmt.SQL != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
args = append(args, stmt.Args...)
stmtRelated, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{
Dialect: filter.DialectPostgres,
PlaceholderOffset: len(args),
})
if err != nil {
return nil, err
}
if stmtRelated.SQL != "" {
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmtRelated.SQL))
args = append(args, stmtRelated.Args...)
}
}
}

@ -2,10 +2,8 @@ package postgres
import (
"context"
"fmt"
"strings"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/store"
)
@ -25,27 +23,7 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
}
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
where, args := []string{"1 = 1"}, []interface{}{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = "+placeholder(len(args)+1)), append(args, *find.ID)
@ -56,6 +34,18 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st
if find.ContentID != nil {
where, args = append(where, "content_id = "+placeholder(len(args)+1)), append(args, *find.ContentID)
}
if len(find.ContentIDList) > 0 {
holders := make([]string, 0, len(find.ContentIDList))
for range find.ContentIDList {
holders = append(holders, placeholder(len(args)+1))
}
if len(holders) > 0 {
where = append(where, "content_id IN ("+strings.Join(holders, ", ")+")")
for _, id := range find.ContentIDList {
args = append(args, id)
}
}
}
rows, err := d.db.QueryContext(ctx, `
SELECT

@ -1,39 +0,0 @@
package postgres
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestReactionConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
want: "reaction.content_id IN ($1)",
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "reaction.content_id IN ($1,$2)",
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -5,7 +5,8 @@ import (
"fmt"
"strings"
"github.com/usememos/memos/plugin/filter"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
@ -86,24 +87,8 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewUserSQLConverter(&filter.PostgreSQLDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
if len(find.Filters) > 0 {
return nil, errors.Errorf("user filters are not supported")
}
if v := find.ID; v != nil {

@ -9,7 +9,6 @@ import (
"github.com/pkg/errors"
"google.golang.org/protobuf/encoding/protojson"
"github.com/usememos/memos/plugin/filter"
storepb "github.com/usememos/memos/proto/gen/store"
"github.com/usememos/memos/store"
)
@ -42,26 +41,6 @@ func (d *DB) CreateAttachment(ctx context.Context, create *store.Attachment) (*s
func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([]*store.Attachment, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.AttachmentFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
if v := find.ID; v != nil {
where, args = append(where, "`resource`.`id` = ?"), append(args, *v)
}
@ -80,6 +59,16 @@ func (d *DB) ListAttachments(ctx context.Context, find *store.FindAttachment) ([
if v := find.MemoID; v != nil {
where, args = append(where, "`resource`.`memo_id` = ?"), append(args, *v)
}
if len(find.MemoIDList) > 0 {
placeholders := make([]string, 0, len(find.MemoIDList))
for range find.MemoIDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`resource`.`memo_id` IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.MemoIDList {
args = append(args, id)
}
}
if find.HasRelatedMemo {
where = append(where, "`resource`.`memo_id` IS NOT NULL")
}

@ -1,39 +0,0 @@
package sqlite
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestAttachmentConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
want: "`resource`.`memo_id` IN (?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "`resource`.`memo_id` IN (?,?)",
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -42,31 +42,39 @@ func (d *DB) CreateMemo(ctx context.Context, create *store.Memo) (*store.Memo, e
func (d *DB) ListMemos(ctx context.Context, find *store.FindMemo) ([]*store.Memo, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.MemoFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
if err := filter.AppendConditions(ctx, engine, find.Filters, filter.DialectSQLite, &where, &args); err != nil {
return nil, err
}
if v := find.ID; v != nil {
where, args = append(where, "`memo`.`id` = ?"), append(args, *v)
}
if len(find.IDList) > 0 {
placeholders := make([]string, 0, len(find.IDList))
for range find.IDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`memo`.`id` IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.IDList {
args = append(args, id)
}
}
if v := find.UID; v != nil {
where, args = append(where, "`memo`.`uid` = ?"), append(args, *v)
}
if len(find.UIDList) > 0 {
placeholders := make([]string, 0, len(find.UIDList))
for range find.UIDList {
placeholders = append(placeholders, "?")
}
where = append(where, "`memo`.`uid` IN ("+strings.Join(placeholders, ",")+")")
for _, uid := range find.UIDList {
args = append(args, uid)
}
}
if v := find.CreatorID; v != nil {
where, args = append(where, "`memo`.`creator_id` = ?"), append(args, *v)
}

@ -1,6 +1,7 @@
package sqlite
import (
"context"
"testing"
"time"
@ -152,14 +153,13 @@ func TestConvertExprToSQL(t *testing.T) {
},
}
engine, err := filter.DefaultEngine()
require.NoError(t, err)
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.MemoFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
stmt, err := engine.CompileToStatement(context.Background(), tt.filter, filter.RenderOptions{Dialect: filter.DialectSQLite})
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
require.Equal(t, tt.want, stmt.SQL)
require.Equal(t, tt.args, stmt.Args)
}
}

@ -49,23 +49,18 @@ func (d *DB) ListMemoRelations(ctx context.Context, find *store.FindMemoRelation
where, args = append(where, "type = ?"), append(args, find.Type)
}
if find.MemoFilter != nil {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(*find.MemoFilter, filter.MemoFilterCELAttributes...)
engine, err := filter.DefaultEngine()
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
stmt, err := engine.CompileToStatement(ctx, *find.MemoFilter, filter.RenderOptions{Dialect: filter.DialectSQLite})
if err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", condition))
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", condition))
args = append(args, append(convertCtx.Args, convertCtx.Args...)...)
if stmt.SQL != "" {
where = append(where, fmt.Sprintf("memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
where = append(where, fmt.Sprintf("related_memo_id IN (SELECT id FROM memo WHERE %s)", stmt.SQL))
args = append(args, append(stmt.Args, stmt.Args...)...)
}
}

@ -2,10 +2,8 @@ package sqlite
import (
"context"
"fmt"
"strings"
"github.com/usememos/memos/plugin/filter"
"github.com/usememos/memos/store"
)
@ -26,27 +24,7 @@ func (d *DB) UpsertReaction(ctx context.Context, upsert *store.Reaction) (*store
}
func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*store.Reaction, error) {
where, args := []string{"1 = 1"}, []interface{}{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.ReactionFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
}
where, args := []string{"1 = 1"}, []any{}
if find.ID != nil {
where, args = append(where, "id = ?"), append(args, *find.ID)
@ -57,6 +35,18 @@ func (d *DB) ListReactions(ctx context.Context, find *store.FindReaction) ([]*st
if find.ContentID != nil {
where, args = append(where, "content_id = ?"), append(args, *find.ContentID)
}
if len(find.ContentIDList) > 0 {
placeholders := make([]string, 0, len(find.ContentIDList))
for range find.ContentIDList {
placeholders = append(placeholders, "?")
}
if len(placeholders) > 0 {
where = append(where, "content_id IN ("+strings.Join(placeholders, ",")+")")
for _, id := range find.ContentIDList {
args = append(args, id)
}
}
}
rows, err := d.db.QueryContext(ctx, `
SELECT

@ -1,39 +0,0 @@
package sqlite
import (
"testing"
"github.com/stretchr/testify/require"
"github.com/usememos/memos/plugin/filter"
)
func TestReactionConvertExprToSQL(t *testing.T) {
tests := []struct {
filter string
want string
args []any
}{
{
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
want: "`reaction`.`content_id` IN (?)",
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
},
{
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
want: "`reaction`.`content_id` IN (?,?)",
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
},
}
for _, tt := range tests {
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
require.NoError(t, err)
convertCtx := filter.NewConvertContext()
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
require.NoError(t, err)
require.Equal(t, tt.want, convertCtx.Buffer.String())
require.Equal(t, tt.args, convertCtx.Args)
}
}

@ -5,7 +5,8 @@ import (
"fmt"
"strings"
"github.com/usememos/memos/plugin/filter"
"github.com/pkg/errors"
"github.com/usememos/memos/store"
)
@ -87,24 +88,8 @@ func (d *DB) UpdateUser(ctx context.Context, update *store.UpdateUser) (*store.U
func (d *DB) ListUsers(ctx context.Context, find *store.FindUser) ([]*store.User, error) {
where, args := []string{"1 = 1"}, []any{}
for _, filterStr := range find.Filters {
// Parse filter string and return the parsed expression.
// The filter string should be a CEL expression.
parsedExpr, err := filter.Parse(filterStr, filter.UserFilterCELAttributes...)
if err != nil {
return nil, err
}
convertCtx := filter.NewConvertContext()
// ConvertExprToSQL converts the parsed expression to a SQL condition string.
converter := filter.NewUserSQLConverter(&filter.SQLiteDialect{})
if err := converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr()); err != nil {
return nil, err
}
condition := convertCtx.Buffer.String()
if condition != "" {
where = append(where, fmt.Sprintf("(%s)", condition))
args = append(args, convertCtx.Args...)
}
if len(find.Filters) > 0 {
return nil, errors.Errorf("user filters are not supported")
}
if v := find.ID; v != nil {

@ -60,6 +60,9 @@ type FindMemo struct {
ID *int32
UID *string
IDList []int32
UIDList []string
// Standard fields
RowStatus *RowStatus
CreatorID *int32

@ -14,10 +14,10 @@ type Reaction struct {
}
type FindReaction struct {
ID *int32
CreatorID *int32
ContentID *string
Filters []string
ID *int32
CreatorID *int32
ContentID *string
ContentIDList []string
}
type DeleteReaction struct {

Loading…
Cancel
Save