mirror of https://github.com/usememos/memos
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
627 lines
16 KiB
Go
627 lines
16 KiB
Go
package filter
|
|
|
|
import (
|
|
"fmt"
|
|
"strings"
|
|
|
|
"github.com/pkg/errors"
|
|
)
|
|
|
|
type renderer struct {
|
|
schema Schema
|
|
dialect DialectName
|
|
placeholderOffset int
|
|
placeholderCounter int
|
|
args []any
|
|
}
|
|
|
|
// renderResult is the intermediate output for one rendered sub-condition.
type renderResult struct {
	// sql is the rendered fragment for a non-constant condition.
	sql string
	// trivial marks a condition that always holds; unsatisfiable marks one
	// that never holds. Render emits "" and "1 = 0" for them respectively.
	trivial, unsatisfiable bool
}
|
|
|
|
func newRenderer(schema Schema, opts RenderOptions) *renderer {
|
|
return &renderer{
|
|
schema: schema,
|
|
dialect: opts.Dialect,
|
|
placeholderOffset: opts.PlaceholderOffset,
|
|
}
|
|
}
|
|
|
|
func (r *renderer) Render(cond Condition) (Statement, error) {
|
|
result, err := r.renderCondition(cond)
|
|
if err != nil {
|
|
return Statement{}, err
|
|
}
|
|
args := r.args
|
|
if args == nil {
|
|
args = []any{}
|
|
}
|
|
|
|
switch {
|
|
case result.unsatisfiable:
|
|
return Statement{
|
|
SQL: "1 = 0",
|
|
Args: args,
|
|
}, nil
|
|
case result.trivial:
|
|
return Statement{
|
|
SQL: "",
|
|
Args: args,
|
|
}, nil
|
|
default:
|
|
return Statement{
|
|
SQL: result.sql,
|
|
Args: args,
|
|
}, nil
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderCondition(cond Condition) (renderResult, error) {
|
|
switch c := cond.(type) {
|
|
case *LogicalCondition:
|
|
return r.renderLogicalCondition(c)
|
|
case *NotCondition:
|
|
return r.renderNotCondition(c)
|
|
case *FieldPredicateCondition:
|
|
return r.renderFieldPredicate(c)
|
|
case *ComparisonCondition:
|
|
return r.renderComparison(c)
|
|
case *InCondition:
|
|
return r.renderInCondition(c)
|
|
case *ElementInCondition:
|
|
return r.renderElementInCondition(c)
|
|
case *ContainsCondition:
|
|
return r.renderContainsCondition(c)
|
|
case *ConstantCondition:
|
|
if c.Value {
|
|
return renderResult{trivial: true}, nil
|
|
}
|
|
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
|
default:
|
|
return renderResult{}, errors.Errorf("unsupported condition type %T", c)
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderLogicalCondition(cond *LogicalCondition) (renderResult, error) {
|
|
left, err := r.renderCondition(cond.Left)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
right, err := r.renderCondition(cond.Right)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
|
|
switch cond.Operator {
|
|
case LogicalAnd:
|
|
return combineAnd(left, right), nil
|
|
case LogicalOr:
|
|
return combineOr(left, right), nil
|
|
default:
|
|
return renderResult{}, errors.Errorf("unsupported logical operator %s", cond.Operator)
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderNotCondition(cond *NotCondition) (renderResult, error) {
|
|
child, err := r.renderCondition(cond.Expr)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
|
|
if child.trivial {
|
|
return renderResult{sql: "1 = 0", unsatisfiable: true}, nil
|
|
}
|
|
if child.unsatisfiable {
|
|
return renderResult{trivial: true}, nil
|
|
}
|
|
return renderResult{
|
|
sql: fmt.Sprintf("NOT (%s)", child.sql),
|
|
}, nil
|
|
}
|
|
|
|
func (r *renderer) renderFieldPredicate(cond *FieldPredicateCondition) (renderResult, error) {
|
|
field, ok := r.schema.Field(cond.Field)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
|
}
|
|
|
|
switch field.Kind {
|
|
case FieldKindBoolColumn:
|
|
column := qualifyColumn(r.dialect, field.Column)
|
|
return renderResult{
|
|
sql: fmt.Sprintf("%s IS TRUE", column),
|
|
}, nil
|
|
case FieldKindJSONBool:
|
|
sql, err := r.jsonBoolPredicate(field)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
return renderResult{sql: sql}, nil
|
|
default:
|
|
return renderResult{}, errors.Errorf("field %q cannot be used as a predicate", cond.Field)
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderComparison(cond *ComparisonCondition) (renderResult, error) {
|
|
switch left := cond.Left.(type) {
|
|
case *FieldRef:
|
|
field, ok := r.schema.Field(left.Name)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("unknown field %q", left.Name)
|
|
}
|
|
switch field.Kind {
|
|
case FieldKindBoolColumn:
|
|
return r.renderBoolColumnComparison(field, cond.Operator, cond.Right)
|
|
case FieldKindJSONBool:
|
|
return r.renderJSONBoolComparison(field, cond.Operator, cond.Right)
|
|
case FieldKindScalar:
|
|
return r.renderScalarComparison(field, cond.Operator, cond.Right)
|
|
default:
|
|
return renderResult{}, errors.Errorf("field %q does not support comparison", field.Name)
|
|
}
|
|
case *FunctionValue:
|
|
return r.renderFunctionComparison(left, cond.Operator, cond.Right)
|
|
default:
|
|
return renderResult{}, errors.New("comparison must start with a field reference or supported function")
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderFunctionComparison(fn *FunctionValue, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
|
if fn.Name != "size" {
|
|
return renderResult{}, errors.Errorf("unsupported function %s in comparison", fn.Name)
|
|
}
|
|
if len(fn.Args) != 1 {
|
|
return renderResult{}, errors.New("size() expects one argument")
|
|
}
|
|
fieldArg, ok := fn.Args[0].(*FieldRef)
|
|
if !ok {
|
|
return renderResult{}, errors.New("size() argument must be a field")
|
|
}
|
|
|
|
field, ok := r.schema.Field(fieldArg.Name)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("unknown field %q", fieldArg.Name)
|
|
}
|
|
if field.Kind != FieldKindJSONList {
|
|
return renderResult{}, errors.Errorf("size() only supports tag lists, got %q", field.Name)
|
|
}
|
|
|
|
value, err := expectNumericLiteral(right)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
|
|
expr := jsonArrayLengthExpr(r.dialect, field)
|
|
placeholder := r.addArg(value)
|
|
return renderResult{
|
|
sql: fmt.Sprintf("%s %s %s", expr, sqlOperator(op), placeholder),
|
|
}, nil
|
|
}
|
|
|
|
func (r *renderer) renderScalarComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
|
lit, err := expectLiteral(right)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
|
|
columnExpr := field.columnExpr(r.dialect)
|
|
placeholder := ""
|
|
switch field.Type {
|
|
case FieldTypeString:
|
|
value, ok := lit.(string)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("field %q expects string value", field.Name)
|
|
}
|
|
placeholder = r.addArg(value)
|
|
case FieldTypeInt, FieldTypeTimestamp:
|
|
num, err := toInt64(lit)
|
|
if err != nil {
|
|
return renderResult{}, errors.Wrapf(err, "field %q expects integer value", field.Name)
|
|
}
|
|
placeholder = r.addArg(num)
|
|
default:
|
|
return renderResult{}, errors.Errorf("unsupported data type %q for field %s", field.Type, field.Name)
|
|
}
|
|
|
|
return renderResult{
|
|
sql: fmt.Sprintf("%s %s %s", columnExpr, sqlOperator(op), placeholder),
|
|
}, nil
|
|
}
|
|
|
|
func (r *renderer) renderBoolColumnComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
|
value, err := expectBool(right)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
placeholder := r.addBoolArg(value)
|
|
column := qualifyColumn(r.dialect, field.Column)
|
|
return renderResult{
|
|
sql: fmt.Sprintf("%s %s %s", column, sqlOperator(op), placeholder),
|
|
}, nil
|
|
}
|
|
|
|
func (r *renderer) renderJSONBoolComparison(field Field, op ComparisonOperator, right ValueExpr) (renderResult, error) {
|
|
value, err := expectBool(right)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
|
|
jsonExpr := jsonExtractExpr(r.dialect, field)
|
|
switch r.dialect {
|
|
case DialectSQLite:
|
|
switch op {
|
|
case CompareEq:
|
|
if field.Name == "has_task_list" {
|
|
target := "0"
|
|
if value {
|
|
target = "1"
|
|
}
|
|
return renderResult{sql: fmt.Sprintf("%s = %s", jsonExpr, target)}, nil
|
|
}
|
|
if value {
|
|
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
|
|
}
|
|
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
|
|
case CompareNeq:
|
|
if field.Name == "has_task_list" {
|
|
target := "0"
|
|
if value {
|
|
target = "1"
|
|
}
|
|
return renderResult{sql: fmt.Sprintf("%s != %s", jsonExpr, target)}, nil
|
|
}
|
|
if value {
|
|
return renderResult{sql: fmt.Sprintf("NOT(%s IS TRUE)", jsonExpr)}, nil
|
|
}
|
|
return renderResult{sql: fmt.Sprintf("%s IS TRUE", jsonExpr)}, nil
|
|
default:
|
|
return renderResult{}, errors.Errorf("operator %s not supported for boolean JSON field", op)
|
|
}
|
|
case DialectMySQL:
|
|
boolStr := "false"
|
|
if value {
|
|
boolStr = "true"
|
|
}
|
|
return renderResult{
|
|
sql: fmt.Sprintf("%s %s CAST('%s' AS JSON)", jsonExpr, sqlOperator(op), boolStr),
|
|
}, nil
|
|
case DialectPostgres:
|
|
placeholder := r.addArg(value)
|
|
return renderResult{
|
|
sql: fmt.Sprintf("(%s)::boolean %s %s", jsonExpr, sqlOperator(op), placeholder),
|
|
}, nil
|
|
default:
|
|
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderInCondition(cond *InCondition) (renderResult, error) {
|
|
fieldRef, ok := cond.Left.(*FieldRef)
|
|
if !ok {
|
|
return renderResult{}, errors.New("IN operator requires a field on the left-hand side")
|
|
}
|
|
|
|
if fieldRef.Name == "tag" {
|
|
return r.renderTagInList(cond.Values)
|
|
}
|
|
|
|
field, ok := r.schema.Field(fieldRef.Name)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("unknown field %q", fieldRef.Name)
|
|
}
|
|
|
|
if field.Kind != FieldKindScalar {
|
|
return renderResult{}, errors.Errorf("field %q does not support IN()", fieldRef.Name)
|
|
}
|
|
|
|
return r.renderScalarInCondition(field, cond.Values)
|
|
}
|
|
|
|
func (r *renderer) renderTagInList(values []ValueExpr) (renderResult, error) {
|
|
field, ok := r.schema.ResolveAlias("tag")
|
|
if !ok {
|
|
return renderResult{}, errors.New("tag attribute is not configured")
|
|
}
|
|
|
|
conditions := make([]string, 0, len(values))
|
|
for _, v := range values {
|
|
lit, err := expectLiteral(v)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
str, ok := lit.(string)
|
|
if !ok {
|
|
return renderResult{}, errors.New("tags must be compared with string literals")
|
|
}
|
|
|
|
switch r.dialect {
|
|
case DialectSQLite:
|
|
expr := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
|
conditions = append(conditions, expr)
|
|
case DialectMySQL:
|
|
expr := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
|
conditions = append(conditions, expr)
|
|
case DialectPostgres:
|
|
expr := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`"%s"`, str)))
|
|
conditions = append(conditions, expr)
|
|
default:
|
|
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
|
}
|
|
}
|
|
|
|
if len(conditions) == 1 {
|
|
return renderResult{sql: conditions[0]}, nil
|
|
}
|
|
return renderResult{
|
|
sql: fmt.Sprintf("(%s)", strings.Join(conditions, " OR ")),
|
|
}, nil
|
|
}
|
|
|
|
func (r *renderer) renderElementInCondition(cond *ElementInCondition) (renderResult, error) {
|
|
field, ok := r.schema.Field(cond.Field)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
|
}
|
|
if field.Kind != FieldKindJSONList {
|
|
return renderResult{}, errors.Errorf("field %q is not a tag list", cond.Field)
|
|
}
|
|
|
|
lit, err := expectLiteral(cond.Element)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
str, ok := lit.(string)
|
|
if !ok {
|
|
return renderResult{}, errors.New("tags membership requires string literal")
|
|
}
|
|
|
|
switch r.dialect {
|
|
case DialectSQLite:
|
|
sql := fmt.Sprintf("%s LIKE %s", jsonArrayExpr(r.dialect, field), r.addArg(fmt.Sprintf(`%%"%s"%%`, str)))
|
|
return renderResult{sql: sql}, nil
|
|
case DialectMySQL:
|
|
sql := fmt.Sprintf("JSON_CONTAINS(%s, %s)", jsonArrayExpr(r.dialect, field), r.addArg(str))
|
|
return renderResult{sql: sql}, nil
|
|
case DialectPostgres:
|
|
sql := fmt.Sprintf("%s @> jsonb_build_array(%s::json)", jsonArrayExpr(r.dialect, field), r.addArg(str))
|
|
return renderResult{sql: sql}, nil
|
|
default:
|
|
return renderResult{}, errors.Errorf("unsupported dialect %s", r.dialect)
|
|
}
|
|
}
|
|
|
|
func (r *renderer) renderScalarInCondition(field Field, values []ValueExpr) (renderResult, error) {
|
|
placeholders := make([]string, 0, len(values))
|
|
|
|
for _, v := range values {
|
|
lit, err := expectLiteral(v)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
switch field.Type {
|
|
case FieldTypeString:
|
|
str, ok := lit.(string)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("field %q expects string values", field.Name)
|
|
}
|
|
placeholders = append(placeholders, r.addArg(str))
|
|
case FieldTypeInt:
|
|
num, err := toInt64(lit)
|
|
if err != nil {
|
|
return renderResult{}, err
|
|
}
|
|
placeholders = append(placeholders, r.addArg(num))
|
|
default:
|
|
return renderResult{}, errors.Errorf("field %q does not support IN() comparisons", field.Name)
|
|
}
|
|
}
|
|
|
|
column := field.columnExpr(r.dialect)
|
|
return renderResult{
|
|
sql: fmt.Sprintf("%s IN (%s)", column, strings.Join(placeholders, ",")),
|
|
}, nil
|
|
}
|
|
|
|
func (r *renderer) renderContainsCondition(cond *ContainsCondition) (renderResult, error) {
|
|
field, ok := r.schema.Field(cond.Field)
|
|
if !ok {
|
|
return renderResult{}, errors.Errorf("unknown field %q", cond.Field)
|
|
}
|
|
column := field.columnExpr(r.dialect)
|
|
arg := fmt.Sprintf("%%%s%%", cond.Value)
|
|
switch r.dialect {
|
|
case DialectPostgres:
|
|
sql := fmt.Sprintf("%s ILIKE %s", column, r.addArg(arg))
|
|
return renderResult{sql: sql}, nil
|
|
default:
|
|
sql := fmt.Sprintf("%s LIKE %s", column, r.addArg(arg))
|
|
return renderResult{sql: sql}, nil
|
|
}
|
|
}
|
|
|
|
func (r *renderer) jsonBoolPredicate(field Field) (string, error) {
|
|
expr := jsonExtractExpr(r.dialect, field)
|
|
switch r.dialect {
|
|
case DialectSQLite:
|
|
return fmt.Sprintf("%s IS TRUE", expr), nil
|
|
case DialectMySQL:
|
|
return fmt.Sprintf("%s = CAST('true' AS JSON)", expr), nil
|
|
case DialectPostgres:
|
|
return fmt.Sprintf("(%s)::boolean IS TRUE", expr), nil
|
|
default:
|
|
return "", errors.Errorf("unsupported dialect %s", r.dialect)
|
|
}
|
|
}
|
|
|
|
func combineAnd(left, right renderResult) renderResult {
|
|
if left.unsatisfiable || right.unsatisfiable {
|
|
return renderResult{sql: "1 = 0", unsatisfiable: true}
|
|
}
|
|
if left.trivial {
|
|
return right
|
|
}
|
|
if right.trivial {
|
|
return left
|
|
}
|
|
return renderResult{
|
|
sql: fmt.Sprintf("(%s AND %s)", left.sql, right.sql),
|
|
}
|
|
}
|
|
|
|
func combineOr(left, right renderResult) renderResult {
|
|
if left.trivial || right.trivial {
|
|
return renderResult{trivial: true}
|
|
}
|
|
if left.unsatisfiable {
|
|
return right
|
|
}
|
|
if right.unsatisfiable {
|
|
return left
|
|
}
|
|
return renderResult{
|
|
sql: fmt.Sprintf("(%s OR %s)", left.sql, right.sql),
|
|
}
|
|
}
|
|
|
|
func (r *renderer) addArg(value any) string {
|
|
r.placeholderCounter++
|
|
r.args = append(r.args, value)
|
|
if r.dialect == DialectPostgres {
|
|
return fmt.Sprintf("$%d", r.placeholderOffset+r.placeholderCounter)
|
|
}
|
|
return "?"
|
|
}
|
|
|
|
func (r *renderer) addBoolArg(value bool) string {
|
|
var v any
|
|
switch r.dialect {
|
|
case DialectSQLite:
|
|
if value {
|
|
v = 1
|
|
} else {
|
|
v = 0
|
|
}
|
|
default:
|
|
v = value
|
|
}
|
|
return r.addArg(v)
|
|
}
|
|
|
|
func expectLiteral(expr ValueExpr) (any, error) {
|
|
lit, ok := expr.(*LiteralValue)
|
|
if !ok {
|
|
return nil, errors.New("expression must be a literal")
|
|
}
|
|
return lit.Value, nil
|
|
}
|
|
|
|
func expectBool(expr ValueExpr) (bool, error) {
|
|
lit, err := expectLiteral(expr)
|
|
if err != nil {
|
|
return false, err
|
|
}
|
|
value, ok := lit.(bool)
|
|
if !ok {
|
|
return false, errors.New("boolean literal required")
|
|
}
|
|
return value, nil
|
|
}
|
|
|
|
func expectNumericLiteral(expr ValueExpr) (int64, error) {
|
|
lit, err := expectLiteral(expr)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
return toInt64(lit)
|
|
}
|
|
|
|
func toInt64(value any) (int64, error) {
|
|
switch v := value.(type) {
|
|
case int:
|
|
return int64(v), nil
|
|
case int32:
|
|
return int64(v), nil
|
|
case int64:
|
|
return v, nil
|
|
case uint32:
|
|
return int64(v), nil
|
|
case uint64:
|
|
return int64(v), nil
|
|
case float32:
|
|
return int64(v), nil
|
|
case float64:
|
|
return int64(v), nil
|
|
default:
|
|
return 0, errors.Errorf("cannot convert %T to int64", value)
|
|
}
|
|
}
|
|
|
|
func sqlOperator(op ComparisonOperator) string {
|
|
return string(op)
|
|
}
|
|
|
|
func qualifyColumn(d DialectName, col Column) string {
|
|
switch d {
|
|
case DialectPostgres:
|
|
return fmt.Sprintf("%s.%s", col.Table, col.Name)
|
|
default:
|
|
return fmt.Sprintf("`%s`.`%s`", col.Table, col.Name)
|
|
}
|
|
}
|
|
|
|
func jsonPath(field Field) string {
|
|
return "$." + strings.Join(field.JSONPath, ".")
|
|
}
|
|
|
|
func jsonExtractExpr(d DialectName, field Field) string {
|
|
column := qualifyColumn(d, field.Column)
|
|
switch d {
|
|
case DialectSQLite, DialectMySQL:
|
|
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
|
case DialectPostgres:
|
|
return buildPostgresJSONAccessor(column, field.JSONPath, true)
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func jsonArrayExpr(d DialectName, field Field) string {
|
|
column := qualifyColumn(d, field.Column)
|
|
switch d {
|
|
case DialectSQLite, DialectMySQL:
|
|
return fmt.Sprintf("JSON_EXTRACT(%s, '%s')", column, jsonPath(field))
|
|
case DialectPostgres:
|
|
return buildPostgresJSONAccessor(column, field.JSONPath, false)
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
func jsonArrayLengthExpr(d DialectName, field Field) string {
|
|
arrayExpr := jsonArrayExpr(d, field)
|
|
switch d {
|
|
case DialectSQLite:
|
|
return fmt.Sprintf("JSON_ARRAY_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
|
case DialectMySQL:
|
|
return fmt.Sprintf("JSON_LENGTH(COALESCE(%s, JSON_ARRAY()))", arrayExpr)
|
|
case DialectPostgres:
|
|
return fmt.Sprintf("jsonb_array_length(COALESCE(%s, '[]'::jsonb))", arrayExpr)
|
|
default:
|
|
return ""
|
|
}
|
|
}
|
|
|
|
// buildPostgresJSONAccessor appends one ->'key' step to base for each element
// of path. When terminalText is set, the last step uses ->> so the result is
// text rather than JSON. An empty path returns base unchanged.
func buildPostgresJSONAccessor(base string, path []string, terminalText bool) string {
	accessor := base
	for i := range path {
		arrow := "->"
		if terminalText && i+1 == len(path) {
			arrow = "->>"
		}
		accessor += fmt.Sprintf("%s'%s'", arrow, path[i])
	}
	return accessor
}
|