From 4b8a937bee7ac265e03054db82bf8f4ac2f2703e Mon Sep 17 00:00:00 2001 From: sdeleuze Date: Fri, 8 Dec 2017 12:42:16 -0800 Subject: [PATCH] Allow interceptors to add existing header values Provide a fully mutable HttpHeaders to ClientHttpRequestInterceptors of a RestTemplate when headers are set using HttpEntity. This avoids UnsupportedOperationException if both HttpEntity and ClientHttpRequestInterceptor add values for the same HTTP header. Issue: SPR-15066 --- .../web/client/RestTemplate.java | 6 ++- .../web/client/RestTemplateTests.java | 45 ++++++++++++++----- 2 files changed, 38 insertions(+), 13 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java index 87ef209029..4b56e73d52 100644 --- a/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java +++ b/spring-web/src/main/java/org/springframework/web/client/RestTemplate.java @@ -20,6 +20,8 @@ import java.io.IOException; import java.lang.reflect.Type; import java.net.URI; import java.util.ArrayList; +import java.util.Collections; +import java.util.LinkedList; import java.util.List; import java.util.Map; import java.util.Set; @@ -837,7 +839,9 @@ public class RestTemplate extends InterceptingHttpAccessor implements RestOperat HttpHeaders httpHeaders = httpRequest.getHeaders(); HttpHeaders requestHeaders = this.requestEntity.getHeaders(); if (!requestHeaders.isEmpty()) { - httpHeaders.putAll(requestHeaders); + for (Map.Entry> entry : requestHeaders.entrySet()) { + httpHeaders.put(entry.getKey(), new LinkedList(entry.getValue())); + } } if (httpHeaders.getContentLength() < 0) { httpHeaders.setContentLength(0L); diff --git a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java index 6e3236d358..021c377089 100644 --- a/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java +++ b/spring-web/src/test/java/org/springframework/web/client/RestTemplateTests.java @@ -39,23 +39,18 @@ import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.http.client.ClientHttpRequest; import org.springframework.http.client.ClientHttpRequestFactory; +import org.springframework.http.client.ClientHttpRequestInterceptor; import org.springframework.http.client.ClientHttpResponse; import org.springframework.http.converter.GenericHttpMessageConverter; import org.springframework.http.converter.HttpMessageConverter; import org.springframework.web.util.DefaultUriTemplateHandler; -import static org.junit.Assert.assertEquals; -import static org.junit.Assert.assertFalse; -import static org.junit.Assert.assertNull; -import static org.junit.Assert.assertSame; -import static org.junit.Assert.fail; -import static org.mockito.BDDMockito.any; -import static org.mockito.BDDMockito.eq; -import static org.mockito.BDDMockito.given; -import static org.mockito.BDDMockito.mock; -import static org.mockito.BDDMockito.verify; -import static org.mockito.BDDMockito.willThrow; -import static org.springframework.http.MediaType.parseMediaType; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.collection.IsIterableContainingInOrder.contains; +import static org.junit.Assert.*; +import static org.mockito.BDDMockito.*; +import static org.springframework.http.HttpMethod.POST; +import static org.springframework.http.MediaType.*; /** * @author Arjen Poutsma @@ -840,4 +835,30 @@ public class RestTemplateTests { verify(response).close(); } + @Test // SPR-15066 + public void requestInterceptorCanAddExistingHeaderValue() throws Exception { + ClientHttpRequestInterceptor interceptor = (request, body, execution) -> { + request.getHeaders().add("MyHeader", "MyInterceptorValue"); + return execution.execute(request, body); + }; + template.setInterceptors(Collections.singletonList(interceptor)); + + given(requestFactory.createRequest(new URI("http://example.com"), HttpMethod.POST)).willReturn(request); + HttpHeaders requestHeaders = new HttpHeaders(); + given(request.getHeaders()).willReturn(requestHeaders); + given(request.execute()).willReturn(response); + given(errorHandler.hasError(response)).willReturn(false); + HttpStatus status = HttpStatus.OK; + given(response.getStatusCode()).willReturn(status); + given(response.getStatusText()).willReturn(status.getReasonPhrase()); + + HttpHeaders entityHeaders = new HttpHeaders(); + entityHeaders.add("MyHeader", "MyEntityValue"); + HttpEntity entity = new HttpEntity<>(null, entityHeaders); + template.exchange("http://example.com", HttpMethod.POST, entity, Void.class); + assertThat(requestHeaders.get("MyHeader"), contains("MyEntityValue", "MyInterceptorValue")); + + verify(response).close(); + } + }