Commit 2b51cf44 authored by Phillip Webb's avatar Phillip Webb

Merge pull request #16972 from kstrijbos

* pr/16972:
  Polish "Make it easier to set bufferRequestBody"
  Make it easier to set bufferRequestBody

Closes gh-16972
parents 750d251a af1a6d86
...@@ -33,8 +33,11 @@ import java.util.function.Supplier; ...@@ -33,8 +33,11 @@ import java.util.function.Supplier;
import org.springframework.beans.BeanUtils; import org.springframework.beans.BeanUtils;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
...@@ -60,6 +63,7 @@ import org.springframework.web.util.UriTemplateHandler; ...@@ -60,6 +63,7 @@ import org.springframework.web.util.UriTemplateHandler;
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Brian Clozel * @author Brian Clozel
* @author Dmytro Nosan * @author Dmytro Nosan
* @author Kevin Strijbos
* @since 1.4.0 * @since 1.4.0
*/ */
public class RestTemplateBuilder { public class RestTemplateBuilder {
...@@ -506,6 +510,24 @@ public class RestTemplateBuilder { ...@@ -506,6 +510,24 @@ public class RestTemplateBuilder {
this.interceptors); this.interceptors);
} }
/**
* Sets if the underling {@link ClientHttpRequestFactory} should buffer the
* {@linkplain ClientHttpRequest#getBody() request body} internally.
* @param bufferRequestBody value of the bufferRequestBody parameter
* @return a new builder instance.
* @since 2.2.0
* @see SimpleClientHttpRequestFactory#setBufferRequestBody(boolean)
* @see HttpComponentsClientHttpRequestFactory#setBufferRequestBody(boolean)
*/
public RestTemplateBuilder setBufferRequestBody(boolean bufferRequestBody) {
return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri,
this.messageConverters, this.requestFactorySupplier,
this.uriTemplateHandler, this.errorHandler, this.basicAuthentication,
this.restTemplateCustomizers,
this.requestFactoryCustomizer.bufferRequestBody(bufferRequestBody),
this.interceptors);
}
/** /**
* Build a new {@link RestTemplate} instance and configure it using this builder. * Build a new {@link RestTemplate} instance and configure it using this builder.
* @return a configured {@link RestTemplate} instance. * @return a configured {@link RestTemplate} instance.
...@@ -617,21 +639,32 @@ public class RestTemplateBuilder { ...@@ -617,21 +639,32 @@ public class RestTemplateBuilder {
private final Duration readTimeout; private final Duration readTimeout;
private final Boolean bufferRequestBody;
RequestFactoryCustomizer() { RequestFactoryCustomizer() {
this(null, null); this(null, null, null);
} }
private RequestFactoryCustomizer(Duration connectTimeout, Duration readTimeout) { private RequestFactoryCustomizer(Duration connectTimeout, Duration readTimeout,
Boolean bufferRequestBody) {
this.connectTimeout = connectTimeout; this.connectTimeout = connectTimeout;
this.readTimeout = readTimeout; this.readTimeout = readTimeout;
this.bufferRequestBody = bufferRequestBody;
} }
public RequestFactoryCustomizer connectTimeout(Duration connectTimeout) { public RequestFactoryCustomizer connectTimeout(Duration connectTimeout) {
return new RequestFactoryCustomizer(connectTimeout, this.readTimeout); return new RequestFactoryCustomizer(connectTimeout, this.readTimeout,
this.bufferRequestBody);
} }
public RequestFactoryCustomizer readTimeout(Duration readTimeout) { public RequestFactoryCustomizer readTimeout(Duration readTimeout) {
return new RequestFactoryCustomizer(this.connectTimeout, readTimeout); return new RequestFactoryCustomizer(this.connectTimeout, readTimeout,
this.bufferRequestBody);
}
public RequestFactoryCustomizer bufferRequestBody(boolean bufferRequestBody) {
return new RequestFactoryCustomizer(this.connectTimeout, this.readTimeout,
bufferRequestBody);
} }
@Override @Override
...@@ -639,12 +672,13 @@ public class RestTemplateBuilder { ...@@ -639,12 +672,13 @@ public class RestTemplateBuilder {
ClientHttpRequestFactory unwrappedRequestFactory = unwrapRequestFactoryIfNecessary( ClientHttpRequestFactory unwrappedRequestFactory = unwrapRequestFactoryIfNecessary(
requestFactory); requestFactory);
if (this.connectTimeout != null) { if (this.connectTimeout != null) {
new TimeoutRequestFactoryCustomizer(this.connectTimeout, setConnectTimeout(unwrappedRequestFactory);
"setConnectTimeout").customize(unwrappedRequestFactory);
} }
if (this.readTimeout != null) { if (this.readTimeout != null) {
new TimeoutRequestFactoryCustomizer(this.readTimeout, "setReadTimeout") setReadTimeout(unwrappedRequestFactory);
.customize(unwrappedRequestFactory); }
if (this.bufferRequestBody != null) {
setBufferRequestBody(unwrappedRequestFactory);
} }
} }
...@@ -664,35 +698,37 @@ public class RestTemplateBuilder { ...@@ -664,35 +698,37 @@ public class RestTemplateBuilder {
return unwrappedRequestFactory; return unwrappedRequestFactory;
} }
/** private void setConnectTimeout(ClientHttpRequestFactory factory) {
* {@link ClientHttpRequestFactory} customizer to call a "set timeout" method. Method method = findMethod(factory, "setConnectTimeout", int.class);
*/ int timeout = Math.toIntExact(this.connectTimeout.toMillis());
private static final class TimeoutRequestFactoryCustomizer { invoke(factory, method, timeout);
}
private final Duration timeout;
private final String methodName;
TimeoutRequestFactoryCustomizer(Duration timeout, String methodName) { private void setReadTimeout(ClientHttpRequestFactory factory) {
this.timeout = timeout; Method method = findMethod(factory, "setReadTimeout", int.class);
this.methodName = methodName; int timeout = Math.toIntExact(this.readTimeout.toMillis());
} invoke(factory, method, timeout);
}
void customize(ClientHttpRequestFactory factory) { private void setBufferRequestBody(ClientHttpRequestFactory factory) {
ReflectionUtils.invokeMethod(findMethod(factory), factory, Method method = findMethod(factory, "setBufferRequestBody", boolean.class);
Math.toIntExact(this.timeout.toMillis())); invoke(factory, method, this.bufferRequestBody);
} }
private Method findMethod(ClientHttpRequestFactory factory) { private Method findMethod(ClientHttpRequestFactory requestFactory,
Method method = ReflectionUtils.findMethod(factory.getClass(), String methodName, Class<?>... parameters) {
this.methodName, int.class); Method method = ReflectionUtils.findMethod(requestFactory.getClass(),
if (method != null) { methodName, parameters);
return method; if (method != null) {
} return method;
throw new IllegalStateException("Request factory " + factory.getClass()
+ " does not have a " + this.methodName + "(int) method");
} }
throw new IllegalStateException("Request factory " + requestFactory.getClass()
+ " does not have a suitable " + methodName + " method");
}
private void invoke(ClientHttpRequestFactory requestFactory, Method method,
Object... parameters) {
ReflectionUtils.invokeMethod(method, requestFactory, parameters);
} }
} }
......
...@@ -47,6 +47,7 @@ import org.springframework.web.util.UriTemplateHandler; ...@@ -47,6 +47,7 @@ import org.springframework.web.util.UriTemplateHandler;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException; import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException;
import static org.assertj.core.api.Assertions.assertThatIllegalStateException;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
...@@ -63,6 +64,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat ...@@ -63,6 +64,7 @@ import static org.springframework.test.web.client.response.MockRestResponseCreat
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Dmytro Nosan * @author Dmytro Nosan
* @author Kevin Strijbos
*/ */
public class RestTemplateBuilderTests { public class RestTemplateBuilderTests {
...@@ -480,6 +482,23 @@ public class RestTemplateBuilderTests { ...@@ -480,6 +482,23 @@ public class RestTemplateBuilderTests {
"requestConfig")).getSocketTimeout()).isEqualTo(1234); "requestConfig")).getSocketTimeout()).isEqualTo(1234);
} }
@Test
public void bufferRequestBodyCanBeConfiguredOnHttpComponentsRequestFactory() {
ClientHttpRequestFactory requestFactory = this.builder
.requestFactory(HttpComponentsClientHttpRequestFactory.class)
.setBufferRequestBody(false).build().getRequestFactory();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody",
false);
requestFactory = this.builder
.requestFactory(HttpComponentsClientHttpRequestFactory.class)
.setBufferRequestBody(true).build().getRequestFactory();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true);
requestFactory = this.builder
.requestFactory(HttpComponentsClientHttpRequestFactory.class).build()
.getRequestFactory();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true);
}
@Test @Test
public void connectTimeoutCanBeConfiguredOnSimpleRequestFactory() { public void connectTimeoutCanBeConfiguredOnSimpleRequestFactory() {
ClientHttpRequestFactory requestFactory = this.builder ClientHttpRequestFactory requestFactory = this.builder
...@@ -496,6 +515,21 @@ public class RestTemplateBuilderTests { ...@@ -496,6 +515,21 @@ public class RestTemplateBuilderTests {
assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234);
} }
@Test
public void bufferRequestBodyCanBeConfiguredOnSimpleRequestFactory() {
ClientHttpRequestFactory requestFactory = this.builder
.requestFactory(SimpleClientHttpRequestFactory.class)
.setBufferRequestBody(false).build().getRequestFactory();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody",
false);
requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class)
.setBufferRequestBody(true).build().getRequestFactory();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true);
requestFactory = this.builder.requestFactory(SimpleClientHttpRequestFactory.class)
.build().getRequestFactory();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true);
}
@Test @Test
public void connectTimeoutCanBeConfiguredOnOkHttp3RequestFactory() { public void connectTimeoutCanBeConfiguredOnOkHttp3RequestFactory() {
ClientHttpRequestFactory requestFactory = this.builder ClientHttpRequestFactory requestFactory = this.builder
...@@ -516,6 +550,15 @@ public class RestTemplateBuilderTests { ...@@ -516,6 +550,15 @@ public class RestTemplateBuilderTests {
.isEqualTo(1234); .isEqualTo(1234);
} }
@Test
public void bufferRequestBodyCanNotBeConfiguredOnOkHttp3RequestFactory() {
assertThatIllegalStateException()
.isThrownBy(() -> this.builder
.requestFactory(OkHttp3ClientHttpRequestFactory.class)
.setBufferRequestBody(false).build().getRequestFactory())
.withMessageContaining(OkHttp3ClientHttpRequestFactory.class.getName());
}
@Test @Test
public void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() { public void connectTimeoutCanBeConfiguredOnAWrappedRequestFactory() {
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
...@@ -536,6 +579,27 @@ public class RestTemplateBuilderTests { ...@@ -536,6 +579,27 @@ public class RestTemplateBuilderTests {
assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234); assertThat(requestFactory).hasFieldOrPropertyWithValue("readTimeout", 1234);
} }
@Test
public void bufferRequestBodyCanBeConfiguredOnAWrappedRequestFactory() {
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
this.builder
.requestFactory(
() -> new BufferingClientHttpRequestFactory(requestFactory))
.setBufferRequestBody(false).build();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody",
false);
this.builder
.requestFactory(
() -> new BufferingClientHttpRequestFactory(requestFactory))
.setBufferRequestBody(true).build();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true);
this.builder
.requestFactory(
() -> new BufferingClientHttpRequestFactory(requestFactory))
.build();
assertThat(requestFactory).hasFieldOrPropertyWithValue("bufferRequestBody", true);
}
@Test @Test
public void unwrappingDoesNotAffectRequestFactoryThatIsSetOnTheBuiltTemplate() { public void unwrappingDoesNotAffectRequestFactoryThatIsSetOnTheBuiltTemplate() {
SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory(); SimpleClientHttpRequestFactory requestFactory = new SimpleClientHttpRequestFactory();
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment