Fix VectorStoreDocumentRetriever to handle Filter.Expression objects directly
- Updated computeRequestFilterExpression to check if the context value is already a Filter.Expression object before attempting to parse it as a string - Added docs for FILTER_EXPRESSION key that it accepts both String and Filter.Expression - Added test Fixes #3179
This commit is contained in:
committed by
Ilayaperumal Gopinathan
parent
368be3a04f
commit
c0b9240c8d
@@ -45,6 +45,11 @@ import org.springframework.util.StringUtils;
|
||||
* List<Document> documents = retriever.retrieve(new Query("example query"));
|
||||
* }</pre>
|
||||
*
|
||||
* <p>
|
||||
* The {@link #FILTER_EXPRESSION} context key can be used to provide a filter expression
|
||||
* for a specific query. This key accepts either a string representation of a filter
|
||||
* expression or a {@link Filter.Expression} object directly.
|
||||
*
|
||||
* @author Thomas Vitale
|
||||
* @since 1.0.0
|
||||
*/
|
||||
@@ -89,10 +94,27 @@ public final class VectorStoreDocumentRetriever implements DocumentRetriever {
|
||||
return this.vectorStore.similaritySearch(searchRequest);
|
||||
}
|
||||
|
||||
/**
|
||||
* Computes the filter expression to use for the current request.
|
||||
* <p>
|
||||
* The filter expression can be provided in the query context using the
|
||||
* {@link #FILTER_EXPRESSION} key. This key accepts either a string representation of
|
||||
* a filter expression or a {@link Filter.Expression} object directly.
|
||||
* <p>
|
||||
* If no filter expression is provided in the context, the default filter expression
|
||||
* configured for this retriever is used.
|
||||
* @param query the query containing potential context with filter expression
|
||||
* @return the filter expression to use for the request
|
||||
*/
|
||||
private Filter.Expression computeRequestFilterExpression(Query query) {
|
||||
var contextFilterExpression = query.context().get(FILTER_EXPRESSION);
|
||||
if (contextFilterExpression != null && StringUtils.hasText(contextFilterExpression.toString())) {
|
||||
return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
|
||||
if (contextFilterExpression != null) {
|
||||
if (contextFilterExpression instanceof Filter.Expression) {
|
||||
return (Filter.Expression) contextFilterExpression;
|
||||
}
|
||||
else if (StringUtils.hasText(contextFilterExpression.toString())) {
|
||||
return new FilterExpressionTextParser().parse(contextFilterExpression.toString());
|
||||
}
|
||||
}
|
||||
return this.filterExpression.get();
|
||||
}
|
||||
|
||||
@@ -234,6 +234,32 @@ class VectorStoreDocumentRetrieverTests {
|
||||
.isEqualTo(new FilterExpressionBuilder().eq("location", "Rivendell").build());
|
||||
}
|
||||
|
||||
@Test
|
||||
void retrieveWithQueryObjectAndFilterExpressionObject() {
|
||||
var mockVectorStore = mock(VectorStore.class);
|
||||
var documentRetriever = VectorStoreDocumentRetriever.builder().vectorStore(mockVectorStore).build();
|
||||
|
||||
// Create a Filter.Expression object directly
|
||||
var filterExpression = new Filter.Expression(EQ, new Filter.Key("location"), new Filter.Value("Rivendell"));
|
||||
|
||||
var query = Query.builder()
|
||||
.text("test query")
|
||||
.context(Map.of(VectorStoreDocumentRetriever.FILTER_EXPRESSION, filterExpression))
|
||||
.build();
|
||||
documentRetriever.retrieve(query);
|
||||
|
||||
// Verify the mock interaction
|
||||
var searchRequestCaptor = ArgumentCaptor.forClass(SearchRequest.class);
|
||||
verify(mockVectorStore).similaritySearch(searchRequestCaptor.capture());
|
||||
|
||||
// Verify the search request
|
||||
var searchRequest = searchRequestCaptor.getValue();
|
||||
assertThat(searchRequest.getQuery()).isEqualTo("test query");
|
||||
assertThat(searchRequest.getSimilarityThreshold()).isEqualTo(SearchRequest.SIMILARITY_THRESHOLD_ACCEPT_ALL);
|
||||
assertThat(searchRequest.getTopK()).isEqualTo(SearchRequest.DEFAULT_TOP_K);
|
||||
assertThat(searchRequest.getFilterExpression()).isEqualTo(filterExpression);
|
||||
}
|
||||
|
||||
static final class TenantContextHolder {
|
||||
|
||||
private static final ThreadLocal<String> tenantIdentifier = new ThreadLocal<>();
|
||||
|
||||
Reference in New Issue
Block a user