mirror of https://github.com/usememos/memos
refactor: memo filter
- Updated memo and reaction filtering logic to use a unified engine for compiling filter expressions into SQL statements. - Removed redundant filter parsing and conversion code from ListMemoRelations, ListReactions, and ListAttachments methods. - Introduced IDList and UIDList fields in FindMemo and FindReaction structs to support filtering by multiple IDs. - Removed old filter test files for reactions and attachments, as the filtering logic has been centralized. - Updated tests for memo filtering to reflect the new SQL statement compilation approach. - Ensured that unsupported user filters return an error in ListUsers method.pull/5091/merge
parent
228cc6105d
commit
b685ffacdf
@ -1,746 +0,0 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"slices"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// CommonSQLConverter handles the common CEL to SQL conversion logic.
|
||||
type CommonSQLConverter struct {
|
||||
dialect SQLDialect
|
||||
paramIndex int
|
||||
allowedFields []string
|
||||
entityType string
|
||||
}
|
||||
|
||||
// NewCommonSQLConverter creates a new converter with the specified dialect for memo filters.
|
||||
func NewCommonSQLConverter(dialect SQLDialect) *CommonSQLConverter {
|
||||
return &CommonSQLConverter{
|
||||
dialect: dialect,
|
||||
paramIndex: 1,
|
||||
allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"},
|
||||
entityType: "memo",
|
||||
}
|
||||
}
|
||||
|
||||
// NewCommonSQLConverterWithOffset creates a new converter with the specified dialect and parameter offset for memo filters.
|
||||
func NewCommonSQLConverterWithOffset(dialect SQLDialect, offset int) *CommonSQLConverter {
|
||||
return &CommonSQLConverter{
|
||||
dialect: dialect,
|
||||
paramIndex: offset + 1,
|
||||
allowedFields: []string{"creator_id", "created_ts", "updated_ts", "visibility", "content", "pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"},
|
||||
entityType: "memo",
|
||||
}
|
||||
}
|
||||
|
||||
// NewUserSQLConverter creates a new converter for user filters.
|
||||
func NewUserSQLConverter(dialect SQLDialect) *CommonSQLConverter {
|
||||
return &CommonSQLConverter{
|
||||
dialect: dialect,
|
||||
paramIndex: 1,
|
||||
allowedFields: []string{"username"},
|
||||
entityType: "user",
|
||||
}
|
||||
}
|
||||
|
||||
// ConvertExprToSQL converts a CEL expression to SQL using the configured dialect.
|
||||
func (c *CommonSQLConverter) ConvertExprToSQL(ctx *ConvertContext, expr *exprv1.Expr) error {
|
||||
if v, ok := expr.ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
switch v.CallExpr.Function {
|
||||
case "_||_", "_&&_":
|
||||
return c.handleLogicalOperator(ctx, v.CallExpr)
|
||||
case "!_":
|
||||
return c.handleNotOperator(ctx, v.CallExpr)
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
return c.handleComparisonOperator(ctx, v.CallExpr)
|
||||
case "@in":
|
||||
return c.handleInOperator(ctx, v.CallExpr)
|
||||
case "contains":
|
||||
return c.handleContainsOperator(ctx, v.CallExpr)
|
||||
default:
|
||||
return errors.Errorf("unsupported call expression function: %s", v.CallExpr.Function)
|
||||
}
|
||||
} else if v, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr); ok {
|
||||
return c.handleIdentifier(ctx, v.IdentExpr)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleLogicalOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString("("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator := "AND"
|
||||
if callExpr.Function == "_||_" {
|
||||
operator = "OR"
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf(" %s ", operator)); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[1]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleNotOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString("NOT ("); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if err := c.ConvertExprToSQL(ctx, callExpr.Args[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(")"); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleComparisonOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
// Check if the left side is a function call like size(tags)
|
||||
if leftCallExpr, ok := callExpr.Args[0].ExprKind.(*exprv1.Expr_CallExpr); ok {
|
||||
if leftCallExpr.CallExpr.Function == "size" {
|
||||
return c.handleSizeComparison(ctx, callExpr, leftCallExpr.CallExpr)
|
||||
}
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains(c.allowedFields, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
value, err := GetExprValue(callExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
operator := c.getComparisonOperator(callExpr.Function)
|
||||
|
||||
// Handle memo fields
|
||||
if c.entityType == "memo" {
|
||||
switch identifier {
|
||||
case "created_ts", "updated_ts":
|
||||
return c.handleTimestampComparison(ctx, identifier, operator, value)
|
||||
case "visibility", "content":
|
||||
return c.handleStringComparison(ctx, identifier, operator, value)
|
||||
case "creator_id":
|
||||
return c.handleIntComparison(ctx, identifier, operator, value)
|
||||
case "pinned":
|
||||
return c.handlePinnedComparison(ctx, operator, value)
|
||||
case "has_task_list", "has_link", "has_code", "has_incomplete_tasks":
|
||||
return c.handleBooleanComparison(ctx, identifier, operator, value)
|
||||
default:
|
||||
return errors.Errorf("unsupported identifier in comparison: %s", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
// Handle user fields
|
||||
if c.entityType == "user" {
|
||||
switch identifier {
|
||||
case "username":
|
||||
return c.handleUserStringComparison(ctx, identifier, operator, value)
|
||||
default:
|
||||
return errors.Errorf("unsupported user identifier in comparison: %s", identifier)
|
||||
}
|
||||
}
|
||||
|
||||
return errors.Errorf("unsupported entity type: %s", c.entityType)
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleSizeComparison(ctx *ConvertContext, callExpr *exprv1.Expr_Call, sizeCall *exprv1.Expr_Call) error {
|
||||
if len(sizeCall.Args) != 1 {
|
||||
return errors.New("size function requires exactly one argument")
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(sizeCall.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if identifier != "tags" {
|
||||
return errors.Errorf("size function only supports 'tags' identifier, got: %s", identifier)
|
||||
}
|
||||
|
||||
value, err := GetExprValue(callExpr.Args[1])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("size comparison value must be an integer")
|
||||
}
|
||||
|
||||
operator := c.getComparisonOperator(callExpr.Function)
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s",
|
||||
c.dialect.GetJSONArrayLength("$.tags"),
|
||||
operator,
|
||||
c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleInOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 2 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
// Check if this is "element in collection" syntax
|
||||
if identifier, err := GetIdentExprName(callExpr.Args[1]); err == nil {
|
||||
if identifier == "tags" {
|
||||
return c.handleElementInTags(ctx, callExpr.Args[0])
|
||||
}
|
||||
return errors.Errorf("invalid collection identifier for %s: %s", callExpr.Function, identifier)
|
||||
}
|
||||
|
||||
// Original logic for "identifier in [list]" syntax
|
||||
identifier, err := GetIdentExprName(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"tag", "visibility", "content_id", "memo_id"}, identifier) {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
values := []any{}
|
||||
for _, element := range callExpr.Args[1].GetListExpr().Elements {
|
||||
value, err := GetConstValue(element)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
|
||||
if identifier == "tag" {
|
||||
return c.handleTagInList(ctx, values)
|
||||
} else if identifier == "visibility" {
|
||||
return c.handleVisibilityInList(ctx, values)
|
||||
} else if identifier == "content_id" {
|
||||
return c.handleContentIDInList(ctx, values)
|
||||
} else if identifier == "memo_id" {
|
||||
return c.handleMemoIDInList(ctx, values)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleElementInTags(ctx *ConvertContext, elementExpr *exprv1.Expr) error {
|
||||
element, err := GetConstValue(elementExpr)
|
||||
if err != nil {
|
||||
return errors.Errorf("first argument must be a constant value for 'element in tags': %v", err)
|
||||
}
|
||||
|
||||
// Use dialect-specific JSON contains logic
|
||||
template := c.dialect.GetJSONContains("$.tags", "element")
|
||||
sqlExpr := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
// Handle args based on dialect
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
// SQLite uses LIKE with pattern
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf(`%%"%s"%%`, element))
|
||||
} else {
|
||||
// MySQL and PostgreSQL expect plain values
|
||||
ctx.Args = append(ctx.Args, element)
|
||||
}
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleTagInList(ctx *ConvertContext, values []any) error {
|
||||
subconditions := []string{}
|
||||
args := []any{}
|
||||
|
||||
for _, v := range values {
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
subconditions = append(subconditions, c.dialect.GetJSONLike("$.tags", "pattern"))
|
||||
args = append(args, fmt.Sprintf(`%%"%s"%%`, v))
|
||||
} else {
|
||||
// Replace ? with proper placeholder for each dialect
|
||||
template := c.dialect.GetJSONContains("$.tags", "element")
|
||||
sql := strings.Replace(template, "?", c.dialect.GetParameterPlaceholder(c.paramIndex), 1)
|
||||
subconditions = append(subconditions, sql)
|
||||
args = append(args, fmt.Sprintf(`"%s"`, v))
|
||||
}
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
if len(subconditions) == 1 {
|
||||
if _, err := ctx.Buffer.WriteString(subconditions[0]); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("(%s)", strings.Join(subconditions, " OR "))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, args...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleVisibilityInList(ctx *ConvertContext, values []any) error {
|
||||
placeholders := []string{}
|
||||
for range values {
|
||||
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.visibility IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`visibility` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleContentIDInList(ctx *ConvertContext, values []any) error {
|
||||
placeholders := []string{}
|
||||
for range values {
|
||||
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("reaction")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleMemoIDInList(ctx *ConvertContext, values []any) error {
|
||||
placeholders := []string{}
|
||||
for range values {
|
||||
placeholders = append(placeholders, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
c.paramIndex++
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("resource")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.memo_id IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`memo_id` IN (%s)", tablePrefix, strings.Join(placeholders, ","))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, values...)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleContainsOperator(ctx *ConvertContext, callExpr *exprv1.Expr_Call) error {
|
||||
if len(callExpr.Args) != 1 {
|
||||
return errors.Errorf("invalid number of arguments for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
identifier, err := GetIdentExprName(callExpr.Target)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if identifier != "content" {
|
||||
return errors.Errorf("invalid identifier for %s", callExpr.Function)
|
||||
}
|
||||
|
||||
arg, err := GetConstValue(callExpr.Args[0])
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
// PostgreSQL uses ILIKE and no backticks
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.content ILIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`content` LIKE %s", tablePrefix, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, fmt.Sprintf("%%%s%%", arg))
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleIdentifier(ctx *ConvertContext, identExpr *exprv1.Expr_Ident) error {
|
||||
identifier := identExpr.GetName()
|
||||
|
||||
// Only memo entity has boolean identifiers that can be used standalone
|
||||
if c.entityType != "memo" {
|
||||
return errors.Errorf("invalid identifier %s for entity type %s", identifier, c.entityType)
|
||||
}
|
||||
|
||||
if !slices.Contains([]string{"pinned", "has_task_list", "has_link", "has_code", "has_incomplete_tasks"}, identifier) {
|
||||
return errors.Errorf("invalid identifier %s", identifier)
|
||||
}
|
||||
|
||||
if identifier == "pinned" {
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.pinned IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`pinned` IS TRUE", tablePrefix)); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
} else if identifier == "has_task_list" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasTaskList")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_link" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasLink")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_code" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasCode")); err != nil {
|
||||
return err
|
||||
}
|
||||
} else if identifier == "has_incomplete_tasks" {
|
||||
if _, err := ctx.Buffer.WriteString(c.dialect.GetBooleanCheck("$.property.hasIncompleteTasks")); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleTimestampComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid integer timestamp value")
|
||||
}
|
||||
|
||||
timestampField := c.dialect.GetTimestampComparison(field)
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s %s %s", timestampField, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// MySQL and SQLite use backticks
|
||||
fieldName := field
|
||||
if field == "visibility" {
|
||||
fieldName = "`visibility`"
|
||||
} else if field == "content" {
|
||||
fieldName = "`content`"
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, fieldName, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleUserStringComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueStr, ok := value.(string)
|
||||
if !ok {
|
||||
return errors.New("invalid string value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("user")
|
||||
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// MySQL and SQLite use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueStr)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleIntComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueInt, ok := value.(int64)
|
||||
if !ok {
|
||||
return errors.New("invalid int value")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
// PostgreSQL doesn't use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.%s %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
} else {
|
||||
// MySQL and SQLite use backticks
|
||||
if _, err := ctx.Buffer.WriteString(fmt.Sprintf("%s.`%s` %s %s", tablePrefix, field, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))); err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, valueInt)
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handlePinnedComparison(ctx *ConvertContext, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for pinned field")
|
||||
}
|
||||
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.New("invalid boolean value for pinned field")
|
||||
}
|
||||
|
||||
tablePrefix := c.dialect.GetTablePrefix("memo")
|
||||
|
||||
var sqlExpr string
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
sqlExpr = fmt.Sprintf("%s.pinned %s %s", tablePrefix, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s.`pinned` %s %s", tablePrefix, operator, c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
}
|
||||
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
ctx.Args = append(ctx.Args, c.dialect.GetBooleanValue(valueBool))
|
||||
c.paramIndex++
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *CommonSQLConverter) handleBooleanComparison(ctx *ConvertContext, field, operator string, value interface{}) error {
|
||||
if operator != "=" && operator != "!=" {
|
||||
return errors.Errorf("invalid operator for %s", field)
|
||||
}
|
||||
|
||||
valueBool, ok := value.(bool)
|
||||
if !ok {
|
||||
return errors.Errorf("invalid boolean value for %s", field)
|
||||
}
|
||||
|
||||
// Map field name to JSON path
|
||||
var jsonPath string
|
||||
switch field {
|
||||
case "has_task_list":
|
||||
jsonPath = "$.property.hasTaskList"
|
||||
case "has_link":
|
||||
jsonPath = "$.property.hasLink"
|
||||
case "has_code":
|
||||
jsonPath = "$.property.hasCode"
|
||||
case "has_incomplete_tasks":
|
||||
jsonPath = "$.property.hasIncompleteTasks"
|
||||
default:
|
||||
return errors.Errorf("unsupported boolean field: %s", field)
|
||||
}
|
||||
|
||||
// Special handling for SQLite based on field
|
||||
if _, ok := c.dialect.(*SQLiteDialect); ok {
|
||||
if field == "has_task_list" {
|
||||
// has_task_list uses = 1 / = 0 / != 1 / != 0
|
||||
var sqlExpr string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("%s = 1", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s = 0", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("%s != 1", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s != 0", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
// Other fields use IS TRUE / NOT(... IS TRUE)
|
||||
var sqlExpr string
|
||||
if operator == "=" {
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
} else { // operator == "!="
|
||||
if valueBool {
|
||||
sqlExpr = fmt.Sprintf("NOT(%s IS TRUE)", c.dialect.GetJSONExtract(jsonPath))
|
||||
} else {
|
||||
sqlExpr = fmt.Sprintf("%s IS TRUE", c.dialect.GetJSONExtract(jsonPath))
|
||||
}
|
||||
}
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Special handling for MySQL - use raw operator with CAST
|
||||
if _, ok := c.dialect.(*MySQLDialect); ok {
|
||||
var sqlExpr string
|
||||
boolStr := "false"
|
||||
if valueBool {
|
||||
boolStr = "true"
|
||||
}
|
||||
sqlExpr = fmt.Sprintf("%s %s CAST('%s' AS JSON)", c.dialect.GetJSONExtract(jsonPath), operator, boolStr)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle PostgreSQL differently - it uses the raw operator
|
||||
if _, ok := c.dialect.(*PostgreSQLDialect); ok {
|
||||
jsonExtract := c.dialect.GetJSONExtract(jsonPath)
|
||||
|
||||
sqlExpr := fmt.Sprintf("(%s)::boolean %s %s",
|
||||
jsonExtract,
|
||||
operator,
|
||||
c.dialect.GetParameterPlaceholder(c.paramIndex))
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
ctx.Args = append(ctx.Args, valueBool)
|
||||
c.paramIndex++
|
||||
return nil
|
||||
}
|
||||
|
||||
// Handle other dialects
|
||||
if operator == "!=" {
|
||||
valueBool = !valueBool
|
||||
}
|
||||
|
||||
sqlExpr := c.dialect.GetBooleanComparison(jsonPath, valueBool)
|
||||
if _, err := ctx.Buffer.WriteString(sqlExpr); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (*CommonSQLConverter) getComparisonOperator(function string) string {
|
||||
switch function {
|
||||
case "_==_":
|
||||
return "="
|
||||
case "_!=_":
|
||||
return "!="
|
||||
case "_<_":
|
||||
return "<"
|
||||
case "_>_":
|
||||
return ">"
|
||||
case "_<=_":
|
||||
return "<="
|
||||
case "_>=_":
|
||||
return ">="
|
||||
default:
|
||||
return "="
|
||||
}
|
||||
}
|
||||
@ -1,20 +0,0 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
type ConvertContext struct {
|
||||
Buffer strings.Builder
|
||||
Args []any
|
||||
// The offset of the next argument in the condition string.
|
||||
// Mainly using for PostgreSQL.
|
||||
ArgsOffset int
|
||||
}
|
||||
|
||||
func NewConvertContext() *ConvertContext {
|
||||
return &ConvertContext{
|
||||
Buffer: strings.Builder{},
|
||||
Args: []any{},
|
||||
}
|
||||
}
|
||||
@ -1,215 +0,0 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// SQLDialect defines database-specific SQL generation methods.
|
||||
type SQLDialect interface {
|
||||
// Basic field access
|
||||
GetTablePrefix(entityName string) string
|
||||
GetParameterPlaceholder(index int) string
|
||||
|
||||
// JSON operations
|
||||
GetJSONExtract(path string) string
|
||||
GetJSONArrayLength(path string) string
|
||||
GetJSONContains(path, element string) string
|
||||
GetJSONLike(path, pattern string) string
|
||||
|
||||
// Boolean operations
|
||||
GetBooleanValue(value bool) interface{}
|
||||
GetBooleanComparison(path string, value bool) string
|
||||
GetBooleanCheck(path string) string
|
||||
|
||||
// Timestamp operations
|
||||
GetTimestampComparison(field string) string
|
||||
GetCurrentTimestamp() string
|
||||
}
|
||||
|
||||
// DatabaseType represents the type of database.
|
||||
type DatabaseType string
|
||||
|
||||
const (
|
||||
SQLite DatabaseType = "sqlite"
|
||||
MySQL DatabaseType = "mysql"
|
||||
PostgreSQL DatabaseType = "postgres"
|
||||
)
|
||||
|
||||
// GetDialect returns the appropriate dialect for the database type.
|
||||
func GetDialect(dbType DatabaseType) SQLDialect {
|
||||
switch dbType {
|
||||
case SQLite:
|
||||
return &SQLiteDialect{}
|
||||
case MySQL:
|
||||
return &MySQLDialect{}
|
||||
case PostgreSQL:
|
||||
return &PostgreSQLDialect{}
|
||||
default:
|
||||
return &SQLiteDialect{} // default fallback
|
||||
}
|
||||
}
|
||||
|
||||
// SQLiteDialect implements SQLDialect for SQLite.
|
||||
type SQLiteDialect struct{}
|
||||
|
||||
func (*SQLiteDialect) GetTablePrefix(entityName string) string {
|
||||
return fmt.Sprintf("`%s`", entityName)
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetParameterPlaceholder(_ int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONArrayLength(path string) string {
|
||||
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONContains(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetJSONLike(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetBooleanValue(value bool) interface{} {
|
||||
if value {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanComparison(path string, value bool) string {
|
||||
if value {
|
||||
return fmt.Sprintf("%s = 1", d.GetJSONExtract(path))
|
||||
}
|
||||
return fmt.Sprintf("%s = 0", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("%s IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *SQLiteDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("%s.`%s`", d.GetTablePrefix("memo"), field)
|
||||
}
|
||||
|
||||
func (*SQLiteDialect) GetCurrentTimestamp() string {
|
||||
return "strftime('%s', 'now')"
|
||||
}
|
||||
|
||||
// MySQLDialect implements SQLDialect for MySQL.
|
||||
type MySQLDialect struct{}
|
||||
|
||||
func (*MySQLDialect) GetTablePrefix(entityName string) string {
|
||||
return fmt.Sprintf("`%s`", entityName)
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetParameterPlaceholder(_ int) string {
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONExtract(path string) string {
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s.`payload`, '%s')", d.GetTablePrefix("memo"), path)
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONArrayLength(path string) string {
|
||||
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONContains(path, _ string) string {
|
||||
return fmt.Sprintf("JSON_CONTAINS(%s, ?)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetJSONLike(path, _ string) string {
|
||||
return fmt.Sprintf("%s LIKE ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
return value
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanComparison(path string, value bool) string {
|
||||
if value {
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
return fmt.Sprintf("%s != CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *MySQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("UNIX_TIMESTAMP(%s.`%s`)", d.GetTablePrefix("memo"), field)
|
||||
}
|
||||
|
||||
func (*MySQLDialect) GetCurrentTimestamp() string {
|
||||
return "UNIX_TIMESTAMP()"
|
||||
}
|
||||
|
||||
// PostgreSQLDialect implements SQLDialect for PostgreSQL.
|
||||
type PostgreSQLDialect struct{}
|
||||
|
||||
func (*PostgreSQLDialect) GetTablePrefix(entityName string) string {
|
||||
return entityName
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetParameterPlaceholder(index int) string {
|
||||
return fmt.Sprintf("$%d", index)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONExtract(path string) string {
|
||||
// Convert $.property.hasTaskList to memo.payload->'property'->>'hasTaskList'
|
||||
parts := strings.Split(strings.TrimPrefix(path, "$."), ".")
|
||||
result := fmt.Sprintf("%s.payload", d.GetTablePrefix("memo"))
|
||||
for i, part := range parts {
|
||||
if i == len(parts)-1 {
|
||||
result += fmt.Sprintf("->>'%s'", part)
|
||||
} else {
|
||||
result += fmt.Sprintf("->'%s'", part)
|
||||
}
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONArrayLength(path string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s.%s, '[]'::jsonb))", d.GetTablePrefix("memo"), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONContains(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetJSONLike(path, _ string) string {
|
||||
jsonPath := strings.Replace(path, "$.tags", "payload->'tags'", 1)
|
||||
return fmt.Sprintf("%s.%s @> jsonb_build_array(?::json)", d.GetTablePrefix("memo"), jsonPath)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetBooleanValue(value bool) interface{} {
|
||||
return value
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanComparison(path string, _ bool) string {
|
||||
// Note: The parameter placeholder will be replaced by the caller
|
||||
return fmt.Sprintf("(%s)::boolean = ?", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetBooleanCheck(path string) string {
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", d.GetJSONExtract(path))
|
||||
}
|
||||
|
||||
func (d *PostgreSQLDialect) GetTimestampComparison(field string) string {
|
||||
return fmt.Sprintf("EXTRACT(EPOCH FROM TO_TIMESTAMP(%s.%s))", d.GetTablePrefix("memo"), field)
|
||||
}
|
||||
|
||||
func (*PostgreSQLDialect) GetCurrentTimestamp() string {
|
||||
return "EXTRACT(EPOCH FROM NOW())"
|
||||
}
|
||||
@ -0,0 +1,180 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
// Engine parses CEL filters into a dialect-agnostic condition tree.
|
||||
type Engine struct {
|
||||
schema Schema
|
||||
env *cel.Env
|
||||
}
|
||||
|
||||
// NewEngine builds a new Engine for the provided schema.
|
||||
func NewEngine(schema Schema) (*Engine, error) {
|
||||
env, err := cel.NewEnv(schema.EnvOptions...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||
}
|
||||
return &Engine{
|
||||
schema: schema,
|
||||
env: env,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Program stores a compiled filter condition.
|
||||
type Program struct {
|
||||
schema Schema
|
||||
condition Condition
|
||||
}
|
||||
|
||||
// ConditionTree exposes the underlying condition tree.
|
||||
func (p *Program) ConditionTree() Condition {
|
||||
return p.condition
|
||||
}
|
||||
|
||||
// Compile parses the filter string into an executable program.
|
||||
func (e *Engine) Compile(_ context.Context, filter string) (*Program, error) {
|
||||
if strings.TrimSpace(filter) == "" {
|
||||
return nil, errors.New("filter expression is empty")
|
||||
}
|
||||
|
||||
filter = normalizeLegacyFilter(filter)
|
||||
|
||||
ast, issues := e.env.Compile(filter)
|
||||
if issues != nil && issues.Err() != nil {
|
||||
return nil, errors.Wrap(issues.Err(), "failed to compile filter")
|
||||
}
|
||||
parsed, err := cel.AstToParsedExpr(ast)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to convert AST")
|
||||
}
|
||||
|
||||
cond, err := buildCondition(parsed.GetExpr(), e.schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return &Program{
|
||||
schema: e.schema,
|
||||
condition: cond,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// CompileToStatement compiles and renders the filter in a single step.
|
||||
func (e *Engine) CompileToStatement(ctx context.Context, filter string, opts RenderOptions) (Statement, error) {
|
||||
program, err := e.Compile(ctx, filter)
|
||||
if err != nil {
|
||||
return Statement{}, err
|
||||
}
|
||||
return program.Render(opts)
|
||||
}
|
||||
|
||||
// RenderOptions configure SQL rendering.
|
||||
type RenderOptions struct {
|
||||
Dialect DialectName
|
||||
PlaceholderOffset int
|
||||
DisableNullChecks bool
|
||||
}
|
||||
|
||||
// Statement contains the rendered SQL fragment and its args.
|
||||
type Statement struct {
|
||||
SQL string
|
||||
Args []any
|
||||
}
|
||||
|
||||
// Render converts the program into a dialect-specific SQL fragment.
|
||||
func (p *Program) Render(opts RenderOptions) (Statement, error) {
|
||||
renderer := newRenderer(p.schema, opts)
|
||||
return renderer.Render(p.condition)
|
||||
}
|
||||
|
||||
var (
|
||||
defaultOnce sync.Once
|
||||
defaultInst *Engine
|
||||
defaultErr error
|
||||
)
|
||||
|
||||
// DefaultEngine returns the process-wide memo filter engine.
|
||||
func DefaultEngine() (*Engine, error) {
|
||||
defaultOnce.Do(func() {
|
||||
defaultInst, defaultErr = NewEngine(NewSchema())
|
||||
})
|
||||
return defaultInst, defaultErr
|
||||
}
|
||||
|
||||
func normalizeLegacyFilter(expr string) string {
|
||||
expr = rewriteNumericLogicalOperand(expr, "&&")
|
||||
expr = rewriteNumericLogicalOperand(expr, "||")
|
||||
return expr
|
||||
}
|
||||
|
||||
func rewriteNumericLogicalOperand(expr, op string) string {
|
||||
var builder strings.Builder
|
||||
n := len(expr)
|
||||
i := 0
|
||||
var inQuote rune
|
||||
|
||||
for i < n {
|
||||
ch := expr[i]
|
||||
|
||||
if inQuote != 0 {
|
||||
builder.WriteByte(ch)
|
||||
if ch == '\\' && i+1 < n {
|
||||
builder.WriteByte(expr[i+1])
|
||||
i += 2
|
||||
continue
|
||||
}
|
||||
if ch == byte(inQuote) {
|
||||
inQuote = 0
|
||||
}
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if ch == '\'' || ch == '"' {
|
||||
inQuote = rune(ch)
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
continue
|
||||
}
|
||||
|
||||
if strings.HasPrefix(expr[i:], op) {
|
||||
builder.WriteString(op)
|
||||
i += len(op)
|
||||
|
||||
// Preserve whitespace following the operator.
|
||||
wsStart := i
|
||||
for i < n && (expr[i] == ' ' || expr[i] == '\t') {
|
||||
i++
|
||||
}
|
||||
builder.WriteString(expr[wsStart:i])
|
||||
|
||||
signStart := i
|
||||
if i < n && (expr[i] == '+' || expr[i] == '-') {
|
||||
i++
|
||||
}
|
||||
for i < n && expr[i] >= '0' && expr[i] <= '9' {
|
||||
i++
|
||||
}
|
||||
if i > signStart {
|
||||
numLiteral := expr[signStart:i]
|
||||
builder.WriteString(fmt.Sprintf("(%s != 0)", numLiteral))
|
||||
} else {
|
||||
builder.WriteString(expr[signStart:i])
|
||||
}
|
||||
continue
|
||||
}
|
||||
|
||||
builder.WriteByte(ch)
|
||||
i++
|
||||
}
|
||||
|
||||
return builder.String()
|
||||
}
|
||||
@ -1,127 +0,0 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// GetConstValue returns the constant value of the expression.
|
||||
func GetConstValue(expr *exprv1.Expr) (any, error) {
|
||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid constant expression")
|
||||
}
|
||||
|
||||
switch v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return v.ConstExpr.GetUint64Value(), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
default:
|
||||
return nil, errors.New("unexpected constant type")
|
||||
}
|
||||
}
|
||||
|
||||
// GetIdentExprName returns the name of the identifier expression.
|
||||
func GetIdentExprName(expr *exprv1.Expr) (string, error) {
|
||||
_, ok := expr.ExprKind.(*exprv1.Expr_IdentExpr)
|
||||
if !ok {
|
||||
return "", errors.New("invalid identifier expression")
|
||||
}
|
||||
return expr.GetIdentExpr().GetName(), nil
|
||||
}
|
||||
|
||||
// GetFunctionValue evaluates CEL function calls and returns their value.
|
||||
// This is specifically for time functions like now().
|
||||
func GetFunctionValue(expr *exprv1.Expr) (any, error) {
|
||||
callExpr, ok := expr.ExprKind.(*exprv1.Expr_CallExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("invalid function call expression")
|
||||
}
|
||||
|
||||
switch callExpr.CallExpr.Function {
|
||||
case "now":
|
||||
if len(callExpr.CallExpr.Args) != 0 {
|
||||
return nil, errors.New("now() function takes no arguments")
|
||||
}
|
||||
return time.Now().Unix(), nil
|
||||
case "_-_":
|
||||
// Handle subtraction for expressions like "now() - 60 * 60 * 24"
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("subtraction requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("subtraction operands must be integers")
|
||||
}
|
||||
return leftInt - rightInt, nil
|
||||
case "_*_":
|
||||
// Handle multiplication for expressions like "60 * 60 * 24"
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("multiplication requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("multiplication operands must be integers")
|
||||
}
|
||||
return leftInt * rightInt, nil
|
||||
case "_+_":
|
||||
// Handle addition
|
||||
if len(callExpr.CallExpr.Args) != 2 {
|
||||
return nil, errors.New("addition requires exactly two arguments")
|
||||
}
|
||||
left, err := GetExprValue(callExpr.CallExpr.Args[0])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := GetExprValue(callExpr.CallExpr.Args[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
leftInt, ok1 := left.(int64)
|
||||
rightInt, ok2 := right.(int64)
|
||||
if !ok1 || !ok2 {
|
||||
return nil, errors.New("addition operands must be integers")
|
||||
}
|
||||
return leftInt + rightInt, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported function: " + callExpr.CallExpr.Function)
|
||||
}
|
||||
}
|
||||
|
||||
// GetExprValue attempts to get a value from an expression, trying constants first, then functions.
|
||||
func GetExprValue(expr *exprv1.Expr) (any, error) {
|
||||
// Try to get constant value first
|
||||
if constValue, err := GetConstValue(expr); err == nil {
|
||||
return constValue, nil
|
||||
}
|
||||
|
||||
// If not a constant, try to evaluate as a function
|
||||
return GetFunctionValue(expr)
|
||||
}
|
||||
@ -1,66 +0,0 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
// MemoFilterCELAttributes are the CEL attributes for memo.
|
||||
var MemoFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("creator_id", cel.IntType),
|
||||
cel.Variable("created_ts", cel.IntType),
|
||||
cel.Variable("updated_ts", cel.IntType),
|
||||
cel.Variable("pinned", cel.BoolType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
cel.Variable("tags", cel.ListType(cel.StringType)),
|
||||
cel.Variable("visibility", cel.StringType),
|
||||
cel.Variable("has_task_list", cel.BoolType),
|
||||
cel.Variable("has_link", cel.BoolType),
|
||||
cel.Variable("has_code", cel.BoolType),
|
||||
cel.Variable("has_incomplete_tasks", cel.BoolType),
|
||||
// Current timestamp function.
|
||||
cel.Function("now",
|
||||
cel.Overload("now",
|
||||
[]*cel.Type{},
|
||||
cel.IntType,
|
||||
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
|
||||
return types.Int(time.Now().Unix())
|
||||
}),
|
||||
),
|
||||
),
|
||||
}
|
||||
|
||||
// ReactionFilterCELAttributes are the CEL attributes for reaction.
|
||||
var ReactionFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("content_id", cel.StringType),
|
||||
}
|
||||
|
||||
// UserFilterCELAttributes are the CEL attributes for user.
|
||||
var UserFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("username", cel.StringType),
|
||||
}
|
||||
|
||||
// AttachmentFilterCELAttributes are the CEL attributes for user.
|
||||
var AttachmentFilterCELAttributes = []cel.EnvOption{
|
||||
cel.Variable("memo_id", cel.StringType),
|
||||
}
|
||||
|
||||
// Parse parses the filter string and returns the parsed expression.
|
||||
// The filter string should be a CEL expression.
|
||||
func Parse(filter string, opts ...cel.EnvOption) (expr *exprv1.ParsedExpr, err error) {
|
||||
e, err := cel.NewEnv(opts...)
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "failed to create CEL environment")
|
||||
}
|
||||
ast, issues := e.Compile(filter)
|
||||
if issues != nil {
|
||||
return nil, errors.Errorf("failed to compile filter: %v", issues)
|
||||
}
|
||||
return cel.AstToParsedExpr(ast)
|
||||
}
|
||||
@ -0,0 +1,25 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// AppendConditions compiles the provided filters and appends the resulting SQL fragments and args.
|
||||
func AppendConditions(ctx context.Context, engine *Engine, filters []string, dialect DialectName, where *[]string, args *[]any) error {
|
||||
for _, filterStr := range filters {
|
||||
stmt, err := engine.CompileToStatement(ctx, filterStr, RenderOptions{
|
||||
Dialect: dialect,
|
||||
PlaceholderOffset: len(*args),
|
||||
})
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if stmt.SQL == "" {
|
||||
continue
|
||||
}
|
||||
*where = append(*where, fmt.Sprintf("(%s)", stmt.SQL))
|
||||
*args = append(*args, stmt.Args...)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
@ -0,0 +1,116 @@
|
||||
package filter
|
||||
|
||||
// Condition represents a boolean expression derived from the CEL filter.
|
||||
type Condition interface {
|
||||
isCondition()
|
||||
}
|
||||
|
||||
// LogicalOperator enumerates the supported logical operators.
|
||||
type LogicalOperator string
|
||||
|
||||
const (
|
||||
LogicalAnd LogicalOperator = "AND"
|
||||
LogicalOr LogicalOperator = "OR"
|
||||
)
|
||||
|
||||
// LogicalCondition composes two conditions with a logical operator.
|
||||
type LogicalCondition struct {
|
||||
Operator LogicalOperator
|
||||
Left Condition
|
||||
Right Condition
|
||||
}
|
||||
|
||||
func (*LogicalCondition) isCondition() {}
|
||||
|
||||
// NotCondition negates a child condition.
|
||||
type NotCondition struct {
|
||||
Expr Condition
|
||||
}
|
||||
|
||||
func (*NotCondition) isCondition() {}
|
||||
|
||||
// FieldPredicateCondition asserts that a field evaluates to true.
|
||||
type FieldPredicateCondition struct {
|
||||
Field string
|
||||
}
|
||||
|
||||
func (*FieldPredicateCondition) isCondition() {}
|
||||
|
||||
// ComparisonOperator lists supported comparison operators.
|
||||
type ComparisonOperator string
|
||||
|
||||
const (
|
||||
CompareEq ComparisonOperator = "="
|
||||
CompareNeq ComparisonOperator = "!="
|
||||
CompareLt ComparisonOperator = "<"
|
||||
CompareLte ComparisonOperator = "<="
|
||||
CompareGt ComparisonOperator = ">"
|
||||
CompareGte ComparisonOperator = ">="
|
||||
)
|
||||
|
||||
// ComparisonCondition represents a binary comparison.
|
||||
type ComparisonCondition struct {
|
||||
Left ValueExpr
|
||||
Operator ComparisonOperator
|
||||
Right ValueExpr
|
||||
}
|
||||
|
||||
func (*ComparisonCondition) isCondition() {}
|
||||
|
||||
// InCondition represents an IN predicate with literal list values.
|
||||
type InCondition struct {
|
||||
Left ValueExpr
|
||||
Values []ValueExpr
|
||||
}
|
||||
|
||||
func (*InCondition) isCondition() {}
|
||||
|
||||
// ElementInCondition represents the CEL syntax `"value" in field`.
|
||||
type ElementInCondition struct {
|
||||
Element ValueExpr
|
||||
Field string
|
||||
}
|
||||
|
||||
func (*ElementInCondition) isCondition() {}
|
||||
|
||||
// ContainsCondition models the <field>.contains(<value>) call.
|
||||
type ContainsCondition struct {
|
||||
Field string
|
||||
Value string
|
||||
}
|
||||
|
||||
func (*ContainsCondition) isCondition() {}
|
||||
|
||||
// ConstantCondition captures a literal boolean outcome.
|
||||
type ConstantCondition struct {
|
||||
Value bool
|
||||
}
|
||||
|
||||
func (*ConstantCondition) isCondition() {}
|
||||
|
||||
// ValueExpr models arithmetic or scalar expressions whose result feeds a comparison.
|
||||
type ValueExpr interface {
|
||||
isValueExpr()
|
||||
}
|
||||
|
||||
// FieldRef references a named schema field.
|
||||
type FieldRef struct {
|
||||
Name string
|
||||
}
|
||||
|
||||
func (*FieldRef) isValueExpr() {}
|
||||
|
||||
// LiteralValue holds a literal scalar.
|
||||
type LiteralValue struct {
|
||||
Value interface{}
|
||||
}
|
||||
|
||||
func (*LiteralValue) isValueExpr() {}
|
||||
|
||||
// FunctionValue captures simple function calls like size(tags).
|
||||
type FunctionValue struct {
|
||||
Name string
|
||||
Args []ValueExpr
|
||||
}
|
||||
|
||||
func (*FunctionValue) isValueExpr() {}
|
||||
@ -0,0 +1,413 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
|
||||
)
|
||||
|
||||
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
|
||||
switch v := expr.ExprKind.(type) {
|
||||
case *exprv1.Expr_CallExpr:
|
||||
return buildCallCondition(v.CallExpr, schema)
|
||||
case *exprv1.Expr_ConstExpr:
|
||||
val, err := getConstValue(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
switch v := val.(type) {
|
||||
case bool:
|
||||
return &ConstantCondition{Value: v}, nil
|
||||
case int64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
case float64:
|
||||
return &ConstantCondition{Value: v != 0}, nil
|
||||
default:
|
||||
return nil, errors.New("filter must evaluate to a boolean value")
|
||||
}
|
||||
case *exprv1.Expr_IdentExpr:
|
||||
name := v.IdentExpr.GetName()
|
||||
field, ok := schema.Field(name)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", name)
|
||||
}
|
||||
if field.Type != FieldTypeBool {
|
||||
return nil, errors.Errorf("identifier %q is not boolean", name)
|
||||
}
|
||||
return &FieldPredicateCondition{Field: name}, nil
|
||||
default:
|
||||
return nil, errors.New("unsupported top-level expression")
|
||||
}
|
||||
}
|
||||
|
||||
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
switch call.Function {
|
||||
case "_&&_":
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("logical AND expects two arguments")
|
||||
}
|
||||
left, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildCondition(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LogicalCondition{
|
||||
Operator: LogicalAnd,
|
||||
Left: left,
|
||||
Right: right,
|
||||
}, nil
|
||||
case "_||_":
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("logical OR expects two arguments")
|
||||
}
|
||||
left, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildCondition(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &LogicalCondition{
|
||||
Operator: LogicalOr,
|
||||
Left: left,
|
||||
Right: right,
|
||||
}, nil
|
||||
case "!_":
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("logical NOT expects one argument")
|
||||
}
|
||||
child, err := buildCondition(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &NotCondition{Expr: child}, nil
|
||||
case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
|
||||
return buildComparisonCondition(call, schema)
|
||||
case "@in":
|
||||
return buildInCondition(call, schema)
|
||||
case "contains":
|
||||
return buildContainsCondition(call, schema)
|
||||
default:
|
||||
val, ok, err := evaluateBool(call)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return &ConstantCondition{Value: val}, nil
|
||||
}
|
||||
return nil, errors.Errorf("unsupported call expression %q", call.Function)
|
||||
}
|
||||
}
|
||||
|
||||
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("comparison expects two arguments")
|
||||
}
|
||||
op, err := toComparisonOperator(call.Function)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
left, err := buildValueExpr(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
right, err := buildValueExpr(call.Args[1], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// If the left side is a field, validate allowed operators.
|
||||
if field, ok := left.(*FieldRef); ok {
|
||||
def, exists := schema.Field(field.Name)
|
||||
if !exists {
|
||||
return nil, errors.Errorf("unknown identifier %q", field.Name)
|
||||
}
|
||||
if def.Kind == FieldKindVirtualAlias {
|
||||
def, exists = schema.ResolveAlias(field.Name)
|
||||
if !exists {
|
||||
return nil, errors.Errorf("invalid alias %q", field.Name)
|
||||
}
|
||||
}
|
||||
if def.AllowedComparisonOps != nil {
|
||||
if _, allowed := def.AllowedComparisonOps[op]; !allowed {
|
||||
return nil, errors.Errorf("operator %s not allowed for field %q", op, field.Name)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return &ComparisonCondition{
|
||||
Left: left,
|
||||
Operator: op,
|
||||
Right: right,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
if len(call.Args) != 2 {
|
||||
return nil, errors.New("in operator expects two arguments")
|
||||
}
|
||||
|
||||
// Handle identifier in list syntax.
|
||||
if identName, err := getIdentName(call.Args[0]); err == nil {
|
||||
if field, ok := schema.Field(identName); ok && field.Kind == FieldKindVirtualAlias {
|
||||
if _, aliasOk := schema.ResolveAlias(identName); !aliasOk {
|
||||
return nil, errors.Errorf("invalid alias %q", identName)
|
||||
}
|
||||
} else if !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", identName)
|
||||
}
|
||||
|
||||
if listExpr := call.Args[1].GetListExpr(); listExpr != nil {
|
||||
values := make([]ValueExpr, 0, len(listExpr.Elements))
|
||||
for _, element := range listExpr.Elements {
|
||||
value, err := buildValueExpr(element, schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
values = append(values, value)
|
||||
}
|
||||
return &InCondition{
|
||||
Left: &FieldRef{Name: identName},
|
||||
Values: values,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
// Handle "value in identifier" syntax.
|
||||
if identName, err := getIdentName(call.Args[1]); err == nil {
|
||||
if _, ok := schema.Field(identName); !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", identName)
|
||||
}
|
||||
element, err := buildValueExpr(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &ElementInCondition{
|
||||
Element: element,
|
||||
Field: identName,
|
||||
}, nil
|
||||
}
|
||||
|
||||
return nil, errors.New("invalid use of in operator")
|
||||
}
|
||||
|
||||
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
|
||||
if call.Target == nil {
|
||||
return nil, errors.New("contains requires a target")
|
||||
}
|
||||
targetName, err := getIdentName(call.Target)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
field, ok := schema.Field(targetName)
|
||||
if !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", targetName)
|
||||
}
|
||||
if !field.SupportsContains {
|
||||
return nil, errors.Errorf("identifier %q does not support contains()", targetName)
|
||||
}
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("contains expects exactly one argument")
|
||||
}
|
||||
value, err := getConstValue(call.Args[0])
|
||||
if err != nil {
|
||||
return nil, errors.Wrap(err, "contains only supports literal arguments")
|
||||
}
|
||||
str, ok := value.(string)
|
||||
if !ok {
|
||||
return nil, errors.New("contains argument must be a string")
|
||||
}
|
||||
return &ContainsCondition{
|
||||
Field: targetName,
|
||||
Value: str,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
|
||||
if identName, err := getIdentName(expr); err == nil {
|
||||
if _, ok := schema.Field(identName); !ok {
|
||||
return nil, errors.Errorf("unknown identifier %q", identName)
|
||||
}
|
||||
return &FieldRef{Name: identName}, nil
|
||||
}
|
||||
|
||||
if literal, err := getConstValue(expr); err == nil {
|
||||
return &LiteralValue{Value: literal}, nil
|
||||
}
|
||||
|
||||
if value, ok, err := evaluateNumeric(expr); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
return &LiteralValue{Value: value}, nil
|
||||
}
|
||||
|
||||
if boolVal, ok, err := evaluateBoolExpr(expr); err != nil {
|
||||
return nil, err
|
||||
} else if ok {
|
||||
return &LiteralValue{Value: boolVal}, nil
|
||||
}
|
||||
|
||||
if call := expr.GetCallExpr(); call != nil {
|
||||
switch call.Function {
|
||||
case "size":
|
||||
if len(call.Args) != 1 {
|
||||
return nil, errors.New("size() expects one argument")
|
||||
}
|
||||
arg, err := buildValueExpr(call.Args[0], schema)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &FunctionValue{
|
||||
Name: "size",
|
||||
Args: []ValueExpr{arg},
|
||||
}, nil
|
||||
case "now":
|
||||
return &LiteralValue{Value: timeNowUnix()}, nil
|
||||
case "_+_", "_-_", "_*_":
|
||||
value, ok, err := evaluateNumeric(expr)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if ok {
|
||||
return &LiteralValue{Value: value}, nil
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil, errors.New("unsupported value expression")
|
||||
}
|
||||
|
||||
func toComparisonOperator(fn string) (ComparisonOperator, error) {
|
||||
switch fn {
|
||||
case "_==_":
|
||||
return CompareEq, nil
|
||||
case "_!=_":
|
||||
return CompareNeq, nil
|
||||
case "_<_":
|
||||
return CompareLt, nil
|
||||
case "_>_":
|
||||
return CompareGt, nil
|
||||
case "_<=_":
|
||||
return CompareLte, nil
|
||||
case "_>=_":
|
||||
return CompareGte, nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported comparison operator %q", fn)
|
||||
}
|
||||
}
|
||||
|
||||
func getIdentName(expr *exprv1.Expr) (string, error) {
|
||||
if ident := expr.GetIdentExpr(); ident != nil {
|
||||
return ident.GetName(), nil
|
||||
}
|
||||
return "", errors.New("expression is not an identifier")
|
||||
}
|
||||
|
||||
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
|
||||
v, ok := expr.ExprKind.(*exprv1.Expr_ConstExpr)
|
||||
if !ok {
|
||||
return nil, errors.New("expression is not a literal")
|
||||
}
|
||||
switch x := v.ConstExpr.ConstantKind.(type) {
|
||||
case *exprv1.Constant_StringValue:
|
||||
return v.ConstExpr.GetStringValue(), nil
|
||||
case *exprv1.Constant_Int64Value:
|
||||
return v.ConstExpr.GetInt64Value(), nil
|
||||
case *exprv1.Constant_Uint64Value:
|
||||
return int64(v.ConstExpr.GetUint64Value()), nil
|
||||
case *exprv1.Constant_DoubleValue:
|
||||
return v.ConstExpr.GetDoubleValue(), nil
|
||||
case *exprv1.Constant_BoolValue:
|
||||
return v.ConstExpr.GetBoolValue(), nil
|
||||
case *exprv1.Constant_NullValue:
|
||||
return nil, nil
|
||||
default:
|
||||
return nil, errors.Errorf("unsupported constant %T", x)
|
||||
}
|
||||
}
|
||||
|
||||
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
|
||||
val, ok, err := evaluateBoolExpr(&exprv1.Expr{ExprKind: &exprv1.Expr_CallExpr{CallExpr: call}})
|
||||
return val, ok, err
|
||||
}
|
||||
|
||||
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
|
||||
if literal, err := getConstValue(expr); err == nil {
|
||||
if b, ok := literal.(bool); ok {
|
||||
return b, true, nil
|
||||
}
|
||||
return false, false, nil
|
||||
}
|
||||
if call := expr.GetCallExpr(); call != nil && call.Function == "!_" {
|
||||
if len(call.Args) != 1 {
|
||||
return false, false, errors.New("NOT expects exactly one argument")
|
||||
}
|
||||
val, ok, err := evaluateBoolExpr(call.Args[0])
|
||||
if err != nil || !ok {
|
||||
return false, false, err
|
||||
}
|
||||
return !val, true, nil
|
||||
}
|
||||
return false, false, nil
|
||||
}
|
||||
|
||||
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
|
||||
if literal, err := getConstValue(expr); err == nil {
|
||||
switch v := literal.(type) {
|
||||
case int64:
|
||||
return v, true, nil
|
||||
case float64:
|
||||
return int64(v), true, nil
|
||||
}
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
call := expr.GetCallExpr()
|
||||
if call == nil {
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
switch call.Function {
|
||||
case "now":
|
||||
return timeNowUnix(), true, nil
|
||||
case "_+_", "_-_", "_*_":
|
||||
if len(call.Args) != 2 {
|
||||
return 0, false, errors.New("arithmetic requires two arguments")
|
||||
}
|
||||
left, ok, err := evaluateNumeric(call.Args[0])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
right, ok, err := evaluateNumeric(call.Args[1])
|
||||
if err != nil {
|
||||
return 0, false, err
|
||||
}
|
||||
if !ok {
|
||||
return 0, false, nil
|
||||
}
|
||||
switch call.Function {
|
||||
case "_+_":
|
||||
return left + right, true, nil
|
||||
case "_-_":
|
||||
return left - right, true, nil
|
||||
case "_*_":
|
||||
return left * right, true, nil
|
||||
}
|
||||
}
|
||||
|
||||
return 0, false, nil
|
||||
}
|
||||
|
||||
func timeNowUnix() int64 {
|
||||
return time.Now().Unix()
|
||||
}
|
||||
@ -0,0 +1,626 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/pkg/errors"
|
||||
)
|
||||
|
||||
type renderer struct {
|
||||
schema Schema
|
||||
dialect DialectName
|
||||
placeholderOffset int
|
||||
placeholderCounter int
|
||||
args []any
|
||||
}
|
||||
|
||||
type renderResult struct {
|
||||
sql string
|
||||
trivial bool
|
||||
unsatisfiable bool
|
||||
}
|
||||
|
||||
func newRenderer(schema Schema, opts RenderOptions) *renderer {
|
||||
return &renderer{
|
||||
schema: schema,
|
||||
dialect: opts.Dialect,
|
||||
placeholderOffset: opts.PlaceholderOffset,
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) Render(cond Condition) (Statement, error) {
|
||||
result, err := r.renderCondition(cond)
|
||||
if err != nil {
|
||||
return Statement{}, err
|
||||
}
|
||||
args := r.args
|
||||
if args == nil {
|
||||
args = []any{}
|
||||
}
|
||||
|
||||
switch {
|
||||
case result.unsatisfiable:
|
||||
return Statement{
|
||||
SQL: "1 = 0",
|
||||
Args: args,
|
||||
}, nil
|
||||
case result.trivial:
|
||||
return Statement{
|
||||
SQL: "",
|
||||
Args: args,
|
||||
}, nil
|
||||
default:
|
||||
return Statement{
|
||||
SQL: result.sql,
|
||||
Args: args,
|
||||
}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
|
||||
switch c := cond.(type) {
|
||||
case *LogicalCondition:
|
||||
return r.renderLogicalCondition(c)
|
||||
case *NotCondition:
|
||||
return r.renderNotCondition(c)
|
||||
case *FieldPredicateCondition:
|
||||
return r.renderFieldPredicate(c)
|
||||
case *ComparisonCondition:
|
||||
return r.renderComparison(c)
|
||||
case *InCondition:
|
||||
return r.renderInCondition(c)
|
||||
case *ElementInCondition:
|
||||
return r.renderElementInCondition(c)
|
||||
case *ContainsCondition:
|
||||
return r.renderContainsCondition(c)
|
||||
case *ConstantCondition:
|
||||
if c.Value {
|
||||
return renderResult{trivial: true}, nil
|
||||
}
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
|
||||
left, err := r.renderCondition(cond.Left)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
right, err := r.renderCondition(cond.Right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
switch cond.Operator {
|
||||
case LogicalAnd:
|
||||
return combineAnd(left, right), nil
|
||||
case LogicalOr:
|
||||
return combineOr(left, right), nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
|
||||
child, err := r.renderCondition(cond.Expr)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
if child.trivial {
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
||||
}
|
||||
if child.unsatisfiable {
|
||||
return renderResult{trivial: true}, nil
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("NOT (%s)", child.sql),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
|
||||
switch field.Kind {
|
||||
case FieldKindBoolColumn:
|
||||
column := qualifyColumn(r.dialect, field.Column)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s IS TRUE", column),
|
||||
}, nil
|
||||
case FieldKindJSONBool:
|
||||
sql, err := r.jsonBoolPredicate(field)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
|
||||
switch left := cond.Left.(type) {
|
||||
case *FieldRef:
|
||||
field, ok := r.schema.Field(left.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
|
||||
}
|
||||
switch field.Kind {
|
||||
case FieldKindBoolColumn:
|
||||
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
|
||||
case FieldKindJSONBool:
|
||||
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
|
||||
case FieldKindScalar:
|
||||
return r.renderScalarComparison(field, cond.Operator, cond.Right)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
|
||||
}
|
||||
case *FunctionValue:
|
||||
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
|
||||
default:
|
||||
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
if fn.Name != "size" {
|
||||
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
|
||||
}
|
||||
if len(fn.Args) != 1 {
|
||||
return renderResult{}, errors.New("size() expects one argument")
|
||||
}
|
||||
fieldArg, ok := fn.Args[0].(*FieldRef)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("size() argument must be a field")
|
||||
}
|
||||
|
||||
field, ok := r.schema.Field(fieldArg.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
|
||||
}
|
||||
|
||||
value, err := expectNumericLiteral(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
expr := jsonArrayLengthExpr(r.dialect, field)
|
||||
placeholder := r.addArg(value)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
lit, err := expectLiteral(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
columnExpr := field.columnExpr(r.dialect)
|
||||
placeholder := ""
|
||||
switch field.Type {
|
||||
case FieldTypeString:
|
||||
value, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
|
||||
}
|
||||
placeholder = r.addArg(value)
|
||||
case FieldTypeInt, FieldTypeTimestamp:
|
||||
num, err := toInt64(lit)
|
||||
if err != nil {
|
||||
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
|
||||
}
|
||||
placeholder = r.addArg(num)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
|
||||
}
|
||||
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
value, err := expectBool(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
placeholder := r.addBoolArg(value)
|
||||
column := qualifyColumn(r.dialect, field.Column)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
||||
value, err := expectBool(right)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
|
||||
jsonExpr := jsonExtractExpr(r.dialect, field)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
switch op {
|
||||
case CompareEq:
|
||||
if field.Name == "has_task_list" {
|
||||
target := "0"
|
||||
if value {
|
||||
target = "1"
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
|
||||
}
|
||||
if value {
|
||||
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
|
||||
case CompareNeq:
|
||||
if field.Name == "has_task_list" {
|
||||
target := "0"
|
||||
if value {
|
||||
target = "1"
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
|
||||
}
|
||||
if value {
|
||||
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
|
||||
}
|
||||
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
|
||||
}
|
||||
case DialectMySQL:
|
||||
boolStr := "false"
|
||||
if value {
|
||||
boolStr = "true"
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
|
||||
}, nil
|
||||
case DialectPostgres:
|
||||
placeholder := r.addArg(value)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
|
||||
}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
|
||||
fieldRef, ok := cond.Left.(*FieldRef)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
|
||||
}
|
||||
|
||||
if fieldRef.Name == "tag" {
|
||||
return r.renderTagInList(cond.Values)
|
||||
}
|
||||
|
||||
field, ok := r.schema.Field(fieldRef.Name)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
|
||||
}
|
||||
|
||||
if field.Kind != FieldKindScalar {
|
||||
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
|
||||
}
|
||||
|
||||
return r.renderScalarInCondition(field, cond.Values)
|
||||
}
|
||||
|
||||
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
|
||||
field, ok := r.schema.ResolveAlias("tag")
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tag attribute is not configured")
|
||||
}
|
||||
|
||||
conditions := make([]string, 0, len(values))
|
||||
for _, v := range values {
|
||||
lit, err := expectLiteral(v)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tags must be compared with string literals")
|
||||
}
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
expr := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
||||
conditions = append(conditions, expr)
|
||||
case DialectMySQL:
|
||||
expr := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
conditions = append(conditions, expr)
|
||||
case DialectPostgres:
|
||||
expr := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
||||
conditions = append(conditions, expr)
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
if len(conditions) == 1 {
|
||||
return renderResult{sql: conditions[0]}, nil
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
if field.Kind != FieldKindJSONList {
|
||||
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
|
||||
}
|
||||
|
||||
lit, err := expectLiteral(cond.Element)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.New("tags membership requires string literal")
|
||||
}
|
||||
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectMySQL:
|
||||
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(str))
|
||||
return renderResult{sql: sql}, nil
|
||||
case DialectPostgres:
|
||||
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(str))
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
|
||||
placeholders := make([]string, 0, len(values))
|
||||
|
||||
for _, v := range values {
|
||||
lit, err := expectLiteral(v)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
switch field.Type {
|
||||
case FieldTypeString:
|
||||
str, ok := lit.(string)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
|
||||
}
|
||||
placeholders = append(placeholders, r.addArg(str))
|
||||
case FieldTypeInt:
|
||||
num, err := toInt64(lit)
|
||||
if err != nil {
|
||||
return renderResult{}, err
|
||||
}
|
||||
placeholders = append(placeholders, r.addArg(num))
|
||||
default:
|
||||
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
|
||||
}
|
||||
}
|
||||
|
||||
column := field.columnExpr(r.dialect)
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
|
||||
field, ok := r.schema.Field(cond.Field)
|
||||
if !ok {
|
||||
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
||||
}
|
||||
column := field.columnExpr(r.dialect)
|
||||
arg := fmt.Sprintf("%%%s%%", cond.Value)
|
||||
switch r.dialect {
|
||||
case DialectPostgres:
|
||||
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
default:
|
||||
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
|
||||
return renderResult{sql: sql}, nil
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
|
||||
expr := jsonExtractExpr(r.dialect, field)
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("%s IS TRUE", expr), nil
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
|
||||
default:
|
||||
return "", errors.Errorf("unsupported dialect %s", r.dialect)
|
||||
}
|
||||
}
|
||||
|
||||
func combineAnd(left, right renderResult) renderResult {
|
||||
if left.unsatisfiable || right.unsatisfiable {
|
||||
return renderResult{sql: "1 = 0", unsatisfiable: true}
|
||||
}
|
||||
if left.trivial {
|
||||
return right
|
||||
}
|
||||
if right.trivial {
|
||||
return left
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
|
||||
}
|
||||
}
|
||||
|
||||
func combineOr(left, right renderResult) renderResult {
|
||||
if left.trivial || right.trivial {
|
||||
return renderResult{trivial: true}
|
||||
}
|
||||
if left.unsatisfiable {
|
||||
return right
|
||||
}
|
||||
if right.unsatisfiable {
|
||||
return left
|
||||
}
|
||||
return renderResult{
|
||||
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
|
||||
}
|
||||
}
|
||||
|
||||
func (r *renderer) addArg(value any) string {
|
||||
r.placeholderCounter++
|
||||
r.args = append(r.args, value)
|
||||
if r.dialect == DialectPostgres {
|
||||
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
|
||||
}
|
||||
return "?"
|
||||
}
|
||||
|
||||
func (r *renderer) addBoolArg(value bool) string {
|
||||
var v any
|
||||
switch r.dialect {
|
||||
case DialectSQLite:
|
||||
if value {
|
||||
v = 1
|
||||
} else {
|
||||
v = 0
|
||||
}
|
||||
default:
|
||||
v = value
|
||||
}
|
||||
return r.addArg(v)
|
||||
}
|
||||
|
||||
func expectLiteral(expr ValueExpr) (any, error) {
|
||||
lit, ok := expr.(*LiteralValue)
|
||||
if !ok {
|
||||
return nil, errors.New("expression must be a literal")
|
||||
}
|
||||
return lit.Value, nil
|
||||
}
|
||||
|
||||
func expectBool(expr ValueExpr) (bool, error) {
|
||||
lit, err := expectLiteral(expr)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
value, ok := lit.(bool)
|
||||
if !ok {
|
||||
return false, errors.New("boolean literal required")
|
||||
}
|
||||
return value, nil
|
||||
}
|
||||
|
||||
func expectNumericLiteral(expr ValueExpr) (int64, error) {
|
||||
lit, err := expectLiteral(expr)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return toInt64(lit)
|
||||
}
|
||||
|
||||
func toInt64(value any) (int64, error) {
|
||||
switch v := value.(type) {
|
||||
case int:
|
||||
return int64(v), nil
|
||||
case int32:
|
||||
return int64(v), nil
|
||||
case int64:
|
||||
return v, nil
|
||||
case uint32:
|
||||
return int64(v), nil
|
||||
case uint64:
|
||||
return int64(v), nil
|
||||
case float32:
|
||||
return int64(v), nil
|
||||
case float64:
|
||||
return int64(v), nil
|
||||
default:
|
||||
return 0, errors.Errorf("cannot convert %T to int64", value)
|
||||
}
|
||||
}
|
||||
|
||||
func sqlOperator(op ComparisonOperator) string {
|
||||
return string(op)
|
||||
}
|
||||
|
||||
func qualifyColumn(d DialectName, col Column) string {
|
||||
switch d {
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("%s.%s", col.Table, col.Name)
|
||||
default:
|
||||
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
|
||||
}
|
||||
}
|
||||
|
||||
func jsonPath(field Field) string {
|
||||
return "$." + strings.Join(field.JSONPath, ".")
|
||||
}
|
||||
|
||||
func jsonExtractExpr(d DialectName, field Field) string {
|
||||
column := qualifyColumn(d, field.Column)
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
||||
case DialectPostgres:
|
||||
return buildPostgresJSONAccessor(column, field.JSONPath, true)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func jsonArrayExpr(d DialectName, field Field) string {
|
||||
column := qualifyColumn(d, field.Column)
|
||||
switch d {
|
||||
case DialectSQLite, DialectMySQL:
|
||||
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
||||
case DialectPostgres:
|
||||
return buildPostgresJSONAccessor(column, field.JSONPath, false)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func jsonArrayLengthExpr(d DialectName, field Field) string {
|
||||
arrayExpr := jsonArrayExpr(d, field)
|
||||
switch d {
|
||||
case DialectSQLite:
|
||||
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
||||
case DialectMySQL:
|
||||
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
||||
case DialectPostgres:
|
||||
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
|
||||
expr := base
|
||||
for idx, part := range path {
|
||||
if idx == len(path)-1 && terminalText {
|
||||
expr = fmt.Sprintf("%s->>'%s'", expr, part)
|
||||
} else {
|
||||
expr = fmt.Sprintf("%s->'%s'", expr, part)
|
||||
}
|
||||
}
|
||||
return expr
|
||||
}
|
||||
@ -0,0 +1,254 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/google/cel-go/cel"
|
||||
"github.com/google/cel-go/common/types"
|
||||
"github.com/google/cel-go/common/types/ref"
|
||||
)
|
||||
|
||||
// DialectName enumerates supported SQL dialects.
|
||||
type DialectName string
|
||||
|
||||
const (
|
||||
DialectSQLite DialectName = "sqlite"
|
||||
DialectMySQL DialectName = "mysql"
|
||||
DialectPostgres DialectName = "postgres"
|
||||
)
|
||||
|
||||
// FieldType represents the logical type of a field.
|
||||
type FieldType string
|
||||
|
||||
const (
|
||||
FieldTypeString FieldType = "string"
|
||||
FieldTypeInt FieldType = "int"
|
||||
FieldTypeBool FieldType = "bool"
|
||||
FieldTypeTimestamp FieldType = "timestamp"
|
||||
)
|
||||
|
||||
// FieldKind describes how a field is stored.
|
||||
type FieldKind string
|
||||
|
||||
const (
|
||||
FieldKindScalar FieldKind = "scalar"
|
||||
FieldKindBoolColumn FieldKind = "bool_column"
|
||||
FieldKindJSONBool FieldKind = "json_bool"
|
||||
FieldKindJSONList FieldKind = "json_list"
|
||||
FieldKindVirtualAlias FieldKind = "virtual_alias"
|
||||
)
|
||||
|
||||
// Column identifies the backing table column.
|
||||
type Column struct {
|
||||
Table string
|
||||
Name string
|
||||
}
|
||||
|
||||
// Field captures the schema metadata for an exposed CEL identifier.
|
||||
type Field struct {
|
||||
Name string
|
||||
Kind FieldKind
|
||||
Type FieldType
|
||||
Column Column
|
||||
JSONPath []string
|
||||
AliasFor string
|
||||
SupportsContains bool
|
||||
Expressions map[DialectName]string
|
||||
AllowedComparisonOps map[ComparisonOperator]bool
|
||||
}
|
||||
|
||||
// Schema collects CEL environment options and field metadata.
|
||||
type Schema struct {
|
||||
Name string
|
||||
Fields map[string]Field
|
||||
EnvOptions []cel.EnvOption
|
||||
}
|
||||
|
||||
// Field returns the field metadata if present.
|
||||
func (s Schema) Field(name string) (Field, bool) {
|
||||
f, ok := s.Fields[name]
|
||||
return f, ok
|
||||
}
|
||||
|
||||
// ResolveAlias resolves a virtual alias to its target field.
|
||||
func (s Schema) ResolveAlias(name string) (Field, bool) {
|
||||
field, ok := s.Fields[name]
|
||||
if !ok {
|
||||
return Field{}, false
|
||||
}
|
||||
if field.Kind == FieldKindVirtualAlias {
|
||||
target, ok := s.Fields[field.AliasFor]
|
||||
if !ok {
|
||||
return Field{}, false
|
||||
}
|
||||
return target, true
|
||||
}
|
||||
return field, true
|
||||
}
|
||||
|
||||
var nowFunction = cel.Function("now",
|
||||
cel.Overload("now",
|
||||
[]*cel.Type{},
|
||||
cel.IntType,
|
||||
cel.FunctionBinding(func(_ ...ref.Val) ref.Val {
|
||||
return types.Int(time.Now().Unix())
|
||||
}),
|
||||
),
|
||||
)
|
||||
|
||||
// NewSchema constructs the memo filter schema and CEL environment.
|
||||
func NewSchema() Schema {
|
||||
fields := map[string]Field{
|
||||
"content": {
|
||||
Name: "content",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo", Name: "content"},
|
||||
SupportsContains: true,
|
||||
Expressions: map[DialectName]string{},
|
||||
},
|
||||
"creator_id": {
|
||||
Name: "creator_id",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeInt,
|
||||
Column: Column{Table: "memo", Name: "creator_id"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"created_ts": {
|
||||
Name: "created_ts",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeTimestamp,
|
||||
Column: Column{Table: "memo", Name: "created_ts"},
|
||||
Expressions: map[DialectName]string{
|
||||
DialectMySQL: "UNIX_TIMESTAMP(%s)",
|
||||
DialectPostgres: "EXTRACT(EPOCH FROM TO_TIMESTAMP(%s))",
|
||||
},
|
||||
},
|
||||
"updated_ts": {
|
||||
Name: "updated_ts",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeTimestamp,
|
||||
Column: Column{Table: "memo", Name: "updated_ts"},
|
||||
Expressions: map[DialectName]string{
|
||||
DialectMySQL: "UNIX_TIMESTAMP(%s)",
|
||||
DialectPostgres: "EXTRACT(EPOCH FROM TO_TIMESTAMP(%s))",
|
||||
},
|
||||
},
|
||||
"pinned": {
|
||||
Name: "pinned",
|
||||
Kind: FieldKindBoolColumn,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "pinned"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"visibility": {
|
||||
Name: "visibility",
|
||||
Kind: FieldKindScalar,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo", Name: "visibility"},
|
||||
Expressions: map[DialectName]string{},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"tags": {
|
||||
Name: "tags",
|
||||
Kind: FieldKindJSONList,
|
||||
Type: FieldTypeString,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"tags"},
|
||||
},
|
||||
"tag": {
|
||||
Name: "tag",
|
||||
Kind: FieldKindVirtualAlias,
|
||||
Type: FieldTypeString,
|
||||
AliasFor: "tags",
|
||||
},
|
||||
"has_task_list": {
|
||||
Name: "has_task_list",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasTaskList"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"has_link": {
|
||||
Name: "has_link",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasLink"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"has_code": {
|
||||
Name: "has_code",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasCode"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
"has_incomplete_tasks": {
|
||||
Name: "has_incomplete_tasks",
|
||||
Kind: FieldKindJSONBool,
|
||||
Type: FieldTypeBool,
|
||||
Column: Column{Table: "memo", Name: "payload"},
|
||||
JSONPath: []string{"property", "hasIncompleteTasks"},
|
||||
AllowedComparisonOps: map[ComparisonOperator]bool{
|
||||
CompareEq: true,
|
||||
CompareNeq: true,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
envOptions := []cel.EnvOption{
|
||||
cel.Variable("content", cel.StringType),
|
||||
cel.Variable("creator_id", cel.IntType),
|
||||
cel.Variable("created_ts", cel.IntType),
|
||||
cel.Variable("updated_ts", cel.IntType),
|
||||
cel.Variable("pinned", cel.BoolType),
|
||||
cel.Variable("tag", cel.StringType),
|
||||
cel.Variable("tags", cel.ListType(cel.StringType)),
|
||||
cel.Variable("visibility", cel.StringType),
|
||||
cel.Variable("has_task_list", cel.BoolType),
|
||||
cel.Variable("has_link", cel.BoolType),
|
||||
cel.Variable("has_code", cel.BoolType),
|
||||
cel.Variable("has_incomplete_tasks", cel.BoolType),
|
||||
nowFunction,
|
||||
}
|
||||
|
||||
return Schema{
|
||||
Name: "memo",
|
||||
Fields: fields,
|
||||
EnvOptions: envOptions,
|
||||
}
|
||||
}
|
||||
|
||||
// columnExpr returns the field expression for the given dialect, applying
|
||||
// any schema-specific overrides (e.g. UNIX timestamp conversions).
|
||||
func (f Field) columnExpr(d DialectName) string {
|
||||
base := qualifyColumn(d, f.Column)
|
||||
if expr, ok := f.Expressions[d]; ok && expr != "" {
|
||||
return fmt.Sprintf(expr, base)
|
||||
}
|
||||
return base
|
||||
}
|
||||
@ -1,146 +0,0 @@
|
||||
package filter
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// SQLTemplate holds database-specific SQL fragments.
|
||||
type SQLTemplate struct {
|
||||
SQLite string
|
||||
MySQL string
|
||||
PostgreSQL string
|
||||
}
|
||||
|
||||
// TemplateDBType represents the database type for templates.
|
||||
type TemplateDBType string
|
||||
|
||||
const (
|
||||
SQLiteTemplate TemplateDBType = "sqlite"
|
||||
MySQLTemplate TemplateDBType = "mysql"
|
||||
PostgreSQLTemplate TemplateDBType = "postgres"
|
||||
)
|
||||
|
||||
// SQLTemplates contains common SQL patterns for different databases.
|
||||
var SQLTemplates = map[string]SQLTemplate{
|
||||
"json_extract": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '%s')",
|
||||
PostgreSQL: "memo.payload%s",
|
||||
},
|
||||
"json_array_length": {
|
||||
SQLite: "JSON_ARRAY_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||
MySQL: "JSON_LENGTH(COALESCE(JSON_EXTRACT(`memo`.`payload`, '$.tags'), JSON_ARRAY()))",
|
||||
PostgreSQL: "jsonb_array_length(COALESCE(memo.payload->'tags', '[]'::jsonb))",
|
||||
},
|
||||
"json_contains_element": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||
},
|
||||
"json_contains_tag": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.tags') LIKE ?",
|
||||
MySQL: "JSON_CONTAINS(JSON_EXTRACT(`memo`.`payload`, '$.tags'), ?)",
|
||||
PostgreSQL: "memo.payload->'tags' @> jsonb_build_array(?)",
|
||||
},
|
||||
"boolean_true": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 1",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = true",
|
||||
},
|
||||
"boolean_false": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = 0",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('false' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean = false",
|
||||
},
|
||||
"boolean_not_true": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 1",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != true",
|
||||
},
|
||||
"boolean_not_false": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != 0",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') != CAST('false' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean != false",
|
||||
},
|
||||
"boolean_compare": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s ?",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') %s CAST(? AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean %s ?",
|
||||
},
|
||||
"boolean_check": {
|
||||
SQLite: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') IS TRUE",
|
||||
MySQL: "JSON_EXTRACT(`memo`.`payload`, '$.property.hasTaskList') = CAST('true' AS JSON)",
|
||||
PostgreSQL: "(memo.payload->'property'->>'hasTaskList')::boolean IS TRUE",
|
||||
},
|
||||
"table_prefix": {
|
||||
SQLite: "`memo`",
|
||||
MySQL: "`memo`",
|
||||
PostgreSQL: "memo",
|
||||
},
|
||||
"timestamp_field": {
|
||||
SQLite: "`memo`.`%s`",
|
||||
MySQL: "UNIX_TIMESTAMP(`memo`.`%s`)",
|
||||
PostgreSQL: "EXTRACT(EPOCH FROM memo.%s)",
|
||||
},
|
||||
"content_like": {
|
||||
SQLite: "`memo`.`content` LIKE ?",
|
||||
MySQL: "`memo`.`content` LIKE ?",
|
||||
PostgreSQL: "memo.content ILIKE ?",
|
||||
},
|
||||
"visibility_in": {
|
||||
SQLite: "`memo`.`visibility` IN (%s)",
|
||||
MySQL: "`memo`.`visibility` IN (%s)",
|
||||
PostgreSQL: "memo.visibility IN (%s)",
|
||||
},
|
||||
}
|
||||
|
||||
// GetSQL returns the appropriate SQL for the given template and database type.
|
||||
func GetSQL(templateName string, dbType TemplateDBType) string {
|
||||
template, exists := SQLTemplates[templateName]
|
||||
if !exists {
|
||||
return ""
|
||||
}
|
||||
|
||||
switch dbType {
|
||||
case SQLiteTemplate:
|
||||
return template.SQLite
|
||||
case MySQLTemplate:
|
||||
return template.MySQL
|
||||
case PostgreSQLTemplate:
|
||||
return template.PostgreSQL
|
||||
default:
|
||||
return template.SQLite
|
||||
}
|
||||
}
|
||||
|
||||
// GetParameterPlaceholder returns the appropriate parameter placeholder for the database.
|
||||
func GetParameterPlaceholder(dbType TemplateDBType, index int) string {
|
||||
switch dbType {
|
||||
case PostgreSQLTemplate:
|
||||
return fmt.Sprintf("$%d", index)
|
||||
default:
|
||||
return "?"
|
||||
}
|
||||
}
|
||||
|
||||
// GetParameterValue returns the appropriate parameter value for the database.
|
||||
func GetParameterValue(dbType TemplateDBType, templateName string, value interface{}) interface{} {
|
||||
switch templateName {
|
||||
case "json_contains_element", "json_contains_tag":
|
||||
if dbType == SQLiteTemplate {
|
||||
return fmt.Sprintf(`%%"%s"%%`, value)
|
||||
}
|
||||
return value
|
||||
default:
|
||||
return value
|
||||
}
|
||||
}
|
||||
|
||||
// FormatPlaceholders formats a list of placeholders for the given database type.
|
||||
func FormatPlaceholders(dbType TemplateDBType, count int, startIndex int) []string {
|
||||
placeholders := make([]string, count)
|
||||
for i := 0; i < count; i++ {
|
||||
placeholders[i] = GetParameterPlaceholder(dbType, startIndex+i)
|
||||
}
|
||||
return placeholders
|
||||
}
|
||||
@ -1,68 +0,0 @@
|
||||
package v1
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestUserFilterValidation(t *testing.T) {
|
||||
testCases := []struct {
|
||||
name string
|
||||
filter string
|
||||
expectErr bool
|
||||
}{
|
||||
{
|
||||
name: "valid username filter with equals",
|
||||
filter: `username == "testuser"`,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "valid username filter with contains",
|
||||
filter: `username.contains("admin")`,
|
||||
expectErr: false,
|
||||
},
|
||||
{
|
||||
name: "invalid filter - unknown field",
|
||||
filter: `invalid_field == "test"`,
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "empty filter",
|
||||
filter: "",
|
||||
expectErr: true,
|
||||
},
|
||||
{
|
||||
name: "invalid syntax",
|
||||
filter: `username ==`,
|
||||
expectErr: true,
|
||||
},
|
||||
}
|
||||
|
||||
for _, tc := range testCases {
|
||||
t.Run(tc.name, func(t *testing.T) {
|
||||
// Test the filter parsing directly
|
||||
_, err := filter.Parse(tc.filter, filter.UserFilterCELAttributes...)
|
||||
|
||||
if tc.expectErr && err == nil {
|
||||
t.Errorf("Expected error for filter %q, but got none", tc.filter)
|
||||
}
|
||||
if !tc.expectErr && err != nil {
|
||||
t.Errorf("Expected no error for filter %q, but got: %v", tc.filter, err)
|
||||
}
|
||||
})
|
||||
}
|
||||
}
|
||||
|
||||
func TestUserFilterCELAttributes(t *testing.T) {
|
||||
// Test that our UserFilterCELAttributes contains the username variable
|
||||
expectedAttributes := map[string]bool{
|
||||
"username": true,
|
||||
}
|
||||
|
||||
// This is a basic test to ensure the attributes are defined
|
||||
// In a real test, you would create a CEL environment and verify the attributes
|
||||
for attrName := range expectedAttributes {
|
||||
t.Logf("Expected attribute %s should be available in UserFilterCELAttributes", attrName)
|
||||
}
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestAttachmentConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "`resource`.`memo_id` IN (?)",
|
||||
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "`resource`.`memo_id` IN (?,?)",
|
||||
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
package mysql
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "`reaction`.`content_id` IN (?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "`reaction`.`content_id` IN (?,?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.MySQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestAttachmentConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "resource.memo_id IN ($1)",
|
||||
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "resource.memo_id IN ($1,$2)",
|
||||
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
package postgres
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "reaction.content_id IN ($1)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "reaction.content_id IN ($1,$2)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.PostgreSQLDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestAttachmentConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "`resource`.`memo_id` IN (?)",
|
||||
args: []any{"5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `memo_id in ["5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "`resource`.`memo_id` IN (?,?)",
|
||||
args: []any{"5atZAj8GcvkSuUA3X2KLaY", "4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.AttachmentFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
@ -1,39 +0,0 @@
|
||||
package sqlite
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/require"
|
||||
|
||||
"github.com/usememos/memos/plugin/filter"
|
||||
)
|
||||
|
||||
func TestReactionConvertExprToSQL(t *testing.T) {
|
||||
tests := []struct {
|
||||
filter string
|
||||
want string
|
||||
args []any
|
||||
}{
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY"]`,
|
||||
want: "`reaction`.`content_id` IN (?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY"},
|
||||
},
|
||||
{
|
||||
filter: `content_id in ["memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"]`,
|
||||
want: "`reaction`.`content_id` IN (?,?)",
|
||||
args: []any{"memos/5atZAj8GcvkSuUA3X2KLaY", "memos/4EN8aEpcJ3MaK4ExHTpiTE"},
|
||||
},
|
||||
}
|
||||
|
||||
for _, tt := range tests {
|
||||
parsedExpr, err := filter.Parse(tt.filter, filter.ReactionFilterCELAttributes...)
|
||||
require.NoError(t, err)
|
||||
convertCtx := filter.NewConvertContext()
|
||||
converter := filter.NewCommonSQLConverter(&filter.SQLiteDialect{})
|
||||
err = converter.ConvertExprToSQL(convertCtx, parsedExpr.GetExpr())
|
||||
require.NoError(t, err)
|
||||
require.Equal(t, tt.want, convertCtx.Buffer.String())
|
||||
require.Equal(t, tt.args, convertCtx.Args)
|
||||
}
|
||||
}
|
||||
Loading…
Reference in New Issue