Commit 03a6f97e authored by Brian Clozel's avatar Brian Clozel

TestRestTemplate should not override request factory

Previously `TestRestTemplate` would override the configured
`ClientHttpRequestFactory` if the Apache HTTP client library was on
classpath.

This commit fixes two issues:

1. The existing `ClientHttpRequestFactory` is overridden *only* if it is
using the Apache HTTP client variant, in order to wrap it with the
`TestRestTemplate` custom support

2. Calling `withBasicAuth` will no longer directly use the request
factory returned by the internal `RestTemplate`; if client interceptors
are configured, the request factory is wrapped with an
`InterceptingClientHttpRequestFactory`. If we don't unwrap it,
interceptors are copied/applied twice in the newly created
`TestRestTemplate` instance. For that, we need to use reflection as the
underlying request factory is not accessible directly.

Closes gh-8697
parent 7872cda8
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
package org.springframework.boot.test.web.client; package org.springframework.boot.test.web.client;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URI; import java.net.URI;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
...@@ -45,12 +46,14 @@ import org.springframework.http.HttpHeaders; ...@@ -45,12 +46,14 @@ import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity; import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthorizationInterceptor; import org.springframework.http.client.support.BasicAuthorizationInterceptor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils;
import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RequestCallback; import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseExtractor; import org.springframework.web.client.ResponseExtractor;
...@@ -135,7 +138,8 @@ public class TestRestTemplate { ...@@ -135,7 +138,8 @@ public class TestRestTemplate {
HttpClientOption... httpClientOptions) { HttpClientOption... httpClientOptions) {
Assert.notNull(restTemplate, "RestTemplate must not be null"); Assert.notNull(restTemplate, "RestTemplate must not be null");
this.httpClientOptions = httpClientOptions; this.httpClientOptions = httpClientOptions;
if (ClassUtils.isPresent("org.apache.http.client.config.RequestConfig", null)) { if (restTemplate.getRequestFactory().getClass().getName()
.equals("org.springframework.http.client.HttpComponentsClientHttpRequestFactory")) {
restTemplate.setRequestFactory( restTemplate.setRequestFactory(
new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions));
} }
...@@ -1021,7 +1025,7 @@ public class TestRestTemplate { ...@@ -1021,7 +1025,7 @@ public class TestRestTemplate {
RestTemplate restTemplate = new RestTemplate(); RestTemplate restTemplate = new RestTemplate();
restTemplate.setMessageConverters(getRestTemplate().getMessageConverters()); restTemplate.setMessageConverters(getRestTemplate().getMessageConverters());
restTemplate.setInterceptors(getRestTemplate().getInterceptors()); restTemplate.setInterceptors(getRestTemplate().getInterceptors());
restTemplate.setRequestFactory(getRestTemplate().getRequestFactory()); restTemplate.setRequestFactory(getRequestFactory(getRestTemplate()));
restTemplate.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler()); restTemplate.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler());
TestRestTemplate testRestTemplate = new TestRestTemplate(restTemplate, username, TestRestTemplate testRestTemplate = new TestRestTemplate(restTemplate, username,
password, this.httpClientOptions); password, this.httpClientOptions);
...@@ -1030,6 +1034,18 @@ public class TestRestTemplate { ...@@ -1030,6 +1034,18 @@ public class TestRestTemplate {
return testRestTemplate; return testRestTemplate;
} }
private ClientHttpRequestFactory getRequestFactory(RestTemplate restTemplate) {
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
if (InterceptingClientHttpRequestFactory.class.isAssignableFrom(requestFactory.getClass())) {
Field requestFactoryField = ReflectionUtils
.findField(RestTemplate.class, "requestFactory");
ReflectionUtils.makeAccessible(requestFactoryField);
requestFactory = (ClientHttpRequestFactory)
ReflectionUtils.getField(requestFactoryField, getRestTemplate());
}
return requestFactory;
}
@SuppressWarnings({ "rawtypes", "unchecked" }) @SuppressWarnings({ "rawtypes", "unchecked" })
private RequestEntity<?> createRequestEntityWithRootAppliedUri( private RequestEntity<?> createRequestEntityWithRootAppliedUri(
RequestEntity<?> requestEntity) { RequestEntity<?> requestEntity) {
......
...@@ -37,6 +37,8 @@ import org.springframework.http.client.ClientHttpRequestFactory; ...@@ -37,6 +37,8 @@ import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthorizationInterceptor; import org.springframework.http.client.support.BasicAuthorizationInterceptor;
import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.env.MockEnvironment;
import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.http.client.MockClientHttpRequest;
...@@ -82,6 +84,15 @@ public class TestRestTemplateTests { ...@@ -82,6 +84,15 @@ public class TestRestTemplateTests {
.isInstanceOf(HttpComponentsClientHttpRequestFactory.class); .isInstanceOf(HttpComponentsClientHttpRequestFactory.class);
} }
@Test
public void doNotReplaceCustomRequestFactory() {
RestTemplateBuilder builder = new RestTemplateBuilder()
.requestFactory(OkHttp3ClientHttpRequestFactory.class);
TestRestTemplate testRestTemplate = new TestRestTemplate(builder);
assertThat(testRestTemplate.getRestTemplate().getRequestFactory())
.isInstanceOf(OkHttp3ClientHttpRequestFactory.class);
}
@Test @Test
public void getRootUriRootUriSetViaRestTemplateBuilder() { public void getRootUriRootUriSetViaRestTemplateBuilder() {
String rootUri = "http://example.com"; String rootUri = "http://example.com";
...@@ -125,6 +136,7 @@ public class TestRestTemplateTests { ...@@ -125,6 +136,7 @@ public class TestRestTemplateTests {
@Test @Test
public void restOperationsAreAvailable() { public void restOperationsAreAvailable() {
RestTemplate delegate = mock(RestTemplate.class); RestTemplate delegate = mock(RestTemplate.class);
given(delegate.getRequestFactory()).willReturn(new SimpleClientHttpRequestFactory());
given(delegate.getUriTemplateHandler()) given(delegate.getUriTemplateHandler())
.willReturn(new DefaultUriBuilderFactory()); .willReturn(new DefaultUriBuilderFactory());
RestTemplateBuilder builder = mock(RestTemplateBuilder.class); RestTemplateBuilder builder = mock(RestTemplateBuilder.class);
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment