diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java index 3919994db..02c18f6b2 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/Filter.java @@ -114,7 +114,7 @@ public class Filter { */ public enum ExpressionType { - AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN + AND, OR, EQ, NE, GT, GTE, LT, LTE, IN, NIN, NOT } @@ -131,6 +131,9 @@ public class Filter { * be another {@link Expression}. */ public record Expression(ExpressionType type, Operand left, Operand right) implements Operand { + public Expression(ExpressionType type, Operand operand) { + this(type, operand, null); + } } /** diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java index 7b6595e86..b0c3e9c25 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilder.java @@ -122,4 +122,8 @@ public class FilterExpressionBuilder { return new Op(new Filter.Group(content.build())); } + public Op not(Op content) { + return new Op(new Filter.Expression(ExpressionType.NOT, content.expression, null)); + } + } diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java index dddbf50ba..bbaff2044 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParser.java @@ -35,6 +35,7 @@ import org.antlr.v4.runtime.misc.ParseCancellationException; import org.springframework.ai.vectorstore.filter.antlr4.FiltersBaseVisitor; import org.springframework.ai.vectorstore.filter.antlr4.FiltersLexer; import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser; +import org.springframework.ai.vectorstore.filter.antlr4.FiltersParser.NotExpressionContext; import org.springframework.core.NestedExceptionUtils; import org.springframework.util.Assert; @@ -263,6 +264,11 @@ public class FilterExpressionTextParser { return new Filter.Group(castToExpression(this.visit(ctx.booleanExpression()))); } + @Override + public Filter.Operand visitNotExpression(NotExpressionContext ctx) { + return new Filter.Expression(Filter.ExpressionType.NOT, this.visit(ctx.booleanExpression()), null); + } + public Filter.Expression castToExpression(Filter.Operand expression) { if (expression instanceof Filter.Group group) { // Remove the top-level grouping. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java new file mode 100644 index 000000000..0bf221774 --- /dev/null +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/FilterHelper.java @@ -0,0 +1,205 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter; + +import java.util.ArrayList; +import java.util.List; +import java.util.Map; + +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Operand; +import org.springframework.ai.vectorstore.filter.converter.FilterExpressionConverter; +import org.springframework.util.Assert; + +/** + * Helper class providing various boolean transformation. + * + * @author Christian Tzolov + */ +public class FilterHelper { + + private FilterHelper() { + } + + private final static Map TYPE_NEGATION_MAP = Map.of(ExpressionType.AND, + ExpressionType.OR, ExpressionType.OR, ExpressionType.AND, ExpressionType.EQ, ExpressionType.NE, + ExpressionType.NE, ExpressionType.EQ, ExpressionType.GT, ExpressionType.LTE, ExpressionType.GTE, + ExpressionType.LT, ExpressionType.LT, ExpressionType.GTE, ExpressionType.LTE, ExpressionType.GT, + ExpressionType.IN, ExpressionType.NIN, ExpressionType.NIN, ExpressionType.IN); + + /** + * Transforms the input expression into a semantically equivalent one with negation + * operators propagated thought the expression tree by following the negation rules: + * + *
+	 * 	NOT(NOT(a)) = a
+	 *
+	 * 	NOT(a AND b) = NOT(a) OR NOT(b)
+	 * 	NOT(a OR b) = NOT(a) AND NOT(b)
+	 *
+	 * 	NOT(a EQ b) = a NE b
+	 * 	NOT(a NE b) = a EQ b
+	 *
+	 * 	NOT(a GT b) = a LTE b
+	 * 	NOT(a GTE b) = a LT b
+	 *
+	 * 	NOT(a LT b) = a GTE b
+	 * 	NOT(a LTE b) = a GT b
+	 *
+	 * 	NOT(a IN [...]) = a NIN [...]
+	 * 	NOT(a NIN [...]) = a IN [...]
+	 * 
+ * @param operand Filter expression to negate. + * @return Returns an negation of the input expression. + */ + public static Filter.Operand negate(Filter.Operand operand) { + + if (operand instanceof Filter.Group group) { + Operand inEx = negate(group.content()); + if (inEx instanceof Filter.Group inEx2) { + inEx = inEx2.content(); + } + return new Filter.Group((Expression) inEx); + } + else if (operand instanceof Filter.Expression exp) { + switch (exp.type()) { + case NOT: // NOT(NOT(a)) = a + return negate(exp.left()); + case AND: // NOT(a AND b) = NOT(a) OR NOT(b) + case OR: // NOT(a OR b) = NOT(a) AND NOT(b) + return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), negate(exp.left()), + negate(exp.right())); + case EQ: // NOT(e EQ b) = e NE b + case NE: // NOT(e NE b) = e EQ b + case GT: // NOT(e GT b) = e LTE b + case GTE: // NOT(e GTE b) = e LT b + case LT: // NOT(e LT b) = e GTE b + case LTE: // NOT(e LTE b) = e GT b + return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right()); + case IN: // NOT(e IN [...]) = e NIN [...] + case NIN: // NOT(e NIN [...]) = e IN [...] + return new Filter.Expression(TYPE_NEGATION_MAP.get(exp.type()), exp.left(), exp.right()); + default: + throw new IllegalArgumentException("Unknown expression type: " + exp.type()); + } + } + else { + throw new IllegalArgumentException("Can not negate operand of type: " + operand.getClass()); + } + } + + /** + * Expands the IN into a semantically equivalent boolean expressions of ORs of EQs. + * Useful for providers that don't provide native IN support. + * + * For example the
+	 * foo IN ["bar1", "bar2", "bar3"]
+	 * 
+ * + * expression is equivalent to + * + *
+	 * {@code foo == "bar1" || foo == "bar2" || foo == "bar3" (e.g. OR(foo EQ "bar1" OR(foo EQ "bar2" OR(foo EQ "bar3")))}
+	 * 
+ * @param exp input IN expression. + * @param context Output native expression. + * @param filterExpressionConverter {@link FilterExpressionConverter} used to compose + * the OR and EQ expanded expressions. + */ + public static void expandIn(Expression exp, StringBuilder context, + FilterExpressionConverter filterExpressionConverter) { + Assert.isTrue(exp.type() == ExpressionType.IN, "Expected IN expressions but was: " + exp.type()); + expandInNinExpressions(ExpressionType.OR, ExpressionType.EQ, exp, context, filterExpressionConverter); + } + + /** + * + * Expands the NIN (e.g. NOT IN) into a semantically equivalent boolean expressions of + * ANDs of NEs. Useful for providers that don't provide native NIN support.
+ * + * For example the + * + *
+	 * foo NIN ["bar1", "bar2", "bar3"] (or foo NOT IN ["bar1", "bar2", "bar3"])
+	 * 
+ * + * express is equivalent to + * + *
+	 * {@code foo != "bar1" && foo != "bar2" && foo != "bar3" (e.g. AND(foo NE "bar1" AND( foo NE "bar2" OR(foo NE "bar3"))) )}
+	 * 
+ * @param exp input NIN expression. + * @param context Output native expression. + * @param filterExpressionConverter {@link FilterExpressionConverter} used to compose + * the AND and NE expanded expressions. + */ + public static void expandNin(Expression exp, StringBuilder context, + FilterExpressionConverter filterExpressionConverter) { + Assert.isTrue(exp.type() == ExpressionType.NIN, "Expected NIN expressions but was: " + exp.type()); + expandInNinExpressions(ExpressionType.AND, ExpressionType.NE, exp, context, filterExpressionConverter); + } + + private static void expandInNinExpressions(Filter.ExpressionType outerExpressionType, + Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context, + FilterExpressionConverter expressionConverter) { + if (exp.right() instanceof Filter.Value value) { + if (value.value() instanceof List list) { + // 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo == "bar1" || + // foo == "bar2" || foo == "bar3" + // or equivalent to OR(foo == "bar1" OR( foo == "bar2" OR(foo == "bar3"))) + // 2. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo != "bar1" && + // foo != "bar2" && foo != "bar3" + // or equivalent to AND(foo != "bar1" AND( foo != "bar2" OR(foo != + // "bar3"))) + List eqExprs = new ArrayList<>(); + for (Object o : list) { + eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o))); + } + context.append(expressionConverter.convertExpression(aggregate(outerExpressionType, eqExprs))); + } + else { + // 1. foo IN ["bar"] is equivalent to foo == "BAR" + // 2. foo NIN ["bar"] is equivalent to foo != "BAR" + context.append(expressionConverter + .convertExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right()))); + } + } + else { + throw new IllegalStateException( + "Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass()); + } + } + + /** + * Recursively aggregates a list of expression into a binary tree with 'aggregateType' + * join nodes. + * @param aggregateType type all tree splits. + * @param expressions list of expressions to aggregate. + * @return Returns a binary tree expression. + */ + private static Filter.Expression aggregate(Filter.ExpressionType aggregateType, + List expressions) { + + if (expressions.size() == 1) { + return expressions.get(0); + } + return new Filter.Expression(aggregateType, expressions.get(0), + aggregate(aggregateType, expressions.subList(1, expressions.size()))); + } + +} diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp index 615892786..51775a8a5 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/Filters.interp @@ -66,4 +66,4 @@ constant atn: -[4, 1, 26, 87, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 38, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 46, 8, 1, 10, 1, 12, 1, 49, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 55, 8, 2, 10, 2, 12, 2, 58, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 69, 8, 4, 1, 5, 3, 5, 72, 8, 5, 1, 5, 1, 5, 3, 5, 76, 8, 5, 1, 5, 1, 5, 4, 5, 80, 8, 5, 11, 5, 12, 5, 81, 1, 5, 3, 5, 85, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 95, 0, 12, 1, 0, 0, 0, 2, 37, 1, 0, 0, 0, 4, 50, 1, 0, 0, 0, 6, 61, 1, 0, 0, 0, 8, 68, 1, 0, 0, 0, 10, 84, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 38, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 38, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 38, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 38, 1, 0, 0, 0, 37, 16, 1, 0, 0, 0, 37, 21, 1, 0, 0, 0, 37, 25, 1, 0, 0, 0, 37, 33, 1, 0, 0, 0, 38, 47, 1, 0, 0, 0, 39, 40, 10, 3, 0, 0, 40, 41, 5, 16, 0, 0, 41, 46, 3, 2, 1, 4, 42, 43, 10, 2, 0, 0, 43, 44, 5, 17, 0, 0, 44, 46, 3, 2, 1, 3, 45, 39, 1, 0, 0, 0, 45, 42, 1, 0, 0, 0, 46, 49, 1, 0, 0, 0, 47, 45, 1, 0, 0, 0, 47, 48, 1, 0, 0, 0, 48, 3, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 50, 51, 5, 4, 0, 0, 51, 56, 3, 10, 5, 0, 52, 53, 5, 3, 0, 0, 53, 55, 3, 10, 5, 0, 54, 52, 1, 0, 0, 0, 55, 58, 1, 0, 0, 0, 56, 54, 1, 0, 0, 0, 56, 57, 1, 0, 0, 0, 57, 59, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 59, 60, 5, 5, 0, 0, 60, 5, 1, 0, 0, 0, 61, 62, 7, 0, 0, 0, 62, 7, 1, 0, 0, 0, 63, 64, 5, 25, 0, 0, 64, 65, 5, 2, 0, 0, 65, 69, 5, 25, 0, 0, 66, 69, 5, 25, 0, 0, 67, 69, 5, 22, 0, 0, 68, 63, 1, 0, 0, 0, 68, 66, 1, 0, 0, 0, 68, 67, 1, 0, 0, 0, 69, 9, 1, 0, 0, 0, 70, 72, 7, 1, 0, 0, 71, 70, 1, 0, 0, 0, 71, 72, 1, 0, 0, 0, 72, 73, 1, 0, 0, 0, 73, 85, 5, 23, 0, 0, 74, 76, 7, 1, 0, 0, 75, 74, 1, 0, 0, 0, 75, 76, 1, 0, 0, 0, 76, 77, 1, 0, 0, 0, 77, 85, 5, 24, 0, 0, 78, 80, 5, 22, 0, 0, 79, 78, 1, 0, 0, 0, 80, 81, 1, 0, 0, 0, 81, 79, 1, 0, 0, 0, 81, 82, 1, 0, 0, 0, 82, 85, 1, 0, 0, 0, 83, 85, 5, 21, 0, 0, 84, 71, 1, 0, 0, 0, 84, 75, 1, 0, 0, 0, 84, 79, 1, 0, 0, 0, 84, 83, 1, 0, 0, 0, 85, 11, 1, 0, 0, 0, 10, 29, 37, 45, 47, 56, 68, 71, 75, 81, 84] \ No newline at end of file +[4, 1, 26, 89, 2, 0, 7, 0, 2, 1, 7, 1, 2, 2, 7, 2, 2, 3, 7, 3, 2, 4, 7, 4, 2, 5, 7, 5, 1, 0, 1, 0, 1, 0, 1, 0, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 30, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 3, 1, 40, 8, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 1, 5, 1, 48, 8, 1, 10, 1, 12, 1, 51, 9, 1, 1, 2, 1, 2, 1, 2, 1, 2, 5, 2, 57, 8, 2, 10, 2, 12, 2, 60, 9, 2, 1, 2, 1, 2, 1, 3, 1, 3, 1, 4, 1, 4, 1, 4, 1, 4, 1, 4, 3, 4, 71, 8, 4, 1, 5, 3, 5, 74, 8, 5, 1, 5, 1, 5, 3, 5, 78, 8, 5, 1, 5, 1, 5, 4, 5, 82, 8, 5, 11, 5, 12, 5, 83, 1, 5, 3, 5, 87, 8, 5, 1, 5, 0, 1, 2, 6, 0, 2, 4, 6, 8, 10, 0, 2, 2, 0, 8, 8, 11, 15, 1, 0, 9, 10, 98, 0, 12, 1, 0, 0, 0, 2, 39, 1, 0, 0, 0, 4, 52, 1, 0, 0, 0, 6, 63, 1, 0, 0, 0, 8, 70, 1, 0, 0, 0, 10, 86, 1, 0, 0, 0, 12, 13, 5, 1, 0, 0, 13, 14, 3, 2, 1, 0, 14, 15, 5, 0, 0, 1, 15, 1, 1, 0, 0, 0, 16, 17, 6, 1, -1, 0, 17, 18, 3, 8, 4, 0, 18, 19, 3, 6, 3, 0, 19, 20, 3, 10, 5, 0, 20, 40, 1, 0, 0, 0, 21, 22, 3, 8, 4, 0, 22, 23, 5, 18, 0, 0, 23, 24, 3, 4, 2, 0, 24, 40, 1, 0, 0, 0, 25, 29, 3, 8, 4, 0, 26, 27, 5, 20, 0, 0, 27, 30, 5, 18, 0, 0, 28, 30, 5, 19, 0, 0, 29, 26, 1, 0, 0, 0, 29, 28, 1, 0, 0, 0, 30, 31, 1, 0, 0, 0, 31, 32, 3, 4, 2, 0, 32, 40, 1, 0, 0, 0, 33, 34, 5, 6, 0, 0, 34, 35, 3, 2, 1, 0, 35, 36, 5, 7, 0, 0, 36, 40, 1, 0, 0, 0, 37, 38, 5, 20, 0, 0, 38, 40, 3, 2, 1, 1, 39, 16, 1, 0, 0, 0, 39, 21, 1, 0, 0, 0, 39, 25, 1, 0, 0, 0, 39, 33, 1, 0, 0, 0, 39, 37, 1, 0, 0, 0, 40, 49, 1, 0, 0, 0, 41, 42, 10, 4, 0, 0, 42, 43, 5, 16, 0, 0, 43, 48, 3, 2, 1, 5, 44, 45, 10, 3, 0, 0, 45, 46, 5, 17, 0, 0, 46, 48, 3, 2, 1, 4, 47, 41, 1, 0, 0, 0, 47, 44, 1, 0, 0, 0, 48, 51, 1, 0, 0, 0, 49, 47, 1, 0, 0, 0, 49, 50, 1, 0, 0, 0, 50, 3, 1, 0, 0, 0, 51, 49, 1, 0, 0, 0, 52, 53, 5, 4, 0, 0, 53, 58, 3, 10, 5, 0, 54, 55, 5, 3, 0, 0, 55, 57, 3, 10, 5, 0, 56, 54, 1, 0, 0, 0, 57, 60, 1, 0, 0, 0, 58, 56, 1, 0, 0, 0, 58, 59, 1, 0, 0, 0, 59, 61, 1, 0, 0, 0, 60, 58, 1, 0, 0, 0, 61, 62, 5, 5, 0, 0, 62, 5, 1, 0, 0, 0, 63, 64, 7, 0, 0, 0, 64, 7, 1, 0, 0, 0, 65, 66, 5, 25, 0, 0, 66, 67, 5, 2, 0, 0, 67, 71, 5, 25, 0, 0, 68, 71, 5, 25, 0, 0, 69, 71, 5, 22, 0, 0, 70, 65, 1, 0, 0, 0, 70, 68, 1, 0, 0, 0, 70, 69, 1, 0, 0, 0, 71, 9, 1, 0, 0, 0, 72, 74, 7, 1, 0, 0, 73, 72, 1, 0, 0, 0, 73, 74, 1, 0, 0, 0, 74, 75, 1, 0, 0, 0, 75, 87, 5, 23, 0, 0, 76, 78, 7, 1, 0, 0, 77, 76, 1, 0, 0, 0, 77, 78, 1, 0, 0, 0, 78, 79, 1, 0, 0, 0, 79, 87, 5, 24, 0, 0, 80, 82, 5, 22, 0, 0, 81, 80, 1, 0, 0, 0, 82, 83, 1, 0, 0, 0, 83, 81, 1, 0, 0, 0, 83, 84, 1, 0, 0, 0, 84, 87, 1, 0, 0, 0, 85, 87, 5, 21, 0, 0, 86, 73, 1, 0, 0, 0, 86, 77, 1, 0, 0, 0, 86, 81, 1, 0, 0, 0, 86, 85, 1, 0, 0, 0, 87, 11, 1, 0, 0, 0, 10, 29, 39, 47, 49, 58, 70, 73, 77, 83, 86] \ No newline at end of file diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java index c8e98410a..962a36c67 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseListener.java @@ -121,6 +121,28 @@ public class FiltersBaseListener implements FiltersListener { public void exitInExpression(FiltersParser.InExpressionContext ctx) { } + /** + * {@inheritDoc} + * + *

+ * The default implementation does nothing. + *

+ */ + @Override + public void enterNotExpression(FiltersParser.NotExpressionContext ctx) { + } + + /** + * {@inheritDoc} + * + *

+ * The default implementation does nothing. + *

+ */ + @Override + public void exitNotExpression(FiltersParser.NotExpressionContext ctx) { + } + /** * {@inheritDoc} * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java index 5c0c81193..555a69629 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersBaseVisitor.java @@ -86,6 +86,19 @@ public class FiltersBaseVisitor extends AbstractParseTreeVisitor implement return visitChildren(ctx); } + /** + * {@inheritDoc} + * + *

+ * The default implementation returns the result of calling {@link #visitChildren} on + * {@code ctx}. + *

+ */ + @Override + public T visitNotExpression(FiltersParser.NotExpressionContext ctx) { + return visitChildren(ctx); + } + /** * {@inheritDoc} * diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java index f6b920479..77444e527 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersListener.java @@ -83,6 +83,20 @@ public interface FiltersListener extends ParseTreeListener { */ void exitInExpression(FiltersParser.InExpressionContext ctx); + /** + * Enter a parse tree produced by the {@code NotExpression} labeled alternative in + * {@link FiltersParser#booleanExpression}. + * @param ctx the parse tree + */ + void enterNotExpression(FiltersParser.NotExpressionContext ctx); + + /** + * Exit a parse tree produced by the {@code NotExpression} labeled alternative in + * {@link FiltersParser#booleanExpression}. + * @param ctx the parse tree + */ + void exitNotExpression(FiltersParser.NotExpressionContext ctx); + /** * Enter a parse tree produced by the {@code CompareExpression} labeled alternative in * {@link FiltersParser#booleanExpression}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java index 3171c1991..a17e355d7 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersParser.java @@ -358,6 +358,43 @@ public class FiltersParser extends Parser { } + @SuppressWarnings("CheckReturnValue") + public static class NotExpressionContext extends BooleanExpressionContext { + + public TerminalNode NOT() { + return getToken(FiltersParser.NOT, 0); + } + + public BooleanExpressionContext booleanExpression() { + return getRuleContext(BooleanExpressionContext.class, 0); + } + + public NotExpressionContext(BooleanExpressionContext ctx) { + copyFrom(ctx); + } + + @Override + public void enterRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) + ((FiltersListener) listener).enterNotExpression(this); + } + + @Override + public void exitRule(ParseTreeListener listener) { + if (listener instanceof FiltersListener) + ((FiltersListener) listener).exitNotExpression(this); + } + + @Override + public T accept(ParseTreeVisitor visitor) { + if (visitor instanceof FiltersVisitor) + return ((FiltersVisitor) visitor).visitNotExpression(this); + else + return visitor.visitChildren(this); + } + + } + @SuppressWarnings("CheckReturnValue") public static class CompareExpressionContext extends BooleanExpressionContext { @@ -502,7 +539,7 @@ public class FiltersParser extends Parser { int _alt; enterOuterAlt(_localctx, 1); { - setState(37); + setState(39); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 1, _ctx)) { case 1: { @@ -570,9 +607,19 @@ public class FiltersParser extends Parser { match(RIGHT_PARENTHESIS); } break; + case 5: { + _localctx = new NotExpressionContext(_localctx); + _ctx = _localctx; + _prevctx = _localctx; + setState(37); + match(NOT); + setState(38); + booleanExpression(1); + } + break; } _ctx.stop = _input.LT(-1); - setState(47); + setState(49); _errHandler.sync(this); _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); while (_alt != 2 && _alt != org.antlr.v4.runtime.atn.ATN.INVALID_ALT_NUMBER) { @@ -581,7 +628,7 @@ public class FiltersParser extends Parser { triggerExitRuleEvent(); _prevctx = _localctx; { - setState(45); + setState(47); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 2, _ctx)) { case 1: { @@ -589,13 +636,13 @@ public class FiltersParser extends Parser { new BooleanExpressionContext(_parentctx, _parentState)); ((AndExpressionContext) _localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); - setState(39); - if (!(precpred(_ctx, 3))) - throw new FailedPredicateException(this, "precpred(_ctx, 3)"); - setState(40); - ((AndExpressionContext) _localctx).operator = match(AND); setState(41); - ((AndExpressionContext) _localctx).right = booleanExpression(4); + if (!(precpred(_ctx, 4))) + throw new FailedPredicateException(this, "precpred(_ctx, 4)"); + setState(42); + ((AndExpressionContext) _localctx).operator = match(AND); + setState(43); + ((AndExpressionContext) _localctx).right = booleanExpression(5); } break; case 2: { @@ -603,19 +650,19 @@ public class FiltersParser extends Parser { new BooleanExpressionContext(_parentctx, _parentState)); ((OrExpressionContext) _localctx).left = _prevctx; pushNewRecursionContext(_localctx, _startState, RULE_booleanExpression); - setState(42); - if (!(precpred(_ctx, 2))) - throw new FailedPredicateException(this, "precpred(_ctx, 2)"); - setState(43); - ((OrExpressionContext) _localctx).operator = match(OR); setState(44); - ((OrExpressionContext) _localctx).right = booleanExpression(3); + if (!(precpred(_ctx, 3))) + throw new FailedPredicateException(this, "precpred(_ctx, 3)"); + setState(45); + ((OrExpressionContext) _localctx).operator = match(OR); + setState(46); + ((OrExpressionContext) _localctx).right = booleanExpression(4); } break; } } } - setState(49); + setState(51); _errHandler.sync(this); _alt = getInterpreter().adaptivePredict(_input, 3, _ctx); } @@ -697,27 +744,27 @@ public class FiltersParser extends Parser { try { enterOuterAlt(_localctx, 1); { - setState(50); + setState(52); match(LEFT_SQUARE_BRACKETS); - setState(51); + setState(53); constant(); - setState(56); + setState(58); _errHandler.sync(this); _la = _input.LA(1); while (_la == COMMA) { { { - setState(52); + setState(54); match(COMMA); - setState(53); + setState(55); constant(); } } - setState(58); + setState(60); _errHandler.sync(this); _la = _input.LA(1); } - setState(59); + setState(61); match(RIGHT_SQUARE_BRACKETS); } } @@ -797,7 +844,7 @@ public class FiltersParser extends Parser { try { enterOuterAlt(_localctx, 1); { - setState(61); + setState(63); _la = _input.LA(1); if (!((((_la) & ~0x3f) == 0 && ((1L << _la) & 63744L) != 0))) { _errHandler.recoverInline(this); @@ -875,28 +922,28 @@ public class FiltersParser extends Parser { IdentifierContext _localctx = new IdentifierContext(_ctx, getState()); enterRule(_localctx, 8, RULE_identifier); try { - setState(68); + setState(70); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 5, _ctx)) { case 1: enterOuterAlt(_localctx, 1); { - setState(63); - match(IDENTIFIER); - setState(64); - match(DOT); setState(65); match(IDENTIFIER); + setState(66); + match(DOT); + setState(67); + match(IDENTIFIER); } break; case 2: enterOuterAlt(_localctx, 2); { - setState(66); + setState(68); match(IDENTIFIER); } break; case 3: enterOuterAlt(_localctx, 3); { - setState(67); + setState(69); match(QUOTED_STRING); } break; @@ -1092,18 +1139,18 @@ public class FiltersParser extends Parser { int _la; try { int _alt; - setState(84); + setState(86); _errHandler.sync(this); switch (getInterpreter().adaptivePredict(_input, 9, _ctx)) { case 1: _localctx = new IntegerConstantContext(_localctx); enterOuterAlt(_localctx, 1); { - setState(71); + setState(73); _errHandler.sync(this); _la = _input.LA(1); if (_la == MINUS || _la == PLUS) { { - setState(70); + setState(72); _la = _input.LA(1); if (!(_la == MINUS || _la == PLUS)) { _errHandler.recoverInline(this); @@ -1117,19 +1164,19 @@ public class FiltersParser extends Parser { } } - setState(73); + setState(75); match(INTEGER_VALUE); } break; case 2: _localctx = new DecimalConstantContext(_localctx); enterOuterAlt(_localctx, 2); { - setState(75); + setState(77); _errHandler.sync(this); _la = _input.LA(1); if (_la == MINUS || _la == PLUS) { { - setState(74); + setState(76); _la = _input.LA(1); if (!(_la == MINUS || _la == PLUS)) { _errHandler.recoverInline(this); @@ -1143,21 +1190,21 @@ public class FiltersParser extends Parser { } } - setState(77); + setState(79); match(DECIMAL_VALUE); } break; case 3: _localctx = new TextConstantContext(_localctx); enterOuterAlt(_localctx, 3); { - setState(79); + setState(81); _errHandler.sync(this); _alt = 1; do { switch (_alt) { case 1: { { - setState(78); + setState(80); match(QUOTED_STRING); } } @@ -1165,7 +1212,7 @@ public class FiltersParser extends Parser { default: throw new NoViableAltException(this); } - setState(81); + setState(83); _errHandler.sync(this); _alt = getInterpreter().adaptivePredict(_input, 8, _ctx); } @@ -1175,7 +1222,7 @@ public class FiltersParser extends Parser { case 4: _localctx = new BooleanConstantContext(_localctx); enterOuterAlt(_localctx, 4); { - setState(83); + setState(85); match(BOOLEAN_VALUE); } break; @@ -1203,67 +1250,68 @@ public class FiltersParser extends Parser { private boolean booleanExpression_sempred(BooleanExpressionContext _localctx, int predIndex) { switch (predIndex) { case 0: - return precpred(_ctx, 3); + return precpred(_ctx, 4); case 1: - return precpred(_ctx, 2); + return precpred(_ctx, 3); } return true; } - public static final String _serializedATN = "\u0004\u0001\u001aW\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" + public static final String _serializedATN = "\u0004\u0001\u001aY\u0002\u0000\u0007\u0000\u0002\u0001\u0007\u0001\u0002" + "\u0002\u0007\u0002\u0002\u0003\u0007\u0003\u0002\u0004\u0007\u0004\u0002" + "\u0005\u0007\u0005\u0001\u0000\u0001\u0000\u0001\u0000\u0001\u0000\u0001" + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + "\u0001\u0003\u0001\u001e\b\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0003\u0001&\b\u0001\u0001\u0001\u0001" - + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0005\u0001.\b" - + "\u0001\n\u0001\f\u00011\t\u0001\u0001\u0002\u0001\u0002\u0001\u0002\u0001" - + "\u0002\u0005\u00027\b\u0002\n\u0002\f\u0002:\t\u0002\u0001\u0002\u0001" - + "\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001\u0004\u0001\u0004\u0001" - + "\u0004\u0001\u0004\u0003\u0004E\b\u0004\u0001\u0005\u0003\u0005H\b\u0005" - + "\u0001\u0005\u0001\u0005\u0003\u0005L\b\u0005\u0001\u0005\u0001\u0005" - + "\u0004\u0005P\b\u0005\u000b\u0005\f\u0005Q\u0001\u0005\u0003\u0005U\b" - + "\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000\u0002\u0004\u0006\b\n" - + "\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000\t\n_\u0000\f\u0001" - + "\u0000\u0000\u0000\u0002%\u0001\u0000\u0000\u0000\u00042\u0001\u0000\u0000" - + "\u0000\u0006=\u0001\u0000\u0000\u0000\bD\u0001\u0000\u0000\u0000\nT\u0001" - + "\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000\r\u000e\u0003\u0002\u0001" - + "\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f\u0001\u0001\u0000\u0000" - + "\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000\u0011\u0012\u0003\b\u0004" - + "\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013\u0014\u0003\n\u0005\u0000" - + "\u0014&\u0001\u0000\u0000\u0000\u0015\u0016\u0003\b\u0004\u0000\u0016" - + "\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003\u0004\u0002\u0000\u0018" - + "&\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b\u0004\u0000\u001a\u001b" - + "\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012\u0000\u0000\u001c\u001e" - + "\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000\u0000\u0000\u001d\u001c" - + "\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000\u0000\u0000\u001f \u0003" - + "\u0004\u0002\u0000 &\u0001\u0000\u0000\u0000!\"\u0005\u0006\u0000\u0000" - + "\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000\u0000$&\u0001\u0000\u0000" - + "\u0000%\u0010\u0001\u0000\u0000\u0000%\u0015\u0001\u0000\u0000\u0000%" - + "\u0019\u0001\u0000\u0000\u0000%!\u0001\u0000\u0000\u0000&/\u0001\u0000" - + "\u0000\u0000\'(\n\u0003\u0000\u0000()\u0005\u0010\u0000\u0000).\u0003" - + "\u0002\u0001\u0004*+\n\u0002\u0000\u0000+,\u0005\u0011\u0000\u0000,.\u0003" - + "\u0002\u0001\u0003-\'\u0001\u0000\u0000\u0000-*\u0001\u0000\u0000\u0000" - + ".1\u0001\u0000\u0000\u0000/-\u0001\u0000\u0000\u0000/0\u0001\u0000\u0000" - + "\u00000\u0003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000\u000023\u0005" - + "\u0004\u0000\u000038\u0003\n\u0005\u000045\u0005\u0003\u0000\u000057\u0003" - + "\n\u0005\u000064\u0001\u0000\u0000\u00007:\u0001\u0000\u0000\u000086\u0001" - + "\u0000\u0000\u000089\u0001\u0000\u0000\u00009;\u0001\u0000\u0000\u0000" - + ":8\u0001\u0000\u0000\u0000;<\u0005\u0005\u0000\u0000<\u0005\u0001\u0000" - + "\u0000\u0000=>\u0007\u0000\u0000\u0000>\u0007\u0001\u0000\u0000\u0000" - + "?@\u0005\u0019\u0000\u0000@A\u0005\u0002\u0000\u0000AE\u0005\u0019\u0000" - + "\u0000BE\u0005\u0019\u0000\u0000CE\u0005\u0016\u0000\u0000D?\u0001\u0000" - + "\u0000\u0000DB\u0001\u0000\u0000\u0000DC\u0001\u0000\u0000\u0000E\t\u0001" - + "\u0000\u0000\u0000FH\u0007\u0001\u0000\u0000GF\u0001\u0000\u0000\u0000" - + "GH\u0001\u0000\u0000\u0000HI\u0001\u0000\u0000\u0000IU\u0005\u0017\u0000" - + "\u0000JL\u0007\u0001\u0000\u0000KJ\u0001\u0000\u0000\u0000KL\u0001\u0000" - + "\u0000\u0000LM\u0001\u0000\u0000\u0000MU\u0005\u0018\u0000\u0000NP\u0005" - + "\u0016\u0000\u0000ON\u0001\u0000\u0000\u0000PQ\u0001\u0000\u0000\u0000" - + "QO\u0001\u0000\u0000\u0000QR\u0001\u0000\u0000\u0000RU\u0001\u0000\u0000" - + "\u0000SU\u0005\u0015\u0000\u0000TG\u0001\u0000\u0000\u0000TK\u0001\u0000" - + "\u0000\u0000TO\u0001\u0000\u0000\u0000TS\u0001\u0000\u0000\u0000U\u000b" - + "\u0001\u0000\u0000\u0000\n\u001d%-/8DGKQT"; + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0003\u0001(\b" + + "\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001\u0001" + + "\u0001\u0005\u00010\b\u0001\n\u0001\f\u00013\t\u0001\u0001\u0002\u0001" + + "\u0002\u0001\u0002\u0001\u0002\u0005\u00029\b\u0002\n\u0002\f\u0002<\t" + + "\u0002\u0001\u0002\u0001\u0002\u0001\u0003\u0001\u0003\u0001\u0004\u0001" + + "\u0004\u0001\u0004\u0001\u0004\u0001\u0004\u0003\u0004G\b\u0004\u0001" + + "\u0005\u0003\u0005J\b\u0005\u0001\u0005\u0001\u0005\u0003\u0005N\b\u0005" + + "\u0001\u0005\u0001\u0005\u0004\u0005R\b\u0005\u000b\u0005\f\u0005S\u0001" + + "\u0005\u0003\u0005W\b\u0005\u0001\u0005\u0000\u0001\u0002\u0006\u0000" + + "\u0002\u0004\u0006\b\n\u0000\u0002\u0002\u0000\b\b\u000b\u000f\u0001\u0000" + + "\t\nb\u0000\f\u0001\u0000\u0000\u0000\u0002\'\u0001\u0000\u0000\u0000" + + "\u00044\u0001\u0000\u0000\u0000\u0006?\u0001\u0000\u0000\u0000\bF\u0001" + + "\u0000\u0000\u0000\nV\u0001\u0000\u0000\u0000\f\r\u0005\u0001\u0000\u0000" + + "\r\u000e\u0003\u0002\u0001\u0000\u000e\u000f\u0005\u0000\u0000\u0001\u000f" + + "\u0001\u0001\u0000\u0000\u0000\u0010\u0011\u0006\u0001\uffff\uffff\u0000" + + "\u0011\u0012\u0003\b\u0004\u0000\u0012\u0013\u0003\u0006\u0003\u0000\u0013" + + "\u0014\u0003\n\u0005\u0000\u0014(\u0001\u0000\u0000\u0000\u0015\u0016" + + "\u0003\b\u0004\u0000\u0016\u0017\u0005\u0012\u0000\u0000\u0017\u0018\u0003" + + "\u0004\u0002\u0000\u0018(\u0001\u0000\u0000\u0000\u0019\u001d\u0003\b" + + "\u0004\u0000\u001a\u001b\u0005\u0014\u0000\u0000\u001b\u001e\u0005\u0012" + + "\u0000\u0000\u001c\u001e\u0005\u0013\u0000\u0000\u001d\u001a\u0001\u0000" + + "\u0000\u0000\u001d\u001c\u0001\u0000\u0000\u0000\u001e\u001f\u0001\u0000" + + "\u0000\u0000\u001f \u0003\u0004\u0002\u0000 (\u0001\u0000\u0000\u0000" + + "!\"\u0005\u0006\u0000\u0000\"#\u0003\u0002\u0001\u0000#$\u0005\u0007\u0000" + + "\u0000$(\u0001\u0000\u0000\u0000%&\u0005\u0014\u0000\u0000&(\u0003\u0002" + + "\u0001\u0001\'\u0010\u0001\u0000\u0000\u0000\'\u0015\u0001\u0000\u0000" + + "\u0000\'\u0019\u0001\u0000\u0000\u0000\'!\u0001\u0000\u0000\u0000\'%\u0001" + + "\u0000\u0000\u0000(1\u0001\u0000\u0000\u0000)*\n\u0004\u0000\u0000*+\u0005" + + "\u0010\u0000\u0000+0\u0003\u0002\u0001\u0005,-\n\u0003\u0000\u0000-.\u0005" + + "\u0011\u0000\u0000.0\u0003\u0002\u0001\u0004/)\u0001\u0000\u0000\u0000" + + "/,\u0001\u0000\u0000\u000003\u0001\u0000\u0000\u00001/\u0001\u0000\u0000" + + "\u000012\u0001\u0000\u0000\u00002\u0003\u0001\u0000\u0000\u000031\u0001" + + "\u0000\u0000\u000045\u0005\u0004\u0000\u00005:\u0003\n\u0005\u000067\u0005" + + "\u0003\u0000\u000079\u0003\n\u0005\u000086\u0001\u0000\u0000\u00009<\u0001" + + "\u0000\u0000\u0000:8\u0001\u0000\u0000\u0000:;\u0001\u0000\u0000\u0000" + + ";=\u0001\u0000\u0000\u0000<:\u0001\u0000\u0000\u0000=>\u0005\u0005\u0000" + + "\u0000>\u0005\u0001\u0000\u0000\u0000?@\u0007\u0000\u0000\u0000@\u0007" + + "\u0001\u0000\u0000\u0000AB\u0005\u0019\u0000\u0000BC\u0005\u0002\u0000" + + "\u0000CG\u0005\u0019\u0000\u0000DG\u0005\u0019\u0000\u0000EG\u0005\u0016" + + "\u0000\u0000FA\u0001\u0000\u0000\u0000FD\u0001\u0000\u0000\u0000FE\u0001" + + "\u0000\u0000\u0000G\t\u0001\u0000\u0000\u0000HJ\u0007\u0001\u0000\u0000" + + "IH\u0001\u0000\u0000\u0000IJ\u0001\u0000\u0000\u0000JK\u0001\u0000\u0000" + + "\u0000KW\u0005\u0017\u0000\u0000LN\u0007\u0001\u0000\u0000ML\u0001\u0000" + + "\u0000\u0000MN\u0001\u0000\u0000\u0000NO\u0001\u0000\u0000\u0000OW\u0005" + + "\u0018\u0000\u0000PR\u0005\u0016\u0000\u0000QP\u0001\u0000\u0000\u0000" + + "RS\u0001\u0000\u0000\u0000SQ\u0001\u0000\u0000\u0000ST\u0001\u0000\u0000" + + "\u0000TW\u0001\u0000\u0000\u0000UW\u0005\u0015\u0000\u0000VI\u0001\u0000" + + "\u0000\u0000VM\u0001\u0000\u0000\u0000VQ\u0001\u0000\u0000\u0000VU\u0001" + + "\u0000\u0000\u0000W\u000b\u0001\u0000\u0000\u0000\n\u001d\'/1:FIMSV"; public static final ATN _ATN = new ATNDeserializer().deserialize(_serializedATN.toCharArray()); static { diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java index 831271273..27413bd6e 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/antlr4/FiltersVisitor.java @@ -63,6 +63,14 @@ public interface FiltersVisitor extends ParseTreeVisitor { */ T visitInExpression(FiltersParser.InExpressionContext ctx); + /** + * Visit a parse tree produced by the {@code NotExpression} labeled alternative in + * {@link FiltersParser#booleanExpression}. + * @param ctx the parse tree + * @return the visitor result + */ + T visitNotExpression(FiltersParser.NotExpressionContext ctx); + /** * Visit a parse tree produced by the {@code CompareExpression} labeled alternative in * {@link FiltersParser#booleanExpression}. diff --git a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java index 468805654..8769d97ec 100644 --- a/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java +++ b/spring-ai-core/src/main/java/org/springframework/ai/vectorstore/filter/converter/AbstractFilterExpressionConverter.java @@ -19,6 +19,7 @@ package org.springframework.ai.vectorstore.filter.converter; import java.util.List; import org.springframework.ai.vectorstore.filter.Filter; +import org.springframework.ai.vectorstore.filter.FilterHelper; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; @@ -52,14 +53,27 @@ public abstract class AbstractFilterExpressionConverter implements FilterExpress this.doValue(value, context); } else if (operand instanceof Filter.Expression expression) { - if ((expression.type() != ExpressionType.AND && expression.type() != ExpressionType.OR) - && !(expression.right() instanceof Filter.Value)) { + if ((expression.type() != ExpressionType.NOT && expression.type() != ExpressionType.AND + && expression.type() != ExpressionType.OR) && !(expression.right() instanceof Filter.Value)) { throw new RuntimeException("Non AND/OR expression must have Value right argument!"); } - this.doExpression(expression, context); + if (expression.type() == ExpressionType.NOT) { + this.doNot(expression, context); + } + else { + this.doExpression(expression, context); + } } } + protected void doNot(Filter.Expression expression, StringBuilder context) { + // Default behavior is to convert the NOT expression into its semantically + // equivalent negation expression. + // Effectively removing the NOT types form the boolean expression tree before + // passing it to the doExpression. + this.convertOperand(FilterHelper.negate(expression), context); + } + protected abstract void doExpression(Filter.Expression expression, StringBuilder context); protected abstract void doKey(Filter.Key filterKey, StringBuilder context); diff --git a/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 b/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 index 21bb45ad4..086766b6a 100644 --- a/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 +++ b/spring-ai-core/src/main/resources/antlr4/org/springframework/ai/vectorstore/filter/antlr4/Filters.g4 @@ -33,6 +33,7 @@ booleanExpression | left=booleanExpression operator=AND right=booleanExpression # AndExpression | left=booleanExpression operator=OR right=booleanExpression # OrExpression | LEFT_PARENTHESIS booleanExpression RIGHT_PARENTHESIS # GroupExpression + | NOT booleanExpression # NotExpression ; constantArray diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java index c05cba98b..d5f8e077a 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionBuilderTests.java @@ -33,6 +33,7 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT; /** * @author Christian Tzolov @@ -97,4 +98,18 @@ public class FilterExpressionBuilderTests { new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))); } + @Test + public void tesNot() { + // isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"] + var exp = b.not(b.and(b.and(b.eq("isOpen", true), b.gte("year", 2020)), b.in("country", "BG", "NL", "US"))) + .build(); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))), + null)); + } + } diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java index 057ae86cc..27e62f1d8 100644 --- a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterExpressionTextParserTests.java @@ -30,10 +30,11 @@ import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.AN import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.EQ; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.GTE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.IN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NE; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NIN; +import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.NOT; import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.OR; -import static org.springframework.ai.vectorstore.filter.Filter.ExpressionType.LTE; /** * @author Christian Tzolov @@ -111,6 +112,61 @@ public class FilterExpressionTextParserTests { .get("WHERE " + "isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"]")).isEqualTo(exp); } + @Test + public void tesNot() { + // NOT(isOpen == true AND year >= 2020 AND country IN ["BG", "NL", "US"]) + Expression exp = parser.parse("not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Group(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US"))))), + null)); + + assertThat(parser.getCache() + .get("WHERE " + "not(isOpen == true AND year >= 2020 AND country IN [\"BG\", \"NL\", \"US\"])")) + .isEqualTo(exp); + } + + @Test + public void tesNotNin() { + // NOT(country NOT IN ["BG", "NL", "US"]) + Expression exp = parser.parse("not(country NOT IN [\"BG\", \"NL\", \"US\"])"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Group(new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US")))), null)); + } + + @Test + public void tesNotNin2() { + // NOT country NOT IN ["BG", "NL", "US"] + Expression exp = parser.parse("NOT country NOT IN [\"BG\", \"NL\", \"US\"]"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Expression(NIN, new Key("country"), new Value(List.of("BG", "NL", "US"))), null)); + } + + @Test + public void tesNestedNot() { + // NOT(isOpen == true AND year >= 2020 AND NOT(country IN ["BG", "NL", "US"])) + Expression exp = parser + .parse("not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))"); + + assertThat(exp).isEqualTo(new Expression(NOT, + new Group(new Expression(AND, + new Expression(AND, new Expression(EQ, new Key("isOpen"), new Value(true)), + new Expression(GTE, new Key("year"), new Value(2020))), + new Expression(NOT, + new Group(new Expression(IN, new Key("country"), new Value(List.of("BG", "NL", "US")))), + null))), + null)); + + assertThat(parser.getCache() + .get("WHERE " + "not(isOpen == true AND year >= 2020 AND NOT(country IN [\"BG\", \"NL\", \"US\"]))")) + .isEqualTo(exp); + } + @Test public void testDecimal() { // temperature >= -15.6 && temperature <= +20.13 diff --git a/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java new file mode 100644 index 000000000..154c94490 --- /dev/null +++ b/spring-ai-core/src/test/java/org/springframework/ai/vectorstore/filter/FilterHelperTests.java @@ -0,0 +1,171 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.ai.vectorstore.filter; + +import java.util.List; + +import org.junit.jupiter.api.Test; + +import org.springframework.ai.vectorstore.filter.Filter.Expression; +import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; +import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.Filter.Value; +import org.springframework.ai.vectorstore.filter.converter.PrintFilterExpressionConverter; + +import static org.assertj.core.api.Assertions.assertThat; + +/** + * @author Christian Tzolov + */ +public class FilterHelperTests { + + @Test + public void negateEQ() { + assertThat(Filter.parser().parse("NOT key == 'UK' ")).isEqualTo(new Filter.Expression(ExpressionType.NOT, + new Filter.Expression(ExpressionType.EQ, new Key("key"), new Value("UK")), null)); + + assertThat(FilterHelper.negate(Filter.parser().parse("NOT key == 'UK' "))) + .isEqualTo(new Filter.Expression(ExpressionType.NE, new Key("key"), new Value("UK"))); + + assertThat(FilterHelper.negate(Filter.parser().parse("NOT (key == 'UK') "))) + .isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.NE, new Key("key"), new Value("UK")))); + } + + @Test + public void negateNE() { + var exp = Filter.parser().parse("NOT key != 'UK' "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.EQ, new Key("key"), new Value("UK"))); + + } + + @Test + public void negateGT() { + var exp = Filter.parser().parse("NOT key > 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.LTE, new Key("key"), new Value(13))); + + } + + @Test + public void negateGTE() { + var exp = Filter.parser().parse("NOT key >= 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(13))); + } + + @Test + public void negateLT() { + var exp = Filter.parser().parse("NOT key < 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(13))); + } + + @Test + public void negateLTE() { + var exp = Filter.parser().parse("NOT key <= 13 "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.GT, new Key("key"), new Value(13))); + } + + @Test + public void negateIN() { + var exp = Filter.parser().parse("NOT key IN [11, 12, 13] "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.NIN, new Key("key"), new Value(List.of(11, 12, 13)))); + } + + @Test + public void negateNIN() { + var exp = Filter.parser().parse("NOT key NIN [11, 12, 13] "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.IN, new Key("key"), new Value(List.of(11, 12, 13)))); + } + + @Test + public void negateNIN2() { + var exp = Filter.parser().parse("NOT key NOT IN [11, 12, 13] "); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Expression(ExpressionType.IN, new Key("key"), new Value(List.of(11, 12, 13)))); + } + + @Test + public void negateAND() { + var exp = Filter.parser().parse("NOT(key >= 11 AND key < 13)"); + assertThat(FilterHelper.negate(exp)).isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.OR, + new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)), + new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(13))))); + } + + @Test + public void negateOR() { + var exp = Filter.parser().parse("NOT(key >= 11 OR key < 13)"); + assertThat(FilterHelper.negate(exp)).isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.AND, + new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)), + new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(13))))); + } + + @Test + public void negateNot() { + var exp = Filter.parser().parse("NOT NOT(key >= 11)"); + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)))); + } + + @Test + public void negateNestedNot() { + var exp = Filter.parser().parse("NOT(NOT(key >= 11))"); + assertThat(exp).isEqualTo( + new Filter.Expression(ExpressionType.NOT, new Filter.Group(new Filter.Expression(ExpressionType.NOT, + new Filter.Group(new Filter.Expression(ExpressionType.GTE, new Key("key"), new Value(11))))))); + + assertThat(FilterHelper.negate(exp)) + .isEqualTo(new Filter.Group(new Filter.Expression(ExpressionType.LT, new Key("key"), new Value(11)))); + } + + @Test + public void expandIN() { + var exp = Filter.parser().parse("key IN [11, 12, 13] "); + assertThat(new InNinTestConverter().convertExpression(exp)).isEqualTo("key EQ 11 OR key EQ 12 OR key EQ 13"); + } + + @Test + public void expandNIN() { + var exp1 = Filter.parser().parse("key NIN [11, 12, 13] "); + var exp2 = Filter.parser().parse("key NOT IN [11, 12, 13] "); + assertThat(exp1).isEqualTo(exp2); + assertThat(new InNinTestConverter().convertExpression(exp1)).isEqualTo("key NE 11 AND key NE 12 AND key NE 13"); + } + + private static class InNinTestConverter extends PrintFilterExpressionConverter { + + @Override + public void doExpression(Expression expression, StringBuilder context) { + if (expression.type() == ExpressionType.IN) { + FilterHelper.expandIn(expression, context, this); + } + else if (expression.type() == ExpressionType.NIN) { + FilterHelper.expandNin(expression, context, this); + } + else { + super.doExpression(expression, context); + } + } + + }; + +} diff --git a/spring-ai-docs/concepts-staging.adoc b/spring-ai-docs/concepts-staging.adoc index 26873e2d3..37693931c 100644 --- a/spring-ai-docs/concepts-staging.adoc +++ b/spring-ai-docs/concepts-staging.adoc @@ -113,7 +113,7 @@ One approach involves presenting both the user's request and the AI model's resp Furthermore, leveraging the information stored in the Vector Database as supplementary data can enhance the evaluation process, aiding in the determination of response relevance. -The Spring AI project currenlty provides some very basic examples of how you can evaluate the responses in the form of prompts to include in a JUnit test. +The Spring AI project currently provides some very basic examples of how you can evaluate the responses in the form of prompts to include in a JUnit test. diff --git a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc index b5f8be513..73dfb679f 100644 --- a/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc +++ b/spring-ai-docs/src/main/antora/modules/ROOT/pages/api/vectordbs.adoc @@ -87,8 +87,7 @@ country == 'UK' && year >= 2020 && isActive == true. These are the available implementations of the `VectorStore` interface: -* `InMemoryVectorStore` -* `SimplePersistentVectorStore` +* `InMemoryVectorStore` and `SimplePersistentVectorStore`. * Pinecone: https://www.pinecone.io/[PineCone] vector store. * PgVector [`PgVectorStore`]: The https://github.com/pgvector/pgvector[PostgreSQL/PGVector] vector store. * Milvus [`MilvusVectorStore`]: The https://milvus.io/[Milvus] vector store @@ -117,7 +116,7 @@ The `VectorStore` implementation computes the embeddings and stores the JSON and @Autowired VectorStore vectorStore; - void load(String sourceFile) {} + void load(String sourceFile) { JsonReader jsonReader = new JsonReader(new FileSystemResource(sourceFile), "price", "name", "shortDescription", "description", "tags"); List documents = jsonReader.get(); diff --git a/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java b/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java index 84c59636e..a99c22af6 100644 --- a/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java +++ b/vector-stores/spring-ai-azure/src/test/java/org/springframework/ai/vectorstore/azure/AzureVectorStoreIT.java @@ -166,11 +166,20 @@ public class AzureVectorStoreIT { results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() - .withFilterExpression("country nin ['BG']")); + .withFilterExpression("country not in ['BG']")); assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country not in ['BG'])")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(bgDocument.getId(), bgDocument2.getId()); + // List results = // vectorStore.similaritySearch(SearchRequest.query("The World") // .withTopK(5) diff --git a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java index f6e88c0f1..12f00cf17 100644 --- a/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java +++ b/vector-stores/spring-ai-chroma/src/test/java/org/springframework/experimental/ai/vectorstore/ChromaVectorStoreIT.java @@ -122,6 +122,11 @@ public class ChromaVectorStoreIT { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch( + request.withSimilarityThresholdAll().withFilterExpression("NOT(country == 'Netherland')")); + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + // Remove all documents from the store vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); }); diff --git a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java index fd19b8fdf..6d035d645 100644 --- a/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java +++ b/vector-stores/spring-ai-milvus-store/src/test/java/org/springframework/ai/vectorstore/MilvusVectorStoreIT.java @@ -188,6 +188,16 @@ public class MilvusVectorStoreIT { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'BG' && year == 2020)")); + + assertThat(results).hasSize(2); + assertThat(results.get(0).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + assertThat(results.get(1).getId()).isIn(nlDocument.getId(), bgDocument2.getId()); + }); } diff --git a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java index d6d154f09..adf887011 100644 --- a/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java +++ b/vector-stores/spring-ai-pgvector-store/src/test/java/org/springframework/ai/vectorstore/PgVectorStoreIT.java @@ -182,6 +182,12 @@ public class PgVectorStoreIT { assertThat(results.get(0).getId()).isIn(bgDocument.getId(), nlDocument.getId()); assertThat(results.get(1).getId()).isIn(bgDocument.getId(), nlDocument.getId()); + results = vectorStore.similaritySearch(searchRequest + .withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") .withTopK(5) .withSimilarityThresholdAll() diff --git a/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java b/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java index 63baf358b..2506aa19f 100644 --- a/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java +++ b/vector-stores/spring-ai-pinecone/src/test/java/org/springframework/ai/vectorstore/PineconeVectorStoreIT.java @@ -156,6 +156,13 @@ public class PineconeVectorStoreIT { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(nlDocument.getId()); + results = vectorStore.similaritySearch(searchRequest.withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT(country == 'Netherland')")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + // Remove all documents from the store vectorStore.delete(List.of(bgDocument, nlDocument).stream().map(doc -> doc.getId()).toList()); diff --git a/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java index 3083749c5..1c66e3a1f 100644 --- a/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java +++ b/vector-stores/spring-ai-weaviate/src/main/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverter.java @@ -16,7 +16,6 @@ package org.springframework.ai.vectorstore; -import java.util.ArrayList; import java.util.Date; import java.util.List; @@ -27,6 +26,7 @@ import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.ExpressionType; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; +import org.springframework.ai.vectorstore.filter.FilterHelper; import org.springframework.ai.vectorstore.filter.converter.AbstractFilterExpressionConverter; import org.springframework.util.Assert; @@ -62,10 +62,10 @@ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionC protected void doExpression(Expression exp, StringBuilder context) { if (exp.type() == ExpressionType.IN) { - rewriteInNinExpressions(Filter.ExpressionType.OR, Filter.ExpressionType.EQ, exp, context); + FilterHelper.expandIn(exp, context, this); } else if (exp.type() == ExpressionType.NIN) { - rewriteInNinExpressions(Filter.ExpressionType.AND, Filter.ExpressionType.NE, exp, context); + FilterHelper.expandNin(exp, context, this); } else if (exp.type() == ExpressionType.AND || exp.type() == ExpressionType.OR) { context.append(getOperationSymbol(exp)); @@ -82,51 +82,6 @@ public class WeaviateFilterExpressionConverter extends AbstractFilterExpressionC } } - /** - * Recursively aggregates a list of expression into a binary tree with 'aggregateType' - * join nodes. - * @param aggregateType type all tree splits. - * @param expressions list of expressions to aggregate. - * @return Returns a binary tree expression. - */ - private Filter.Expression aggregate(Filter.ExpressionType aggregateType, List expressions) { - - if (expressions.size() == 1) { - return expressions.get(0); - } - return new Filter.Expression(aggregateType, expressions.get(0), - aggregate(aggregateType, expressions.subList(1, expressions.size()))); - } - - private void rewriteInNinExpressions(Filter.ExpressionType outerExpressionType, - Filter.ExpressionType innerExpressionType, Expression exp, StringBuilder context) { - if (exp.right() instanceof Filter.Value value) { - if (value.value() instanceof List list) { - // 1. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo == "bar1" || - // foo == "bar2" || foo == "bar3" - // or equivalent to OR(foo == "bar1" OR( foo == "bar2" OR(foo == "bar3"))) - // 2. foo IN ["bar1", "bar2", "bar3"] is equivalent to foo != "bar1" && - // foo != "bar2" && foo != "bar3" - // or equivalent to AND(foo != "bar1" AND( foo != "bar2" OR(foo != - // "bar3"))) - List eqExprs = new ArrayList<>(); - for (Object o : list) { - eqExprs.add(new Filter.Expression(innerExpressionType, exp.left(), new Filter.Value(o))); - } - this.doExpression(aggregate(outerExpressionType, eqExprs), context); - } - else { - // 1. foo IN ["bar"] is equivalent to foo == "BAR" - // 2. foo NIN ["bar"] is equivalent to foo != "BAR" - this.doExpression(new Filter.Expression(innerExpressionType, exp.left(), exp.right()), context); - } - } - else { - throw new IllegalStateException( - "Filter IN right expression should be of Filter.Value type but was " + exp.right().getClass()); - } - } - private String getOperationSymbol(Expression exp) { switch (exp.type()) { case AND: diff --git a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java index 9f4b92646..6b3084a2c 100644 --- a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java +++ b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateFilterExpressionConverterTests.java @@ -20,7 +20,6 @@ import java.util.List; import org.junit.jupiter.api.Test; -import org.springframework.ai.vectorstore.filter.Filter; import org.springframework.ai.vectorstore.filter.Filter.Expression; import org.springframework.ai.vectorstore.filter.Filter.Group; import org.springframework.ai.vectorstore.filter.Filter.Key; diff --git a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java index 336cecce1..ed0c6e870 100644 --- a/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java +++ b/vector-stores/spring-ai-weaviate/src/test/java/org/springframework/ai/vectorstore/WeaviateVectorStoreIT.java @@ -152,6 +152,14 @@ public class WeaviateVectorStoreIT { assertThat(results).hasSize(1); assertThat(results.get(0).getId()).isEqualTo(bgDocument.getId()); + results = vectorStore.similaritySearch(SearchRequest.query("The World") + .withTopK(5) + .withSimilarityThresholdAll() + .withFilterExpression("NOT((country == 'BG' && year == 2020) || (country == 'NL'))")); + + assertThat(results).hasSize(1); + assertThat(results.get(0).getId()).isEqualTo(bgDocument2.getId()); + vectorStore.delete(List.of(bgDocument.getId(), nlDocument.getId(), bgDocument2.getId())); }); }