Commit ad326036 authored by Phillip Webb's avatar Phillip Webb Committed by Stephane Nicoll

Restore compatibility with MockRestServiceServer

Closes gh-17885
parent 3d5530d1
...@@ -99,10 +99,7 @@ class TestRestTemplateTests { ...@@ -99,10 +99,7 @@ class TestRestTemplateTests {
RestTemplateBuilder builder = new RestTemplateBuilder().requestFactory(() -> customFactory); RestTemplateBuilder builder = new RestTemplateBuilder().requestFactory(() -> customFactory);
TestRestTemplate testRestTemplate = new TestRestTemplate(builder).withBasicAuth("test", "test"); TestRestTemplate testRestTemplate = new TestRestTemplate(builder).withBasicAuth("test", "test");
RestTemplate restTemplate = testRestTemplate.getRestTemplate(); RestTemplate restTemplate = testRestTemplate.getRestTemplate();
assertThat(restTemplate.getRequestFactory().getClass().getName()) assertThat(restTemplate.getRequestFactory()).isEqualTo(customFactory).hasSameClassAs(customFactory);
.contains("RestTemplateBuilderClientHttpRequestFactoryWrapper");
Object requestFactory = ReflectionTestUtils.getField(restTemplate.getRequestFactory(), "requestFactory");
assertThat(requestFactory).isEqualTo(customFactory).hasSameClassAs(customFactory);
} }
@Test @Test
...@@ -203,28 +200,21 @@ class TestRestTemplateTests { ...@@ -203,28 +200,21 @@ class TestRestTemplateTests {
} }
@Test @Test
void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() throws Exception { void withBasicAuthAddsBasicAuthWhenNotAlreadyPresent() throws Exception {
TestRestTemplate original = new TestRestTemplate(); TestRestTemplate original = new TestRestTemplate();
TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(getConverterClasses(original)).containsExactlyElementsOf(getConverterClasses(basicAuth)); assertThat(getConverterClasses(original)).containsExactlyElementsOf(getConverterClasses(basicAuth));
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
.contains("RestTemplateBuilderClientHttpRequestFactoryWrapper");
assertThat(ReflectionTestUtils.getField(basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty(); assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
assertBasicAuthorizationCredentials(original, null, null);
assertBasicAuthorizationCredentials(basicAuth, "user", "password"); assertBasicAuthorizationCredentials(basicAuth, "user", "password");
} }
@Test @Test
void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() throws Exception { void withBasicAuthReplacesBasicAuthWhenAlreadyPresent() throws Exception {
TestRestTemplate original = new TestRestTemplate("foo", "bar").withBasicAuth("replace", "replace"); TestRestTemplate original = new TestRestTemplate("foo", "bar").withBasicAuth("replace", "replace");
TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(getConverterClasses(basicAuth)).containsExactlyElementsOf(getConverterClasses(original)); assertThat(getConverterClasses(basicAuth)).containsExactlyElementsOf(getConverterClasses(original));
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()) assertBasicAuthorizationCredentials(original, "replace", "replace");
.contains("RestTemplateBuilderClientHttpRequestFactoryWrapper");
assertThat(ReflectionTestUtils.getField(basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
assertBasicAuthorizationCredentials(basicAuth, "user", "password"); assertBasicAuthorizationCredentials(basicAuth, "user", "password");
} }
...@@ -347,11 +337,16 @@ class TestRestTemplateTests { ...@@ -347,11 +337,16 @@ class TestRestTemplateTests {
private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate, String username, private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate, String username,
String password) throws Exception { String password) throws Exception {
ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate().getRequestFactory(); ClientHttpRequest request = ReflectionTestUtils.invokeMethod(testRestTemplate.getRestTemplate(),
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.POST); "createRequest", URI.create("http://localhost"), HttpMethod.POST);
assertThat(request.getHeaders()).containsKeys(HttpHeaders.AUTHORIZATION); if (username == null) {
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly( assertThat(request.getHeaders()).doesNotContainKey(HttpHeaders.AUTHORIZATION);
"Basic " + Base64Utils.encodeToString(String.format("%s:%s", username, password).getBytes())); }
else {
assertThat(request.getHeaders()).containsKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly(
"Basic " + Base64Utils.encodeToString(String.format("%s:%s", username, password).getBytes()));
}
} }
......
...@@ -615,7 +615,7 @@ public class RestTemplateBuilder { ...@@ -615,7 +615,7 @@ public class RestTemplateBuilder {
if (requestFactory != null) { if (requestFactory != null) {
restTemplate.setRequestFactory(requestFactory); restTemplate.setRequestFactory(requestFactory);
} }
addClientHttpRequestFactoryWrapper(restTemplate); addClientHttpRequestInitializer(restTemplate);
if (!CollectionUtils.isEmpty(this.messageConverters)) { if (!CollectionUtils.isEmpty(this.messageConverters)) {
restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters)); restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
} }
...@@ -659,24 +659,12 @@ public class RestTemplateBuilder { ...@@ -659,24 +659,12 @@ public class RestTemplateBuilder {
return requestFactory; return requestFactory;
} }
private void addClientHttpRequestFactoryWrapper(RestTemplate restTemplate) { private void addClientHttpRequestInitializer(RestTemplate restTemplate) {
if (this.basicAuthentication == null && this.defaultHeaders.isEmpty() && this.requestCustomizers.isEmpty()) { if (this.basicAuthentication == null && this.defaultHeaders.isEmpty() && this.requestCustomizers.isEmpty()) {
return; return;
} }
List<ClientHttpRequestInterceptor> interceptors = null; restTemplate.getClientHttpRequestInitializers().add(new RestTemplateBuilderClientHttpRequestInitializer(
if (!restTemplate.getInterceptors().isEmpty()) { this.basicAuthentication, this.defaultHeaders, this.requestCustomizers));
// Stash and clear the interceptors so we can access the real factory
interceptors = new ArrayList<>(restTemplate.getInterceptors());
restTemplate.getInterceptors().clear();
}
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
ClientHttpRequestFactory wrapper = new RestTemplateBuilderClientHttpRequestFactoryWrapper(requestFactory,
this.basicAuthentication, this.defaultHeaders, this.requestCustomizers);
restTemplate.setRequestFactory(wrapper);
// Restore the original interceptors
if (interceptors != null) {
restTemplate.getInterceptors().addAll(interceptors);
}
} }
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
......
...@@ -16,18 +16,15 @@ ...@@ -16,18 +16,15 @@
package org.springframework.boot.web.client; package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import java.util.List; import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.springframework.boot.util.LambdaSafe; import org.springframework.boot.util.LambdaSafe;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInitializer;
/** /**
* {@link ClientHttpRequestFactory} to apply customizations from the * {@link ClientHttpRequestFactory} to apply customizations from the
...@@ -36,7 +33,7 @@ import org.springframework.http.client.ClientHttpRequestFactory; ...@@ -36,7 +33,7 @@ import org.springframework.http.client.ClientHttpRequestFactory;
* @author Dmytro Nosan * @author Dmytro Nosan
* @author Ilya Lukyanovich * @author Ilya Lukyanovich
*/ */
class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientHttpRequestFactoryWrapper { class RestTemplateBuilderClientHttpRequestInitializer implements ClientHttpRequestInitializer {
private final BasicAuthentication basicAuthentication; private final BasicAuthentication basicAuthentication;
...@@ -44,10 +41,8 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH ...@@ -44,10 +41,8 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
private final Set<RestTemplateRequestCustomizer<?>> requestCustomizers; private final Set<RestTemplateRequestCustomizer<?>> requestCustomizers;
RestTemplateBuilderClientHttpRequestFactoryWrapper(ClientHttpRequestFactory requestFactory, RestTemplateBuilderClientHttpRequestInitializer(BasicAuthentication basicAuthentication,
BasicAuthentication basicAuthentication, Map<String, List<String>> defaultHeaders, Map<String, List<String>> defaultHeaders, Set<RestTemplateRequestCustomizer<?>> requestCustomizers) {
Set<RestTemplateRequestCustomizer<?>> requestCustomizers) {
super(requestFactory);
this.basicAuthentication = basicAuthentication; this.basicAuthentication = basicAuthentication;
this.defaultHeaders = defaultHeaders; this.defaultHeaders = defaultHeaders;
this.requestCustomizers = requestCustomizers; this.requestCustomizers = requestCustomizers;
...@@ -55,9 +50,7 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH ...@@ -55,9 +50,7 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
@Override @Override
@SuppressWarnings("unchecked") @SuppressWarnings("unchecked")
protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory) public void initialize(ClientHttpRequest request) {
throws IOException {
ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod);
HttpHeaders headers = request.getHeaders(); HttpHeaders headers = request.getHeaders();
if (this.basicAuthentication != null) { if (this.basicAuthentication != null) {
this.basicAuthentication.applyTo(headers); this.basicAuthentication.applyTo(headers);
...@@ -65,7 +58,6 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH ...@@ -65,7 +58,6 @@ class RestTemplateBuilderClientHttpRequestFactoryWrapper extends AbstractClientH
this.defaultHeaders.forEach(headers::putIfAbsent); this.defaultHeaders.forEach(headers::putIfAbsent);
LambdaSafe.callbacks(RestTemplateRequestCustomizer.class, this.requestCustomizers, request) LambdaSafe.callbacks(RestTemplateRequestCustomizer.class, this.requestCustomizers, request)
.invoke((customizer) -> customizer.customize(request)); .invoke((customizer) -> customizer.customize(request));
return request;
} }
} }
...@@ -17,6 +17,7 @@ ...@@ -17,6 +17,7 @@
package org.springframework.boot.web.client; package org.springframework.boot.web.client;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.web.client.RestTemplate; import org.springframework.web.client.RestTemplate;
/** /**
...@@ -28,6 +29,7 @@ import org.springframework.web.client.RestTemplate; ...@@ -28,6 +29,7 @@ import org.springframework.web.client.RestTemplate;
* @author Phillip Webb * @author Phillip Webb
* @since 2.2.0 * @since 2.2.0
* @see RestTemplateBuilder * @see RestTemplateBuilder
* @see ClientHttpRequestInitializer
*/ */
@FunctionalInterface @FunctionalInterface
public interface RestTemplateRequestCustomizer<T extends ClientHttpRequest> { public interface RestTemplateRequestCustomizer<T extends ClientHttpRequest> {
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package org.springframework.boot.web.client; package org.springframework.boot.web.client;
import java.io.IOException; import java.io.IOException;
import java.net.URI;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
...@@ -26,72 +25,55 @@ import java.util.List; ...@@ -26,72 +25,55 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.InOrder; import org.mockito.InOrder;
import org.springframework.http.HttpHeaders; import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.mock.http.client.MockClientHttpRequest;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
/** /**
* Tests for {@link RestTemplateBuilderClientHttpRequestFactoryWrapper}. * Tests for {@link RestTemplateBuilderClientHttpRequestInitializer}.
* *
* @author Dmytro Nosan * @author Dmytro Nosan
* @author Ilya Lukyanovich * @author Ilya Lukyanovich
* @author Phillip Webb * @author Phillip Webb
*/ */
public class RestTemplateBuilderClientHttpRequestFactoryWrapperTests { public class RestTemplateBuilderClientHttpRequestInitializerTests {
private ClientHttpRequestFactory requestFactory; private final MockClientHttpRequest request = new MockClientHttpRequest();
private final HttpHeaders headers = new HttpHeaders();
@BeforeEach
void setUp() throws IOException {
this.requestFactory = mock(ClientHttpRequestFactory.class);
ClientHttpRequest request = mock(ClientHttpRequest.class);
given(this.requestFactory.createRequest(any(), any())).willReturn(request);
given(request.getHeaders()).willReturn(this.headers);
}
@Test @Test
void createRequestWhenHasBasicAuthAndNoAuthHeaderAddsHeader() throws IOException { void createRequestWhenHasBasicAuthAndNoAuthHeaderAddsHeader() throws IOException {
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, new RestTemplateBuilderClientHttpRequestInitializer(new BasicAuthentication("spring", "boot", null),
new BasicAuthentication("spring", "boot", null), Collections.emptyMap(), Collections.emptySet()); Collections.emptyMap(), Collections.emptySet()).initialize(this.request);
ClientHttpRequest request = createRequest(); assertThat(this.request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
} }
@Test @Test
void createRequestWhenHasBasicAuthAndExistingAuthHeaderDoesNotAddHeader() throws IOException { void createRequestWhenHasBasicAuthAndExistingAuthHeaderDoesNotAddHeader() throws IOException {
this.headers.setBasicAuth("boot", "spring"); this.request.getHeaders().setBasicAuth("boot", "spring");
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, new RestTemplateBuilderClientHttpRequestInitializer(new BasicAuthentication("spring", "boot", null),
new BasicAuthentication("spring", "boot", null), Collections.emptyMap(), Collections.emptySet()); Collections.emptyMap(), Collections.emptySet()).initialize(this.request);
ClientHttpRequest request = createRequest(); assertThat(this.request.getHeaders().get(HttpHeaders.AUTHORIZATION)).doesNotContain("Basic c3ByaW5nOmJvb3Q=");
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).doesNotContain("Basic c3ByaW5nOmJvb3Q=");
} }
@Test @Test
void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException { void createRequestWhenHasDefaultHeadersAddsMissing() throws IOException {
this.headers.add("one", "existing"); this.request.getHeaders().add("one", "existing");
Map<String, List<String>> defaultHeaders = new LinkedHashMap<>(); Map<String, List<String>> defaultHeaders = new LinkedHashMap<>();
defaultHeaders.put("one", Collections.singletonList("1")); defaultHeaders.put("one", Collections.singletonList("1"));
defaultHeaders.put("two", Arrays.asList("2", "3")); defaultHeaders.put("two", Arrays.asList("2", "3"));
defaultHeaders.put("three", Collections.singletonList("4")); defaultHeaders.put("three", Collections.singletonList("4"));
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null, new RestTemplateBuilderClientHttpRequestInitializer(null, defaultHeaders, Collections.emptySet())
defaultHeaders, Collections.emptySet()); .initialize(this.request);
ClientHttpRequest request = createRequest(); assertThat(this.request.getHeaders().get("one")).containsExactly("existing");
assertThat(request.getHeaders().get("one")).containsExactly("existing"); assertThat(this.request.getHeaders().get("two")).containsExactly("2", "3");
assertThat(request.getHeaders().get("two")).containsExactly("2", "3"); assertThat(this.request.getHeaders().get("three")).containsExactly("4");
assertThat(request.getHeaders().get("three")).containsExactly("4");
} }
@Test @Test
...@@ -101,17 +83,12 @@ public class RestTemplateBuilderClientHttpRequestFactoryWrapperTests { ...@@ -101,17 +83,12 @@ public class RestTemplateBuilderClientHttpRequestFactoryWrapperTests {
customizers.add(mock(RestTemplateRequestCustomizer.class)); customizers.add(mock(RestTemplateRequestCustomizer.class));
customizers.add(mock(RestTemplateRequestCustomizer.class)); customizers.add(mock(RestTemplateRequestCustomizer.class));
customizers.add(mock(RestTemplateRequestCustomizer.class)); customizers.add(mock(RestTemplateRequestCustomizer.class));
this.requestFactory = new RestTemplateBuilderClientHttpRequestFactoryWrapper(this.requestFactory, null, new RestTemplateBuilderClientHttpRequestInitializer(null, Collections.emptyMap(), customizers)
Collections.emptyMap(), customizers); .initialize(this.request);
ClientHttpRequest request = createRequest();
InOrder inOrder = inOrder(customizers.toArray()); InOrder inOrder = inOrder(customizers.toArray());
for (RestTemplateRequestCustomizer<?> customizer : customizers) { for (RestTemplateRequestCustomizer<?> customizer : customizers) {
inOrder.verify((RestTemplateRequestCustomizer<ClientHttpRequest>) customizer).customize(request); inOrder.verify((RestTemplateRequestCustomizer<ClientHttpRequest>) customizer).customize(this.request);
} }
} }
private ClientHttpRequest createRequest() throws IOException {
return this.requestFactory.createRequest(URI.create("https://localhost:8080"), HttpMethod.POST);
}
} }
...@@ -38,6 +38,7 @@ import org.springframework.http.MediaType; ...@@ -38,6 +38,7 @@ import org.springframework.http.MediaType;
import org.springframework.http.client.BufferingClientHttpRequestFactory; import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInitializer;
import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory;
...@@ -309,8 +310,7 @@ class RestTemplateBuilderTests { ...@@ -309,8 +310,7 @@ class RestTemplateBuilderTests {
@Test @Test
void basicAuthenticationShouldApply() throws Exception { void basicAuthenticationShouldApply() throws Exception {
RestTemplate template = this.builder.basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build(); RestTemplate template = this.builder.basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory(); ClientHttpRequest request = createRequest(template);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.POST);
assertThat(request.getHeaders()).containsOnlyKeys(HttpHeaders.AUTHORIZATION); assertThat(request.getHeaders()).containsOnlyKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q="); assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
} }
...@@ -318,8 +318,7 @@ class RestTemplateBuilderTests { ...@@ -318,8 +318,7 @@ class RestTemplateBuilderTests {
@Test @Test
void defaultHeaderAddsHeader() throws IOException { void defaultHeaderAddsHeader() throws IOException {
RestTemplate template = this.builder.defaultHeader("spring", "boot").build(); RestTemplate template = this.builder.defaultHeader("spring", "boot").build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory(); ClientHttpRequest request = createRequest(template);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot"))); assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot")));
} }
...@@ -328,17 +327,23 @@ class RestTemplateBuilderTests { ...@@ -328,17 +327,23 @@ class RestTemplateBuilderTests {
String name = HttpHeaders.ACCEPT; String name = HttpHeaders.ACCEPT;
String[] values = { MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_XML_VALUE }; String[] values = { MediaType.APPLICATION_JSON_VALUE, MediaType.APPLICATION_XML_VALUE };
RestTemplate template = this.builder.defaultHeader(name, values).build(); RestTemplate template = this.builder.defaultHeader(name, values).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory(); ClientHttpRequest request = createRequest(template);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
assertThat(request.getHeaders()).contains(entry(name, Arrays.asList(values))); assertThat(request.getHeaders()).contains(entry(name, Arrays.asList(values)));
} }
@Test // gh-17885
void defaultHeaderWhenUsingMockRestServiceServerAddsHeader() throws IOException {
RestTemplate template = this.builder.defaultHeader("spring", "boot").build();
MockRestServiceServer.bindTo(template).build();
ClientHttpRequest request = createRequest(template);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("boot")));
}
@Test @Test
void requestCustomizersAddsCustomizers() throws IOException { void requestCustomizersAddsCustomizers() throws IOException {
RestTemplate template = this.builder RestTemplate template = this.builder
.requestCustomizers((request) -> request.getHeaders().add("spring", "framework")).build(); .requestCustomizers((request) -> request.getHeaders().add("spring", "framework")).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory(); ClientHttpRequest request = createRequest(template);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("framework"))); assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("framework")));
} }
...@@ -347,8 +352,7 @@ class RestTemplateBuilderTests { ...@@ -347,8 +352,7 @@ class RestTemplateBuilderTests {
RestTemplate template = this.builder RestTemplate template = this.builder
.requestCustomizers((request) -> request.getHeaders().add("spring", "framework")) .requestCustomizers((request) -> request.getHeaders().add("spring", "framework"))
.additionalRequestCustomizers((request) -> request.getHeaders().add("for", "java")).build(); .additionalRequestCustomizers((request) -> request.getHeaders().add("for", "java")).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory(); ClientHttpRequest request = createRequest(template);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.GET);
assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("framework"))) assertThat(request.getHeaders()).contains(entry("spring", Collections.singletonList("framework")))
.contains(entry("for", Collections.singletonList("java"))); .contains(entry("for", Collections.singletonList("java")));
} }
...@@ -428,11 +432,8 @@ class RestTemplateBuilderTests { ...@@ -428,11 +432,8 @@ class RestTemplateBuilderTests {
assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler); assertThat(restTemplate.getErrorHandler()).isEqualTo(errorHandler);
ClientHttpRequestFactory actualRequestFactory = restTemplate.getRequestFactory(); ClientHttpRequestFactory actualRequestFactory = restTemplate.getRequestFactory();
assertThat(actualRequestFactory).isInstanceOf(InterceptingClientHttpRequestFactory.class); assertThat(actualRequestFactory).isInstanceOf(InterceptingClientHttpRequestFactory.class);
ClientHttpRequestFactory authRequestFactory = (ClientHttpRequestFactory) ReflectionTestUtils ClientHttpRequestInitializer initializer = restTemplate.getClientHttpRequestInitializers().get(0);
.getField(actualRequestFactory, "requestFactory"); assertThat(initializer).isInstanceOf(RestTemplateBuilderClientHttpRequestInitializer.class);
assertThat(authRequestFactory)
.isInstanceOf(RestTemplateBuilderClientHttpRequestFactoryWrapper.class);
assertThat(authRequestFactory).hasFieldOrPropertyWithValue("requestFactory", requestFactory);
}).build(); }).build();
} }
...@@ -589,6 +590,11 @@ class RestTemplateBuilderTests { ...@@ -589,6 +590,11 @@ class RestTemplateBuilderTests {
assertThat(template.getRequestFactory()).isInstanceOf(BufferingClientHttpRequestFactory.class); assertThat(template.getRequestFactory()).isInstanceOf(BufferingClientHttpRequestFactory.class);
} }
private ClientHttpRequest createRequest(RestTemplate template) {
return ReflectionTestUtils.invokeMethod(template, "createRequest", URI.create("http://localhost"),
HttpMethod.GET);
}
static class RestTemplateSubclass extends RestTemplate { static class RestTemplateSubclass extends RestTemplate {
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment