Simplify use of headers for SockJsClient requests
Before this change, XhrTransport implementations had to be configured with the headers to use for HTTP requests other than the initial handshake. After this change the handshake headers passed to SockJsClient by default are used for all other HTTP requests related to the SockJS connection (e.g. info request, xhr send/receive). A property on SockJsClient allows restricting the headers to use for other HTTP requests to a subset of the handshake headers. Issue: SPR-13254
This commit is contained in:
@@ -49,6 +49,8 @@ import org.junit.rules.TestName;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Configuration;
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.server.ServletServerHttpRequest;
|
||||
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
|
||||
import org.springframework.tests.Assume;
|
||||
import org.springframework.tests.TestGroup;
|
||||
@@ -100,7 +102,7 @@ public abstract class AbstractSockJsIntegrationTests {
|
||||
|
||||
@BeforeClass
|
||||
public static void performanceTestGroupAssumption() throws Exception {
|
||||
Assume.group(TestGroup.PERFORMANCE);
|
||||
// Assume.group(TestGroup.PERFORMANCE);
|
||||
}
|
||||
|
||||
|
||||
@@ -164,19 +166,36 @@ public abstract class AbstractSockJsIntegrationTests {
|
||||
|
||||
@Test
|
||||
public void echoWebSocket() throws Exception {
|
||||
testEcho(100, createWebSocketTransport());
|
||||
testEcho(100, createWebSocketTransport(), null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void echoXhrStreaming() throws Exception {
|
||||
testEcho(100, createXhrTransport());
|
||||
testEcho(100, createXhrTransport(), null);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void echoXhr() throws Exception {
|
||||
AbstractXhrTransport xhrTransport = createXhrTransport();
|
||||
xhrTransport.setXhrStreamingDisabled(true);
|
||||
testEcho(100, xhrTransport);
|
||||
testEcho(100, xhrTransport, null);
|
||||
}
|
||||
|
||||
// SPR-13254
|
||||
|
||||
@Test
|
||||
public void echoXhrWithHeaders() throws Exception {
|
||||
AbstractXhrTransport xhrTransport = createXhrTransport();
|
||||
xhrTransport.setXhrStreamingDisabled(true);
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
headers.add("auth", "123");
|
||||
testEcho(10, xhrTransport, headers);
|
||||
|
||||
for (Map.Entry<String, HttpHeaders> entry : this.testFilter.requests.entrySet()) {
|
||||
HttpHeaders httpHeaders = entry.getValue();
|
||||
assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth"));
|
||||
}
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -246,14 +265,15 @@ public abstract class AbstractSockJsIntegrationTests {
|
||||
}
|
||||
|
||||
|
||||
private void testEcho(int messageCount, Transport transport) throws Exception {
|
||||
private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception {
|
||||
List<TextMessage> messages = new ArrayList<>();
|
||||
for (int i = 0; i < messageCount; i++) {
|
||||
messages.add(new TextMessage("m" + i));
|
||||
}
|
||||
TestClientHandler handler = new TestClientHandler();
|
||||
initSockJsClient(transport);
|
||||
WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get();
|
||||
URI url = new URI(this.baseUrl + "/echo");
|
||||
WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get();
|
||||
for (TextMessage message : messages) {
|
||||
session.sendMessage(message);
|
||||
}
|
||||
@@ -386,7 +406,7 @@ public abstract class AbstractSockJsIntegrationTests {
|
||||
|
||||
private static class TestFilter implements Filter {
|
||||
|
||||
private final List<ServletRequest> requests = new ArrayList<>();
|
||||
private final Map<String, HttpHeaders> requests = new HashMap<>();
|
||||
|
||||
private final Map<String, Long> sleepDelayMap = new HashMap<>();
|
||||
|
||||
@@ -397,10 +417,13 @@ public abstract class AbstractSockJsIntegrationTests {
|
||||
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
|
||||
throws IOException, ServletException {
|
||||
|
||||
this.requests.add(request);
|
||||
HttpServletRequest httpRequest = (HttpServletRequest) request;
|
||||
String uri = httpRequest.getRequestURI();
|
||||
HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders();
|
||||
this.requests.put(uri, headers);
|
||||
|
||||
for (String suffix : this.sleepDelayMap.keySet()) {
|
||||
if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) {
|
||||
if ((httpRequest).getRequestURI().endsWith(suffix)) {
|
||||
try {
|
||||
Thread.sleep(this.sleepDelayMap.get(suffix));
|
||||
break;
|
||||
@@ -411,7 +434,7 @@ public abstract class AbstractSockJsIntegrationTests {
|
||||
}
|
||||
}
|
||||
for (String suffix : this.sendErrorMap.keySet()) {
|
||||
if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) {
|
||||
if ((httpRequest).getRequestURI().endsWith(suffix)) {
|
||||
((HttpServletResponse) response).sendError(this.sendErrorMap.get(suffix));
|
||||
return;
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -64,7 +64,7 @@ public class ClientSockJsSessionTests {
|
||||
public void setup() throws Exception {
|
||||
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
|
||||
Transport transport = mock(Transport.class);
|
||||
TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC);
|
||||
TransportRequest request = new DefaultTransportRequest(urlInfo, null, null, transport, TransportType.XHR, CODEC);
|
||||
this.handler = mock(WebSocketHandler.class);
|
||||
this.connectFuture = new SettableListenableFuture<>();
|
||||
this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture);
|
||||
|
||||
@@ -127,7 +127,7 @@ public class DefaultTransportRequestTests {
|
||||
|
||||
protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception {
|
||||
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
|
||||
return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC);
|
||||
return new DefaultTransportRequest(urlInfo, new HttpHeaders(), new HttpHeaders(), transport, type, CODEC);
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -182,7 +182,8 @@ public class RestTemplateXhrTransportTests {
|
||||
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
|
||||
HttpHeaders headers = new HttpHeaders();
|
||||
headers.add("h-foo", "h-bar");
|
||||
TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC);
|
||||
TransportRequest request = new DefaultTransportRequest(urlInfo, headers, headers,
|
||||
transport, TransportType.XHR, CODEC);
|
||||
|
||||
return transport.connect(request, this.webSocketHandler);
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -16,22 +16,34 @@
|
||||
|
||||
package org.springframework.web.socket.sockjs.client;
|
||||
|
||||
import java.net.URI;
|
||||
import java.net.URISyntaxException;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
|
||||
import org.junit.Before;
|
||||
import org.junit.Test;
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.http.HttpStatus;
|
||||
import org.springframework.util.concurrent.ListenableFutureCallback;
|
||||
import org.springframework.web.client.HttpServerErrorException;
|
||||
import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketHttpHeaders;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport;
|
||||
|
||||
import static org.junit.Assert.*;
|
||||
import static org.mockito.BDDMockito.*;
|
||||
import static org.junit.Assert.assertEquals;
|
||||
import static org.junit.Assert.assertFalse;
|
||||
import static org.junit.Assert.assertTrue;
|
||||
import static org.mockito.BDDMockito.any;
|
||||
import static org.mockito.BDDMockito.given;
|
||||
import static org.mockito.BDDMockito.mock;
|
||||
import static org.mockito.BDDMockito.times;
|
||||
import static org.mockito.BDDMockito.verify;
|
||||
import static org.mockito.BDDMockito.verifyNoMoreInteractions;
|
||||
import static org.mockito.BDDMockito.when;
|
||||
|
||||
/**
|
||||
* Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}.
|
||||
@@ -102,11 +114,51 @@ public class SockJsClientTests {
|
||||
assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr"));
|
||||
}
|
||||
|
||||
// SPR-13254
|
||||
|
||||
@Test
|
||||
public void connectWithHandshakeHeaders() throws Exception {
|
||||
ArgumentCaptor<HttpHeaders> headersCaptor = setupInfoRequest(false);
|
||||
this.xhrTransport.setStreamingDisabled(true);
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
headers.set("foo", "bar");
|
||||
headers.set("auth", "123");
|
||||
this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback);
|
||||
|
||||
HttpHeaders httpHeaders = headersCaptor.getValue();
|
||||
assertEquals(2, httpHeaders.size());
|
||||
assertEquals("bar", httpHeaders.getFirst("foo"));
|
||||
assertEquals("123", httpHeaders.getFirst("auth"));
|
||||
|
||||
httpHeaders = this.xhrTransport.getRequest().getHttpRequestHeaders();
|
||||
assertEquals(2, httpHeaders.size());
|
||||
assertEquals("bar", httpHeaders.getFirst("foo"));
|
||||
assertEquals("123", httpHeaders.getFirst("auth"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void connectAndUseSubsetOfHandshakeHeadersForHttpRequests() throws Exception {
|
||||
ArgumentCaptor<HttpHeaders> headersCaptor = setupInfoRequest(false);
|
||||
this.xhrTransport.setStreamingDisabled(true);
|
||||
|
||||
WebSocketHttpHeaders headers = new WebSocketHttpHeaders();
|
||||
headers.set("foo", "bar");
|
||||
headers.set("auth", "123");
|
||||
this.sockJsClient.setHttpHeaderNames("auth");
|
||||
this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback);
|
||||
|
||||
assertEquals(1, headersCaptor.getValue().size());
|
||||
assertEquals("123", headersCaptor.getValue().getFirst("auth"));
|
||||
assertEquals(1, this.xhrTransport.getRequest().getHttpRequestHeaders().size());
|
||||
assertEquals("123", this.xhrTransport.getRequest().getHttpRequestHeaders().getFirst("auth"));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void connectSockJsInfo() throws Exception {
|
||||
setupInfoRequest(true);
|
||||
this.sockJsClient.doHandshake(handler, URL);
|
||||
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
|
||||
verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
@@ -115,22 +167,27 @@ public class SockJsClientTests {
|
||||
this.sockJsClient.doHandshake(handler, URL);
|
||||
this.sockJsClient.doHandshake(handler, URL);
|
||||
this.sockJsClient.doHandshake(handler, URL);
|
||||
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
|
||||
verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any());
|
||||
}
|
||||
|
||||
@Test
|
||||
public void connectInfoRequestFailure() throws URISyntaxException {
|
||||
HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE);
|
||||
given(this.infoReceiver.executeInfoRequest(any())).willThrow(exception);
|
||||
given(this.infoReceiver.executeInfoRequest(any(), any())).willThrow(exception);
|
||||
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
|
||||
verify(this.connectCallback).onFailure(exception);
|
||||
assertFalse(this.webSocketTransport.invoked());
|
||||
assertFalse(this.xhrTransport.invoked());
|
||||
}
|
||||
|
||||
private void setupInfoRequest(boolean webSocketEnabled) {
|
||||
given(this.infoReceiver.executeInfoRequest(any())).willReturn("{\"entropy\":123," +
|
||||
"\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}");
|
||||
private ArgumentCaptor<HttpHeaders> setupInfoRequest(boolean webSocketEnabled) {
|
||||
ArgumentCaptor<HttpHeaders> headersCaptor = ArgumentCaptor.forClass(HttpHeaders.class);
|
||||
when(this.infoReceiver.executeInfoRequest(any(), headersCaptor.capture())).thenReturn(
|
||||
"{\"entropy\":123," +
|
||||
"\"origins\":[\"*:*\"]," +
|
||||
"\"cookie_needed\":true," +
|
||||
"\"websocket\":" + webSocketEnabled + "}");
|
||||
return headersCaptor;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -18,9 +18,12 @@ package org.springframework.web.socket.sockjs.client;
|
||||
|
||||
import java.net.URI;
|
||||
import java.util.Arrays;
|
||||
import java.util.Collections;
|
||||
import java.util.List;
|
||||
|
||||
import org.mockito.ArgumentCaptor;
|
||||
|
||||
import org.springframework.http.HttpHeaders;
|
||||
import org.springframework.util.concurrent.ListenableFuture;
|
||||
import org.springframework.util.concurrent.ListenableFutureCallback;
|
||||
import org.springframework.web.socket.TextMessage;
|
||||
@@ -28,7 +31,8 @@ import org.springframework.web.socket.WebSocketHandler;
|
||||
import org.springframework.web.socket.WebSocketSession;
|
||||
import org.springframework.web.socket.sockjs.transport.TransportType;
|
||||
|
||||
import static org.mockito.Mockito.*;
|
||||
import static org.mockito.Mockito.mock;
|
||||
import static org.mockito.Mockito.verify;
|
||||
|
||||
/**
|
||||
* Test SockJS Transport.
|
||||
@@ -51,7 +55,7 @@ class TestTransport implements Transport {
|
||||
|
||||
@Override
|
||||
public List<TransportType> getTransportTypes() {
|
||||
return Arrays.asList(TransportType.WEBSOCKET);
|
||||
return Collections.singletonList(TransportType.WEBSOCKET);
|
||||
}
|
||||
|
||||
public TransportRequest getRequest() {
|
||||
@@ -95,7 +99,7 @@ class TestTransport implements Transport {
|
||||
@Override
|
||||
public List<TransportType> getTransportTypes() {
|
||||
return (isXhrStreamingDisabled() ?
|
||||
Arrays.asList(TransportType.XHR) :
|
||||
Collections.singletonList(TransportType.XHR) :
|
||||
Arrays.asList(TransportType.XHR_STREAMING, TransportType.XHR));
|
||||
}
|
||||
|
||||
@@ -109,11 +113,11 @@ class TestTransport implements Transport {
|
||||
}
|
||||
|
||||
@Override
|
||||
public void executeSendRequest(URI transportUrl, TextMessage message) {
|
||||
public void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message) {
|
||||
}
|
||||
|
||||
@Override
|
||||
public String executeInfoRequest(URI infoUrl) {
|
||||
public String executeInfoRequest(URI infoUrl, HttpHeaders headers) {
|
||||
return null;
|
||||
}
|
||||
}
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
/*
|
||||
* Copyright 2002-2014 the original author or authors.
|
||||
* Copyright 2002-2015 the original author or authors.
|
||||
*
|
||||
* Licensed under the Apache License, Version 2.0 (the "License");
|
||||
* you may not use this file except in compliance with the License.
|
||||
@@ -46,25 +46,25 @@ public class XhrTransportTests {
|
||||
public void infoResponse() throws Exception {
|
||||
TestXhrTransport transport = new TestXhrTransport();
|
||||
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
|
||||
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
|
||||
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null));
|
||||
}
|
||||
|
||||
@Test(expected = HttpServerErrorException.class)
|
||||
public void infoResponseError() throws Exception {
|
||||
TestXhrTransport transport = new TestXhrTransport();
|
||||
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
|
||||
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
|
||||
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null));
|
||||
}
|
||||
|
||||
@Test
|
||||
public void sendMessage() throws Exception {
|
||||
HttpHeaders requestHeaders = new HttpHeaders();
|
||||
requestHeaders.set("foo", "bar");
|
||||
requestHeaders.setContentType(MediaType.APPLICATION_JSON);
|
||||
TestXhrTransport transport = new TestXhrTransport();
|
||||
transport.setRequestHeaders(requestHeaders);
|
||||
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT);
|
||||
URI url = new URI("http://example.com");
|
||||
transport.executeSendRequest(url, new TextMessage("payload"));
|
||||
transport.executeSendRequest(url, requestHeaders, new TextMessage("payload"));
|
||||
assertEquals(2, transport.actualSendRequestHeaders.size());
|
||||
assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo"));
|
||||
assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType());
|
||||
@@ -75,9 +75,10 @@ public class XhrTransportTests {
|
||||
TestXhrTransport transport = new TestXhrTransport();
|
||||
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST);
|
||||
URI url = new URI("http://example.com");
|
||||
transport.executeSendRequest(url, new TextMessage("payload"));
|
||||
transport.executeSendRequest(url, null, new TextMessage("payload"));
|
||||
}
|
||||
|
||||
@SuppressWarnings("deprecation")
|
||||
@Test
|
||||
public void connect() throws Exception {
|
||||
HttpHeaders handshakeHeaders = new HttpHeaders();
|
||||
@@ -101,6 +102,7 @@ public class XhrTransportTests {
|
||||
verify(request).addTimeoutTask(captor.capture());
|
||||
verify(request).getTransportUrl();
|
||||
verify(request).getHandshakeHeaders();
|
||||
verify(request).getHttpRequestHeaders();
|
||||
verifyNoMoreInteractions(request);
|
||||
|
||||
assertEquals(2, transport.actualHandshakeHeaders.size());
|
||||
@@ -127,7 +129,7 @@ public class XhrTransportTests {
|
||||
|
||||
|
||||
@Override
|
||||
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
|
||||
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) {
|
||||
return this.infoResponseToReturn;
|
||||
}
|
||||
|
||||
|
||||
Reference in New Issue
Block a user