You cannot select more than 25 topics. Topics must start with a letter or number, can include dashes ('-'), and can be up to 35 characters long.
memos/plugin/filter/parser.go

414 lines
10 KiB
Go

package filter
import (
"time"
"github.com/pkg/errors"
exprv1 "google.golang.org/genproto/googleapis/api/expr/v1alpha1"
)
// buildCondition turns a parsed CEL expression into a Condition tree.
//
// Three expression kinds are accepted at this level:
//   - call expressions, which are delegated to buildCallCondition;
//   - bare identifiers, which must name a schema field whose Type is
//     FieldTypeBool and become a FieldPredicateCondition;
//   - constants, where a bool is used as-is and an int64/float64 is treated
//     as true when non-zero.
//
// Any other expression kind, an unknown or non-boolean identifier, or a
// constant of another type yields an error.
func buildCondition(expr *exprv1.Expr, schema Schema) (Condition, error) {
	switch kind := expr.ExprKind.(type) {
	case *exprv1.Expr_CallExpr:
		return buildCallCondition(kind.CallExpr, schema)

	case *exprv1.Expr_IdentExpr:
		ident := kind.IdentExpr.GetName()
		def, known := schema.Field(ident)
		if !known {
			return nil, errors.Errorf("unknown identifier %q", ident)
		}
		if def.Type != FieldTypeBool {
			return nil, errors.Errorf("identifier %q is not boolean", ident)
		}
		return &FieldPredicateCondition{Field: ident}, nil

	case *exprv1.Expr_ConstExpr:
		literal, err := getConstValue(expr)
		if err != nil {
			return nil, err
		}
		// Collapse the literal to a single truth value; numbers follow the
		// "non-zero is true" rule.
		var truth bool
		switch c := literal.(type) {
		case bool:
			truth = c
		case int64:
			truth = c != 0
		case float64:
			truth = c != 0
		default:
			return nil, errors.New("filter must evaluate to a boolean value")
		}
		return &ConstantCondition{Value: truth}, nil
	}
	return nil, errors.New("unsupported top-level expression")
}
// buildCallCondition converts a CEL call expression into a Condition by
// dispatching on the call's function name.
//
// Logical AND ("_&&_"), OR ("_||_") and NOT ("!_") build their operands
// through buildCondition. Comparison operators go to
// buildComparisonCondition, "@in" to buildInCondition and "contains" to
// buildContainsCondition. Any other function is handed to evaluateBool: if it
// reports a value the result is a ConstantCondition, otherwise an
// "unsupported call expression" error is returned.
func buildCallCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	// operands checks the arity of a binary logical operator and builds both
	// sides, left before right, stopping at the first error.
	operands := func(arityMsg string) (Condition, Condition, error) {
		if len(call.Args) != 2 {
			return nil, nil, errors.New(arityMsg)
		}
		lhs, err := buildCondition(call.Args[0], schema)
		if err != nil {
			return nil, nil, err
		}
		rhs, err := buildCondition(call.Args[1], schema)
		if err != nil {
			return nil, nil, err
		}
		return lhs, rhs, nil
	}

	switch call.Function {
	case "_&&_":
		lhs, rhs, err := operands("logical AND expects two arguments")
		if err != nil {
			return nil, err
		}
		return &LogicalCondition{Operator: LogicalAnd, Left: lhs, Right: rhs}, nil

	case "_||_":
		lhs, rhs, err := operands("logical OR expects two arguments")
		if err != nil {
			return nil, err
		}
		return &LogicalCondition{Operator: LogicalOr, Left: lhs, Right: rhs}, nil

	case "!_":
		if len(call.Args) != 1 {
			return nil, errors.New("logical NOT expects one argument")
		}
		operand, err := buildCondition(call.Args[0], schema)
		if err != nil {
			return nil, err
		}
		return &NotCondition{Expr: operand}, nil

	case "_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_":
		return buildComparisonCondition(call, schema)

	case "@in":
		return buildInCondition(call, schema)

	case "contains":
		return buildContainsCondition(call, schema)
	}

	// Not a recognised function: last resort is folding it to a constant.
	folded, isConst, err := evaluateBool(call)
	if err != nil {
		return nil, err
	}
	if !isConst {
		return nil, errors.Errorf("unsupported call expression %q", call.Function)
	}
	return &ConstantCondition{Value: folded}, nil
}
// buildComparisonCondition builds a ComparisonCondition from a binary
// comparison call ("_==_", "_!=_", "_<_", "_>_", "_<=_", "_>=_").
//
// Both operands are converted with buildValueExpr. When the left operand is a
// field reference, the field definition (resolved through
// schema.ResolveAlias if the field is a virtual alias) is consulted: if it
// declares AllowedComparisonOps, the operator must be present in that set.
//
// Errors are returned for a wrong argument count, an unsupported operator,
// an operand that cannot be converted, an unknown field, an alias that does
// not resolve, or an operator the field does not allow.
func buildComparisonCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	if len(call.Args) != 2 {
		return nil, errors.New("comparison expects two arguments")
	}

	operator, err := toComparisonOperator(call.Function)
	if err != nil {
		return nil, err
	}
	lhs, err := buildValueExpr(call.Args[0], schema)
	if err != nil {
		return nil, err
	}
	rhs, err := buildValueExpr(call.Args[1], schema)
	if err != nil {
		return nil, err
	}

	cond := &ComparisonCondition{
		Left:     lhs,
		Operator: operator,
		Right:    rhs,
	}

	// Operator restrictions only apply when the left side names a field.
	ref, isField := lhs.(*FieldRef)
	if !isField {
		return cond, nil
	}

	def, found := schema.Field(ref.Name)
	if !found {
		return nil, errors.Errorf("unknown identifier %q", ref.Name)
	}
	if def.Kind == FieldKindVirtualAlias {
		// Restrictions are taken from the definition the alias resolves to.
		resolved, ok := schema.ResolveAlias(ref.Name)
		if !ok {
			return nil, errors.Errorf("invalid alias %q", ref.Name)
		}
		def = resolved
	}

	// A nil set means "no restriction".
	if def.AllowedComparisonOps == nil {
		return cond, nil
	}
	if _, permitted := def.AllowedComparisonOps[operator]; !permitted {
		return nil, errors.Errorf("operator %s not allowed for field %q", operator, ref.Name)
	}
	return cond, nil
}
// buildInCondition builds the condition for the CEL "in" operator ("@in").
//
// Two shapes are recognised, tried in this order:
//
//  1. identifier in [list]: the identifier must be a known schema field (and,
//     if it is a virtual alias, must resolve); each list element is converted
//     with buildValueExpr and the result is an InCondition.
//  2. value in identifier: the right-hand identifier must be a known schema
//     field; the left side is converted with buildValueExpr and the result is
//     an ElementInCondition.
//
// If the left side is a valid identifier but the right side is not a list
// literal, the second shape is attempted. Anything else is rejected with
// "invalid use of in operator".
func buildInCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	if len(call.Args) != 2 {
		return nil, errors.New("in operator expects two arguments")
	}
	lhs, rhs := call.Args[0], call.Args[1]

	// Shape 1: identifier in [list].
	if fieldName, err := getIdentName(lhs); err == nil {
		def, known := schema.Field(fieldName)
		if !known {
			return nil, errors.Errorf("unknown identifier %q", fieldName)
		}
		if def.Kind == FieldKindVirtualAlias {
			if _, resolves := schema.ResolveAlias(fieldName); !resolves {
				return nil, errors.Errorf("invalid alias %q", fieldName)
			}
		}

		if list := rhs.GetListExpr(); list != nil {
			items := make([]ValueExpr, len(list.Elements))
			for i, elem := range list.Elements {
				item, err := buildValueExpr(elem, schema)
				if err != nil {
					return nil, err
				}
				items[i] = item
			}
			return &InCondition{
				Left:   &FieldRef{Name: fieldName},
				Values: items,
			}, nil
		}
	}

	// Shape 2: value in identifier.
	container, err := getIdentName(rhs)
	if err != nil {
		return nil, errors.New("invalid use of in operator")
	}
	if _, known := schema.Field(container); !known {
		return nil, errors.Errorf("unknown identifier %q", container)
	}
	member, err := buildValueExpr(lhs, schema)
	if err != nil {
		return nil, err
	}
	return &ElementInCondition{
		Element: member,
		Field:   container,
	}, nil
}
// buildContainsCondition builds a ContainsCondition from a member call of the
// form field.contains("literal").
//
// The call must have a target that is an identifier naming a schema field
// with SupportsContains set, and exactly one argument that is a string
// literal. Each violated requirement produces its own error; the checks run
// in the order just listed.
func buildContainsCondition(call *exprv1.Expr_Call, schema Schema) (Condition, error) {
	receiver := call.Target
	if receiver == nil {
		return nil, errors.New("contains requires a target")
	}
	fieldName, err := getIdentName(receiver)
	if err != nil {
		return nil, err
	}

	def, known := schema.Field(fieldName)
	switch {
	case !known:
		return nil, errors.Errorf("unknown identifier %q", fieldName)
	case !def.SupportsContains:
		return nil, errors.Errorf("identifier %q does not support contains()", fieldName)
	case len(call.Args) != 1:
		return nil, errors.New("contains expects exactly one argument")
	}

	// Only a literal is accepted as the search term; field references and
	// computed expressions are rejected here.
	arg, err := getConstValue(call.Args[0])
	if err != nil {
		return nil, errors.Wrap(err, "contains only supports literal arguments")
	}
	needle, isString := arg.(string)
	if !isString {
		return nil, errors.New("contains argument must be a string")
	}

	return &ContainsCondition{
		Field: fieldName,
		Value: needle,
	}, nil
}
// buildValueExpr converts an operand expression into a ValueExpr.
//
// The forms are tried in this order:
//  1. an identifier, which must be a known schema field -> FieldRef;
//  2. a literal constant -> LiteralValue holding the constant as returned by
//     getConstValue;
//  3. a numeric expression that evaluateNumeric can fold (now() and
//     +, -, * over foldable operands) -> LiteralValue holding an int64;
//  4. a boolean expression that evaluateBoolExpr can fold -> LiteralValue
//     holding a bool;
//  5. size(arg) -> FunctionValue whose single argument is built recursively.
//
// Anything else yields "unsupported value expression". Errors raised while
// folding (for example a wrong argument count) are returned as-is.
func buildValueExpr(expr *exprv1.Expr, schema Schema) (ValueExpr, error) {
	if identName, err := getIdentName(expr); err == nil {
		if _, ok := schema.Field(identName); !ok {
			return nil, errors.Errorf("unknown identifier %q", identName)
		}
		return &FieldRef{Name: identName}, nil
	}

	if literal, err := getConstValue(expr); err == nil {
		return &LiteralValue{Value: literal}, nil
	}

	// evaluateNumeric already covers now() and the arithmetic operators, so
	// no separate handling of those function names is needed afterwards: a
	// second attempt on the same expression could only repeat this result.
	number, ok, err := evaluateNumeric(expr)
	if err != nil {
		return nil, err
	}
	if ok {
		return &LiteralValue{Value: number}, nil
	}

	boolVal, ok, err := evaluateBoolExpr(expr)
	if err != nil {
		return nil, err
	}
	if ok {
		return &LiteralValue{Value: boolVal}, nil
	}

	if call := expr.GetCallExpr(); call != nil && call.Function == "size" {
		if len(call.Args) != 1 {
			return nil, errors.New("size() expects one argument")
		}
		arg, err := buildValueExpr(call.Args[0], schema)
		if err != nil {
			return nil, err
		}
		return &FunctionValue{
			Name: "size",
			Args: []ValueExpr{arg},
		}, nil
	}

	return nil, errors.New("unsupported value expression")
}
// toComparisonOperator maps a CEL comparison function name ("_==_", "_!=_",
// "_<_", "_>_", "_<=_", "_>=_") to the matching ComparisonOperator. Any other
// name yields an empty operator and an error.
func toComparisonOperator(fn string) (ComparisonOperator, error) {
	var op ComparisonOperator
	switch fn {
	case "_==_":
		op = CompareEq
	case "_!=_":
		op = CompareNeq
	case "_<_":
		op = CompareLt
	case "_>_":
		op = CompareGt
	case "_<=_":
		op = CompareLte
	case "_>=_":
		op = CompareGte
	default:
		return "", errors.Errorf("unsupported comparison operator %q", fn)
	}
	return op, nil
}
// getIdentName returns the name of expr when it is an identifier expression,
// and an error for every other kind of expression.
func getIdentName(expr *exprv1.Expr) (string, error) {
	ident := expr.GetIdentExpr()
	if ident == nil {
		return "", errors.New("expression is not an identifier")
	}
	return ident.GetName(), nil
}
// getConstValue extracts the Go value of a literal constant expression.
//
// The result is a string, int64, float64 or bool according to the constant's
// kind; a null constant yields (nil, nil). Unsigned constants are converted
// to int64 and rejected with an error when they do not fit, rather than
// wrapping around to a negative number.
//
// An error is returned when expr is not a constant expression (including a
// nil expr) or when the constant is of a kind not listed above.
func getConstValue(expr *exprv1.Expr) (interface{}, error) {
	// The generated getters are nil-safe, so a nil expr or an empty constant
	// wrapper is reported as an error instead of panicking.
	wrapper, ok := expr.GetExprKind().(*exprv1.Expr_ConstExpr)
	if !ok {
		return nil, errors.New("expression is not a literal")
	}
	constant := wrapper.ConstExpr

	switch kind := constant.GetConstantKind().(type) {
	case *exprv1.Constant_StringValue:
		return constant.GetStringValue(), nil
	case *exprv1.Constant_Int64Value:
		return constant.GetInt64Value(), nil
	case *exprv1.Constant_Uint64Value:
		const maxInt64 = 1<<63 - 1
		unsigned := constant.GetUint64Value()
		if unsigned > maxInt64 {
			return nil, errors.Errorf("unsigned constant %d overflows int64", unsigned)
		}
		return int64(unsigned), nil
	case *exprv1.Constant_DoubleValue:
		return constant.GetDoubleValue(), nil
	case *exprv1.Constant_BoolValue:
		return constant.GetBoolValue(), nil
	case *exprv1.Constant_NullValue:
		return nil, nil
	default:
		return nil, errors.Errorf("unsupported constant %T", kind)
	}
}
// evaluateBool wraps call in an expression node and hands it to
// evaluateBoolExpr, returning that function's (value, ok, err) unchanged.
func evaluateBool(call *exprv1.Expr_Call) (bool, bool, error) {
	wrapped := &exprv1.Expr{
		ExprKind: &exprv1.Expr_CallExpr{CallExpr: call},
	}
	return evaluateBoolExpr(wrapped)
}
// evaluateBoolExpr tries to fold expr to a constant boolean.
//
// It understands boolean literals and any number of nested logical NOT
// ("!_") calls around one. The results are (value, ok, err):
//   - ok is true and value holds the folded result when folding succeeded;
//   - ok is false with a nil error when expr is not such an expression (a
//     non-boolean literal, a different call, or any other expression kind);
//   - err is non-nil when a NOT call does not have exactly one argument.
func evaluateBoolExpr(expr *exprv1.Expr) (bool, bool, error) {
	// negated tracks the parity of the NOT calls peeled off so far.
	negated := false
	for {
		if literal, err := getConstValue(expr); err == nil {
			b, isBool := literal.(bool)
			if !isBool {
				return false, false, nil
			}
			// b != negated flips b exactly when an odd number of NOTs
			// wrapped it.
			return b != negated, true, nil
		}

		call := expr.GetCallExpr()
		if call == nil || call.Function != "!_" {
			return false, false, nil
		}
		if len(call.Args) != 1 {
			return false, false, errors.New("NOT expects exactly one argument")
		}
		negated = !negated
		expr = call.Args[0]
	}
}
// evaluateNumeric tries to fold expr to a constant int64.
//
// Supported inputs are integer literals, floating-point literals (converted
// with int64(v), i.e. the fractional part is dropped), the now() call
// (current Unix time from timeNowUnix) and the binary operators +, - and *
// applied to operands that are themselves foldable. Arithmetic is plain Go
// int64 arithmetic with no overflow check.
//
// The results are (value, ok, err): ok is false with a nil error when expr is
// not foldable; err is non-nil when an arithmetic call does not have exactly
// two arguments.
func evaluateNumeric(expr *exprv1.Expr) (int64, bool, error) {
	if literal, err := getConstValue(expr); err == nil {
		if i, isInt := literal.(int64); isInt {
			return i, true, nil
		}
		if f, isFloat := literal.(float64); isFloat {
			return int64(f), true, nil
		}
		return 0, false, nil
	}

	call := expr.GetCallExpr()
	if call == nil {
		return 0, false, nil
	}

	fn := call.Function
	if fn == "now" {
		return timeNowUnix(), true, nil
	}
	if fn != "_+_" && fn != "_-_" && fn != "_*_" {
		return 0, false, nil
	}

	if len(call.Args) != 2 {
		return 0, false, errors.New("arithmetic requires two arguments")
	}
	// Fold the left operand first; the right one is only looked at when the
	// left one folded successfully.
	lhs, ok, err := evaluateNumeric(call.Args[0])
	if err != nil {
		return 0, false, err
	}
	if !ok {
		return 0, false, nil
	}
	rhs, ok, err := evaluateNumeric(call.Args[1])
	if err != nil {
		return 0, false, err
	}
	if !ok {
		return 0, false, nil
	}

	switch fn {
	case "_+_":
		return lhs + rhs, true, nil
	case "_-_":
		return lhs - rhs, true, nil
	default: // "_*_"
		return lhs * rhs, true, nil
	}
}
// timeNowUnix returns the current wall-clock time as Unix seconds.
func timeNowUnix() int64 {
	now := time.Now()
	return now.Unix()
}