Commit 43b1a667 authored by Ilya Lukyanovich's avatar Ilya Lukyanovich Committed by Phillip Webb

Support default headers with RestTemplateBuilder

Update `RestTemplateBuilder` so that it is easier to apply custom
headers to the outgoing request. The update is particularly useful
for setting the `User-Agent` header, for example so that a GitHub
username can be used when calling `api.github.com`.

See gh-17091
parent 9b5cb4f9
......@@ -31,6 +31,7 @@ import org.springframework.boot.test.web.client.TestRestTemplate.HttpClientOptio
import org.springframework.boot.web.client.RestTemplateBuilder;
import org.springframework.core.ParameterizedTypeReference;
import org.springframework.http.HttpEntity;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.RequestEntity;
......@@ -43,6 +44,7 @@ import org.springframework.mock.env.MockEnvironment;
import org.springframework.mock.http.client.MockClientHttpRequest;
import org.springframework.mock.http.client.MockClientHttpResponse;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.Base64Utils;
import org.springframework.util.ReflectionUtils;
import org.springframework.util.ReflectionUtils.MethodCallback;
import org.springframework.web.client.ResponseErrorHandler;
......@@ -97,7 +99,8 @@ class TestRestTemplateTests {
RestTemplateBuilder builder = new RestTemplateBuilder().requestFactory(() -> customFactory);
TestRestTemplate testRestTemplate = new TestRestTemplate(builder).withBasicAuth("test", "test");
RestTemplate restTemplate = testRestTemplate.getRestTemplate();
assertThat(restTemplate.getRequestFactory().getClass().getName()).contains("BasicAuth");
assertThat(restTemplate.getRequestFactory().getClass().getName())
.contains("HttpHeadersCustomizingClientHttpRequestFactory");
Object requestFactory = ReflectionTestUtils.getField(restTemplate.getRequestFactory(), "requestFactory");
assertThat(requestFactory).isEqualTo(customFactory).hasSameClassAs(customFactory);
}
......@@ -125,10 +128,9 @@ class TestRestTemplateTests {
}
@Test
void authenticated() {
RestTemplate restTemplate = new TestRestTemplate("user", "password").getRestTemplate();
ClientHttpRequestFactory factory = restTemplate.getRequestFactory();
assertThat(factory.getClass().getName()).contains("BasicAuthentication");
void authenticated() throws Exception {
TestRestTemplate restTemplate = new TestRestTemplate("user", "password");
assertBasicAuthorizationCredentials(restTemplate, "user", "password");
}
@Test
......@@ -201,11 +203,12 @@ class TestRestTemplateTests {
}
@Test
void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() {
void withBasicAuthAddsBasicAuthClientFactoryWhenNotAlreadyPresent() throws Exception {
TestRestTemplate original = new TestRestTemplate();
TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(getConverterClasses(original)).containsExactlyElementsOf(getConverterClasses(basicAuth));
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()).contains("BasicAuth");
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
.contains("HttpHeadersCustomizingClientHttpRequestFactory");
assertThat(ReflectionTestUtils.getField(basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
......@@ -213,11 +216,12 @@ class TestRestTemplateTests {
}
@Test
void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() {
void withBasicAuthReplacesBasicAuthClientFactoryWhenAlreadyPresent() throws Exception {
TestRestTemplate original = new TestRestTemplate("foo", "bar").withBasicAuth("replace", "replace");
TestRestTemplate basicAuth = original.withBasicAuth("user", "password");
assertThat(getConverterClasses(basicAuth)).containsExactlyElementsOf(getConverterClasses(original));
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName()).contains("BasicAuth");
assertThat(basicAuth.getRestTemplate().getRequestFactory().getClass().getName())
.contains("HttpHeadersCustomizingClientHttpRequestFactory");
assertThat(ReflectionTestUtils.getField(basicAuth.getRestTemplate().getRequestFactory(), "requestFactory"))
.isInstanceOf(CustomHttpComponentsClientHttpRequestFactory.class);
assertThat(basicAuth.getRestTemplate().getInterceptors()).isEmpty();
......@@ -342,11 +346,12 @@ class TestRestTemplateTests {
}
private void assertBasicAuthorizationCredentials(TestRestTemplate testRestTemplate, String username,
String password) {
String password) throws Exception {
ClientHttpRequestFactory requestFactory = testRestTemplate.getRestTemplate().getRequestFactory();
Object authentication = ReflectionTestUtils.getField(requestFactory, "authentication");
assertThat(authentication).hasFieldOrPropertyWithValue("username", username);
assertThat(authentication).hasFieldOrPropertyWithValue("password", password);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.POST);
assertThat(request.getHeaders()).containsKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly(
"Basic " + Base64Utils.encodeToString(String.format("%s:%s", username, password).getBytes()));
}
......@@ -356,16 +361,4 @@ class TestRestTemplateTests {
}
static class TestClientHttpRequestFactory implements ClientHttpRequestFactory {
TestClientHttpRequestFactory(String value) {
}
@Override
public ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod) throws IOException {
return null;
}
}
}
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import org.springframework.http.HttpHeaders;
/**
* {@link HttpHeadersCustomizer} that only adds headers that were not populated in the
* request.
*
* @author Ilya Lukyanovich
*/
public abstract class AbstractHttpHeadersDefaultingCustomizer implements HttpHeadersCustomizer {
@Override
public void applyTo(HttpHeaders headers) {
createHeaders().forEach((key, value) -> headers.merge(key, value, (oldValue, ignored) -> oldValue));
}
protected abstract HttpHeaders createHeaders();
}
......@@ -19,17 +19,17 @@ package org.springframework.boot.web.client;
import java.nio.charset.Charset;
import org.springframework.http.HttpHeaders;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.util.Assert;
/**
* Basic authentication properties which are used by
* {@link BasicAuthenticationClientHttpRequestFactory}.
* {@link AbstractHttpHeadersDefaultingCustomizer} that applies basic authentication
* header unless it was provided in the request.
*
* @author Dmytro Nosan
* @see BasicAuthenticationClientHttpRequestFactory
* @author Ilya Lukyanovich
* @see HttpHeadersCustomizingClientHttpRequestFactory
*/
class BasicAuthentication {
class BasicAuthenticationHeaderDefaultingCustomizer extends AbstractHttpHeadersDefaultingCustomizer {
private final String username;
......@@ -37,7 +37,7 @@ class BasicAuthentication {
private final Charset charset;
BasicAuthentication(String username, String password, Charset charset) {
BasicAuthenticationHeaderDefaultingCustomizer(String username, String password, Charset charset) {
Assert.notNull(username, "Username must not be null");
Assert.notNull(password, "Password must not be null");
this.username = username;
......@@ -45,11 +45,11 @@ class BasicAuthentication {
this.charset = charset;
}
void applyTo(ClientHttpRequest request) {
HttpHeaders headers = request.getHeaders();
if (!headers.containsKey(HttpHeaders.AUTHORIZATION)) {
@Override
protected HttpHeaders createHeaders() {
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(this.username, this.password, this.charset);
}
return headers;
}
}
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import org.springframework.http.HttpHeaders;
/**
* Callback interface that can be used to customize a {@link HttpHeaders}.
*
* @author Ilya Lukyanovich
* @see HttpHeadersCustomizingClientHttpRequestFactory
*/
@FunctionalInterface
public interface HttpHeadersCustomizer {
/**
* Callback to customize a {@link HttpHeaders} instance.
* @param headers the headers to customize
*/
void applyTo(HttpHeaders headers);
}
......@@ -18,6 +18,9 @@ package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import java.util.Collection;
import org.jetbrains.annotations.NotNull;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.AbstractClientHttpRequestFactoryWrapper;
......@@ -26,27 +29,29 @@ import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.util.Assert;
/**
* {@link ClientHttpRequestFactory} to apply a given HTTP Basic Authentication
* username/password pair, unless a custom Authorization header has been set before.
* {@link ClientHttpRequestFactory} to apply default headers to a request unless header
* values were provided.
*
* @author Ilya Lukyanovich
* @author Dmytro Nosan
*/
class BasicAuthenticationClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper {
class HttpHeadersCustomizingClientHttpRequestFactory extends AbstractClientHttpRequestFactoryWrapper {
private final BasicAuthentication authentication;
private final Collection<? extends HttpHeadersCustomizer> customizers;
BasicAuthenticationClientHttpRequestFactory(BasicAuthentication authentication,
HttpHeadersCustomizingClientHttpRequestFactory(Collection<? extends HttpHeadersCustomizer> customizers,
ClientHttpRequestFactory clientHttpRequestFactory) {
super(clientHttpRequestFactory);
Assert.notNull(authentication, "Authentication must not be null");
this.authentication = authentication;
Assert.notEmpty(customizers, "Customizers must not be empty");
this.customizers = customizers;
}
@NotNull
@Override
protected ClientHttpRequest createRequest(URI uri, HttpMethod httpMethod, ClientHttpRequestFactory requestFactory)
throws IOException {
protected ClientHttpRequest createRequest(@NotNull URI uri, @NotNull HttpMethod httpMethod,
ClientHttpRequestFactory requestFactory) throws IOException {
ClientHttpRequest request = requestFactory.createRequest(uri, httpMethod);
this.authentication.applyTo(request);
this.customizers.forEach((customizer) -> customizer.applyTo(request.getHeaders()));
return request;
}
......
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import java.nio.charset.Charset;
import org.springframework.http.HttpHeaders;
import org.springframework.lang.NonNull;
import org.springframework.lang.Nullable;
import org.springframework.util.Assert;
import org.springframework.util.LinkedMultiValueMap;
/**
* A {@link AbstractHttpHeadersDefaultingCustomizer} that uses provided
* {@link HttpHeaders} instance as default headers.
*
* @author Ilya Lukyanovich
* @see HttpHeadersCustomizingClientHttpRequestFactory
*/
public class SimpleHttpHeaderDefaultingCustomizer extends AbstractHttpHeadersDefaultingCustomizer {
private final HttpHeaders httpHeaders;
public SimpleHttpHeaderDefaultingCustomizer(HttpHeaders httpHeaders) {
Assert.notNull(httpHeaders, "Header must not be null");
this.httpHeaders = httpHeaders;
}
@Override
protected HttpHeaders createHeaders() {
return new HttpHeaders(new LinkedMultiValueMap<>(this.httpHeaders));
}
/**
* A factory method that creates a {@link SimpleHttpHeaderDefaultingCustomizer} with a
* single header and a single value.
* @param header the header
* @param value the value
* @return new {@link SimpleHttpHeaderDefaultingCustomizer} instance
* @see HttpHeaders#set(String, String)
*/
public static HttpHeadersCustomizer singleHeader(@NonNull String header, @NonNull String value) {
Assert.notNull(header, "Header must not be null empty");
Assert.notNull(value, "Value must not be null empty");
HttpHeaders headers = new HttpHeaders();
headers.set(header, value);
return new SimpleHttpHeaderDefaultingCustomizer(headers);
}
/**
* A factory method that creates a {@link SimpleHttpHeaderDefaultingCustomizer} for
* {@link HttpHeaders#AUTHORIZATION} header with pre-defined username and password
* pair.
* @param username the username
* @param password the password
* @return new {@link SimpleHttpHeaderDefaultingCustomizer} instance
* @see #basicAuthentication(String, String, Charset)
*/
public static HttpHeadersCustomizer basicAuthentication(@NonNull String username, @NonNull String password) {
return basicAuthentication(username, password, null);
}
/**
* A factory method that creates a {@link SimpleHttpHeaderDefaultingCustomizer} for
* {@link HttpHeaders#AUTHORIZATION} header with pre-defined username and password
* pair.
* @param username the username
* @param password the password
* @param charset the header encoding charset
* @return new {@link SimpleHttpHeaderDefaultingCustomizer} instance
* @see HttpHeaders#setBasicAuth(String, String, Charset)
*/
public static HttpHeadersCustomizer basicAuthentication(@NonNull String username, @NonNull String password,
@Nullable Charset charset) {
Assert.notNull(username, "Username must not be null");
Assert.notNull(password, "Password must not be null");
HttpHeaders headers = new HttpHeaders();
headers.setBasicAuth(username, password, charset);
return new SimpleHttpHeaderDefaultingCustomizer(headers);
}
}
......@@ -18,6 +18,8 @@ package org.springframework.boot.web.client;
import java.io.IOException;
import java.net.URI;
import java.util.Arrays;
import java.util.Collections;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
......@@ -33,29 +35,30 @@ import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link BasicAuthenticationClientHttpRequestFactory}.
* Tests for {@link HttpHeadersCustomizingClientHttpRequestFactory}.
*
* @author Dmytro Nosan
* @author Ilya Lukyanovich
*/
class BasicAuthenticationClientHttpRequestFactoryTests {
public class HttpHeadersCustomizingClientHttpRequestFactoryTests {
private final HttpHeaders headers = new HttpHeaders();
private final BasicAuthentication authentication = new BasicAuthentication("spring", "boot", null);
private ClientHttpRequestFactory requestFactory;
@BeforeEach
public void setUp() throws IOException {
ClientHttpRequestFactory requestFactory = mock(ClientHttpRequestFactory.class);
this.requestFactory = mock(ClientHttpRequestFactory.class);
ClientHttpRequest request = mock(ClientHttpRequest.class);
given(requestFactory.createRequest(any(), any())).willReturn(request);
given(this.requestFactory.createRequest(any(), any())).willReturn(request);
given(request.getHeaders()).willReturn(this.headers);
this.requestFactory = new BasicAuthenticationClientHttpRequestFactory(this.authentication, requestFactory);
}
@Test
void shouldAddAuthorizationHeader() throws IOException {
this.requestFactory = new HttpHeadersCustomizingClientHttpRequestFactory(
Collections.singleton(SimpleHttpHeaderDefaultingCustomizer.basicAuthentication("spring", "boot", null)),
this.requestFactory);
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
......@@ -63,11 +66,27 @@ class BasicAuthenticationClientHttpRequestFactoryTests {
@Test
void shouldNotAddAuthorizationHeaderAuthorizationAlreadySet() throws IOException {
this.headers.setBasicAuth("boot", "spring");
this.requestFactory = new HttpHeadersCustomizingClientHttpRequestFactory(
Collections.singleton(SimpleHttpHeaderDefaultingCustomizer.basicAuthentication("spring", "boot", null)),
this.requestFactory);
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).doesNotContain("Basic c3ByaW5nOmJvb3Q=");
}
@Test
void shouldApplyCustomizersInTheProvidedOrder() throws IOException {
this.requestFactory = new HttpHeadersCustomizingClientHttpRequestFactory(
Arrays.asList((headers) -> headers.add("foo", "bar"),
SimpleHttpHeaderDefaultingCustomizer.basicAuthentication("spring", "boot", null),
SimpleHttpHeaderDefaultingCustomizer.singleHeader(HttpHeaders.AUTHORIZATION, "won't do")),
this.requestFactory);
ClientHttpRequest request = createRequest();
assertThat(request.getHeaders()).containsOnlyKeys("foo", HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get("foo")).containsExactly("bar");
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
private ClientHttpRequest createRequest() throws IOException {
return this.requestFactory.createRequest(URI.create("https://localhost:8080"), HttpMethod.POST);
}
......
......@@ -16,6 +16,7 @@
package org.springframework.boot.web.client;
import java.net.URI;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Collections;
......@@ -29,7 +30,10 @@ import org.mockito.InOrder;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.client.BufferingClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpRequestInterceptor;
import org.springframework.http.client.HttpComponentsClientHttpRequestFactory;
......@@ -298,12 +302,12 @@ class RestTemplateBuilderTests {
}
@Test
void basicAuthenticationShouldApply() {
void basicAuthenticationShouldApply() throws Exception {
RestTemplate template = this.builder.basicAuthentication("spring", "boot", StandardCharsets.UTF_8).build();
ClientHttpRequestFactory requestFactory = template.getRequestFactory();
Object authentication = ReflectionTestUtils.getField(requestFactory, "authentication");
assertThat(authentication).extracting("username", "password", "charset").containsExactly("spring", "boot",
StandardCharsets.UTF_8);
ClientHttpRequest request = requestFactory.createRequest(URI.create("http://localhost"), HttpMethod.POST);
assertThat(request.getHeaders()).containsOnlyKeys(HttpHeaders.AUTHORIZATION);
assertThat(request.getHeaders().get(HttpHeaders.AUTHORIZATION)).containsExactly("Basic c3ByaW5nOmJvb3Q=");
}
@Test
......@@ -383,7 +387,7 @@ class RestTemplateBuilderTests {
assertThat(actualRequestFactory).isInstanceOf(InterceptingClientHttpRequestFactory.class);
ClientHttpRequestFactory authRequestFactory = (ClientHttpRequestFactory) ReflectionTestUtils
.getField(actualRequestFactory, "requestFactory");
assertThat(authRequestFactory).isInstanceOf(BasicAuthenticationClientHttpRequestFactory.class);
assertThat(authRequestFactory).isInstanceOf(HttpHeadersCustomizingClientHttpRequestFactory.class);
assertThat(authRequestFactory).hasFieldOrPropertyWithValue("requestFactory", requestFactory);
}).build();
}
......
/*
* Copyright 2012-2019 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.client;
import org.junit.jupiter.api.Test;
import org.springframework.http.HttpHeaders;
import static org.assertj.core.api.Assertions.assertThat;
/**
* Tests for {@link SimpleHttpHeaderDefaultingCustomizer}.
*
* @author Ilya Lukyanovich
*/
class SimpleHttpHeaderDefaultingCustomizerTest {
@Test
void testApplyTo_shouldAddAllHeaders() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.add("foo", "bar");
httpHeaders.add("donut", "42");
SimpleHttpHeaderDefaultingCustomizer customizer = new SimpleHttpHeaderDefaultingCustomizer(httpHeaders);
HttpHeaders provided = new HttpHeaders();
customizer.applyTo(provided);
assertThat(provided).containsOnlyKeys("foo", "donut");
assertThat(provided.get("foo")).containsExactly("bar");
assertThat(provided.get("donut")).containsExactly("42");
}
@Test
void testApplyTo_shouldIgnoreProvided() {
HttpHeaders httpHeaders = new HttpHeaders();
httpHeaders.add("foo", "bar");
httpHeaders.add("donut", "42");
SimpleHttpHeaderDefaultingCustomizer customizer = new SimpleHttpHeaderDefaultingCustomizer(httpHeaders);
HttpHeaders provided = new HttpHeaders();
provided.add("donut", "touchme");
customizer.applyTo(provided);
assertThat(provided).containsOnlyKeys("foo", "donut");
assertThat(provided.get("foo")).containsExactly("bar");
assertThat(provided.get("donut")).containsExactly("touchme");
}
@Test
void testSingleHeader() {
HttpHeaders provided = new HttpHeaders();
SimpleHttpHeaderDefaultingCustomizer.singleHeader("foo", "bar").applyTo(provided);
assertThat(provided).containsOnlyKeys("foo");
assertThat(provided.get("foo")).containsExactly("bar");
}
@Test
void testBasicAuthentication() {
HttpHeaders provided = new HttpHeaders();
SimpleHttpHeaderDefaultingCustomizer.basicAuthentication("spring", "boot").applyTo(provided);
assertThat(provided).containsOnlyKeys(HttpHeaders.AUTHORIZATION);
assertThat(provided.get(HttpHeaders.AUTHORIZATION)).containsExactly("bar");
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment