Commit f56386ed authored by Phillip Webb's avatar Phillip Webb

Merge pull request #17010 from nosan

* pr/17010:
  Polish "Use request factory to support Basic Authentication"
  Use request factory to support Basic Authentication

Closes gh-17010
parents 4ac1407a 76e075dd
...@@ -17,16 +17,11 @@ ...@@ -17,16 +17,11 @@
package org.springframework.boot.test.web.client; package org.springframework.boot.test.web.client;
import java.io.IOException; import java.io.IOException;
import java.lang.reflect.Field;
import java.net.URI; import java.net.URI;
import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashSet; import java.util.HashSet;
import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.Set; import java.util.Set;
import java.util.function.Supplier;
import org.apache.http.client.HttpClient; import org.apache.http.client.HttpClient;
import org.apache.http.client.config.CookieSpecs; import org.apache.http.client.config.CookieSpecs;
...@@ -39,9 +34,6 @@ import org.apache.http.impl.client.HttpClients; ...@@ -39,9 +34,6 @@ import org.apache.http.impl.client.HttpClients;
import org.apache.http.protocol.HttpContext; import org.apache.http.protocol.HttpContext;
import org.apache.http.ssl.SSLContextBuilder; import org.apache.http.ssl.SSLContextBuilder;
import org.springframework.beans.BeanInstantiationException;
import org.springframework.beans.BeanUtils;
import org.springframework.boot.web.client.ClientHttpRequestFactorySupplier;
import org.springframework.boot.web.client.RestTemplateBuilder; import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.boot.web.client.RootUriTemplateHandler; import org.springframework.boot.web.client.RootUriTemplateHandler;
import org.springframework.core.ParameterizedTypeReference; import org.springframework.core.ParameterizedTypeReference;
...@@ -51,13 +43,9 @@ import org.springframework.http.HttpMethod; ...@@ -51,13 +43,9 @@ import org.springframework.http.HttpMethod;
import org.springframework.http.RequestEntity; import org.springframework.http.RequestEntity;
import org.springframework.http.ResponseEntity; import org.springframework.http.ResponseEntity;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.client.ClientHttpResponse;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ReflectionUtils;
import org.springframework.web.client.DefaultResponseErrorHandler; import org.springframework.web.client.DefaultResponseErrorHandler;
import org.springframework.web.client.RequestCallback; import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseExtractor; import org.springframework.web.client.ResponseExtractor;
...@@ -86,14 +74,17 @@ import org.springframework.web.util.UriTemplateHandler; ...@@ -86,14 +74,17 @@ import org.springframework.web.util.UriTemplateHandler;
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Kristine Jetzke * @author Kristine Jetzke
* @author Dmytro Nosan
* @since 1.4.0 * @since 1.4.0
*/ */
public class TestRestTemplate { public class TestRestTemplate {
private final RestTemplate restTemplate; private final RestTemplateBuilder builder;
private final HttpClientOption[] httpClientOptions; private final HttpClientOption[] httpClientOptions;
private final RestTemplate restTemplate;
/** /**
* Create a new {@link TestRestTemplate} instance. * Create a new {@link TestRestTemplate} instance.
* @param restTemplateBuilder builder used to configure underlying * @param restTemplateBuilder builder used to configure underlying
...@@ -125,60 +116,30 @@ public class TestRestTemplate { ...@@ -125,60 +116,30 @@ public class TestRestTemplate {
/** /**
* Create a new {@link TestRestTemplate} instance with the specified credentials. * Create a new {@link TestRestTemplate} instance with the specified credentials.
* @param restTemplateBuilder builder used to configure underlying * @param builder builder used to configure underlying {@link RestTemplate}
* {@link RestTemplate}
* @param username the username to use (or {@code null}) * @param username the username to use (or {@code null})
* @param password the password (or {@code null}) * @param password the password (or {@code null})
* @param httpClientOptions client options to use if the Apache HTTP Client is used * @param httpClientOptions client options to use if the Apache HTTP Client is used
* @since 2.0.0 * @since 2.0.0
*/ */
public TestRestTemplate(RestTemplateBuilder restTemplateBuilder, String username, public TestRestTemplate(RestTemplateBuilder builder, String username, String password,
String password, HttpClientOption... httpClientOptions) {
this((restTemplateBuilder != null) ? restTemplateBuilder.build() : null, username,
password, httpClientOptions);
}
private TestRestTemplate(RestTemplate restTemplate, String username, String password,
HttpClientOption... httpClientOptions) { HttpClientOption... httpClientOptions) {
Assert.notNull(restTemplate, "RestTemplate must not be null"); Assert.notNull(builder, "Builder must not be null");
this.builder = builder;
this.httpClientOptions = httpClientOptions; this.httpClientOptions = httpClientOptions;
if (getRequestFactoryClass(restTemplate) if (httpClientOptions != null) {
.isAssignableFrom(HttpComponentsClientHttpRequestFactory.class)) { ClientHttpRequestFactory requestFactory = builder.buildRequestFactory();
restTemplate.setRequestFactory( if (requestFactory instanceof HttpComponentsClientHttpRequestFactory) {
new CustomHttpComponentsClientHttpRequestFactory(httpClientOptions)); builder = builder.requestFactory(
} () -> new CustomHttpComponentsClientHttpRequestFactory(
addAuthentication(restTemplate, username, password); httpClientOptions));
restTemplate.setErrorHandler(new NoOpResponseErrorHandler()); }
this.restTemplate = restTemplate;
}
private Class<? extends ClientHttpRequestFactory> getRequestFactoryClass(
RestTemplate restTemplate) {
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
if (InterceptingClientHttpRequestFactory.class
.isAssignableFrom(requestFactory.getClass())) {
Field requestFactoryField = ReflectionUtils.findField(RestTemplate.class,
"requestFactory");
ReflectionUtils.makeAccessible(requestFactoryField);
requestFactory = (ClientHttpRequestFactory) ReflectionUtils
.getField(requestFactoryField, restTemplate);
}
return requestFactory.getClass();
}
private void addAuthentication(RestTemplate restTemplate, String username,
String password) {
if (username == null) {
return;
} }
List<ClientHttpRequestInterceptor> interceptors = restTemplate.getInterceptors(); if (username != null || password != null) {
if (interceptors == null) { builder = builder.basicAuthentication(username, password);
interceptors = Collections.emptyList();
} }
interceptors = new ArrayList<>(interceptors); this.restTemplate = builder.build();
interceptors.removeIf(BasicAuthenticationInterceptor.class::isInstance); this.restTemplate.setErrorHandler(new NoOpResponseErrorHandler());
interceptors.add(new BasicAuthenticationInterceptor(username, password));
restTemplate.setInterceptors(interceptors);
} }
/** /**
...@@ -1035,25 +996,10 @@ public class TestRestTemplate { ...@@ -1035,25 +996,10 @@ public class TestRestTemplate {
* @since 1.4.1 * @since 1.4.1
*/ */
public TestRestTemplate withBasicAuth(String username, String password) { public TestRestTemplate withBasicAuth(String username, String password) {
RestTemplate restTemplate = new RestTemplateBuilder() TestRestTemplate template = new TestRestTemplate(this.builder, username, password,
.requestFactory(getRequestFactorySupplier())
.messageConverters(getRestTemplate().getMessageConverters())
.interceptors(getRestTemplate().getInterceptors())
.uriTemplateHandler(getRestTemplate().getUriTemplateHandler()).build();
return new TestRestTemplate(restTemplate, username, password,
this.httpClientOptions); this.httpClientOptions);
} template.setUriTemplateHandler(getRestTemplate().getUriTemplateHandler());
return template;
private Supplier<ClientHttpRequestFactory> getRequestFactorySupplier() {
return () -> {
try {
return BeanUtils
.instantiateClass(getRequestFactoryClass(getRestTemplate()));
}
catch (BeanInstantiationException ex) {
return new ClientHttpRequestFactorySupplier().get();
}
};
} }
@SuppressWarnings({ "rawtypes", "unchecked" }) @SuppressWarnings({ "rawtypes", "unchecked" })
...@@ -1075,7 +1021,7 @@ public class TestRestTemplate { ...@@ -1075,7 +1021,7 @@ public class TestRestTemplate {
} }
/** /**
* Options used to customize the Apache Http Client if it is used. * Options used to customize the Apache HTTP Client.
*/ */
public enum HttpClientOption { public enum HttpClientOption {
......
...@@ -21,6 +21,7 @@ import java.lang.reflect.Method; ...@@ -21,6 +21,7 @@ import java.lang.reflect.Method;
import java.lang.reflect.Modifier; import java.lang.reflect.Modifier;
import java.net.URI; import java.net.URI;
import java.util.List; import java.util.List;
import java.util.stream.Collectors;
import org.apache.http.client.config.RequestConfig; import org.apache.http.client.config.RequestConfig;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
...@@ -35,12 +36,9 @@ import org.springframework.http.HttpStatus; ...@@ -35,12 +36,9 @@ import org.springframework.http.HttpStatus;
import org.springframework.http.RequestEntity; import org.springframework.http.RequestEntity;
import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.env.MockEnvironment;
import org.springframework.mock.http.client.MockClientHttpRequest; import org.springframework.mock.http.client.MockClientHttpRequest;
import org.springframework.mock.http.client.MockClientHttpResponse; import org.springframework.mock.http.client.MockClientHttpResponse;
...@@ -102,25 +100,11 @@ public class TestRestTemplateTests { ...@@ -102,25 +100,11 @@ public class TestRestTemplateTests {
TestRestTemplate testRestTemplate = new TestRestTemplate(builder) TestRestTemplate testRestTemplate = new TestRestTemplate(builder)
.withBasicAuth("test", "test"); .withBasicAuth("test", "test");
RestTemplate restTemplate = testRestTemplate.getRestTemplate(); RestTemplate restTemplate = testRestTemplate.getRestTemplate();
assertThat(restTemplate.getRequestFactory().getClass().getName())
.contains("BasicAuth");
Object requestFactory = ReflectionTestUtils Object requestFactory = ReflectionTestUtils
.getField(restTemplate.getRequestFactory(), "requestFactory"); .getField(restTemplate.getRequestFactory(), "requestFactory");
assertThat(requestFactory).isNotEqualTo(customFactory) assertThat(requestFactory).isEqualTo(customFactory).hasSameClassAs(customFactory);
.hasSameClassAs(customFactory);
}
@Test
public void withBasicAuthWhenRequestFactoryTypeCannotBeInstantiatedShouldFallback() {
TestClientHttpRequestFactory customFactory = new TestClientHttpRequestFactory(
"my-request-factory");
RestTemplateBuilder builder = new RestTemplateBuilder()
.requestFactory(() -> customFactory);
TestRestTemplate testRestTemplate = new TestRestTemplate(builder)
.withBasicAuth("test", "test");
RestTemplate restTemplate = testRestTemplate.getRestTemplate();
Object requestFactory = ReflectionTestUtils
.getField(restTemplate.getRequestFactory(), "requestFactory");
assertThat(requestFactory).isNotEqualTo(customFactory)
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
} }
@Test @Test
...@@ -148,9 +132,10 @@ public class TestRestTemplateTests { ...@@ -148,9 +132,10 @@ public class TestRestTemplateTests {
@Test @Test
public void authenticated() { public void authenticated() {
assertThat(new TestRestTemplate("user", "password").getRestTemplate() RestTemplate restTemplate = new TestRestTemplate("user", "password")
.getRequestFactory()) .getRestTemplate();
.isInstanceOf(InterceptingClientHttpRequestFactory.class); ClientHttpRequestFactory factory = restTemplate.getRequestFactory();
assertThat(factory.getClass().getName()).contains("BasicAuthentication");
} }
@Test @Test
...@@ -227,43 +212,39 @@ public class TestRestTemplateTests { ...@@ -227,43 +212,39 @@ public class TestRestTemplateTests {
} }
@Test @Test
public void withBasicAuthAddsBasicAuthInterceptorWhenNotAlreadyPresent() { public void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() {
TestRestTemplate originalTemplate = new TestRestTemplate(); TestRestTemplate original = new TestRestTemplate();
TestRestTemplate basicAuthTemplate = originalTemplate.withBasicAuth("user", TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
"password"); assertThat(getConverterClasses(original))
assertThat(basicAuthTemplate.getRestTemplate().getMessageConverters()) .containsExactlyElementsOf(getConverterClasses(basicAuth));
.containsExactlyElementsOf( assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
originalTemplate.getRestTemplate().getMessageConverters()); .contains("BasicAuth");
assertThat(basicAuthTemplate.getRestTemplate().getRequestFactory())
.isInstanceOf(InterceptingClientHttpRequestFactory.class);
assertThat(ReflectionTestUtils.getField( assertThat(ReflectionTestUtils.getField(
basicAuthTemplate.getRestTemplate().getRequestFactory(), basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
"requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuthTemplate.getRestTemplate().getUriTemplateHandler()) assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
.isSameAs(originalTemplate.getRestTemplate().getUriTemplateHandler()); assertBasicAuthorizationCredentials(basicAuth, "user", "password");
assertThat(basicAuthTemplate.getRestTemplate().getInterceptors()).hasSize(1);
assertBasicAuthorizationInterceptorCredentials(basicAuthTemplate, "user",
"password");
} }
@Test @Test
public void withBasicAuthReplacesBasicAuthInterceptorWhenAlreadyPresent() { public void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() {
TestRestTemplate original = new TestRestTemplate("foo", "bar") TestRestTemplate original = new TestRestTemplate("foo", "bar")
.withBasicAuth("replace", "replace"); .withBasicAuth("replace", "replace");
TestRestTemplate basicAuth = original.withBasicAuth("user", "password"); TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(basicAuth.getRestTemplate().getMessageConverters()) assertThat(getConverterClasses(basicAuth))
.containsExactlyElementsOf( .containsExactlyElementsOf(getConverterClasses(original));
original.getRestTemplate().getMessageConverters()); assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
assertThat(basicAuth.getRestTemplate().getRequestFactory()) .contains("BasicAuth");
.isInstanceOf(InterceptingClientHttpRequestFactory.class);
assertThat(ReflectionTestUtils.getField( assertThat(ReflectionTestUtils.getField(
basicAuth.getRestTemplate().getRequestFactory(), "requestFactory")) basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class); .isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getUriTemplateHandler()) assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
.isSameAs(original.getRestTemplate().getUriTemplateHandler()); assertBasicAuthorizationCredentials(basicAuth, "user", "password");
assertThat(basicAuth.getRestTemplate().getInterceptors()).hasSize(1); }
assertBasicAuthorizationInterceptorCredentials(basicAuth, "user", "password");
private List<Class<?>> getConverterClasses(TestRestTemplate testRestTemplate) {
return testRestTemplate.getRestTemplate().getMessageConverters().stream()
.map(Object::getClass).collect(Collectors.toList());
} }
@Test @Test
...@@ -394,17 +375,14 @@ public class TestRestTemplateTests { ...@@ -394,17 +375,14 @@ public class TestRestTemplateTests {
verify(requestFactory).createRequest(eq(absoluteUri), any(HttpMethod.class)); verify(requestFactory).createRequest(eq(absoluteUri), any(HttpMethod.class));
} }
private void assertBasicAuthorizationInterceptorCredentials( private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate,
TestRestTemplate testRestTemplate, String username, String password) { String username, String password) {
@SuppressWarnings("unchecked") ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate()
List<ClientHttpRequestInterceptor> requestFactoryInterceptors = (List<ClientHttpRequestInterceptor>) ReflectionTestUtils .getRequestFactory();
.getField(testRestTemplate.getRestTemplate().getRequestFactory(), Object authentication = ReflectionTestUtils.getField(requestFactory,
"interceptors"); "authentication");
assertThat(requestFactoryInterceptors).hasSize(1); assertThat(authentication).hasFieldOrPropertyWithValue("username", username);
ClientHttpRequestInterceptor interceptor = requestFactoryInterceptors.get(0); assertThat(authentication).hasFieldOrPropertyWithValue("password", password);
assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class);
assertThat(interceptor).hasFieldOrPropertyWithValue("username", username);
assertThat(interceptor).hasFieldOrPropertyWithValue("password", password);
} }
......
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import java.nio.charset.Charset;
import org.springframework.http.HttpHeaders;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.util.Assert;
/**
* Basic authentication properties which are used by
* {@link BasicAuthenticationClientHttpRequestFactory}.
*
* @author Dmytro Nosan
* @see BasicAuthenticationClientHttpRequestFactory
*/
class BasicAuthentication {
private final String username;
private final String password;
private final Charset charset;
BasicAuthentication(String username, String password, Charset charset) {
Assert.notNull(username, "Username must not be null");
Assert.notNull(password, "Password must not be null");
this.username = username;
this.password = password;
this.charset = charset;
}
void applyTo(ClientHttpRequest request) {
HttpHeaders headers = request.getHeaders();
if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) {
headers.setBasicAuth(this.username, this.password, this.charset);
}
}
}
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.util.Assert;
/**
* {@link ClientHttpRequestFactory} to apply a given HTTP Basic Authentication
* username/password pair, unless a custom Authorization header has been set before.
*
* @author Dmytro Nosan
*/
class BasicAuthenticationClientHttpRequestFactory
extends AbstractClientHttpRequestFactoryWrapper {
private final BasicAuthentication authentication;
BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication,
ClientHttpRequestFactory clientHttpRequestFactory) {
super(clientHttpRequestFactory);
Assert.notNull(authentication, "Authentication must not be null");
this.authentication = authentication;
}
@Override
protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod,
ClientHttpRequestFactory requestFactory) throws IOException {
ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod);
this.authentication.applyTo(request);
return request;
}
}
...@@ -19,12 +19,14 @@ package org.springframework.boot.web.client; ...@@ -19,12 +19,14 @@ package org.springframework.boot.web.client;
import java.lang.reflect.Constructor; import java.lang.reflect.Constructor;
import java.lang.reflect.Field; import java.lang.reflect.Field;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.nio.charset.Charset;
import java.time.Duration; import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List;
import java.util.Set; import java.util.Set;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier; import java.util.function.Supplier;
...@@ -33,7 +35,6 @@ import org.springframework.beans.BeanUtils; ...@@ -33,7 +35,6 @@ import org.springframework.beans.BeanUtils;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper; import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
import org.springframework.http.client.ClientHttpRequestFactory; import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.CollectionUtils; import org.springframework.util.CollectionUtils;
...@@ -58,6 +59,7 @@ import org.springframework.web.util.UriTemplateHandler; ...@@ -58,6 +59,7 @@ import org.springframework.web.util.UriTemplateHandler;
* @author Phillip Webb * @author Phillip Webb
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Brian Clozel * @author Brian Clozel
* @author Dmytro Nosan
* @since 1.4.0 * @since 1.4.0
*/ */
public class RestTemplateBuilder { public class RestTemplateBuilder {
...@@ -74,7 +76,7 @@ public class RestTemplateBuilder { ...@@ -74,7 +76,7 @@ public class RestTemplateBuilder {
private final ResponseErrorHandler errorHandler; private final ResponseErrorHandler errorHandler;
private final BasicAuthenticationInterceptor basicAuthentication; private final BasicAuthentication basicAuthentication;
private final Set<RestTemplateCustomizer> restTemplateCustomizers; private final Set<RestTemplateCustomizer> restTemplateCustomizers;
...@@ -106,7 +108,7 @@ public class RestTemplateBuilder { ...@@ -106,7 +108,7 @@ public class RestTemplateBuilder {
Set<HttpMessageConverter<?>> messageConverters, Set<HttpMessageConverter<?>> messageConverters,
Supplier<ClientHttpRequestFactory> requestFactorySupplier, Supplier<ClientHttpRequestFactory> requestFactorySupplier,
UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler, UriTemplateHandler uriTemplateHandler, ResponseErrorHandler errorHandler,
BasicAuthenticationInterceptor basicAuthentication, BasicAuthentication basicAuthentication,
Set<RestTemplateCustomizer> restTemplateCustomizers, Set<RestTemplateCustomizer> restTemplateCustomizers,
RequestFactoryCustomizer requestFactoryCustomizer, RequestFactoryCustomizer requestFactoryCustomizer,
Set<ClientHttpRequestInterceptor> interceptors) { Set<ClientHttpRequestInterceptor> interceptors) {
...@@ -371,18 +373,35 @@ public class RestTemplateBuilder { ...@@ -371,18 +373,35 @@ public class RestTemplateBuilder {
} }
/** /**
* Add HTTP basic authentication to requests. See * Add HTTP Basic Authentication to requests with the given username/password pair,
* {@link BasicAuthenticationInterceptor} for details. * unless a custom Authorization header has been set before.
* @param username the user name * @param username the user name
* @param password the password * @param password the password
* @return a new builder instance * @return a new builder instance
* @since 2.1.0 * @since 2.1.0
* @see #basicAuthentication(String, String, Charset)
*/ */
public RestTemplateBuilder basicAuthentication(String username, String password) { public RestTemplateBuilder basicAuthentication(String username, String password) {
return basicAuthentication(username, password, null);
}
/**
* Add HTTP Basic Authentication to requests with the given username/password pair,
* unless a custom Authorization header has been set before.
* @param username the user name
* @param password the password
* @param charset the charset to use
* @return a new builder instance
* @since 2.2.0
* @see #basicAuthentication(String, String)
*/
public RestTemplateBuilder basicAuthentication(String username, String password,
Charset charset) {
BasicAuthentication basicAuthentication = new BasicAuthentication(username,
password, charset);
return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri, return new RestTemplateBuilder(this.detectRequestFactory, this.rootUri,
this.messageConverters, this.requestFactorySupplier, this.messageConverters, this.requestFactorySupplier,
this.uriTemplateHandler, this.errorHandler, this.uriTemplateHandler, this.errorHandler, basicAuthentication,
new BasicAuthenticationInterceptor(username, password),
this.restTemplateCustomizers, this.requestFactoryCustomizer, this.restTemplateCustomizers, this.requestFactoryCustomizer,
this.interceptors); this.interceptors);
} }
...@@ -506,7 +525,6 @@ public class RestTemplateBuilder { ...@@ -506,7 +525,6 @@ public class RestTemplateBuilder {
* @see RestTemplateBuilder#build() * @see RestTemplateBuilder#build()
* @see #configure(RestTemplate) * @see #configure(RestTemplate)
*/ */
public <T extends RestTemplate> T build(Class<T> restTemplateClass) { public <T extends RestTemplate> T build(Class<T> restTemplateClass) {
return configure(BeanUtils.instantiateClass(restTemplateClass)); return configure(BeanUtils.instantiateClass(restTemplateClass));
} }
...@@ -520,7 +538,13 @@ public class RestTemplateBuilder { ...@@ -520,7 +538,13 @@ public class RestTemplateBuilder {
* @see RestTemplateBuilder#build(Class) * @see RestTemplateBuilder#build(Class)
*/ */
public <T extends RestTemplate> T configure(T restTemplate) { public <T extends RestTemplate> T configure(T restTemplate) {
configureRequestFactory(restTemplate); ClientHttpRequestFactory requestFactory = buildRequestFactory();
if (requestFactory != null) {
restTemplate.setRequestFactory(requestFactory);
}
if (this.basicAuthentication != null) {
configureBasicAuthentication(restTemplate);
}
if (!CollectionUtils.isEmpty(this.messageConverters)) { if (!CollectionUtils.isEmpty(this.messageConverters)) {
restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters)); restTemplate.setMessageConverters(new ArrayList<>(this.messageConverters));
} }
...@@ -533,9 +557,6 @@ public class RestTemplateBuilder { ...@@ -533,9 +557,6 @@ public class RestTemplateBuilder {
if (this.rootUri != null) { if (this.rootUri != null) {
RootUriTemplateHandler.addTo(restTemplate, this.rootUri); RootUriTemplateHandler.addTo(restTemplate, this.rootUri);
} }
if (this.basicAuthentication != null) {
restTemplate.getInterceptors().add(this.basicAuthentication);
}
restTemplate.getInterceptors().addAll(this.interceptors); restTemplate.getInterceptors().addAll(this.interceptors);
if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) { if (!CollectionUtils.isEmpty(this.restTemplateCustomizers)) {
for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) { for (RestTemplateCustomizer customizer : this.restTemplateCustomizers) {
...@@ -545,7 +566,13 @@ public class RestTemplateBuilder { ...@@ -545,7 +566,13 @@ public class RestTemplateBuilder {
return restTemplate; return restTemplate;
} }
private void configureRequestFactory(RestTemplate restTemplate) { /**
* Build a new {@link ClientHttpRequestFactory} instance using the settings of this
* builder.
* @return a {@link ClientHttpRequestFactory} or {@code null}
* @since 2.2.0
*/
public ClientHttpRequestFactory buildRequestFactory() {
ClientHttpRequestFactory requestFactory = null; ClientHttpRequestFactory requestFactory = null;
if (this.requestFactorySupplier != null) { if (this.requestFactorySupplier != null) {
requestFactory = this.requestFactorySupplier.get(); requestFactory = this.requestFactorySupplier.get();
...@@ -557,7 +584,23 @@ public class RestTemplateBuilder { ...@@ -557,7 +584,23 @@ public class RestTemplateBuilder {
if (this.requestFactoryCustomizer != null) { if (this.requestFactoryCustomizer != null) {
this.requestFactoryCustomizer.accept(requestFactory); this.requestFactoryCustomizer.accept(requestFactory);
} }
restTemplate.setRequestFactory(requestFactory); }
return requestFactory;
}
private void configureBasicAuthentication(RestTemplate restTemplate) {
List<ClientHttpRequestInterceptor> interceptors = null;
if (!restTemplate.getInterceptors().isEmpty()) {
// Stash and clear the interceptors so we can access the real factory
interceptors = new ArrayList<>(restTemplate.getInterceptors());
restTemplate.getInterceptors().clear();
}
ClientHttpRequestFactory requestFactory = restTemplate.getRequestFactory();
restTemplate.setRequestFactory(new BasicAuthenticationClientHttpRequestFactory(
this.basicAuthentication, requestFactory));
// Restore the original interceptors
if (interceptors != null) {
restTemplate.getInterceptors().addAll(interceptors);
} }
} }
...@@ -610,15 +653,14 @@ public class RestTemplateBuilder { ...@@ -610,15 +653,14 @@ public class RestTemplateBuilder {
if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) { if (!(requestFactory instanceof AbstractClientHttpRequestFactoryWrapper)) {
return requestFactory; return requestFactory;
} }
ClientHttpRequestFactory unwrappedRequestFactory = requestFactory;
Field field = ReflectionUtils.findField( Field field = ReflectionUtils.findField(
AbstractClientHttpRequestFactoryWrapper.class, "requestFactory"); AbstractClientHttpRequestFactoryWrapper.class, "requestFactory");
ReflectionUtils.makeAccessible(field); ReflectionUtils.makeAccessible(field);
do { ClientHttpRequestFactory unwrappedRequestFactory = requestFactory;
while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper) {
unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils unwrappedRequestFactory = (ClientHttpRequestFactory) ReflectionUtils
.getField(field, unwrappedRequestFactory); .getField(field, unwrappedRequestFactory);
} }
while (unwrappedRequestFactory instanceof AbstractClientHttpRequestFactoryWrapper);
return unwrappedRequestFactory; return unwrappedRequestFactory;
} }
......
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link BasicAuthenticationClientHttpRequestFactory}.
*
* @author Dmytro Nosan
*/
public class BasicAuthenticationClientHttpRequestFactoryTests {
private final HttpHeaders headers = new HttpHeaders();
private final BasicAuthentication authentication = new BasicAuthentication("spring",
"boot", null);
private ClientHttpRequestFactory requestFactory;
@Before
public void setUp() throws IOException {
ClientHttpRequestFactory requestFactory = mock(ClientHttpRequestFactory.class);
ClientHttpRequest request = mock(ClientHttpRequest.class);
given(requestFactory.createRequest(any(), any())).willReturn(request);
given(request.getHeaders()).willReturn(this.headers);
this.requestFactory = new BasicAuthenticationClientHttpRequestFactory(
this.authentication, requestFactory);
}
@Test
public void shouldAddAuthorizationHeader() throws IOException {
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION))
.containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
@Test
public void shouldNotAddAuthorizationHeaderAuthorizationAlreadySet()
throws IOException {
this.headers.setBasicAuth("boot", "spring");
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION))
.doesNotContain("Basic c3ByaW5nOmJvb3Q=");
}
private ClientHttpRequest createRequest() throws IOException {
return this.requestFactory.createRequest(URI.create("https://localhost:8080"),
HttpMethod.POST);
}
}
...@@ -16,6 +16,7 @@ ...@@ -16,6 +16,7 @@
package org.springframework.boot.web.client; package org.springframework.boot.web.client;
import java.nio.charset.StandardCharsets;
import java.time.Duration; import java.time.Duration;
import java.util.Collections; import java.util.Collections;
import java.util.Set; import java.util.Set;
...@@ -35,7 +36,6 @@ import org.springframework.http.client.HttpComponentsClientHttpRequestFactory; ...@@ -35,7 +36,6 @@ import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
import org.springframework.http.client.InterceptingClientHttpRequestFactory; import org.springframework.http.client.InterceptingClientHttpRequestFactory;
import org.springframework.http.client.OkHttp3ClientHttpRequestFactory; import org.springframework.http.client.OkHttp3ClientHttpRequestFactory;
import org.springframework.http.client.SimpleClientHttpRequestFactory; import org.springframework.http.client.SimpleClientHttpRequestFactory;
import org.springframework.http.client.support.BasicAuthenticationInterceptor;
import org.springframework.http.converter.HttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter;
import org.springframework.http.converter.ResourceHttpMessageConverter; import org.springframework.http.converter.ResourceHttpMessageConverter;
import org.springframework.http.converter.StringHttpMessageConverter; import org.springframework.http.converter.StringHttpMessageConverter;
...@@ -324,12 +324,13 @@ public class RestTemplateBuilderTests { ...@@ -324,12 +324,13 @@ public class RestTemplateBuilderTests {
@Test @Test
public void basicAuthenticationShouldApply() { public void basicAuthenticationShouldApply() {
RestTemplate template = this.builder.basicAuthentication("spring", "boot") RestTemplate template = this.builder
.build(); .basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build();
ClientHttpRequestInterceptor interceptor = template.getInterceptors().get(0); ClientHttpRequestFactory requestFactory = template.getRequestFactory();
assertThat(interceptor).isInstanceOf(BasicAuthenticationInterceptor.class); Object authentication = ReflectionTestUtils.getField(requestFactory,
assertThat(interceptor).extracting("username").containsExactly("spring"); "authentication");
assertThat(interceptor).extracting("password").containsExactly("boot"); assertThat(authentication).extracting("username", "password", "charset")
.containsExactly("spring", "boot", StandardCharsets.UTF_8);
} }
@Test @Test
...@@ -406,9 +407,7 @@ public class RestTemplateBuilderTests { ...@@ -406,9 +407,7 @@ public class RestTemplateBuilderTests {
.messageConverters(this.messageConverter).rootUri("http://localhost:8080") .messageConverters(this.messageConverter).rootUri("http://localhost:8080")
.errorHandler(errorHandler).basicAuthentication("spring", "boot") .errorHandler(errorHandler).basicAuthentication("spring", "boot")
.requestFactory(() -> requestFactory).customizers((restTemplate) -> { .requestFactory(() -> requestFactory).customizers((restTemplate) -> {
assertThat(restTemplate.getInterceptors()).hasSize(2) assertThat(restTemplate.getInterceptors()).hasSize(1);
.contains(this.interceptor).anyMatch(
(ic) -> ic instanceof BasicAuthenticationInterceptor);
assertThat(restTemplate.getMessageConverters()) assertThat(restTemplate.getMessageConverters())
.contains(this.messageConverter); .contains(this.messageConverter);
assertThat(restTemplate.getUriTemplateHandler()) assertThat(restTemplate.getUriTemplateHandler())
...@@ -418,7 +417,11 @@ public class RestTemplateBuilderTests { ...@@ -418,7 +417,11 @@ public class RestTemplateBuilderTests {
.getRequestFactory(); .getRequestFactory();
assertThat(actualRequestFactory) assertThat(actualRequestFactory)
.isInstanceOf(InterceptingClientHttpRequestFactory.class); .isInstanceOf(InterceptingClientHttpRequestFactory.class);
assertThat(actualRequestFactory).hasFieldOrPropertyWithValue( ClientHttpRequestFactory authRequestFactory = (ClientHttpRequestFactory) ReflectionTestUtils
.getField(actualRequestFactory, "requestFactory");
assertThat(authRequestFactory).isInstanceOf(
BasicAuthenticationClientHttpRequestFactory.class);
assertThat(authRequestFactory).hasFieldOrPropertyWithValue(
"requestFactory", requestFactory); "requestFactory", requestFactory);
}).build(); }).build();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment