diff --git a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java index 33c6d5bd30..05f309d9c4 100644 --- a/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java +++ b/spring-web/src/main/java/org/springframework/http/server/reactive/ServletServerHttpRequest.java @@ -54,19 +54,13 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { private final HttpServletRequest request; - private final DataBufferFactory bufferFactory; - - private final int bufferSize; - - private final Object bodyPublisherLock = new Object(); - - private volatile RequestBodyPublisher bodyPublisher; + private final RequestBodyPublisher bodyPublisher; private final Object cookieLock = new Object(); public ServletServerHttpRequest(HttpServletRequest request, AsyncContext asyncContext, - DataBufferFactory bufferFactory, int bufferSize) { + DataBufferFactory bufferFactory, int bufferSize) throws IOException { super(initUri(request), initHeaders(request)); @@ -74,12 +68,16 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { Assert.isTrue(bufferSize > 0, "'bufferSize' must be higher than 0"); this.request = request; - this.bufferFactory = bufferFactory; - this.bufferSize = bufferSize; asyncContext.addListener(new RequestAsyncListener()); + + // Tomcat expects ReadListener registration on initial thread + ServletInputStream inputStream = request.getInputStream(); + this.bodyPublisher = new RequestBodyPublisher(inputStream, bufferFactory, bufferSize); + this.bodyPublisher.registerReadListener(); } + private static URI initUri(HttpServletRequest request) { Assert.notNull(request, "'request' must not be null"); try { @@ -168,24 +166,7 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { @Override public Flux getBody() { - try { - RequestBodyPublisher publisher = this.bodyPublisher; - if (publisher == null) { - synchronized (this.bodyPublisherLock) { - publisher = this.bodyPublisher; - if (publisher == null) { - ServletInputStream inputStream = this.request.getInputStream(); - publisher = new RequestBodyPublisher(inputStream, this.bufferFactory, this.bufferSize); - publisher.registerReadListener(); - this.bodyPublisher = publisher; - } - } - } - return Flux.from(publisher); - } - catch (IOException ex) { - return Flux.error(ex); - } + return Flux.from(this.bodyPublisher); } @@ -198,33 +179,22 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { public void onTimeout(AsyncEvent event) { Throwable ex = event.getThrowable(); ex = ex != null ? ex : new IllegalStateException("Async operation timeout."); - handleError(ex); + bodyPublisher.onError(ex); } @Override public void onError(AsyncEvent event) { - handleError(event.getThrowable()); - } - - private void handleError(Throwable ex) { - if (bodyPublisher != null) { - bodyPublisher.onError(ex); - } + bodyPublisher.onError(event.getThrowable()); } @Override public void onComplete(AsyncEvent event) { - if (bodyPublisher != null) { - bodyPublisher.onAllDataRead(); - } + bodyPublisher.onAllDataRead(); } } private static class RequestBodyPublisher extends AbstractListenerReadPublisher { - private final RequestBodyPublisherReadListener readListener = - new RequestBodyPublisherReadListener(); - private final ServletInputStream inputStream; private final DataBufferFactory bufferFactory; @@ -241,7 +211,7 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { } public void registerReadListener() throws IOException { - this.inputStream.setReadListener(this.readListener); + this.inputStream.setReadListener(new RequestBodyPublisherReadListener()); } @Override @@ -268,7 +238,6 @@ public class ServletServerHttpRequest extends AbstractServerHttpRequest { return null; } - private class RequestBodyPublisherReadListener implements ReadListener { @Override diff --git a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java index 4b7d2aab27..0ef6ba345e 100644 --- a/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java +++ b/spring-web/src/test/java/org/springframework/http/server/reactive/ServerHttpRequestTests.java @@ -16,14 +16,18 @@ package org.springframework.http.server.reactive; +import java.io.ByteArrayInputStream; import java.util.Arrays; import java.util.Collections; import javax.servlet.AsyncContext; +import javax.servlet.ReadListener; +import javax.servlet.ServletInputStream; import javax.servlet.http.HttpServletRequest; import org.junit.Test; import org.springframework.core.io.buffer.DefaultDataBufferFactory; +import org.springframework.mock.web.test.DelegatingServletInputStream; import org.springframework.mock.web.test.MockAsyncContext; import org.springframework.mock.web.test.MockHttpServletRequest; import org.springframework.mock.web.test.MockHttpServletResponse; @@ -75,9 +79,25 @@ public class ServerHttpRequestTests { } private ServerHttpRequest createHttpRequest(String path) throws Exception { - HttpServletRequest request = new MockHttpServletRequest("GET", path); + HttpServletRequest request = new MockHttpServletRequest("GET", path) { + @Override + public ServletInputStream getInputStream() { + return new TestServletInputStream(); + } + }; AsyncContext asyncContext = new MockAsyncContext(request, new MockHttpServletResponse()); return new ServletServerHttpRequest(request, asyncContext, new DefaultDataBufferFactory(), 1024); } + private static class TestServletInputStream extends DelegatingServletInputStream { + + public TestServletInputStream() { + super(new ByteArrayInputStream(new byte[0])); + } + + @Override + public void setReadListener(ReadListener readListener) { + // Ignore + } + } }