diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java index 3a935903bd..178a1593dd 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/DefaultWebTestClient.java @@ -300,7 +300,7 @@ class DefaultWebTestClient implements WebTestClient { DefaultResponseSpec(WiretapConnector.Info wiretapInfo, ClientResponse response, @Nullable String uriTemplate, Duration timeout) { - this.exchangeResult = wiretapInfo.createExchangeResult(uriTemplate); + this.exchangeResult = wiretapInfo.createExchangeResult(timeout, uriTemplate); this.response = response; this.timeout = timeout; } @@ -357,13 +357,13 @@ class DefaultWebTestClient implements WebTestClient { @Override public FluxExchangeResult returnResult(Class elementType) { Flux body = this.response.bodyToFlux(elementType); - return new FluxExchangeResult<>(this.exchangeResult, body, this.timeout); + return new FluxExchangeResult<>(this.exchangeResult, body); } @Override public FluxExchangeResult returnResult(ParameterizedTypeReference elementType) { Flux body = this.response.bodyToFlux(elementType); - return new FluxExchangeResult<>(this.exchangeResult, body, this.timeout); + return new FluxExchangeResult<>(this.exchangeResult, body); } } diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeResult.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeResult.java index 823e64ba8e..81ed0ea199 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/ExchangeResult.java @@ -24,7 +24,7 @@ import java.util.Arrays; import java.util.List; import java.util.stream.Collectors; -import reactor.core.publisher.MonoProcessor; +import reactor.core.publisher.Mono; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; @@ -36,7 +36,6 @@ import org.springframework.http.client.reactive.ClientHttpResponse; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.MultiValueMap; -import org.springframework.util.ObjectUtils; /** * Container for request and response details for exchanges performed through @@ -64,9 +63,11 @@ public class ExchangeResult { private final ClientHttpResponse response; - private final MonoProcessor requestBody; + private final Mono requestBody; - private final MonoProcessor responseBody; + private final Mono responseBody; + + private final Duration timeout; @Nullable private final String uriTemplate; @@ -80,11 +81,11 @@ public class ExchangeResult { * @param response the HTTP response * @param requestBody capture of serialized request body content * @param responseBody capture of serialized response body content + * @param timeout how long to wait for content to materialize * @param uriTemplate the URI template used to set up the request, if any */ ExchangeResult(ClientHttpRequest request, ClientHttpResponse response, - MonoProcessor requestBody, MonoProcessor responseBody, - @Nullable String uriTemplate) { + Mono requestBody, Mono responseBody, Duration timeout, @Nullable String uriTemplate) { Assert.notNull(request, "ClientHttpRequest is required"); Assert.notNull(response, "ClientHttpResponse is required"); @@ -95,6 +96,7 @@ public class ExchangeResult { this.response = response; this.requestBody = requestBody; this.responseBody = responseBody; + this.timeout = timeout; this.uriTemplate = uriTemplate; } @@ -106,6 +108,7 @@ public class ExchangeResult { this.response = other.response; this.requestBody = other.requestBody; this.responseBody = other.responseBody; + this.timeout = other.timeout; this.uriTemplate = other.uriTemplate; } @@ -140,14 +143,14 @@ public class ExchangeResult { } /** - * Return the raw request body content written as a {@code byte[]}. - * @throws IllegalStateException if the request body is not fully written yet. + * Return the raw request body content written through the request. + *

Note: If the request content has not been consumed + * for any reason yet, use of this method will trigger consumption. + * @throws IllegalStateException if the request body is not been fully written. */ @Nullable public byte[] getRequestBodyContent() { - MonoProcessor body = this.requestBody; - Assert.isTrue(body.isTerminated(), "Request body incomplete."); - return body.block(Duration.ZERO); + return this.requestBody.block(this.timeout); } @@ -173,14 +176,14 @@ public class ExchangeResult { } /** - * Return the raw request body content written as a {@code byte[]}. - * @throws IllegalStateException if the response is not fully read yet. + * Return the raw request body content written to the response. + *

Note: If the response content has not been consumed + * yet, use of this method will trigger consumption. + * @throws IllegalStateException if the response is not been fully read. */ @Nullable public byte[] getResponseBodyContent() { - MonoProcessor body = this.responseBody; - Assert.state(body.isTerminated(), "Response body incomplete"); - return body.block(Duration.ZERO); + return this.responseBody.block(this.timeout); } @@ -223,30 +226,25 @@ public class ExchangeResult { .collect(Collectors.joining(delimiter)); } - private String formatBody(@Nullable MediaType contentType, MonoProcessor body) { - if (body.isSuccess()) { - byte[] bytes = body.block(Duration.ZERO); - if (ObjectUtils.isEmpty(bytes)) { - return "No content"; - } - if (contentType == null) { - return "Unknown content type (" + bytes.length + " bytes)"; - } - Charset charset = contentType.getCharset(); - if (charset != null) { - return new String(bytes, charset); - } - if (PRINTABLE_MEDIA_TYPES.stream().anyMatch(contentType::isCompatibleWith)) { - return new String(bytes, StandardCharsets.UTF_8); - } - return "Unknown charset (" + bytes.length + " bytes)"; - } - else if (body.isError()) { - return "I/O failure: " + body.getError(); - } - else { - return "Content not available yet"; - } + @Nullable + private String formatBody(@Nullable MediaType contentType, Mono body) { + return body + .map(bytes -> { + if (contentType == null) { + return "Unknown content type (" + bytes.length + " bytes)"; + } + Charset charset = contentType.getCharset(); + if (charset != null) { + return new String(bytes, charset); + } + if (PRINTABLE_MEDIA_TYPES.stream().anyMatch(contentType::isCompatibleWith)) { + return new String(bytes, StandardCharsets.UTF_8); + } + return "Unknown charset (" + bytes.length + " bytes)"; + }) + .defaultIfEmpty("No content") + .onErrorResume(ex -> Mono.just("Failed to obtain content: " + ex.getMessage())) + .block(this.timeout); } } diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/FluxExchangeResult.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/FluxExchangeResult.java index bd563cff8e..1a61932def 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/FluxExchangeResult.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/FluxExchangeResult.java @@ -16,13 +16,9 @@ package org.springframework.test.web.reactive.server; -import java.time.Duration; import java.util.function.Consumer; import reactor.core.publisher.Flux; -import reactor.core.publisher.Mono; - -import org.springframework.lang.Nullable; /** * {@code ExchangeResult} variant with the response body decoded as @@ -35,20 +31,12 @@ import org.springframework.lang.Nullable; */ public class FluxExchangeResult extends ExchangeResult { - private static final IllegalStateException TIMEOUT_ERROR = - new IllegalStateException("Response timeout: for infinite streams " + - "use getResponseBody() first with explicit cancellation, e.g. via take(n)."); - - private final Flux body; - private final Duration timeout; - - FluxExchangeResult(ExchangeResult result, Flux body, Duration timeout) { + FluxExchangeResult(ExchangeResult result, Flux body) { super(result); this.body = body; - this.timeout = timeout; } @@ -81,22 +69,6 @@ public class FluxExchangeResult extends ExchangeResult { return this.body; } - /** - * {@inheritDoc} - *

Note: this method should typically be called after - * the response has been consumed in full via {@link #getResponseBody()}. - * Calling it first will cause the response {@code Flux} to be consumed - * via {@code getResponseBody.ignoreElements()}. - */ - @Override - @Nullable - public byte[] getResponseBodyContent() { - return this.body.ignoreElements() - .timeout(this.timeout, Mono.error(TIMEOUT_ERROR)) - .then(Mono.defer(() -> Mono.justOrEmpty(super.getResponseBodyContent()))) - .block(); - } - /** * Invoke the given consumer within {@link #assertWithDiagnostics(Runnable)} * passing {@code "this"} instance to it. This method allows the following, diff --git a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java index 3dc5de7b46..2a051a3bc0 100644 --- a/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java +++ b/spring-test/src/main/java/org/springframework/test/web/reactive/server/WiretapConnector.java @@ -17,12 +17,14 @@ package org.springframework.test.web.reactive.server; import java.net.URI; +import java.time.Duration; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.atomic.AtomicReference; import java.util.function.Function; import org.reactivestreams.Publisher; +import org.reactivestreams.Subscription; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import reactor.core.publisher.MonoProcessor; @@ -112,9 +114,11 @@ class WiretapConnector implements ClientHttpConnector { } - public ExchangeResult createExchangeResult(@Nullable String uriTemplate) { + public ExchangeResult createExchangeResult(Duration timeout, @Nullable String uriTemplate) { return new ExchangeResult(this.request, this.response, - this.request.getRecorder().getContent(), this.response.getRecorder().getContent(), uriTemplate); + Mono.defer(() -> this.request.getRecorder().getContent()), + Mono.defer(() -> this.response.getRecorder().getContent()), + timeout, uriTemplate); } } @@ -126,21 +130,21 @@ class WiretapConnector implements ClientHttpConnector { private static final DataBufferFactory bufferFactory = new DefaultDataBufferFactory(); - public static final byte[] EMPTY_CONTENT = new byte[0]; - @Nullable - private final Publisher publisher; + private final Flux publisher; @Nullable - private final Publisher> publisherNested; + private final Flux> publisherNested; private final DataBuffer buffer; private final MonoProcessor content; + private volatile boolean subscriberRegistered; - private WiretapRecorder(@Nullable Publisher publisher, + + public WiretapRecorder(@Nullable Publisher publisher, @Nullable Publisher> publisherNested) { if (publisher != null && publisherNested != null) { @@ -149,6 +153,7 @@ class WiretapConnector implements ClientHttpConnector { this.publisher = publisher != null ? Flux.from(publisher) + .doOnSubscribe(this::handleOnSubscribe) .doOnNext(this::handleOnNext) .doOnError(this::handleOnError) .doOnCancel(this::handleOnComplete) @@ -156,6 +161,7 @@ class WiretapConnector implements ClientHttpConnector { this.publisherNested = publisherNested != null ? Flux.from(publisherNested) + .doOnSubscribe(this::handleOnSubscribe) .map(p -> Flux.from(p).doOnNext(this::handleOnNext).doOnError(this::handleOnError)) .doOnError(this::handleOnError) .doOnCancel(this::handleOnComplete) @@ -163,10 +169,6 @@ class WiretapConnector implements ClientHttpConnector { this.buffer = bufferFactory.allocateBuffer(); this.content = MonoProcessor.create(); - - if (this.publisher == null && this.publisherNested == null) { - this.content.onNext(EMPTY_CONTENT); - } } @@ -180,11 +182,36 @@ class WiretapConnector implements ClientHttpConnector { return this.publisherNested; } - public MonoProcessor getContent() { - return this.content; + public Mono getContent() { + // No publisher (e.g. request#setComplete) + if (this.publisher == null && this.publisherNested == null) { + return Mono.empty(); + } + if (this.content.isTerminated()) { + return this.content; + } + if (this.subscriberRegistered) { + return Mono.error(new IllegalStateException( + "Subscriber registered but content is not yet fully consumed.")); + } + else { + // No subscriber, e.g.: + // - mock server request body never consumed (error before read) + // - FluxExchangeResult#getResponseBodyContent called + (this.publisher != null ? this.publisher : this.publisherNested) + .onErrorMap(ex -> new IllegalStateException( + "Content was not been consumed and " + + "an error was raised on attempt to produce it:", ex)) + .subscribe(); + return this.content; + } } + private void handleOnSubscribe(Subscription subscription) { + this.subscriberRegistered = true; + } + private void handleOnNext(DataBuffer nextBuffer) { this.buffer.write(nextBuffer); } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java index 746840143f..584d23eb1b 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/HeaderAssertionTests.java @@ -17,6 +17,7 @@ package org.springframework.test.web.reactive.server; import java.net.URI; +import java.time.Duration; import java.time.ZoneId; import java.time.ZonedDateTime; import java.util.concurrent.TimeUnit; @@ -257,7 +258,7 @@ public class HeaderAssertionTests { MonoProcessor emptyContent = MonoProcessor.create(); emptyContent.onComplete(); - ExchangeResult result = new ExchangeResult(request, response, emptyContent, emptyContent, null); + ExchangeResult result = new ExchangeResult(request, response, emptyContent, emptyContent, Duration.ZERO, null); return new HeaderAssertions(result, mock(WebTestClient.ResponseSpec.class)); } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java index 782e19a062..cc4d5eeaab 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/MockServerTests.java @@ -18,17 +18,19 @@ package org.springframework.test.web.reactive.server; import java.util.Arrays; import org.junit.Test; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBufferFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; import org.springframework.http.ResponseCookie; import org.springframework.http.server.reactive.ServerHttpResponse; -import static java.nio.charset.StandardCharsets.UTF_8; -import static org.junit.Assert.assertEquals; +import static java.nio.charset.StandardCharsets.*; +import static org.junit.Assert.*; /** * Test scenarios involving a mock server. @@ -38,7 +40,7 @@ public class MockServerTests { @Test // SPR-15674 (in comments) - public void mutateDoesNotCreateNewSession() throws Exception { + public void mutateDoesNotCreateNewSession() { WebTestClient client = WebTestClient .bindToWebHandler(exchange -> { @@ -51,8 +53,7 @@ public class MockServerTests { return exchange.getSession() .map(session -> session.getAttributeOrDefault("foo", "none")) .flatMap(value -> { - byte[] bytes = value.getBytes(UTF_8); - DataBuffer buffer = new DefaultDataBufferFactory().wrap(bytes); + DataBuffer buffer = toDataBuffer(value); return exchange.getResponse().writeWith(Mono.just(buffer)); }); } @@ -74,7 +75,7 @@ public class MockServerTests { } @Test // SPR-16059 - public void mutateDoesCopy() throws Exception { + public void mutateDoesCopy() { WebTestClient.Builder builder = WebTestClient .bindToWebHandler(exchange -> exchange.getResponse().setComplete()) @@ -111,7 +112,7 @@ public class MockServerTests { } @Test // SPR-16124 - public void exchangeResultHasCookieHeaders() throws Exception { + public void exchangeResultHasCookieHeaders() { ExchangeResult result = WebTestClient .bindToWebHandler(exchange -> { @@ -136,4 +137,32 @@ public class MockServerTests { result.getRequestHeaders().get(HttpHeaders.COOKIE)); } + @Test + public void responseBodyContentWithFluxExchangeResult() { + + FluxExchangeResult result = WebTestClient + .bindToWebHandler(exchange -> { + ServerHttpResponse response = exchange.getResponse(); + response.getHeaders().setContentType(MediaType.TEXT_PLAIN); + return response.writeWith(Flux.just(toDataBuffer("body"))); + }) + .build() + .get().uri("/") + .exchange() + .expectStatus().isOk() + .returnResult(String.class); + + // Get the raw content without consuming the response body flux.. + byte[] bytes = result.getResponseBodyContent(); + + assertNotNull(bytes); + assertEquals("body", new String(bytes, UTF_8)); + } + + + private DataBuffer toDataBuffer(String value) { + byte[] bytes = value.getBytes(UTF_8); + return new DefaultDataBufferFactory().wrap(bytes); + } + } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java index cd63b3c812..51cbf901f6 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/StatusAssertionTests.java @@ -17,6 +17,7 @@ package org.springframework.test.web.reactive.server; import java.net.URI; +import java.time.Duration; import org.junit.Test; import reactor.core.publisher.MonoProcessor; @@ -182,7 +183,7 @@ public class StatusAssertionTests { MonoProcessor emptyContent = MonoProcessor.create(); emptyContent.onComplete(); - ExchangeResult result = new ExchangeResult(request, response, emptyContent, emptyContent, null); + ExchangeResult result = new ExchangeResult(request, response, emptyContent, emptyContent, Duration.ZERO, null); return new StatusAssertions(result, mock(WebTestClient.ResponseSpec.class)); } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java index 60f3fb979c..5aa403de59 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/WiretapConnectorTests.java @@ -17,6 +17,7 @@ package org.springframework.test.web.reactive.server; import java.net.URI; +import java.time.Duration; import org.junit.Test; import reactor.core.publisher.Mono; @@ -57,7 +58,7 @@ public class WiretapConnectorTests { function.exchange(clientRequest).block(ofMillis(0)); WiretapConnector.Info actual = wiretapConnector.claimRequest("1"); - ExchangeResult result = actual.createExchangeResult(null); + ExchangeResult result = actual.createExchangeResult(Duration.ZERO, null); assertEquals(HttpMethod.GET, result.getMethod()); assertEquals("/test", result.getUrl().toString()); } diff --git a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java index b734a58e3a..dda50c7254 100644 --- a/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/reactive/server/samples/ErrorTests.java @@ -16,13 +16,21 @@ package org.springframework.test.web.reactive.server.samples; +import java.nio.charset.StandardCharsets; + import org.junit.Test; import org.springframework.http.HttpStatus; +import org.springframework.http.MediaType; +import org.springframework.test.web.reactive.server.EntityExchangeResult; import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.web.bind.annotation.GetMapping; +import org.springframework.web.bind.annotation.PostMapping; +import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.RestController; +import static org.junit.Assert.*; + /** * Tests with error status codes or error conditions. * @@ -35,7 +43,7 @@ public class ErrorTests { @Test - public void notFound() throws Exception { + public void notFound(){ this.client.get().uri("/invalid") .exchange() .expectStatus().isNotFound() @@ -43,13 +51,28 @@ public class ErrorTests { } @Test - public void serverException() throws Exception { + public void serverException() { this.client.get().uri("/server-error") .exchange() .expectStatus().isEqualTo(HttpStatus.INTERNAL_SERVER_ERROR) .expectBody(Void.class); } + @Test // SPR-17363 + public void badRequestBeforeRequestBodyConsumed() { + EntityExchangeResult result = this.client.post() + .uri("/post") + .contentType(MediaType.APPLICATION_JSON_UTF8) + .syncBody(new Person("Dan")) + .exchange() + .expectStatus().isBadRequest() + .expectBody().isEmpty(); + + byte[] content = result.getRequestBodyContent(); + assertNotNull(content); + assertEquals("{\"name\":\"Dan\"}", new String(content, StandardCharsets.UTF_8)); + } + @RestController static class TestController { @@ -58,6 +81,10 @@ public class ErrorTests { void handleAndThrowException() { throw new IllegalStateException("server error"); } + + @PostMapping(path = "/post", params = "p") + void handlePost(@RequestBody Person person) { + } } }