diff --git a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java index 3b7b1cd948..117eea044a 100644 --- a/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java +++ b/spring-webflux/src/main/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctions.java @@ -25,11 +25,15 @@ import java.util.function.Consumer; import java.util.function.Function; import java.util.function.Predicate; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.Assert; +import org.springframework.web.reactive.function.BodyExtractors; /** * Static factory methods providing access to built-in implementations of @@ -50,6 +54,21 @@ public abstract class ExchangeFilterFunctions { public static final String BASIC_AUTHENTICATION_CREDENTIALS_ATTRIBUTE = ExchangeFilterFunctions.class.getName() + ".basicAuthenticationCredentials"; + /** + * Consume up to the specified number of bytes from the response body and + * cancel if any more data arrives. Internally delegates to + * {@link DataBufferUtils#takeUntilByteCount}. + * @return the filter to limit the response size with + * @since 5.1 + */ + public static ExchangeFilterFunction limitResponseSize(long maxByteCount) { + return (request, next) -> + next.exchange(request).map(response -> { + Flux body = response.body(BodyExtractors.toDataBuffers()); + body = DataBufferUtils.takeUntilByteCount(body, maxByteCount); + return ClientResponse.from(response).body(body).build(); + }); + } /** * Return a filter for HTTP Basic Authentication that adds an authorization diff --git a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java index b1752950b7..963fa55c05 100644 --- a/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java +++ b/spring-webflux/src/test/java/org/springframework/web/reactive/function/client/ExchangeFilterFunctionsTests.java @@ -17,28 +17,36 @@ package org.springframework.web.reactive.function.client; import java.net.URI; +import java.nio.charset.StandardCharsets; import org.junit.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.test.StepVerifier; +import org.springframework.core.io.buffer.DataBuffer; +import org.springframework.core.io.buffer.DataBufferUtils; +import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.core.io.buffer.support.DataBufferTestUtils; import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; import org.springframework.http.HttpStatus; +import org.springframework.web.reactive.function.BodyExtractors; import static org.junit.Assert.*; import static org.mockito.Mockito.*; -import static org.springframework.http.HttpMethod.GET; -import static org.springframework.web.reactive.function.client.ExchangeFilterFunctions.Credentials.basicAuthenticationCredentials; /** * @author Arjen Poutsma */ -@SuppressWarnings("deprecation") public class ExchangeFilterFunctionsTests { + private static final URI DEFAULT_URL = URI.create("http://example.com"); + + @Test public void andThen() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); ClientResponse response = mock(ClientResponse.class); ExchangeFunction exchange = r -> Mono.just(response); @@ -68,7 +76,7 @@ public class ExchangeFilterFunctionsTests { @Test public void apply() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); ClientResponse response = mock(ClientResponse.class); ExchangeFunction exchange = r -> Mono.just(response); @@ -86,8 +94,9 @@ public class ExchangeFilterFunctionsTests { } @Test + @SuppressWarnings("deprecation") public void basicAuthenticationUsernamePassword() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); ClientResponse response = mock(ClientResponse.class); ExchangeFunction exchange = r -> { @@ -109,9 +118,11 @@ public class ExchangeFilterFunctionsTests { } @Test + @SuppressWarnings("deprecation") public void basicAuthenticationAttributes() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")) - .attributes(basicAuthenticationCredentials("foo", "bar")) + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL) + .attributes(org.springframework.web.reactive.function.client.ExchangeFilterFunctions + .Credentials.basicAuthenticationCredentials("foo", "bar")) .build(); ClientResponse response = mock(ClientResponse.class); @@ -128,8 +139,9 @@ public class ExchangeFilterFunctionsTests { } @Test + @SuppressWarnings("deprecation") public void basicAuthenticationAbsentAttributes() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); ClientResponse response = mock(ClientResponse.class); ExchangeFunction exchange = r -> { @@ -145,7 +157,7 @@ public class ExchangeFilterFunctionsTests { @Test public void statusHandlerMatch() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); ClientResponse response = mock(ClientResponse.class); when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND); @@ -163,16 +175,13 @@ public class ExchangeFilterFunctionsTests { @Test public void statusHandlerNoMatch() { - ClientRequest request = ClientRequest.create(GET, URI.create("http://example.com")).build(); + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); ClientResponse response = mock(ClientResponse.class); when(response.statusCode()).thenReturn(HttpStatus.NOT_FOUND); - ExchangeFunction exchange = r -> Mono.just(response); - - ExchangeFilterFunction errorHandler = ExchangeFilterFunctions.statusError( - HttpStatus::is5xxServerError, r -> new MyException()); - - Mono result = errorHandler.filter(request, exchange); + Mono result = ExchangeFilterFunctions + .statusError(HttpStatus::is5xxServerError, req -> new MyException()) + .filter(request, req -> Mono.just(response)); StepVerifier.create(result) .expectNext(response) @@ -180,6 +189,38 @@ public class ExchangeFilterFunctionsTests { .verify(); } + @Test + public void limitResponseSize() { + DefaultDataBufferFactory bufferFactory = new DefaultDataBufferFactory(); + DataBuffer b1 = dataBuffer("foo", bufferFactory); + DataBuffer b2 = dataBuffer("bar", bufferFactory); + DataBuffer b3 = dataBuffer("baz", bufferFactory); + + ClientRequest request = ClientRequest.create(HttpMethod.GET, DEFAULT_URL).build(); + ClientResponse response = ClientResponse.create(HttpStatus.OK).body(Flux.just(b1, b2, b3)).build(); + + Mono result = ExchangeFilterFunctions.limitResponseSize(5) + .filter(request, req -> Mono.just(response)); + + StepVerifier.create(result.flatMapMany(res -> res.body(BodyExtractors.toDataBuffers()))) + .consumeNextWith(buffer -> assertEquals("foo", string(buffer))) + .consumeNextWith(buffer -> assertEquals("ba", string(buffer))) + .expectComplete() + .verify(); + + } + + private String string(DataBuffer buffer) { + String value = DataBufferTestUtils.dumpString(buffer, StandardCharsets.UTF_8); + DataBufferUtils.release(buffer); + return value; + } + + private DataBuffer dataBuffer(String foo, DefaultDataBufferFactory bufferFactory) { + return bufferFactory.wrap(foo.getBytes(StandardCharsets.UTF_8)); + } + + @SuppressWarnings("serial") private static class MyException extends Exception {