Add NOT operator to VectorStore portable filter expressions

- Add NOT expression type to the portable Filter.Expression model.
 - Add NOT to the Antlr grammar and implement the related parser listener method to generate Filter NOT expressions.
 - Add NOT support to the filter programming DSL.
 - Implement FilterHelper.negate to logically transform any boolean expression containing NOT operators into a
   semantically equivalent one with the negation applied to the leaf expressions.
 - Add tests for parsers, converters and vector store ITs.
 - Move the filter IN/NIN expansion logic to the FilterHelper
 - Factor out the filter IN/NIN boolean expression expansion logic out of Weaviate up to the FilterHelper.
 - Add IN/NIN expansion FilterHelper tests.
This commit is contained in:
Christian Tzolov
2023-11-23 21:29:48 +01:00
committed by Mark Pollack
parent e3e5070256
commit ac9ae589f4
25 changed files with 731 additions and 153 deletions

View File

@@ -114,7 +114,7 @@ public class Filter {
*/
public enum ExpressionType {
AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN
AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN, NOT
}
@@ -131,6 +131,9 @@ public class Filter {
* be another {@link Expression}.
*/
public record Expression(ExpressionType type, Operand left, Operand right) implements Operand {
/**
 * Convenience constructor for expressions that carry a single operand (for
 * example a NOT expression). Delegates to the canonical constructor, storing
 * the operand as {@code left} and {@code null} as {@code right}.
 * @param type the expression type.
 * @param operand the only operand of the expression, stored as {@code left}.
 */
public Expression(ExpressionType type, Operand operand) {
this(type, operand, null);
}
}
/**

View File

@@ -122,4 +122,8 @@ public class FilterExpressionBuilder {
return new Op(new Filter.Group(content.build()));
}
/**
 * Wraps the expression held by the given {@link Op} into a NOT expression.
 * @param content the operation whose expression becomes the operand of the NOT.
 * @return a new {@link Op} holding a {@link ExpressionType#NOT} expression whose
 * left operand is {@code content.expression} and whose right operand is
 * {@code null}.
 */
public Op not(Op content) {
	Filter.Expression notExpression = new Filter.Expression(ExpressionType.NOT, content.expression);
	return new Op(notExpression);
}
}

View File

@@ -35,6 +35,7 @@ import org.antlr.v4.runtime.misc.ParseCancellationException;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersBaseVisitor;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersLexer;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser;
import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser.NotExpressionContext;
import org.springframework.core.NestedExceptionUtils;
import org.springframework.util.Assert;
@@ -263,6 +264,11 @@ public class FilterExpressionTextParser {
return new Filter.Group(castToExpression(this.visit(ctx.booleanExpression())));
}
/**
 * Builds the filter model node for a parsed NOT expression: the nested boolean
 * expression is visited first and its result becomes the left (and only) operand
 * of a {@link Filter.ExpressionType#NOT} expression.
 * @param ctx the parse tree node of the NOT expression.
 * @return a NOT {@link Filter.Expression} with a {@code null} right operand.
 */
@Override
public Filter.Operand visitNotExpression(NotExpressionContext ctx) {
	Filter.Operand negatedOperand = this.visit(ctx.booleanExpression());
	return new Filter.Expression(Filter.ExpressionType.NOT, negatedOperand, null);
}
public Filter.Expression castToExpression(Filter.Operand expression) {
if (expression instanceof Filter.Group group) {
// Remove the top-level grouping.

View File

@@ -0,0 +1,205 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vectorstore.filter;
import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
import org.springframework.ai.vectorstore.filter.Filter.Operand;
import org.springframework.ai.vectorstore.filter.converter.FilterExpressionConverter;
import org.springframework.util.Assert;
/**
 * Helper class providing boolean transformations for portable filter expressions:
 * propagating negation down to the leaf expressions and expanding IN/NIN expressions
 * into chains of EQ/NE comparisons.
 *
 * @author Christian Tzolov
 */
public class FilterHelper {

	/**
	 * Maps each supported expression type to the type used when the expression is
	 * negated (AND/OR are swapped, every comparison maps to its complement).
	 */
	private static final Map<ExpressionType, ExpressionType> TYPE_NEGATION_MAP = Map.of(
			ExpressionType.AND, ExpressionType.OR,
			ExpressionType.OR, ExpressionType.AND,
			ExpressionType.EQ, ExpressionType.NE,
			ExpressionType.NE, ExpressionType.EQ,
			ExpressionType.GT, ExpressionType.LTE,
			ExpressionType.GTE, ExpressionType.LT,
			ExpressionType.LT, ExpressionType.GTE,
			ExpressionType.LTE, ExpressionType.GT,
			ExpressionType.IN, ExpressionType.NIN,
			ExpressionType.NIN, ExpressionType.IN);

	private FilterHelper() {
	}

	/**
	 * Transforms the input expression into one without NOT nodes by propagating the
	 * negation through the expression tree using the following rules:
	 *
	 * <pre>
	 * NOT(a AND b) = NOT(a) OR NOT(b)
	 * NOT(a OR b) = NOT(a) AND NOT(b)
	 *
	 * NOT(a EQ b) = a NE b
	 * NOT(a NE b) = a EQ b
	 *
	 * NOT(a GT b) = a LTE b
	 * NOT(a GTE b) = a LT b
	 *
	 * NOT(a LT b) = a GTE b
	 * NOT(a LTE b) = a GT b
	 *
	 * NOT(a IN [...]) = a NIN [...]
	 * NOT(a NIN [...]) = a IN [...]
	 * </pre>
	 *
	 * A NOT node is dropped and its operand is negated in its place, so passing
	 * {@code NOT(a EQ b)} yields {@code a NE b}. A non-NOT expression such as
	 * {@code a EQ b} is negated directly and yields {@code a NE b} as well. A
	 * {@link Filter.Group} is negated by negating its content and re-wrapping the
	 * result in a single group.
	 * <p>
	 * NOTE(review): because a NOT node is replaced by the negation of its operand, a
	 * NOT nested below another NOT (e.g. {@code NOT(NOT(a EQ b))}) also yields
	 * {@code a NE b}. Confirm that this matches the intended double-negation
	 * semantics before relying on it.
	 * @param operand Filter expression (or group) to negate.
	 * @return Returns the transformed, NOT-free expression.
	 * @throws IllegalArgumentException if the operand is {@code null}, is neither a
	 * group nor an expression, or has an unknown expression type.
	 */
	public static Filter.Operand negate(Filter.Operand operand) {
		// Fail with a descriptive message instead of an NPE raised while building the
		// "Can not negate operand of type" message below.
		Assert.notNull(operand, "Can not negate a null operand");

		if (operand instanceof Filter.Group group) {
			Operand negatedContent = negate(group.content());
			// Avoid producing nested groups: Group(Group(x)) collapses to Group(x).
			if (negatedContent instanceof Filter.Group nestedGroup) {
				negatedContent = nestedGroup.content();
			}
			return new Filter.Group((Expression) negatedContent);
		}

		if (operand instanceof Filter.Expression exp) {
			return switch (exp.type()) {
				// Drop the NOT node and negate its (single, left) operand instead.
				case NOT -> negate(exp.left());
				// De Morgan: swap AND/OR and negate both sides.
				case AND, OR ->
					new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), negate(exp.left()), negate(exp.right()));
				// Leaf comparisons: only the operator is replaced by its complement.
				case EQ, NE, GT, GTE, LT, LTE, IN, NIN ->
					new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right());
				default -> throw new IllegalArgumentException("Unknown expression type: " + exp.type());
			};
		}

		throw new IllegalArgumentException("Can not negate operand of type: " + operand.getClass());
	}

	/**
	 * Expands the IN into a semantically equivalent boolean expression of ORs of EQs.
	 * Useful for providers that don't provide native IN support.
	 *
	 * For example the <pre>
	 * foo IN ["bar1", "bar2", "bar3"]
	 * </pre>
	 *
	 * expression is equivalent to
	 *
	 * <pre>
	 * {@code foo == "bar1" || foo == "bar2" || foo == "bar3" (e.g. OR(foo EQ "bar1", OR(foo EQ "bar2", foo EQ "bar3")))}
	 * </pre>
	 * @param exp input IN expression.
	 * @param context Output native expression. The converted expansion is appended to
	 * it.
	 * @param filterExpressionConverter {@link FilterExpressionConverter} used to compose
	 * the OR and EQ expanded expressions.
	 * @throws IllegalArgumentException if the expression is not of type IN or its
	 * value list is empty.
	 */
	public static void expandIn(Expression exp, StringBuilder context,
			FilterExpressionConverter filterExpressionConverter) {
		Assert.isTrue(exp.type() == ExpressionType.IN, "Expected IN expressions but was: " + exp.type());
		expandInNinExpressions(ExpressionType.OR, ExpressionType.EQ, exp, context, filterExpressionConverter);
	}

	/**
	 * Expands the NIN (e.g. NOT IN) into a semantically equivalent boolean expression
	 * of ANDs of NEs. Useful for providers that don't provide native NIN support.
	 *
	 * For example the
	 *
	 * <pre>
	 * foo NIN ["bar1", "bar2", "bar3"] (or foo NOT IN ["bar1", "bar2", "bar3"])
	 * </pre>
	 *
	 * expression is equivalent to
	 *
	 * <pre>
	 * {@code foo != "bar1" && foo != "bar2" && foo != "bar3" (e.g. AND(foo NE "bar1", AND(foo NE "bar2", foo NE "bar3")))}
	 * </pre>
	 * @param exp input NIN expression.
	 * @param context Output native expression. The converted expansion is appended to
	 * it.
	 * @param filterExpressionConverter {@link FilterExpressionConverter} used to compose
	 * the AND and NE expanded expressions.
	 * @throws IllegalArgumentException if the expression is not of type NIN or its
	 * value list is empty.
	 */
	public static void expandNin(Expression exp, StringBuilder context,
			FilterExpressionConverter filterExpressionConverter) {
		Assert.isTrue(exp.type() == ExpressionType.NIN, "Expected NIN expressions but was: " + exp.type());
		expandInNinExpressions(ExpressionType.AND, ExpressionType.NE, exp, context, filterExpressionConverter);
	}

	/**
	 * Shared implementation of the IN/NIN expansion. Builds one
	 * {@code innerExpressionType} comparison per list element, joins them with
	 * {@code outerExpressionType}, converts the result with the given converter and
	 * appends it to the context.
	 * @param outerExpressionType type used to join the comparisons (OR for IN, AND for
	 * NIN).
	 * @param innerExpressionType type of each comparison (EQ for IN, NE for NIN).
	 * @param exp the IN/NIN expression to expand.
	 * @param context output buffer the converted expression is appended to.
	 * @param expressionConverter converter producing the native representation.
	 * @throws IllegalStateException if the right operand is not a {@link Filter.Value}.
	 * @throws IllegalArgumentException if the right operand holds an empty list.
	 */
	private static void expandInNinExpressions(Filter.ExpressionType outerExpressionType,
			Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context,
			FilterExpressionConverter expressionConverter) {
		if (!(exp.right() instanceof Filter.Value value)) {
			throw new IllegalStateException(
					"Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass());
		}

		if (value.value() instanceof List<?> elements) {
			// 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to
			// foo == "bar1" || foo == "bar2" || foo == "bar3",
			// i.e. OR(foo == "bar1", OR(foo == "bar2", foo == "bar3"))
			// 2. foo NIN ["bar1", "bar2", "bar3"] is equivalent to
			// foo != "bar1" && foo != "bar2" && foo != "bar3",
			// i.e. AND(foo != "bar1", AND(foo != "bar2", foo != "bar3"))
			List<Filter.Expression> comparisons = new ArrayList<>(elements.size());
			for (Object element : elements) {
				comparisons.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(element)));
			}
			context.append(expressionConverter.convertExpression(aggregate(outerExpressionType, comparisons)));
		}
		else {
			// The value is not a list, so it is treated as a single element:
			// 1. foo IN "bar" is equivalent to foo == "bar"
			// 2. foo NIN "bar" is equivalent to foo != "bar"
			context.append(expressionConverter
				.convertExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right())));
		}
	}

	/**
	 * Aggregates a list of expressions into a right-nested binary tree with
	 * 'aggregateType' join nodes, e.g. {@code [a, b, c]} becomes
	 * {@code aggregateType(a, aggregateType(b, c))}. A single-element list is returned
	 * as that element.
	 * @param aggregateType type of all join nodes.
	 * @param expressions list of expressions to aggregate. Must not be empty.
	 * @return Returns a binary tree expression.
	 * @throws IllegalArgumentException if the list is {@code null} or empty.
	 */
	private static Filter.Expression aggregate(Filter.ExpressionType aggregateType,
			List<Filter.Expression> expressions) {
		// An empty IN/NIN list has no EQ/NE expansion; report it explicitly rather
		// than failing with an IndexOutOfBoundsException.
		Assert.notEmpty(expressions, "Can not expand an IN/NIN expression with an empty value list");

		// Fold from the right so the resulting tree has the same shape as the
		// recursive definition: first element on the left, the rest nested right.
		Filter.Expression tree = expressions.get(expressions.size() - 1);
		for (int i = expressions.size() - 2; i >= 0; i--) {
			tree = new Filter.Expression(aggregateType, expressions.get(i), tree);
		}
		return tree;
	}

}

View File

@@ -66,4 +66,4 @@ constant
atn:
[4, 1, 26, 87, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 38, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 46, 8, 1, 10, 1, 12, 1, 49, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 55, 8, 2, 10, 2, 12, 2, 58, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 69, 8, 4, 1, 5, 3, 5, 72, 8, 5, 1, 5, 1, 5, 3, 5, 76, 8, 5, 1, 5, 1, 5, 4, 5, 80, 8, 5, 11, 5, 12, 5, 81, 1, 5, 3, 5, 85, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 95, 0, 12, 1, 0, 0, 0, 2, 37, 1, 0, 0, 0, 4, 50, 1, 0, 0, 0, 6, 61, 1, 0, 0, 0, 8, 68, 1, 0, 0, 0, 10, 84, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 38, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 38, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 38, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 38, 1, 0, 0, 0, 37, 16, 1, 0, 0, 0, 37, 21, 1, 0, 0, 0, 37, 25, 1, 0, 0, 0, 37, 33, 1, 0, 0, 0, 38, 47, 1, 0, 0, 0, 39, 40, 10, 3, 0, 0, 40, 41, 5, 16, 0, 0, 41, 46, 3, 2, 1, 4, 42, 43, 10, 2, 0, 0, 43, 44, 5, 17, 0, 0, 44, 46, 3, 2, 1, 3, 45, 39, 1, 0, 0, 0, 45, 42, 1, 0, 0, 0, 46, 49, 1, 0, 0, 0, 47, 45, 1, 0, 0, 0, 47, 48, 1, 0, 0, 0, 48, 3, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 50, 51, 5, 4, 0, 0, 51, 56, 3, 10, 5, 0, 52, 53, 5, 3, 0, 0, 53, 55, 3, 10, 5, 0, 54, 52, 1, 0, 0, 0, 55, 58, 1, 0, 0, 0, 56, 54, 1, 0, 0, 0, 56, 57, 1, 0, 0, 0, 57, 59, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 59, 60, 5, 5, 0, 0, 60, 5, 1, 0, 0, 0, 61, 62, 7, 0, 0, 0, 62, 7, 1, 0, 0, 0, 63, 64, 5, 25, 0, 0, 64, 65, 5, 2, 0, 0, 65, 69, 5, 
25, 0, 0, 66, 69, 5, 25, 0, 0, 67, 69, 5, 22, 0, 0, 68, 63, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 67, 1, 0, 0, 0, 69, 9, 1, 0, 0, 0, 70, 72, 7, 1, 0, 0, 71, 70, 1, 0, 0, 0, 71, 72, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 85, 5, 23, 0, 0, 74, 76, 7, 1, 0, 0, 75, 74, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 77, 1, 0, 0, 0, 77, 85, 5, 24, 0, 0, 78, 80, 5, 22, 0, 0, 79, 78, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 85, 1, 0, 0, 0, 83, 85, 5, 21, 0, 0, 84, 71, 1, 0, 0, 0, 84, 75, 1, 0, 0, 0, 84, 79, 1, 0, 0, 0, 84, 83, 1, 0, 0, 0, 85, 11, 1, 0, 0, 0, 10, 29, 37, 45, 47, 56, 68, 71, 75, 81, 84]
[4, 1, 26, 89, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 40, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 48, 8, 1, 10, 1, 12, 1, 51, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 57, 8, 2, 10, 2, 12, 2, 60, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 71, 8, 4, 1, 5, 3, 5, 74, 8, 5, 1, 5, 1, 5, 3, 5, 78, 8, 5, 1, 5, 1, 5, 4, 5, 82, 8, 5, 11, 5, 12, 5, 83, 1, 5, 3, 5, 87, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 98, 0, 12, 1, 0, 0, 0, 2, 39, 1, 0, 0, 0, 4, 52, 1, 0, 0, 0, 6, 63, 1, 0, 0, 0, 8, 70, 1, 0, 0, 0, 10, 86, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 40, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 40, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 40, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 40, 1, 0, 0, 0, 37, 38, 5, 20, 0, 0, 38, 40, 3, 2, 1, 1, 39, 16, 1, 0, 0, 0, 39, 21, 1, 0, 0, 0, 39, 25, 1, 0, 0, 0, 39, 33, 1, 0, 0, 0, 39, 37, 1, 0, 0, 0, 40, 49, 1, 0, 0, 0, 41, 42, 10, 4, 0, 0, 42, 43, 5, 16, 0, 0, 43, 48, 3, 2, 1, 5, 44, 45, 10, 3, 0, 0, 45, 46, 5, 17, 0, 0, 46, 48, 3, 2, 1, 4, 47, 41, 1, 0, 0, 0, 47, 44, 1, 0, 0, 0, 48, 51, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 49, 50, 1, 0, 0, 0, 50, 3, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 52, 53, 5, 4, 0, 0, 53, 58, 3, 10, 5, 0, 54, 55, 5, 3, 0, 0, 55, 57, 3, 10, 5, 0, 56, 54, 1, 0, 0, 0, 57, 60, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 58, 59, 1, 0, 0, 0, 59, 61, 1, 0, 0, 0, 60, 58, 1, 0, 0, 0, 61, 62, 5, 5, 0, 0, 62, 5, 1, 0, 0, 0, 63, 64, 7, 0, 0, 
0, 64, 7, 1, 0, 0, 0, 65, 66, 5, 25, 0, 0, 66, 67, 5, 2, 0, 0, 67, 71, 5, 25, 0, 0, 68, 71, 5, 25, 0, 0, 69, 71, 5, 22, 0, 0, 70, 65, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 70, 69, 1, 0, 0, 0, 71, 9, 1, 0, 0, 0, 72, 74, 7, 1, 0, 0, 73, 72, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 87, 5, 23, 0, 0, 76, 78, 7, 1, 0, 0, 77, 76, 1, 0, 0, 0, 77, 78, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 87, 5, 24, 0, 0, 80, 82, 5, 22, 0, 0, 81, 80, 1, 0, 0, 0, 82, 83, 1, 0, 0, 0, 83, 81, 1, 0, 0, 0, 83, 84, 1, 0, 0, 0, 84, 87, 1, 0, 0, 0, 85, 87, 5, 21, 0, 0, 86, 73, 1, 0, 0, 0, 86, 77, 1, 0, 0, 0, 86, 81, 1, 0, 0, 0, 86, 85, 1, 0, 0, 0, 87, 11, 1, 0, 0, 0, 10, 29, 39, 47, 49, 58, 70, 73, 77, 83, 86]

View File

@@ -121,6 +121,28 @@ public class FiltersBaseListener implements FiltersListener {
public void exitInExpression(FiltersParser.InExpressionContext ctx) {
}
/**
* {@inheritDoc}
*
* <p>
* The default implementation does nothing.
* </p>
*/
@Override
public void enterNotExpression(FiltersParser.NotExpressionContext ctx) {
}
/**
* {@inheritDoc}
*
* <p>
* The default implementation does nothing.
* </p>
*/
@Override
public void exitNotExpression(FiltersParser.NotExpressionContext ctx) {
}
/**
* {@inheritDoc}
*

View File

@@ -86,6 +86,19 @@ public class FiltersBaseVisitor<T> extends AbstractParseTreeVisitor<T> implement
return visitChildren(ctx);
}
/**
* {@inheritDoc}
*
* <p>
* The default implementation returns the result of calling {@link #visitChildren} on
* {@code ctx}.
* </p>
*/
@Override
public T visitNotExpression(FiltersParser.NotExpressionContext ctx) {
return visitChildren(ctx);
}
/**
* {@inheritDoc}
*

View File

@@ -83,6 +83,20 @@ public interface FiltersListener extends ParseTreeListener {
*/
void exitInExpression(FiltersParser.InExpressionContext ctx);
/**
* Enter a parse tree produced by the {@code NotExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.
* @param ctx the parse tree
*/
void enterNotExpression(FiltersParser.NotExpressionContext ctx);
/**
* Exit a parse tree produced by the {@code NotExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.
* @param ctx the parse tree
*/
void exitNotExpression(FiltersParser.NotExpressionContext ctx);
/**
* Enter a parse tree produced by the {@code CompareExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.

View File

@@ -358,6 +358,43 @@ public class FiltersParser extends Parser {
}
/**
 * Parse tree context for the {@code NotExpression} labeled alternative of the
 * {@code booleanExpression} rule: a NOT token followed by a nested boolean
 * expression. NOTE(review): this class appears to be ANTLR-generated; manual edits
 * are presumably lost when the parser is regenerated.
 */
@SuppressWarnings("CheckReturnValue")
public static class NotExpressionContext extends BooleanExpressionContext {
// Accessor for the NOT terminal (token type FiltersParser.NOT, index 0).
public TerminalNode NOT() {
return getToken(FiltersParser.NOT, 0);
}
// Accessor for the nested boolean expression (child rule context at index 0).
public BooleanExpressionContext booleanExpression() {
return getRuleContext(BooleanExpressionContext.class, 0);
}
// Creates the labeled context by copying the state of the generic rule context.
public NotExpressionContext(BooleanExpressionContext ctx) {
copyFrom(ctx);
}
// Notifies the listener of rule entry; listeners that are not FiltersListener are ignored.
@Override
public void enterRule(ParseTreeListener listener) {
if (listener instanceof FiltersListener)
((FiltersListener) listener).enterNotExpression(this);
}
// Notifies the listener of rule exit; listeners that are not FiltersListener are ignored.
@Override
public void exitRule(ParseTreeListener listener) {
if (listener instanceof FiltersListener)
((FiltersListener) listener).exitNotExpression(this);
}
// Dispatches to visitNotExpression for a FiltersVisitor, otherwise falls back to
// the visitor's generic visitChildren.
@Override
public <T> T accept(ParseTreeVisitor<? extends T> visitor) {
if (visitor instanceof FiltersVisitor)
return ((FiltersVisitor<? extends T>) visitor).visitNotExpression(this);
else
return visitor.visitChildren(this);
}
}
@SuppressWarnings("CheckReturnValue")
public static class CompareExpressionContext extends BooleanExpressionContext {
@@ -502,7 +539,7 @@ public class FiltersParser extends Parser {
int _alt;
enterOuterAlt(_localctx, 1);
{
setState(37);
setState(39);
_errHandler.sync(this);
switch (getInterpreter().adaptivePredict(_input, 1, _ctx)) {
case 1: {
@@ -570,9 +607,19 @@ public class FiltersParser extends Parser {
match(RIGHT_PARENTHESIS);
}
break;
case 5: {
_localctx = new NotExpressionContext(_localctx);
_ctx = _localctx;
_prevctx = _localctx;
setState(37);
match(NOT);
setState(38);
booleanExpression(1);
}
break;
}
_ctx.stop = _input.LT(-1);
setState(47);
setState(49);
_errHandler.sync(this);
_alt = getInterpreter().adaptivePredict(_input, 3, _ctx);
while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER) {
@@ -581,7 +628,7 @@ public class FiltersParser extends Parser {
triggerExitRuleEvent();
_prevctx = _localctx;
{
setState(45);
setState(47);
_errHandler.sync(this);
switch (getInterpreter().adaptivePredict(_input, 2, _ctx)) {
case 1: {
@@ -589,13 +636,13 @@ public class FiltersParser extends Parser {
new BooleanExpressionContext(_parentctx, _parentState));
((AndExpressionContext) _localctx).left = _prevctx;
pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression);
setState(39);
if (!(precpred(_ctx, 3)))
throw new FailedPredicateException(this, "precpred(_ctx, 3)");
setState(40);
((AndExpressionContext) _localctx).operator = match(AND);
setState(41);
((AndExpressionContext) _localctx).right = booleanExpression(4);
if (!(precpred(_ctx, 4)))
throw new FailedPredicateException(this, "precpred(_ctx, 4)");
setState(42);
((AndExpressionContext) _localctx).operator = match(AND);
setState(43);
((AndExpressionContext) _localctx).right = booleanExpression(5);
}
break;
case 2: {
@@ -603,19 +650,19 @@ public class FiltersParser extends Parser {
new BooleanExpressionContext(_parentctx, _parentState));
((OrExpressionContext) _localctx).left = _prevctx;
pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression);
setState(42);
if (!(precpred(_ctx, 2)))
throw new FailedPredicateException(this, "precpred(_ctx, 2)");
setState(43);
((OrExpressionContext) _localctx).operator = match(OR);
setState(44);
((OrExpressionContext) _localctx).right = booleanExpression(3);
if (!(precpred(_ctx, 3)))
throw new FailedPredicateException(this, "precpred(_ctx, 3)");
setState(45);
((OrExpressionContext) _localctx).operator = match(OR);
setState(46);
((OrExpressionContext) _localctx).right = booleanExpression(4);
}
break;
}
}
}
setState(49);
setState(51);
_errHandler.sync(this);
_alt = getInterpreter().adaptivePredict(_input, 3, _ctx);
}
@@ -697,27 +744,27 @@ public class FiltersParser extends Parser {
try {
enterOuterAlt(_localctx, 1);
{
setState(50);
setState(52);
match(LEFT_SQUARE_BRACKETS);
setState(51);
setState(53);
constant();
setState(56);
setState(58);
_errHandler.sync(this);
_la = _input.LA(1);
while (_la == COMMA) {
{
{
setState(52);
setState(54);
match(COMMA);
setState(53);
setState(55);
constant();
}
}
setState(58);
setState(60);
_errHandler.sync(this);
_la = _input.LA(1);
}
setState(59);
setState(61);
match(RIGHT_SQUARE_BRACKETS);
}
}
@@ -797,7 +844,7 @@ public class FiltersParser extends Parser {
try {
enterOuterAlt(_localctx, 1);
{
setState(61);
setState(63);
_la = _input.LA(1);
if (!((((_la) & ~0x3f) == 0 && ((1L << _la) & 63744L) != 0))) {
_errHandler.recoverInline(this);
@@ -875,28 +922,28 @@ public class FiltersParser extends Parser {
IdentifierContext _localctx = new IdentifierContext(_ctx, getState());
enterRule(_localctx, 8, RULE_identifier);
try {
setState(68);
setState(70);
_errHandler.sync(this);
switch (getInterpreter().adaptivePredict(_input, 5, _ctx)) {
case 1:
enterOuterAlt(_localctx, 1); {
setState(63);
match(IDENTIFIER);
setState(64);
match(DOT);
setState(65);
match(IDENTIFIER);
setState(66);
match(DOT);
setState(67);
match(IDENTIFIER);
}
break;
case 2:
enterOuterAlt(_localctx, 2); {
setState(66);
setState(68);
match(IDENTIFIER);
}
break;
case 3:
enterOuterAlt(_localctx, 3); {
setState(67);
setState(69);
match(QUOTED_STRING);
}
break;
@@ -1092,18 +1139,18 @@ public class FiltersParser extends Parser {
int _la;
try {
int _alt;
setState(84);
setState(86);
_errHandler.sync(this);
switch (getInterpreter().adaptivePredict(_input, 9, _ctx)) {
case 1:
_localctx = new IntegerConstantContext(_localctx);
enterOuterAlt(_localctx, 1); {
setState(71);
setState(73);
_errHandler.sync(this);
_la = _input.LA(1);
if (_la == MINUS || _la == PLUS) {
{
setState(70);
setState(72);
_la = _input.LA(1);
if (!(_la == MINUS || _la == PLUS)) {
_errHandler.recoverInline(this);
@@ -1117,19 +1164,19 @@ public class FiltersParser extends Parser {
}
}
setState(73);
setState(75);
match(INTEGER_VALUE);
}
break;
case 2:
_localctx = new DecimalConstantContext(_localctx);
enterOuterAlt(_localctx, 2); {
setState(75);
setState(77);
_errHandler.sync(this);
_la = _input.LA(1);
if (_la == MINUS || _la == PLUS) {
{
setState(74);
setState(76);
_la = _input.LA(1);
if (!(_la == MINUS || _la == PLUS)) {
_errHandler.recoverInline(this);
@@ -1143,21 +1190,21 @@ public class FiltersParser extends Parser {
}
}
setState(77);
setState(79);
match(DECIMAL_VALUE);
}
break;
case 3:
_localctx = new TextConstantContext(_localctx);
enterOuterAlt(_localctx, 3); {
setState(79);
setState(81);
_errHandler.sync(this);
_alt = 1;
do {
switch (_alt) {
case 1: {
{
setState(78);
setState(80);
match(QUOTED_STRING);
}
}
@@ -1165,7 +1212,7 @@ public class FiltersParser extends Parser {
default:
throw new NoViableAltException(this);
}
setState(81);
setState(83);
_errHandler.sync(this);
_alt = getInterpreter().adaptivePredict(_input, 8, _ctx);
}
@@ -1175,7 +1222,7 @@ public class FiltersParser extends Parser {
case 4:
_localctx = new BooleanConstantContext(_localctx);
enterOuterAlt(_localctx, 4); {
setState(83);
setState(85);
match(BOOLEAN_VALUE);
}
break;
@@ -1203,67 +1250,68 @@ public class FiltersParser extends Parser {
private boolean booleanExpression_sempred(BooleanExpressionContext _localctx, int predIndex) {
switch (predIndex) {
case 0:
return precpred(_ctx, 3);
return precpred(_ctx, 4);
case 1:
return precpred(_ctx, 2);
return precpred(_ctx, 3);
}
return true;
}
public static final String _serializedATN = "\u0004\u0001\u001aW\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002"
public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002"
+ "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002"
+ "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001"
+ "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"
+ "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"
+ "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"
+ "\u0001\u0001\u0001\u0001\u0001\u0003\u0001&\b\u0001\u0001\u0001\u0001"
+ "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0005\u0001.\b"
+ "\u0001\n\u0001\f\u00011\t\u0001\u0001\u0002\u0001\u0002\u0001\u0002\u0001"
+ "\u0002\u0005\u00027\b\u0002\n\u0002\f\u0002:\t\u0002\u0001\u0002\u0001"
+ "\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001\u0004\u0001\u0004\u0001"
+ "\u0004\u0001\u0004\u0003\u0004E\b\u0004\u0001\u0005\u0003\u0005H\b\u0005"
+ "\u0001\u0005\u0001\u0005\u0003\u0005L\b\u0005\u0001\u0005\u0001\u0005"
+ "\u0004\u0005P\b\u0005\u000b\u0005\f\u0005Q\u0001\u0005\u0003\u0005U\b"
+ "\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000\u0002\u0004\u0006\b\n"
+ "\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000\t\n_\u0000\f\u0001"
+ "\u0000\u0000\u0000\u0002%\u0001\u0000\u0000\u0000\u00042\u0001\u0000\u0000"
+ "\u0000\u0006=\u0001\u0000\u0000\u0000\bD\u0001\u0000\u0000\u0000\nT\u0001"
+ "\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000\r\u000e\u0003\u0002\u0001"
+ "\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f\u0001\u0001\u0000\u0000"
+ "\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000\u0011\u0012\u0003\b\u0004"
+ "\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013\u0014\u0003\n\u0005\u0000"
+ "\u0014&\u0001\u0000\u0000\u0000\u0015\u0016\u0003\b\u0004\u0000\u0016"
+ "\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003\u0004\u0002\u0000\u0018"
+ "&\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b\u0004\u0000\u001a\u001b"
+ "\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012\u0000\u0000\u001c\u001e"
+ "\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000\u0000\u0000\u001d\u001c"
+ "\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000\u0000\u0000\u001f \u0003"
+ "\u0004\u0002\u0000 &\u0001\u0000\u0000\u0000!\"\u0005\u0006\u0000\u0000"
+ "\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000\u0000$&\u0001\u0000\u0000"
+ "\u0000%\u0010\u0001\u0000\u0000\u0000%\u0015\u0001\u0000\u0000\u0000%"
+ "\u0019\u0001\u0000\u0000\u0000%!\u0001\u0000\u0000\u0000&/\u0001\u0000"
+ "\u0000\u0000\'(\n\u0003\u0000\u0000()\u0005\u0010\u0000\u0000).\u0003"
+ "\u0002\u0001\u0004*+\n\u0002\u0000\u0000+,\u0005\u0011\u0000\u0000,.\u0003"
+ "\u0002\u0001\u0003-\'\u0001\u0000\u0000\u0000-*\u0001\u0000\u0000\u0000"
+ ".1\u0001\u0000\u0000\u0000/-\u0001\u0000\u0000\u0000/0\u0001\u0000\u0000"
+ "\u00000\u0003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000\u000023\u0005"
+ "\u0004\u0000\u000038\u0003\n\u0005\u000045\u0005\u0003\u0000\u000057\u0003"
+ "\n\u0005\u000064\u0001\u0000\u0000\u00007:\u0001\u0000\u0000\u000086\u0001"
+ "\u0000\u0000\u000089\u0001\u0000\u0000\u00009;\u0001\u0000\u0000\u0000"
+ ":8\u0001\u0000\u0000\u0000;<\u0005\u0005\u0000\u0000<\u0005\u0001\u0000"
+ "\u0000\u0000=>\u0007\u0000\u0000\u0000>\u0007\u0001\u0000\u0000\u0000"
+ "?@\u0005\u0019\u0000\u0000@A\u0005\u0002\u0000\u0000AE\u0005\u0019\u0000"
+ "\u0000BE\u0005\u0019\u0000\u0000CE\u0005\u0016\u0000\u0000D?\u0001\u0000"
+ "\u0000\u0000DB\u0001\u0000\u0000\u0000DC\u0001\u0000\u0000\u0000E\t\u0001"
+ "\u0000\u0000\u0000FH\u0007\u0001\u0000\u0000GF\u0001\u0000\u0000\u0000"
+ "GH\u0001\u0000\u0000\u0000HI\u0001\u0000\u0000\u0000IU\u0005\u0017\u0000"
+ "\u0000JL\u0007\u0001\u0000\u0000KJ\u0001\u0000\u0000\u0000KL\u0001\u0000"
+ "\u0000\u0000LM\u0001\u0000\u0000\u0000MU\u0005\u0018\u0000\u0000NP\u0005"
+ "\u0016\u0000\u0000ON\u0001\u0000\u0000\u0000PQ\u0001\u0000\u0000\u0000"
+ "QO\u0001\u0000\u0000\u0000QR\u0001\u0000\u0000\u0000RU\u0001\u0000\u0000"
+ "\u0000SU\u0005\u0015\u0000\u0000TG\u0001\u0000\u0000\u0000TK\u0001\u0000"
+ "\u0000\u0000TO\u0001\u0000\u0000\u0000TS\u0001\u0000\u0000\u0000U\u000b"
+ "\u0001\u0000\u0000\u0000\n\u001d%-/8DGKQT";
+ "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b"
+ "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001"
+ "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001"
+ "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t"
+ "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001"
+ "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001"
+ "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005"
+ "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001"
+ "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000"
+ "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000"
+ "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000"
+ "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001"
+ "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000"
+ "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f"
+ "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000"
+ "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013"
+ "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016"
+ "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003"
+ "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b"
+ "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012"
+ "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000"
+ "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000"
+ "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000"
+ "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000"
+ "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002"
+ "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000"
+ "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001"
+ "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005"
+ "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005"
+ "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000"
+ "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000"
+ "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001"
+ "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005"
+ "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001"
+ "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000"
+ ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000"
+ "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007"
+ "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000"
+ "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016"
+ "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001"
+ "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000"
+ "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000"
+ "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000"
+ "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005"
+ "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000"
+ "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000"
+ "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000"
+ "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001"
+ "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV";
public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray());
static {

View File

@@ -63,6 +63,14 @@ public interface FiltersVisitor<T> extends ParseTreeVisitor<T> {
*/
T visitInExpression(FiltersParser.InExpressionContext ctx);
/**
* Visit a parse tree produced by the {@code NotExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.
* @param ctx the parse tree
* @return the visitor result
*/
T visitNotExpression(FiltersParser.NotExpressionContext ctx);
/**
* Visit a parse tree produced by the {@code CompareExpression} labeled alternative in
* {@link FiltersParser#booleanExpression}.

View File

@@ -19,6 +19,7 @@ package org.springframework.ai.vectorstore.filter.converter;
import java.util.List;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.FilterHelper;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
import org.springframework.ai.vectorstore.filter.Filter.Group;
@@ -52,14 +53,27 @@ public abstract class AbstractFilterExpressionConverter implements FilterExpress
this.doValue(value, context);
}
else if (operand instanceof Filter.Expression expression) {
if ((expression.type() != ExpressionType.AND && expression.type() != ExpressionType.OR)
&& !(expression.right() instanceof Filter.Value)) {
if ((expression.type() != ExpressionType.NOT && expression.type() != ExpressionType.AND
&& expression.type() != ExpressionType.OR) && !(expression.right() instanceof Filter.Value)) {
throw new RuntimeException("Non AND/OR expression must have Value right argument!");
}
this.doExpression(expression, context);
if (expression.type() == ExpressionType.NOT) {
this.doNot(expression, context);
}
else {
this.doExpression(expression, context);
}
}
}
/**
 * Converts a NOT expression. By default the expression is rewritten with
 * {@link FilterHelper#negate} into its semantically equivalent negation, so the NOT
 * types are removed from the boolean expression tree before it reaches
 * {@code doExpression}; the rewritten operand is then converted as usual.
 * @param expression the NOT expression to convert
 * @param context the buffer handed on to {@code convertOperand}
 */
protected void doNot(Filter.Expression expression, StringBuilder context) {
	var negated = FilterHelper.negate(expression);
	this.convertOperand(negated, context);
}
protected abstract void doExpression(Filter.Expression expression, StringBuilder context);
protected abstract void doKey(Filter.Key filterKey, StringBuilder context);

View File

@@ -33,6 +33,7 @@ booleanExpression
| left=booleanExpression operator=AND right=booleanExpression # AndExpression
| left=booleanExpression operator=OR right=booleanExpression # OrExpression
| LEFT_PARENTHESIS booleanExpression RIGHT_PARENTHESIS # GroupExpression
| NOT booleanExpression # NotExpression
;
constantArray

View File

@@ -33,6 +33,7 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT;
/**
* @author Christian Tzolov
@@ -97,4 +98,18 @@ public class FilterExpressionBuilderTests {
new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))));
}
@Test
public void tesNot() {
	// NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"])
	var openSince2020 = b.and(b.eq("isOpen", true), b.gte("year", 2020));
	var inCountries = b.in("country", "BG", "NL", "US");
	var actual = b.not(b.and(openSince2020, inCountries)).build();

	// Expected tree, assembled bottom-up: the two AND levels first, NOT on top.
	var expectedOpenSince2020 = new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)),
			new Expression(GTE, new Key("year"), new Value(2020)));
	var expectedInCountries = new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")));
	var expectedConjunction = new Expression(AND, expectedOpenSince2020, expectedInCountries);

	assertThat(actual).isEqualTo(new Expression(NOT, expectedConjunction, null));
}
}

View File

@@ -30,10 +30,11 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AN
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR;
import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE;
/**
* @author Christian Tzolov
@@ -111,6 +112,61 @@ public class FilterExpressionTextParserTests {
.get("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]")).isEqualTo(exp);
}
@Test
public void tesNot() {
	// NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"])
	String filterText = "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])";
	Expression parsed = parser.parse(filterText);

	Expression openSince2020 = new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)),
			new Expression(GTE, new Key("year"), new Value(2020)));
	Expression inCountries = new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")));
	// The parenthesised body is kept as a Group under the NOT node.
	Expression expected = new Expression(NOT, new Group(new Expression(AND, openSince2020, inCountries)), null);

	assertThat(parsed).isEqualTo(expected);

	// The parsed expression is looked up in the cache under the "WHERE "-prefixed text.
	assertThat(parser.getCache().get("WHERE " + filterText)).isEqualTo(parsed);
}
@Test
public void tesNotNin() {
	// NOT(country NOT IN ["BG", "NL", "US"])
	Expression parsed = parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])");

	// "NOT IN" is expected to parse to a NIN expression, wrapped in a Group by the parentheses.
	Expression notInCountries = new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US")));
	Expression expected = new Expression(NOT, new Group(notInCountries), null);

	assertThat(parsed).isEqualTo(expected);
}
@Test
public void tesNotNin2() {
	// NOT country NOT IN ["BG", "NL", "US"]
	Expression parsed = parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]");

	// Without parentheses the NIN expression is expected directly under NOT (no Group).
	Expression notInCountries = new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US")));
	Expression expected = new Expression(NOT, notInCountries, null);

	assertThat(parsed).isEqualTo(expected);
}
@Test
public void tesNestedNot() {
	// NOT(isOpen == true AND year >= 2020 AND NOT(country IN ["BG", "NL", "US"]))
	String filterText = "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))";
	Expression parsed = parser.parse(filterText);

	Expression openSince2020 = new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)),
			new Expression(GTE, new Key("year"), new Value(2020)));
	// Inner NOT: its parenthesised IN expression is kept as a Group.
	Expression notInCountries = new Expression(NOT,
			new Group(new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))), null);
	// Outer NOT wraps the grouped conjunction.
	Expression expected = new Expression(NOT, new Group(new Expression(AND, openSince2020, notInCountries)), null);

	assertThat(parsed).isEqualTo(expected);

	// The parsed expression is looked up in the cache under the "WHERE "-prefixed text.
	assertThat(parser.getCache().get("WHERE " + filterText)).isEqualTo(parsed);
}
@Test
public void testDecimal() {
// temperature >= -15.6 && temperature <= +20.13

View File

@@ -0,0 +1,171 @@
/*
* Copyright 2023-2023 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.ai.vectorstore.filter;
import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
import org.springframework.ai.vectorstore.filter.Filter.Key;
import org.springframework.ai.vectorstore.filter.Filter.Value;
import org.springframework.ai.vectorstore.filter.converter.PrintFilterExpressionConverter;
import static org.assertj.core.api.Assertions.assertThat;
/**
 * Tests for {@link FilterHelper}: negation of parsed NOT expressions and the
 * expansion of IN / NIN expressions into chains of EQ / NE comparisons.
 *
 * @author Christian Tzolov
 */
public class FilterHelperTests {

	@Test
	public void negateEQ() {
		Expression keyNotUk = keyComparison(ExpressionType.NE, "UK");

		// The parser itself keeps the NOT node; only FilterHelper.negate removes it.
		var parsedNot = Filter.parser().parse("NOT key == 'UK' ");
		assertThat(parsedNot)
			.isEqualTo(new Expression(ExpressionType.NOT, keyComparison(ExpressionType.EQ, "UK"), null));

		assertNegatesTo("NOT key == 'UK' ", keyNotUk);
		// A parenthesised operand keeps its Group wrapper after negation.
		assertNegatesTo("NOT (key == 'UK') ", new Filter.Group(keyNotUk));
	}

	@Test
	public void negateNE() {
		assertNegatesTo("NOT key != 'UK' ", keyComparison(ExpressionType.EQ, "UK"));
	}

	@Test
	public void negateGT() {
		assertNegatesTo("NOT key > 13 ", keyComparison(ExpressionType.LTE, 13));
	}

	@Test
	public void negateGTE() {
		assertNegatesTo("NOT key >= 13 ", keyComparison(ExpressionType.LT, 13));
	}

	@Test
	public void negateLT() {
		assertNegatesTo("NOT key < 13 ", keyComparison(ExpressionType.GTE, 13));
	}

	@Test
	public void negateLTE() {
		assertNegatesTo("NOT key <= 13 ", keyComparison(ExpressionType.GT, 13));
	}

	@Test
	public void negateIN() {
		assertNegatesTo("NOT key IN [11, 12, 13] ", keyComparison(ExpressionType.NIN, List.of(11, 12, 13)));
	}

	@Test
	public void negateNIN() {
		assertNegatesTo("NOT key NIN [11, 12, 13] ", keyComparison(ExpressionType.IN, List.of(11, 12, 13)));
	}

	@Test
	public void negateNIN2() {
		assertNegatesTo("NOT key NOT IN [11, 12, 13] ", keyComparison(ExpressionType.IN, List.of(11, 12, 13)));
	}

	@Test
	public void negateAND() {
		// NOT(a AND b) is expected to become (NOT a) OR (NOT b).
		var expected = new Filter.Group(new Expression(ExpressionType.OR, keyComparison(ExpressionType.LT, 11),
				keyComparison(ExpressionType.GTE, 13)));
		assertNegatesTo("NOT(key >= 11 AND key < 13)", expected);
	}

	@Test
	public void negateOR() {
		// NOT(a OR b) is expected to become (NOT a) AND (NOT b).
		var expected = new Filter.Group(new Expression(ExpressionType.AND, keyComparison(ExpressionType.LT, 11),
				keyComparison(ExpressionType.GTE, 13)));
		assertNegatesTo("NOT(key >= 11 OR key < 13)", expected);
	}

	@Test
	public void negateNot() {
		assertNegatesTo("NOT NOT(key >= 11)", new Filter.Group(keyComparison(ExpressionType.LT, 11)));
	}

	@Test
	public void negateNestedNot() {
		var parsed = Filter.parser().parse("NOT(NOT(key >= 11))");

		// Each parenthesised operand is parsed into a Group under its NOT node.
		var innerNot = new Expression(ExpressionType.NOT, new Filter.Group(keyComparison(ExpressionType.GTE, 11)));
		assertThat(parsed).isEqualTo(new Expression(ExpressionType.NOT, new Filter.Group(innerNot)));

		assertThat(FilterHelper.negate(parsed)).isEqualTo(new Filter.Group(keyComparison(ExpressionType.LT, 11)));
	}

	@Test
	public void expandIN() {
		var parsed = Filter.parser().parse("key IN [11, 12, 13] ");
		var converted = new InNinTestConverter().convertExpression(parsed);
		assertThat(converted).isEqualTo("key EQ 11 OR key EQ 12 OR key EQ 13");
	}

	@Test
	public void expandNIN() {
		// Both spellings of the operator must parse to the same expression.
		var ninKeyword = Filter.parser().parse("key NIN [11, 12, 13] ");
		var notInKeyword = Filter.parser().parse("key NOT IN [11, 12, 13] ");
		assertThat(ninKeyword).isEqualTo(notInKeyword);

		var converted = new InNinTestConverter().convertExpression(ninKeyword);
		assertThat(converted).isEqualTo("key NE 11 AND key NE 12 AND key NE 13");
	}

	/**
	 * Parses {@code filterText}, passes the result to {@link FilterHelper#negate} and
	 * asserts that the outcome equals {@code expected}.
	 */
	private static void assertNegatesTo(String filterText, Object expected) {
		var parsed = Filter.parser().parse(filterText);
		assertThat(FilterHelper.negate(parsed)).isEqualTo(expected);
	}

	/**
	 * Builds the leaf expression {@code key <type> value} used throughout these tests.
	 */
	private static Expression keyComparison(ExpressionType type, Object value) {
		return new Expression(type, new Key("key"), new Value(value));
	}

	/**
	 * Print converter that routes IN / NIN expressions through the
	 * {@link FilterHelper} expansion helpers and leaves everything else to the parent.
	 */
	private static class InNinTestConverter extends PrintFilterExpressionConverter {

		@Override
		public void doExpression(Expression expression, StringBuilder context) {
			ExpressionType type = expression.type();
			if (type == ExpressionType.IN) {
				FilterHelper.expandIn(expression, context, this);
				return;
			}
			if (type == ExpressionType.NIN) {
				FilterHelper.expandNin(expression, context, this);
				return;
			}
			super.doExpression(expression, context);
		}

	}

}

View File

@@ -113,7 +113,7 @@ One approach involves presenting both the user's request and the AI model's resp
Furthermore, leveraging the information stored in the Vector Database as supplementary data can enhance the evaluation process, aiding in the determination of response relevance.
The Spring AI project currenlty provides some very basic examples of how you can evaluate the responses in the form of prompts to include in a JUnit test.
The Spring AI project currently provides some very basic examples of how you can evaluate the responses in the form of prompts to include in a JUnit test.

View File

@@ -87,8 +87,7 @@ country == 'UK' && year >= 2020 && isActive == true.
These are the available implementations of the `VectorStore` interface:
* `InMemoryVectorStore`
* `SimplePersistentVectorStore`
* `InMemoryVectorStore` and `SimplePersistentVectorStore`.
* Pinecone: https://www.pinecone.io/[PineCone] vector store.
* PgVector [`PgVectorStore`]: The https://github.com/pgvector/pgvector[PostgreSQL/PGVector] vector store.
* Milvus [`MilvusVectorStore`]: The https://milvus.io/[Milvus] vector store.
@@ -117,7 +116,7 @@ The `VectorStore` implementation computes the embeddings and stores the JSON and
@Autowired
VectorStore vectorStore;
void load(String sourceFile) {}
void load(String sourceFile) {
JsonReader jsonReader = new JsonReader(new FileSystemResource(sourceFile),
"price", "name", "shortDescription", "description", "tags");
List<Document> documents = jsonReader.get();

View File

@@ -166,11 +166,20 @@ public class AzureVectorStoreIT {
results = vectorStore.similaritySearch(SearchRequest.query("The World")
.withTopK(5)
.withSimilarityThresholdAll()
.withFilterExpression("country nin ['BG']"));
.withFilterExpression("country not in ['BG']"));
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());
results = vectorStore.similaritySearch(SearchRequest.query("The World")
.withTopK(5)
.withSimilarityThresholdAll()
.withFilterExpression("NOT(country not in ['BG'])"));
assertThat(results).hasSize(2);
assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId());
// List<Document> results =
// vectorStore.similaritySearch(SearchRequest.query("The World")
// .withTopK(5)

View File

@@ -122,6 +122,11 @@ public class ChromaVectorStoreIT {
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());
results = vectorStore.similaritySearch(
request.withSimilarityThresholdAll().withFilterExpression("NOT(country == 'Netherland')"));
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId());
// Remove all documents from the store
vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList());
});

View File

@@ -188,6 +188,16 @@ public class MilvusVectorStoreIT {
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId());
results = vectorStore.similaritySearch(SearchRequest.query("The World")
.withTopK(5)
.withSimilarityThresholdAll()
.withFilterExpression("NOT(country == 'BG' && year == 2020)"));
assertThat(results).hasSize(2);
assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId());
assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId());
});
}

View File

@@ -182,6 +182,12 @@ public class PgVectorStoreIT {
assertThat(results.get(0).getId()).isIn(bgDocument.getId(), nlDocument.getId());
assertThat(results.get(1).getId()).isIn(bgDocument.getId(), nlDocument.getId());
results = vectorStore.similaritySearch(searchRequest
.withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))"));
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId());
results = vectorStore.similaritySearch(SearchRequest.query("The World")
.withTopK(5)
.withSimilarityThresholdAll()

View File

@@ -156,6 +156,13 @@ public class PineconeVectorStoreIT {
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId());
results = vectorStore.similaritySearch(searchRequest.withTopK(5)
.withSimilarityThresholdAll()
.withFilterExpression("NOT(country == 'Netherland')"));
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId());
// Remove all documents from the store
vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList());

View File

@@ -16,7 +16,6 @@
package org.springframework.ai.vectorstore;
import java.util.ArrayList;
import java.util.Date;
import java.util.List;
@@ -27,6 +26,7 @@ import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.ExpressionType;
import org.springframework.ai.vectorstore.filter.Filter.Group;
import org.springframework.ai.vectorstore.filter.Filter.Key;
import org.springframework.ai.vectorstore.filter.FilterHelper;
import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter;
import org.springframework.util.Assert;
@@ -62,10 +62,10 @@ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionC
protected void doExpression(Expression exp, StringBuilder context) {
if (exp.type() == ExpressionType.IN) {
rewriteInNinExpressions(Filter.ExpressionType.OR, Filter.ExpressionType.EQ, exp, context);
FilterHelper.expandIn(exp, context, this);
}
else if (exp.type() == ExpressionType.NIN) {
rewriteInNinExpressions(Filter.ExpressionType.AND, Filter.ExpressionType.NE, exp, context);
FilterHelper.expandNin(exp, context, this);
}
else if (exp.type() == ExpressionType.AND || exp.type() == ExpressionType.OR) {
context.append(getOperationSymbol(exp));
@@ -82,51 +82,6 @@ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionC
}
}
/**
 * Joins a list of expressions into a right-nested binary tree whose inner nodes all
 * have the given type: {@code [a, b, c]} becomes {@code T(a, T(b, c))}. A
 * single-element list is returned as that element, without a join node.
 * @param aggregateType type of every join node in the tree.
 * @param expressions list of expressions to aggregate.
 * @return Returns a binary tree expression.
 */
private Filter.Expression aggregate(Filter.ExpressionType aggregateType, List<Filter.Expression> expressions) {
	// Reading the head first keeps the failure mode for an empty list: get(0) throws.
	Filter.Expression head = expressions.get(0);
	int count = expressions.size();
	if (count == 1) {
		return head;
	}
	// Fold the remaining elements from the right so the nesting leans right.
	Filter.Expression tail = expressions.get(count - 1);
	for (int i = count - 2; i >= 1; i--) {
		tail = new Filter.Expression(aggregateType, expressions.get(i), tail);
	}
	return new Filter.Expression(aggregateType, head, tail);
}
/**
 * Rewrites an IN / NIN expression into simple comparisons and converts the result
 * with {@code doExpression}. The callers use {@code OR}/{@code EQ} for IN and
 * {@code AND}/{@code NE} for NIN.
 * @param outerExpressionType type used to join the per-element comparisons.
 * @param innerExpressionType comparison type applied to each element.
 * @param exp the IN / NIN expression; its right operand must be a
 * {@link Filter.Value}.
 * @param context the buffer the converted expression is written to.
 * @throws IllegalStateException if the right operand is not a {@link Filter.Value}.
 */
private void rewriteInNinExpressions(Filter.ExpressionType outerExpressionType,
		Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context) {
	if (exp.right() instanceof Filter.Value value) {
		if (value.value() instanceof List<?> list) {
			// 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to
			// foo == "bar1" || foo == "bar2" || foo == "bar3",
			// i.e. OR(foo == "bar1", OR(foo == "bar2", foo == "bar3"))
			// 2. foo NIN ["bar1", "bar2", "bar3"] is equivalent to
			// foo != "bar1" && foo != "bar2" && foo != "bar3",
			// i.e. AND(foo != "bar1", AND(foo != "bar2", foo != "bar3"))
			List<Filter.Expression> eqExprs = new ArrayList<>();
			for (Object o : list) {
				eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o)));
			}
			this.doExpression(aggregate(outerExpressionType, eqExprs), context);
		}
		else {
			// The right-hand value is not a list, so a single comparison is enough:
			// 1. foo IN "bar" is treated as foo == "bar"
			// 2. foo NIN "bar" is treated as foo != "bar"
			this.doExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right()), context);
		}
	}
	else {
		throw new IllegalStateException(
				"Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass());
	}
}
private String getOperationSymbol(Expression exp) {
switch (exp.type()) {
case AND:

View File

@@ -20,7 +20,6 @@ import java.util.List;
import org.junit.jupiter.api.Test;
import org.springframework.ai.vectorstore.filter.Filter;
import org.springframework.ai.vectorstore.filter.Filter.Expression;
import org.springframework.ai.vectorstore.filter.Filter.Group;
import org.springframework.ai.vectorstore.filter.Filter.Key;

View File

@@ -152,6 +152,14 @@ public class WeaviateVectorStoreIT {
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId());
results = vectorStore.similaritySearch(SearchRequest.query("The World")
.withTopK(5)
.withSimilarityThresholdAll()
.withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))"));
assertThat(results).hasSize(1);
assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId());
vectorStore.delete(List.of(bgDocument.getId(), nlDocument.getId(), bgDocument2.getId()));
});
}