You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
memos/plugin/filter/render.go

627 lines
16 KiB
Go

package filter
import (
"fmt"
"strings"
"github.com/pkg/errors"
)
// renderer carries the state used while turning a Condition tree into SQL:
// the schema that fields are looked up in, the target dialect, and the bind
// arguments collected so far.
type renderer struct {
// schema resolves the field names and aliases referenced by conditions.
schema Schema
// dialect selects which SQL flavor the render helpers emit.
dialect DialectName
// placeholderOffset is added to placeholderCounter when addArg numbers
// Postgres "$n" placeholders.
placeholderOffset int
// placeholderCounter counts the arguments registered through addArg.
placeholderCounter int
// args holds the bind arguments in the order addArg received them.
args []any
}
// renderResult is the outcome of rendering a single condition node.
type renderResult struct {
// sql is the rendered SQL fragment; trivial results leave it empty.
sql string
// trivial marks a condition that is always true (a true constant or
// something that folded down to one).
trivial bool
// unsatisfiable marks a condition that is always false; such results are
// rendered as "1 = 0".
unsatisfiable bool
}
// newRenderer builds a renderer for schema, taking the dialect and the
// placeholder offset from opts. The argument list and the placeholder counter
// start at their zero values.
func newRenderer(schema Schema, opts RenderOptions) *renderer {
	r := new(renderer)
	r.schema = schema
	r.dialect = opts.Dialect
	r.placeholderOffset = opts.PlaceholderOffset
	return r
}
// Render converts cond into a Statement.
//
// An always-false condition yields the SQL "1 = 0", an always-true condition
// yields an empty SQL string, and anything else yields the rendered fragment.
// Args is never nil: when no argument was registered it is an empty slice.
func (r *renderer) Render(cond Condition) (Statement, error) {
	rendered, err := r.renderCondition(cond)
	if err != nil {
		return Statement{}, err
	}

	// Hand back a non-nil slice even when nothing was bound.
	boundArgs := r.args
	if boundArgs == nil {
		boundArgs = []any{}
	}

	stmt := Statement{SQL: rendered.sql, Args: boundArgs}
	if rendered.unsatisfiable {
		stmt.SQL = "1 = 0"
	} else if rendered.trivial {
		stmt.SQL = ""
	}
	return stmt, nil
}
// renderCondition dispatches on the concrete condition type and delegates to
// the matching render helper. Constant conditions are folded right here: true
// becomes a trivial result, false becomes the unsatisfiable "1 = 0". Any other
// condition type is reported as an error.
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
	switch node := cond.(type) {
	case *ConstantCondition:
		if !node.Value {
			return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
		}
		return renderResult{trivial: true}, nil
	case *LogicalCondition:
		return r.renderLogicalCondition(node)
	case *NotCondition:
		return r.renderNotCondition(node)
	case *FieldPredicateCondition:
		return r.renderFieldPredicate(node)
	case *ComparisonCondition:
		return r.renderComparison(node)
	case *InCondition:
		return r.renderInCondition(node)
	case *ElementInCondition:
		return r.renderElementInCondition(node)
	case *ContainsCondition:
		return r.renderContainsCondition(node)
	default:
		return renderResult{}, errors.Errorf("unsupported condition type %T", node)
	}
}
// renderLogicalCondition renders both operands of an AND/OR node and merges
// them with combineAnd or combineOr. An unknown operator is an error.
//
// Both operands are always rendered (so errors in either side are reported),
// which means their bind arguments are registered through addArg before the
// combine step decides whether their SQL survives. When the combined result
// folds to a constant (trivial or unsatisfiable) its SQL contains no
// placeholder, so every argument registered while rendering this node is
// dropped again. Without that, e.g. `a == 1 && false` would leave an orphan
// argument behind and shift the "$n" numbering of later Postgres placeholders.
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
	// Snapshot of the argument state before either operand is rendered.
	argMark, counterMark := len(r.args), r.placeholderCounter

	left, err := r.renderCondition(cond.Left)
	if err != nil {
		return renderResult{}, err
	}
	right, err := r.renderCondition(cond.Right)
	if err != nil {
		return renderResult{}, err
	}

	var combined renderResult
	switch cond.Operator {
	case LogicalAnd:
		combined = combineAnd(left, right)
	case LogicalOr:
		combined = combineOr(left, right)
	default:
		return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
	}

	if combined.trivial || combined.unsatisfiable {
		// The surviving SQL references no placeholder: rewind to the snapshot.
		r.args = r.args[:argMark]
		r.placeholderCounter = counterMark
	}
	return combined, nil
}
// renderNotCondition negates the rendered child expression. Constant children
// are flipped directly (always-true becomes "1 = 0", always-false becomes
// trivial); everything else is wrapped in NOT (...).
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
	inner, err := r.renderCondition(cond.Expr)
	if err != nil {
		return renderResult{}, err
	}

	switch {
	case inner.trivial:
		return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
	case inner.unsatisfiable:
		return renderResult{trivial: true}, nil
	}

	negated := fmt.Sprintf("NOT (%s)", inner.sql)
	return renderResult{sql: negated}, nil
}
// renderFieldPredicate renders a bare field used as a boolean test. Boolean
// columns become "<column> IS TRUE", JSON booleans go through
// jsonBoolPredicate, and any other field kind is rejected.
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
	field, found := r.schema.Field(cond.Field)
	if !found {
		return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
	}

	if field.Kind == FieldKindBoolColumn {
		qualified := qualifyColumn(r.dialect, field.Column)
		return renderResult{sql: fmt.Sprintf("%s IS TRUE", qualified)}, nil
	}

	if field.Kind == FieldKindJSONBool {
		predicate, err := r.jsonBoolPredicate(field)
		if err != nil {
			return renderResult{}, err
		}
		return renderResult{sql: predicate}, nil
	}

	return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
}
// renderComparison renders a binary comparison. The left-hand side must be a
// field reference or a function call: function calls are delegated to
// renderFunctionComparison, and field references are routed by field kind to
// the boolean-column, JSON-boolean or scalar renderer.
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
	if fn, isFunc := cond.Left.(*FunctionValue); isFunc {
		return r.renderFunctionComparison(fn, cond.Operator, cond.Right)
	}

	ref, isField := cond.Left.(*FieldRef)
	if !isField {
		return renderResult{}, errors.New("comparison must start with a field reference or supported function")
	}

	field, known := r.schema.Field(ref.Name)
	if !known {
		return renderResult{}, errors.Errorf("unknown field %q", ref.Name)
	}

	switch field.Kind {
	case FieldKindScalar:
		return r.renderScalarComparison(field, cond.Operator, cond.Right)
	case FieldKindBoolColumn:
		return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
	case FieldKindJSONBool:
		return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
	}
	return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
}
// renderFunctionComparison renders `size(<field>) <op> <number>`. Only the
// size function is accepted; it takes exactly one argument, which must be a
// reference to a JSON list field, and the right-hand side must be a numeric
// literal. The result compares the dialect's array-length expression with a
// bound placeholder.
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
	switch {
	case fn.Name != "size":
		return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
	case len(fn.Args) != 1:
		return renderResult{}, errors.New("size() expects one argument")
	}

	ref, isField := fn.Args[0].(*FieldRef)
	if !isField {
		return renderResult{}, errors.New("size() argument must be a field")
	}

	field, known := r.schema.Field(ref.Name)
	if !known {
		return renderResult{}, errors.Errorf("unknown field %q", ref.Name)
	}
	if field.Kind != FieldKindJSONList {
		return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
	}

	limit, err := expectNumericLiteral(right)
	if err != nil {
		return renderResult{}, err
	}

	lengthExpr := jsonArrayLengthExpr(r.dialect, field)
	placeholder := r.addArg(limit)
	rendered := fmt.Sprintf("%s %s %s", lengthExpr, sqlOperator(op), placeholder)
	return renderResult{sql: rendered}, nil
}
// renderScalarComparison renders `<column> <op> <placeholder>` for a scalar
// field. The right-hand side must be a literal matching the field type: a
// string for string fields, something toInt64 accepts for int and timestamp
// fields. Other field types are rejected.
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
	raw, err := expectLiteral(right)
	if err != nil {
		return renderResult{}, err
	}
	column := field.columnExpr(r.dialect)

	// Validate and normalize the literal first; it is bound only once it is
	// known to be usable.
	var bound any
	switch field.Type {
	case FieldTypeInt, FieldTypeTimestamp:
		n, convErr := toInt64(raw)
		if convErr != nil {
			return renderResult{}, errors.Wrapf(convErr, "field %q expects integer value", field.Name)
		}
		bound = n
	case FieldTypeString:
		s, isString := raw.(string)
		if !isString {
			return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
		}
		bound = s
	default:
		return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
	}

	rendered := fmt.Sprintf("%s %s %s", column, sqlOperator(op), r.addArg(bound))
	return renderResult{sql: rendered}, nil
}
// renderBoolColumnComparison renders a comparison between a boolean column and
// a boolean literal, binding the literal through addBoolArg so it is encoded
// the way the dialect expects.
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
	want, err := expectBool(right)
	if err != nil {
		return renderResult{}, err
	}

	placeholder := r.addBoolArg(want)
	rendered := fmt.Sprintf("%s %s %s", qualifyColumn(r.dialect, field.Column), sqlOperator(op), placeholder)
	return renderResult{sql: rendered}, nil
}
// renderJSONBoolComparison renders a comparison between a boolean stored in a
// JSON document and a boolean literal.
//
//   - SQLite accepts only equality and inequality. The "has_task_list" field
//     is compared against the literals 0/1; every other field is rendered as
//     "IS TRUE" or its negation.
//   - MySQL compares the extracted value with CAST('true'|'false' AS JSON).
//   - Postgres casts the extracted text to boolean and binds the literal.
//
// Any other dialect is an error.
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
	want, err := expectBool(right)
	if err != nil {
		return renderResult{}, err
	}
	extracted := jsonExtractExpr(r.dialect, field)

	switch r.dialect {
	case DialectPostgres:
		placeholder := r.addArg(want)
		return renderResult{
			sql: fmt.Sprintf("(%s)::boolean %s %s", extracted, sqlOperator(op), placeholder),
		}, nil

	case DialectMySQL:
		literal := "false"
		if want {
			literal = "true"
		}
		return renderResult{
			sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", extracted, sqlOperator(op), literal),
		}, nil

	case DialectSQLite:
		if op != CompareEq && op != CompareNeq {
			return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
		}

		if field.Name == "has_task_list" {
			format := "%s = %s"
			if op == CompareNeq {
				format = "%s != %s"
			}
			digit := "0"
			if want {
				digit = "1"
			}
			return renderResult{sql: fmt.Sprintf(format, extracted, digit)}, nil
		}

		// "== true" and "!= false" both ask for a true value; the two
		// remaining combinations ask for anything that is not true.
		if want == (op == CompareEq) {
			return renderResult{sql: fmt.Sprintf("%s IS TRUE", extracted)}, nil
		}
		return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", extracted)}, nil
	}

	return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
}
// renderInCondition renders `<field> in [...]`. The left-hand side must be a
// field reference. The name "tag" is handed to renderTagInList before any
// schema lookup; every other field must be a known scalar field and is
// rendered by renderScalarInCondition.
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
	ref, isField := cond.Left.(*FieldRef)
	if !isField {
		return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
	}

	switch ref.Name {
	case "tag":
		return r.renderTagInList(cond.Values)
	}

	field, known := r.schema.Field(ref.Name)
	if !known {
		return renderResult{}, errors.Errorf("unknown field %q", ref.Name)
	}
	if field.Kind == FieldKindScalar {
		return r.renderScalarInCondition(field, cond.Values)
	}
	return renderResult{}, errors.Errorf("field %q does not support IN()", ref.Name)
}
// renderTagInList renders `tag in [...]` as one membership check per listed
// value against the JSON list behind the "tag" alias, joined with OR.
//
// Every value must be a string literal. An empty list can never match, so it
// is rendered as the unsatisfiable "1 = 0" rather than the invalid SQL
// fragment "()". Unsupported dialects are an error.
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
	field, ok := r.schema.ResolveAlias("tag")
	if !ok {
		return renderResult{}, errors.New("tag attribute is not configured")
	}
	if len(values) == 0 {
		// `tag in []` is always false and binds no arguments.
		return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
	}

	conditions := make([]string, 0, len(values))
	for _, v := range values {
		lit, err := expectLiteral(v)
		if err != nil {
			return renderResult{}, err
		}
		str, ok := lit.(string)
		if !ok {
			return renderResult{}, errors.New("tags must be compared with string literals")
		}

		var expr string
		switch r.dialect {
		case DialectSQLite:
			// LIKE match of the double-quoted tag inside the extracted list.
			expr = fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
		case DialectMySQL:
			// The candidate is bound as a JSON string, hence the quotes.
			expr = fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
		case DialectPostgres:
			// The bound text is cast to json, hence the quotes.
			expr = fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
		default:
			return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
		}
		conditions = append(conditions, expr)
	}

	if len(conditions) == 1 {
		return renderResult{sql: conditions[0]}, nil
	}
	return renderResult{
		sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
	}, nil
}
// renderElementInCondition renders `"<value>" in <list field>`: a membership
// test of one string literal against a JSON list field.
//
// The field must exist and be a JSON list, and the element must be a string
// literal. The bound argument is always the tag wrapped in double quotes:
// SQLite matches that text with LIKE, while MySQL's JSON_CONTAINS and the
// Postgres `::json` cast need the argument to be a JSON string rather than the
// bare tag text. This mirrors what renderTagInList binds for `tag in [...]`.
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
	field, ok := r.schema.Field(cond.Field)
	if !ok {
		return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
	}
	if field.Kind != FieldKindJSONList {
		return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
	}

	lit, err := expectLiteral(cond.Element)
	if err != nil {
		return renderResult{}, err
	}
	str, ok := lit.(string)
	if !ok {
		return renderResult{}, errors.New("tags membership requires string literal")
	}

	switch r.dialect {
	case DialectSQLite:
		sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
		return renderResult{sql: sql}, nil
	case DialectMySQL:
		// Bind a JSON string ("tag"); the bare text would not be valid JSON.
		sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
		return renderResult{sql: sql}, nil
	case DialectPostgres:
		// Same here: the ::json cast needs a quoted JSON string.
		sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
		return renderResult{sql: sql}, nil
	default:
		return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
	}
}
// renderScalarInCondition renders `<column> IN (<placeholders>)` for a scalar
// field, binding one argument per listed value.
//
// Values must be literals matching the field type: strings for string fields,
// anything toInt64 accepts for int fields. Other field types are rejected. An
// empty list can never match, so it is rendered as the unsatisfiable "1 = 0"
// instead of "<column> IN ()".
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
	if len(values) == 0 {
		// `x in []` is always false and binds no arguments.
		return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
	}

	placeholders := make([]string, 0, len(values))
	for _, v := range values {
		lit, err := expectLiteral(v)
		if err != nil {
			return renderResult{}, err
		}
		switch field.Type {
		case FieldTypeString:
			str, ok := lit.(string)
			if !ok {
				return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
			}
			placeholders = append(placeholders, r.addArg(str))
		case FieldTypeInt:
			num, err := toInt64(lit)
			if err != nil {
				return renderResult{}, err
			}
			placeholders = append(placeholders, r.addArg(num))
		default:
			return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
		}
	}

	column := field.columnExpr(r.dialect)
	return renderResult{
		sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
	}, nil
}
// renderContainsCondition renders a substring match on a field: the value is
// wrapped in % wildcards and bound as an argument. Postgres uses ILIKE; every
// other dialect uses LIKE.
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
	field, known := r.schema.Field(cond.Field)
	if !known {
		return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
	}

	target := field.columnExpr(r.dialect)
	pattern := fmt.Sprintf("%%%s%%", cond.Value)

	format := "%s LIKE %s"
	if r.dialect == DialectPostgres {
		format = "%s ILIKE %s"
	}
	return renderResult{sql: fmt.Sprintf(format, target, r.addArg(pattern))}, nil
}
// jsonBoolPredicate returns the dialect-specific SQL that tests whether a JSON
// boolean field is true. Unknown dialects are an error.
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
	var format string
	switch r.dialect {
	case DialectSQLite:
		format = "%s IS TRUE"
	case DialectMySQL:
		format = "%s = CAST('true' AS JSON)"
	case DialectPostgres:
		format = "(%s)::boolean IS TRUE"
	default:
		return "", errors.Errorf("unsupported dialect %s", r.dialect)
	}
	return fmt.Sprintf(format, jsonExtractExpr(r.dialect, field)), nil
}
// combineAnd merges two rendered operands under AND with constant folding: an
// always-false side makes the whole result "1 = 0", an always-true side is
// dropped in favor of the other, and otherwise both fragments are joined as
// "(left AND right)".
func combineAnd(left, right renderResult) renderResult {
	switch {
	case left.unsatisfiable, right.unsatisfiable:
		return renderResult{sql: "1 = 0", unsatisfiable: true}
	case left.trivial:
		return right
	case right.trivial:
		return left
	default:
		joined := fmt.Sprintf("(%s AND %s)", left.sql, right.sql)
		return renderResult{sql: joined}
	}
}
// combineOr merges two rendered operands under OR with constant folding: an
// always-true side makes the whole result trivial, an always-false side is
// dropped in favor of the other, and otherwise both fragments are joined as
// "(left OR right)".
func combineOr(left, right renderResult) renderResult {
	switch {
	case left.trivial, right.trivial:
		return renderResult{trivial: true}
	case left.unsatisfiable:
		return right
	case right.unsatisfiable:
		return left
	default:
		joined := fmt.Sprintf("(%s OR %s)", left.sql, right.sql)
		return renderResult{sql: joined}
	}
}
// addArg records value as the next bind argument and returns the placeholder
// to embed in the SQL: "$n" for Postgres, where n is the placeholder offset
// plus the running argument count, and "?" for every other dialect.
func (r *renderer) addArg(value any) string {
	r.args = append(r.args, value)
	r.placeholderCounter++

	if r.dialect != DialectPostgres {
		return "?"
	}
	position := r.placeholderOffset + r.placeholderCounter
	return fmt.Sprintf("$%d", position)
}
// addBoolArg binds a boolean argument. SQLite receives the integers 1 or 0;
// every other dialect receives the bool unchanged.
func (r *renderer) addBoolArg(value bool) string {
	if r.dialect != DialectSQLite {
		return r.addArg(value)
	}

	flag := 0
	if value {
		flag = 1
	}
	return r.addArg(flag)
}
// expectLiteral returns the value carried by expr when it is a *LiteralValue
// and an error for any other expression.
func expectLiteral(expr ValueExpr) (any, error) {
	if literal, isLiteral := expr.(*LiteralValue); isLiteral {
		return literal.Value, nil
	}
	return nil, errors.New("expression must be a literal")
}
// expectBool returns the boolean carried by a literal expression. It fails
// when expr is not a literal or when the literal is not a bool.
func expectBool(expr ValueExpr) (bool, error) {
	raw, err := expectLiteral(expr)
	if err != nil {
		return false, err
	}
	if flag, isBool := raw.(bool); isBool {
		return flag, nil
	}
	return false, errors.New("boolean literal required")
}
// expectNumericLiteral returns the literal carried by expr converted with
// toInt64. It fails when expr is not a literal or the conversion fails.
func expectNumericLiteral(expr ValueExpr) (int64, error) {
	raw, err := expectLiteral(expr)
	if err == nil {
		return toInt64(raw)
	}
	return 0, err
}
// toInt64 converts a numeric literal value to int64.
//
// Accepted dynamic types are int, int32, int64, uint32, uint64, float32 and
// float64; any other type is an error. Floating-point values are converted
// with int64(v), which discards the fractional part. A uint64 larger than the
// maximum int64 is rejected instead of silently wrapping to a negative number.
func toInt64(value any) (int64, error) {
	switch v := value.(type) {
	case int:
		return int64(v), nil
	case int32:
		return int64(v), nil
	case int64:
		return v, nil
	case uint32:
		return int64(v), nil
	case uint64:
		const maxInt64 = 1<<63 - 1
		if v > maxInt64 {
			return 0, errors.Errorf("value %d overflows int64", v)
		}
		return int64(v), nil
	case float32:
		return int64(v), nil
	case float64:
		return int64(v), nil
	default:
		return 0, errors.Errorf("cannot convert %T to int64", value)
	}
}
// sqlOperator returns the SQL token for a comparison operator, which is the
// operator's own string value.
func sqlOperator(op ComparisonOperator) string {
	token := string(op)
	return token
}
// qualifyColumn returns the table-qualified column reference: plain
// "table.name" for Postgres, backtick-quoted for every other dialect.
func qualifyColumn(d DialectName, col Column) string {
	if d == DialectPostgres {
		return fmt.Sprintf("%s.%s", col.Table, col.Name)
	}
	return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
}
// jsonPath builds the "$.a.b.c" path string from the field's JSON path
// segments.
func jsonPath(field Field) string {
	segments := strings.Join(field.JSONPath, ".")
	return "$." + segments
}
// jsonExtractExpr returns the SQL expression that extracts the field's value
// from its JSON column: JSON_EXTRACT for SQLite and MySQL, an arrow accessor
// whose last step uses ->> for Postgres, and "" for any other dialect.
func jsonExtractExpr(d DialectName, field Field) string {
	qualified := qualifyColumn(d, field.Column)
	if d == DialectPostgres {
		return buildPostgresJSONAccessor(qualified, field.JSONPath, true)
	}
	if d == DialectSQLite || d == DialectMySQL {
		return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", qualified, jsonPath(field))
	}
	return ""
}
// jsonArrayExpr returns the SQL expression that selects the field's JSON
// array: JSON_EXTRACT for SQLite and MySQL, an arrow accessor using only ->
// steps for Postgres, and "" for any other dialect.
func jsonArrayExpr(d DialectName, field Field) string {
	qualified := qualifyColumn(d, field.Column)
	if d == DialectPostgres {
		return buildPostgresJSONAccessor(qualified, field.JSONPath, false)
	}
	if d == DialectSQLite || d == DialectMySQL {
		return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", qualified, jsonPath(field))
	}
	return ""
}
// jsonArrayLengthExpr returns the SQL expression for the length of the field's
// JSON array, with a missing array coalesced to an empty one. It returns ""
// for an unknown dialect.
func jsonArrayLengthExpr(d DialectName, field Field) string {
	var format string
	switch d {
	case DialectPostgres:
		format = "jsonb_array_length(COALESCE(%s, '[]'::jsonb))"
	case DialectMySQL:
		format = "JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))"
	case DialectSQLite:
		format = "JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))"
	default:
		return ""
	}
	return fmt.Sprintf(format, jsonArrayExpr(d, field))
}
// buildPostgresJSONAccessor chains one arrow step per path segment onto base.
// Every step uses -> except the last one, which uses ->> when terminalText is
// set. An empty path returns base unchanged.
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
	accessor := base
	last := len(path) - 1
	for i := range path {
		format := "%s->'%s'"
		if terminalText && i == last {
			format = "%s->>'%s'"
		}
		accessor = fmt.Sprintf(format, accessor, path[i])
	}
	return accessor
}