diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java index a871267914..f4841a7052 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/AbstractServerHttpResponse.java @@ -176,22 +176,25 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { @Override @SuppressWarnings("unchecked") public final Mono writeWith(Publisher body) { - // Write as Mono if possible as an optimization hint to Reactor Netty - // ChannelSendOperator not necessary for Mono + // For Mono we can avoid ChannelSendOperator and Reactor Netty is more optimized for Mono. + // We must resolve value first however, for a chance to handle potential error. if (body instanceof Mono) { - return ((Mono) body).flatMap(buffer -> - doCommit(() -> writeWithInternal(Mono.just(buffer))) - .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)) - .doOnError(t -> this.getHeaders().clearContentHeaders()); + return ((Mono) body) + .flatMap(buffer -> doCommit(() -> + writeWithInternal(Mono.fromCallable(() -> buffer) + .doOnDiscard(PooledDataBuffer.class, DataBufferUtils::release)))) + .doOnError(t -> getHeaders().clearContentHeaders()); + } + else { + return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeWithInternal(inner))) + .doOnError(t -> getHeaders().clearContentHeaders()); } - return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeWithInternal(inner))) - .doOnError(t -> this.getHeaders().clearContentHeaders()); } @Override public final Mono writeAndFlushWith(Publisher> body) { return new ChannelSendOperator<>(body, inner -> doCommit(() -> writeAndFlushWithInternal(inner))) - .doOnError(t -> this.getHeaders().clearContentHeaders()); + .doOnError(t -> getHeaders().clearContentHeaders()); } @Override @@ -217,21 +220,30 @@ public abstract class AbstractServerHttpResponse implements ServerHttpResponse { if (!this.state.compareAndSet(State.NEW, State.COMMITTING)) { return Mono.empty(); } - this.commitActions.add(() -> - Mono.fromRunnable(() -> { - applyStatusCode(); - applyHeaders(); - applyCookies(); - this.state.set(State.COMMITTED); - })); + + Flux allActions = Flux.empty(); + + if (!this.commitActions.isEmpty()) { + allActions = Flux.concat(Flux.fromIterable(this.commitActions).map(Supplier::get)) + .doOnError(ex -> { + if (this.state.compareAndSet(State.COMMITTING, State.NEW)) { + getHeaders().clearContentHeaders(); + } + }); + } + + allActions = allActions.concatWith(Mono.fromRunnable(() -> { + applyStatusCode(); + applyHeaders(); + applyCookies(); + this.state.set(State.COMMITTED); + })); + if (writeAction != null) { - this.commitActions.add(writeAction); + allActions = allActions.concatWith(writeAction.get()); } - Flux commit = Flux.empty(); - for (Supplier> action : this.commitActions) { - commit = commit.concatWith(action.get()); - } - return commit.then(); + + return allActions.then(); } diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java index c8a8552dd3..cdb4225381 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpResponseTests.java @@ -20,11 +20,14 @@ import java.nio.ByteBuffer; import java.nio.charset.StandardCharsets; import java.util.ArrayList; import java.util.List; +import java.util.function.Consumer; +import java.util.function.Supplier; import org.junit.jupiter.api.Test; import org.reactivestreams.Publisher; import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.core.io.buffer.DataBuffer; import org.springframework.core.io.buffer.DefaultDataBuffer; @@ -36,6 +39,8 @@ import org.springframework.http.ResponseCookie; import static org.assertj.core.api.Assertions.assertThat; /** + * Unit tests for {@link AbstractServerHttpRequest}. + * * @author Rossen Stoyanchev * @author Sebastien Deleuze * @author Brian Clozel @@ -43,7 +48,7 @@ import static org.assertj.core.api.Assertions.assertThat; public class ServerHttpResponseTests { @Test - void writeWith() throws Exception { + void writeWith() { TestServerHttpResponse response = new TestServerHttpResponse(); response.writeWith(Flux.just(wrap("a"), wrap("b"), wrap("c"))).block(); @@ -58,7 +63,7 @@ public class ServerHttpResponseTests { } @Test // SPR-14952 - void writeAndFlushWithFluxOfDefaultDataBuffer() throws Exception { + void writeAndFlushWithFluxOfDefaultDataBuffer() { TestServerHttpResponse response = new TestServerHttpResponse(); Flux> flux = Flux.just(Flux.just(wrap("foo"))); response.writeAndFlushWith(flux).block(); @@ -72,18 +77,18 @@ public class ServerHttpResponseTests { } @Test - void writeWithFluxError() throws Exception { + void writeWithFluxError() { IllegalStateException error = new IllegalStateException("boo"); writeWithError(Flux.error(error)); } @Test - void writeWithMonoError() throws Exception { + void writeWithMonoError() { IllegalStateException error = new IllegalStateException("boo"); writeWithError(Mono.error(error)); } - void writeWithError(Publisher body) throws Exception { + void writeWithError(Publisher body) { TestServerHttpResponse response = new TestServerHttpResponse(); HttpHeaders headers = response.getHeaders(); headers.setContentType(MediaType.APPLICATION_JSON); @@ -100,7 +105,7 @@ public class ServerHttpResponseTests { } @Test - void setComplete() throws Exception { + void setComplete() { TestServerHttpResponse response = new TestServerHttpResponse(); response.setComplete().block(); @@ -111,7 +116,7 @@ public class ServerHttpResponseTests { } @Test - void beforeCommitWithComplete() throws Exception { + void beforeCommitWithComplete() { ResponseCookie cookie = ResponseCookie.from("ID", "123").build(); TestServerHttpResponse response = new TestServerHttpResponse(); response.beforeCommit(() -> Mono.fromRunnable(() -> response.getCookies().add(cookie.getName(), cookie))); @@ -129,7 +134,7 @@ public class ServerHttpResponseTests { } @Test - void beforeCommitActionWithSetComplete() throws Exception { + void beforeCommitActionWithSetComplete() { ResponseCookie cookie = ResponseCookie.from("ID", "123").build(); TestServerHttpResponse response = new TestServerHttpResponse(); response.beforeCommit(() -> { @@ -145,6 +150,32 @@ public class ServerHttpResponseTests { assertThat(response.getCookies().getFirst("ID")).isSameAs(cookie); } + @Test // gh-24186 + void beforeCommitErrorShouldLeaveResponseNotCommitted() { + + Consumer>> tester = preCommitAction -> { + TestServerHttpResponse response = new TestServerHttpResponse(); + response.getHeaders().setContentType(MediaType.APPLICATION_JSON); + response.getHeaders().setContentLength(3); + response.beforeCommit(preCommitAction); + + StepVerifier.create(response.writeWith(Flux.just(wrap("body")))) + .expectErrorMessage("Max sessions") + .verify(); + + assertThat(response.statusCodeWritten).isFalse(); + assertThat(response.headersWritten).isFalse(); + assertThat(response.cookiesWritten).isFalse(); + assertThat(response.isCommitted()).isFalse(); + assertThat(response.getHeaders()).isEmpty(); + }; + + tester.accept(() -> Mono.error(new IllegalStateException("Max sessions"))); + tester.accept(() -> { + throw new IllegalStateException("Max sessions"); + }); + } + private DefaultDataBuffer wrap(String a) { return new DefaultDataBufferFactory().wrap(ByteBuffer.wrap(a.getBytes(StandardCharsets.UTF_8)));