diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java index a8e12ac3b5..ce77aeb279 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/AbstractXhrTransport.java @@ -26,7 +26,6 @@ import org.apache.commons.logging.LogFactory; import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; -import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.SettableListenableFuture; @@ -61,8 +60,6 @@ public abstract class AbstractXhrTransport implements XhrTransport { private HttpHeaders requestHeaders = new HttpHeaders(); - private HttpHeaders xhrSendRequestHeaders = new HttpHeaders(); - @Override public List getTransportTypes() { @@ -97,17 +94,17 @@ public abstract class AbstractXhrTransport implements XhrTransport { /** * Configure headers to be added to every executed HTTP request. * @param requestHeaders the headers to add to requests + * @deprecated as of 4.2 in favor of {@link SockJsClient#setHttpHeaderNames}. */ + @Deprecated public void setRequestHeaders(HttpHeaders requestHeaders) { this.requestHeaders.clear(); - this.xhrSendRequestHeaders.clear(); if (requestHeaders != null) { this.requestHeaders.putAll(requestHeaders); - this.xhrSendRequestHeaders.putAll(requestHeaders); - this.xhrSendRequestHeaders.setContentType(MediaType.APPLICATION_JSON); } } + @Deprecated public HttpHeaders getRequestHeaders() { return this.requestHeaders; } @@ -115,6 +112,7 @@ public abstract class AbstractXhrTransport implements XhrTransport { // Transport methods + @SuppressWarnings("deprecation") @Override public ListenableFuture connect(TransportRequest request, WebSocketHandler handler) { SettableListenableFuture connectFuture = new SettableListenableFuture(); @@ -128,8 +126,8 @@ public abstract class AbstractXhrTransport implements XhrTransport { } HttpHeaders handshakeHeaders = new HttpHeaders(); - handshakeHeaders.putAll(request.getHandshakeHeaders()); handshakeHeaders.putAll(getRequestHeaders()); + handshakeHeaders.putAll(request.getHandshakeHeaders()); connectInternal(request, handler, receiveUrl, handshakeHeaders, session, connectFuture); return connectFuture; @@ -142,11 +140,17 @@ public abstract class AbstractXhrTransport implements XhrTransport { // InfoReceiver methods @Override - public String executeInfoRequest(URI infoUrl) { + @SuppressWarnings("deprecation") + public String executeInfoRequest(URI infoUrl, HttpHeaders headers) { if (logger.isDebugEnabled()) { logger.debug("Executing SockJS Info request, url=" + infoUrl); } - ResponseEntity response = executeInfoRequestInternal(infoUrl); + HttpHeaders infoRequestHeaders = new HttpHeaders(); + infoRequestHeaders.putAll(getRequestHeaders()); + if (headers != null) { + infoRequestHeaders.putAll(headers); + } + ResponseEntity response = executeInfoRequestInternal(infoUrl, infoRequestHeaders); if (response.getStatusCode() != HttpStatus.OK) { if (logger.isErrorEnabled()) { logger.error("SockJS Info request (url=" + infoUrl + ") failed: " + response); @@ -159,16 +163,16 @@ public abstract class AbstractXhrTransport implements XhrTransport { return response.getBody(); } - protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl); + protected abstract ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers); // XhrTransport methods @Override - public void executeSendRequest(URI url, TextMessage message) { + public void executeSendRequest(URI url, HttpHeaders headers, TextMessage message) { if (logger.isTraceEnabled()) { logger.trace("Starting XHR send, url=" + url); } - ResponseEntity response = executeSendRequestInternal(url, this.xhrSendRequestHeaders, message); + ResponseEntity response = executeSendRequestInternal(url, headers, message); if (response.getStatusCode() != HttpStatus.NO_CONTENT) { if (logger.isErrorEnabled()) { logger.error("XHR send request (url=" + url + ") failed: " + response); @@ -180,7 +184,8 @@ public abstract class AbstractXhrTransport implements XhrTransport { } } - protected abstract ResponseEntity executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message); + protected abstract ResponseEntity executeSendRequestInternal(URI url, + HttpHeaders headers, TextMessage message); @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java index bf1eacdcd7..06563ad633 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequest.java @@ -52,6 +52,8 @@ class DefaultTransportRequest implements TransportRequest { private final HttpHeaders handshakeHeaders; + private final HttpHeaders httpRequestHeaders; + private final Transport transport; private final TransportType serverTransportType; @@ -69,7 +71,8 @@ class DefaultTransportRequest implements TransportRequest { private DefaultTransportRequest fallbackRequest; - public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, HttpHeaders handshakeHeaders, + public DefaultTransportRequest(SockJsUrlInfo sockJsUrlInfo, + HttpHeaders handshakeHeaders, HttpHeaders httpRequestHeaders, Transport transport, TransportType serverTransportType, SockJsMessageCodec codec) { Assert.notNull(sockJsUrlInfo, "'sockJsUrlInfo' is required"); @@ -78,6 +81,7 @@ class DefaultTransportRequest implements TransportRequest { Assert.notNull(codec, "'codec' is required"); this.sockJsUrlInfo = sockJsUrlInfo; this.handshakeHeaders = (handshakeHeaders != null ? handshakeHeaders : new HttpHeaders()); + this.httpRequestHeaders = (httpRequestHeaders != null ? httpRequestHeaders : new HttpHeaders()); this.transport = transport; this.serverTransportType = serverTransportType; this.codec = codec; @@ -94,6 +98,11 @@ class DefaultTransportRequest implements TransportRequest { return this.handshakeHeaders; } + @Override + public HttpHeaders getHttpRequestHeaders() { + return this.httpRequestHeaders; + } + @Override public URI getTransportUrl() { return this.sockJsUrlInfo.getTransportUrl(this.serverTransportType); diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java index e921c97d72..b039c4dcca 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/InfoReceiver.java @@ -17,6 +17,8 @@ package org.springframework.web.socket.sockjs.client; import java.net.URI; +import org.springframework.http.HttpHeaders; + /** * A component that can execute the SockJS "Info" request that needs to be * performed before the SockJS session starts in order to check server endpoint @@ -34,10 +36,11 @@ public interface InfoReceiver { /** * Perform an HTTP request to the SockJS "Info" URL. * and return the resulting JSON response content, or raise an exception. - * + *

Note that as of 4.2 this method accepts a {@code headers} parameter. * @param infoUrl the URL to obtain SockJS server information from + * @param headers the headers to use for the request * @return the body of the response */ - String executeInfoRequest(URI infoUrl); + String executeInfoRequest(URI infoUrl, HttpHeaders headers); } \ No newline at end of file diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java index 969c71e278..89088fbdbe 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/JettyXhrTransport.java @@ -106,11 +106,12 @@ public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransp @Override - protected void connectInternal(TransportRequest request, WebSocketHandler handler, + protected void connectInternal(TransportRequest transportRequest, WebSocketHandler handler, URI url, HttpHeaders handshakeHeaders, XhrClientSockJsSession session, SettableListenableFuture connectFuture) { - SockJsResponseListener listener = new SockJsResponseListener(url, getRequestHeaders(), session, connectFuture); + HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); + SockJsResponseListener listener = new SockJsResponseListener(url, httpHeaders, session, connectFuture); executeReceiveRequest(url, handshakeHeaders, listener); } @@ -124,8 +125,8 @@ public class JettyXhrTransport extends AbstractXhrTransport implements XhrTransp } @Override - protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { - return executeRequest(infoUrl, HttpMethod.GET, getRequestHeaders(), null); + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + return executeRequest(infoUrl, HttpMethod.GET, headers, null); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java index 5d6d25cf3c..62f913cc14 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransport.java @@ -94,15 +94,16 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport implements Xh @Override - protected void connectInternal(final TransportRequest request, final WebSocketHandler handler, + protected void connectInternal(final TransportRequest transportRequest, final WebSocketHandler handler, final URI receiveUrl, final HttpHeaders handshakeHeaders, final XhrClientSockJsSession session, final SettableListenableFuture connectFuture) { getTaskExecutor().execute(new Runnable() { @Override public void run() { + HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); XhrRequestCallback requestCallback = new XhrRequestCallback(handshakeHeaders); - XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(getRequestHeaders()); + XhrRequestCallback requestCallbackAfterHandshake = new XhrRequestCallback(httpHeaders); XhrReceiveExtractor responseExtractor = new XhrReceiveExtractor(session); while (true) { if (session.isDisconnected()) { @@ -132,8 +133,8 @@ public class RestTemplateXhrTransport extends AbstractXhrTransport implements Xh } @Override - public ResponseEntity executeInfoRequestInternal(URI infoUrl) { - RequestCallback requestCallback = new XhrRequestCallback(getRequestHeaders()); + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + RequestCallback requestCallback = new XhrRequestCallback(headers); return this.restTemplate.execute(infoUrl, HttpMethod.GET, requestCallback, textResponseExtractor); } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java index e73c912d41..2d34c44b48 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/SockJsClient.java @@ -78,6 +78,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle { private final List transports; + private String[] httpHeaderNames; + private InfoReceiver infoReceiver; private SockJsMessageCodec messageCodec; @@ -116,6 +118,30 @@ public class SockJsClient implements WebSocketClient, Lifecycle { } + /** + * The names of HTTP headers that should be copied from the handshake headers + * of each call to {@link SockJsClient#doHandshake(WebSocketHandler, WebSocketHttpHeaders, URI)} + * and also used with other HTTP requests issued as part of that SockJS + * connection, e.g. the initial info request, XHR send or receive requests. + * + *

By default if this property is not set, all handshake headers are also + * used for other HTTP requests. Set it if you want only a subset of handshake + * headers (e.g. auth headers) to be used for other HTTP requests. + * + * @param httpHeaderNames HTTP header names + */ + public void setHttpHeaderNames(String... httpHeaderNames) { + this.httpHeaderNames = httpHeaderNames; + } + + /** + * The configured HTTP header names to be copied from the handshake + * headers and also included in other HTTP requests. + */ + public String[] getHttpHeaderNames() { + return this.httpHeaderNames; + } + /** * Configure the {@code InfoReceiver} to use to perform the SockJS "Info" * request before the SockJS session starts. @@ -225,7 +251,7 @@ public class SockJsClient implements WebSocketClient, Lifecycle { SettableListenableFuture connectFuture = new SettableListenableFuture(); try { SockJsUrlInfo sockJsUrlInfo = new SockJsUrlInfo(url); - ServerInfo serverInfo = getServerInfo(sockJsUrlInfo); + ServerInfo serverInfo = getServerInfo(sockJsUrlInfo, getHttpRequestHeaders(headers)); createRequest(sockJsUrlInfo, headers, serverInfo).connect(handler, connectFuture); } catch (Throwable exception) { @@ -237,12 +263,27 @@ public class SockJsClient implements WebSocketClient, Lifecycle { return connectFuture; } - private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo) { + private HttpHeaders getHttpRequestHeaders(HttpHeaders webSocketHttpHeaders) { + if (getHttpHeaderNames() == null) { + return webSocketHttpHeaders; + } + else { + HttpHeaders httpHeaders = new HttpHeaders(); + for (String name : getHttpHeaderNames()) { + if (webSocketHttpHeaders.containsKey(name)) { + httpHeaders.put(name, webSocketHttpHeaders.get(name)); + } + } + return httpHeaders; + } + } + + private ServerInfo getServerInfo(SockJsUrlInfo sockJsUrlInfo, HttpHeaders headers) { URI infoUrl = sockJsUrlInfo.getInfoUrl(); ServerInfo info = this.serverInfoCache.get(infoUrl); if (info == null) { long start = System.currentTimeMillis(); - String response = this.infoReceiver.executeInfoRequest(infoUrl); + String response = this.infoReceiver.executeInfoRequest(infoUrl, headers); long infoRequestTime = System.currentTimeMillis() - start; info = new ServerInfo(response, infoRequestTime); this.serverInfoCache.put(infoUrl, info); @@ -255,7 +296,8 @@ public class SockJsClient implements WebSocketClient, Lifecycle { for (Transport transport : this.transports) { for (TransportType type : transport.getTransportTypes()) { if (serverInfo.isWebSocketEnabled() || !TransportType.WEBSOCKET.equals(type)) { - requests.add(new DefaultTransportRequest(urlInfo, headers, transport, type, getMessageCodec())); + requests.add(new DefaultTransportRequest(urlInfo, headers, getHttpRequestHeaders(headers), + transport, type, getMessageCodec())); } } } diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java index d0fc7df319..94bd3c65c6 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/TransportRequest.java @@ -47,6 +47,13 @@ public interface TransportRequest { */ HttpHeaders getHandshakeHeaders(); + /** + * Return the headers to add to all other HTTP requests besides the handshake + * request such XHR receive and send requests. + * @since 4.2 + */ + HttpHeaders getHttpRequestHeaders(); + /** * Return the transport URL for the given transport. * For an {@link XhrTransport} this is the URL for receiving messages. diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java index 68c064fa72..8f3757833c 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/UndertowXhrTransport.java @@ -134,11 +134,11 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra HttpHeaders handshakeHeaders, XhrClientSockJsSession session, SettableListenableFuture connectFuture) { - executeReceiveRequest(receiveUrl, handshakeHeaders, session, connectFuture); + executeReceiveRequest(request, receiveUrl, handshakeHeaders, session, connectFuture); } - private void executeReceiveRequest(final URI url, final HttpHeaders headers, - final XhrClientSockJsSession session, + private void executeReceiveRequest(final TransportRequest transportRequest, + final URI url, final HttpHeaders headers, final XhrClientSockJsSession session, final SettableListenableFuture connectFuture) { if (logger.isTraceEnabled()) { @@ -154,8 +154,9 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra HttpString headerName = HttpString.tryFromString(HttpHeaders.HOST); request.getRequestHeaders().add(headerName, url.getHost()); addHttpHeaders(request, headers); - connection.sendRequest(request, createReceiveCallback(url, - getRequestHeaders(), session, connectFuture)); + HttpHeaders httpHeaders = transportRequest.getHttpRequestHeaders(); + connection.sendRequest(request, createReceiveCallback(transportRequest, + url, httpHeaders, session, connectFuture)); } @Override @@ -175,8 +176,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra } } - private ClientCallback createReceiveCallback(final URI url, final HttpHeaders headers, - final XhrClientSockJsSession sockJsSession, + private ClientCallback createReceiveCallback(final TransportRequest transportRequest, + final URI url, final HttpHeaders headers, final XhrClientSockJsSession sockJsSession, final SettableListenableFuture connectFuture) { return new ClientCallback() { @@ -194,8 +195,9 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra onFailure(new HttpServerErrorException(status, "Unexpected XHR receive status")); } else { - SockJsResponseListener listener = new SockJsResponseListener(result.getConnection(), - url, headers, sockJsSession, connectFuture); + SockJsResponseListener listener = new SockJsResponseListener( + transportRequest, result.getConnection(), url, headers, + sockJsSession, connectFuture); listener.setup(result.getResponseChannel()); } if (logger.isTraceEnabled()) { @@ -254,8 +256,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra } @Override - protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { - return executeRequest(infoUrl, Methods.GET, getRequestHeaders(), null); + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { + return executeRequest(infoUrl, Methods.GET, headers, null); } @Override @@ -360,6 +362,8 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra private class SockJsResponseListener implements ChannelListener { + private final TransportRequest request; + private final ClientConnection connection; private final URI url; @@ -372,10 +376,12 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra private final ByteArrayOutputStream outputStream = new ByteArrayOutputStream(); - public SockJsResponseListener(ClientConnection connection, URI url, + + public SockJsResponseListener(TransportRequest request, ClientConnection connection, URI url, HttpHeaders headers, XhrClientSockJsSession sockJsSession, SettableListenableFuture connectFuture) { + this.request = request; this.connection = connection; this.url = url; this.headers = headers; @@ -455,7 +461,7 @@ public class UndertowXhrTransport extends AbstractXhrTransport implements XhrTra logger.trace("XHR receive request completed."); } IoUtils.safeClose(this.connection); - executeReceiveRequest(this.url, this.headers, this.session, this.connectFuture); + executeReceiveRequest(this.request, this.url, this.headers, this.session, this.connectFuture); } public void onFailure(Throwable failure) { diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java index 1381f33ee4..59c7d25b47 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrClientSockJsSession.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,8 @@ import java.net.InetSocketAddress; import java.net.URI; import java.util.List; +import org.springframework.http.HttpHeaders; +import org.springframework.http.MediaType; import org.springframework.util.Assert; import org.springframework.util.concurrent.SettableListenableFuture; import org.springframework.web.socket.CloseStatus; @@ -39,10 +41,14 @@ import org.springframework.web.socket.sockjs.transport.TransportType; */ public class XhrClientSockJsSession extends AbstractClientSockJsSession { - private final URI sendUrl; - private final XhrTransport transport; + private HttpHeaders headers; + + private HttpHeaders sendHeaders; + + private final URI sendUrl; + private int textMessageSizeLimit = -1; private int binaryMessageSizeLimit = -1; @@ -53,11 +59,21 @@ public class XhrClientSockJsSession extends AbstractClientSockJsSession { super(request, handler, connectFuture); Assert.notNull(transport, "'restTemplate' is required"); - this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); this.transport = transport; + this.headers = request.getHttpRequestHeaders(); + this.sendHeaders = new HttpHeaders(); + if (this.headers != null) { + this.sendHeaders.putAll(this.headers); + } + this.sendHeaders.setContentType(MediaType.APPLICATION_JSON); + this.sendUrl = request.getSockJsUrlInfo().getTransportUrl(TransportType.XHR_SEND); } + public HttpHeaders getHeaders() { + return this.headers; + } + @Override public InetSocketAddress getLocalAddress() { return null; @@ -100,7 +116,7 @@ public class XhrClientSockJsSession extends AbstractClientSockJsSession { @Override protected void sendInternal(TextMessage message) { - this.transport.executeSendRequest(this.sendUrl, message); + this.transport.executeSendRequest(this.sendUrl, this.sendHeaders, message); } @Override diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java index 6fcf7f1651..d5725ed54a 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/sockjs/client/XhrTransport.java @@ -17,14 +17,14 @@ package org.springframework.web.socket.sockjs.client; import java.net.URI; +import org.springframework.http.HttpHeaders; import org.springframework.web.socket.TextMessage; /** * A SockJS {@link Transport} that uses HTTP requests to simulate a WebSocket * interaction. The {@code connect} method of the base {@code Transport} interface * is used to receive messages from the server while the - * {@link #executeSendRequest(java.net.URI, org.springframework.web.socket.TextMessage) - * executeSendRequest(URI, TextMessage)} method here is used to send messages. + * {@link #executeSendRequest} method here is used to send messages. * * @author Rossen Stoyanchev * @since 4.1 @@ -35,7 +35,6 @@ public interface XhrTransport extends Transport, InfoReceiver { * An {@code XhrTransport} supports both the "xhr_streaming" and "xhr" SockJS * server transports. From a client perspective there is no implementation * difference. - * *

By default an {@code XhrTransport} will be used with "xhr_streaming" * first and then with "xhr", if the streaming fails to connect. In some * cases it may be useful to suppress streaming so that only "xhr" is used. @@ -44,9 +43,10 @@ public interface XhrTransport extends Transport, InfoReceiver { /** * Execute a request to send the message to the server. + *

Note that as of 4.2 this method accepts a {@code headers} parameter. * @param transportUrl the URL for sending messages. * @param message the message to send */ - void executeSendRequest(URI transportUrl, TextMessage message); + void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java index 0696bfe4be..4902fd5a7b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/AbstractSockJsIntegrationTests.java @@ -49,6 +49,8 @@ import org.junit.rules.TestName; import org.springframework.beans.factory.annotation.Autowired; import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Configuration; +import org.springframework.http.HttpHeaders; +import org.springframework.http.server.ServletServerHttpRequest; import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler; import org.springframework.tests.Assume; import org.springframework.tests.TestGroup; @@ -100,7 +102,7 @@ public abstract class AbstractSockJsIntegrationTests { @BeforeClass public static void performanceTestGroupAssumption() throws Exception { - Assume.group(TestGroup.PERFORMANCE); +// Assume.group(TestGroup.PERFORMANCE); } @@ -164,19 +166,36 @@ public abstract class AbstractSockJsIntegrationTests { @Test public void echoWebSocket() throws Exception { - testEcho(100, createWebSocketTransport()); + testEcho(100, createWebSocketTransport(), null); } @Test public void echoXhrStreaming() throws Exception { - testEcho(100, createXhrTransport()); + testEcho(100, createXhrTransport(), null); } @Test public void echoXhr() throws Exception { AbstractXhrTransport xhrTransport = createXhrTransport(); xhrTransport.setXhrStreamingDisabled(true); - testEcho(100, xhrTransport); + testEcho(100, xhrTransport, null); + } + + // SPR-13254 + + @Test + public void echoXhrWithHeaders() throws Exception { + AbstractXhrTransport xhrTransport = createXhrTransport(); + xhrTransport.setXhrStreamingDisabled(true); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.add("auth", "123"); + testEcho(10, xhrTransport, headers); + + for (Map.Entry entry : this.testFilter.requests.entrySet()) { + HttpHeaders httpHeaders = entry.getValue(); + assertEquals("No auth header for: " + entry.getKey(), "123", httpHeaders.getFirst("auth")); + } } @Test @@ -246,14 +265,15 @@ public abstract class AbstractSockJsIntegrationTests { } - private void testEcho(int messageCount, Transport transport) throws Exception { + private void testEcho(int messageCount, Transport transport, WebSocketHttpHeaders headers) throws Exception { List messages = new ArrayList<>(); for (int i = 0; i < messageCount; i++) { messages.add(new TextMessage("m" + i)); } TestClientHandler handler = new TestClientHandler(); initSockJsClient(transport); - WebSocketSession session = this.sockJsClient.doHandshake(handler, this.baseUrl + "/echo").get(); + URI url = new URI(this.baseUrl + "/echo"); + WebSocketSession session = this.sockJsClient.doHandshake(handler, headers, url).get(); for (TextMessage message : messages) { session.sendMessage(message); } @@ -386,7 +406,7 @@ public abstract class AbstractSockJsIntegrationTests { private static class TestFilter implements Filter { - private final List requests = new ArrayList<>(); + private final Map requests = new HashMap<>(); private final Map sleepDelayMap = new HashMap<>(); @@ -397,10 +417,13 @@ public abstract class AbstractSockJsIntegrationTests { public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain) throws IOException, ServletException { - this.requests.add(request); + HttpServletRequest httpRequest = (HttpServletRequest) request; + String uri = httpRequest.getRequestURI(); + HttpHeaders headers = new ServletServerHttpRequest(httpRequest).getHeaders(); + this.requests.put(uri, headers); for (String suffix : this.sleepDelayMap.keySet()) { - if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) { + if ((httpRequest).getRequestURI().endsWith(suffix)) { try { Thread.sleep(this.sleepDelayMap.get(suffix)); break; @@ -411,7 +434,7 @@ public abstract class AbstractSockJsIntegrationTests { } } for (String suffix : this.sendErrorMap.keySet()) { - if (((HttpServletRequest) request).getRequestURI().endsWith(suffix)) { + if ((httpRequest).getRequestURI().endsWith(suffix)) { ((HttpServletResponse) response).sendError(this.sendErrorMap.get(suffix)); return; } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java index 8e153335bc..791aa387af 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/ClientSockJsSessionTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -64,7 +64,7 @@ public class ClientSockJsSessionTests { public void setup() throws Exception { SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); Transport transport = mock(Transport.class); - TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC); + TransportRequest request = new DefaultTransportRequest(urlInfo, null, null, transport, TransportType.XHR, CODEC); this.handler = mock(WebSocketHandler.class); this.connectFuture = new SettableListenableFuture<>(); this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture); diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java index d9849ca0ce..8875081a99 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/DefaultTransportRequestTests.java @@ -127,7 +127,7 @@ public class DefaultTransportRequestTests { protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception { SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); - return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC); + return new DefaultTransportRequest(urlInfo, new HttpHeaders(), new HttpHeaders(), transport, type, CODEC); } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java index 74fa763d5c..ca8f2c83f6 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/RestTemplateXhrTransportTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -182,7 +182,8 @@ public class RestTemplateXhrTransportTests { SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com")); HttpHeaders headers = new HttpHeaders(); headers.add("h-foo", "h-bar"); - TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC); + TransportRequest request = new DefaultTransportRequest(urlInfo, headers, headers, + transport, TransportType.XHR, CODEC); return transport.connect(request, this.webSocketHandler); } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java index 918edddbbd..7db25a4ea9 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/SockJsClientTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,22 +16,34 @@ package org.springframework.web.socket.sockjs.client; +import java.net.URI; import java.net.URISyntaxException; import java.util.ArrayList; import java.util.List; import org.junit.Before; import org.junit.Test; +import org.mockito.ArgumentCaptor; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpStatus; import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.client.HttpServerErrorException; import org.springframework.web.socket.WebSocketHandler; +import org.springframework.web.socket.WebSocketHttpHeaders; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport; -import static org.junit.Assert.*; -import static org.mockito.BDDMockito.*; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertTrue; +import static org.mockito.BDDMockito.any; +import static org.mockito.BDDMockito.given; +import static org.mockito.BDDMockito.mock; +import static org.mockito.BDDMockito.times; +import static org.mockito.BDDMockito.verify; +import static org.mockito.BDDMockito.verifyNoMoreInteractions; +import static org.mockito.BDDMockito.when; /** * Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}. @@ -102,11 +114,51 @@ public class SockJsClientTests { assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr")); } + // SPR-13254 + + @Test + public void connectWithHandshakeHeaders() throws Exception { + ArgumentCaptor headersCaptor = setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.set("foo", "bar"); + headers.set("auth", "123"); + this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback); + + HttpHeaders httpHeaders = headersCaptor.getValue(); + assertEquals(2, httpHeaders.size()); + assertEquals("bar", httpHeaders.getFirst("foo")); + assertEquals("123", httpHeaders.getFirst("auth")); + + httpHeaders = this.xhrTransport.getRequest().getHttpRequestHeaders(); + assertEquals(2, httpHeaders.size()); + assertEquals("bar", httpHeaders.getFirst("foo")); + assertEquals("123", httpHeaders.getFirst("auth")); + } + + @Test + public void connectAndUseSubsetOfHandshakeHeadersForHttpRequests() throws Exception { + ArgumentCaptor headersCaptor = setupInfoRequest(false); + this.xhrTransport.setStreamingDisabled(true); + + WebSocketHttpHeaders headers = new WebSocketHttpHeaders(); + headers.set("foo", "bar"); + headers.set("auth", "123"); + this.sockJsClient.setHttpHeaderNames("auth"); + this.sockJsClient.doHandshake(handler, headers, new URI(URL)).addCallback(this.connectCallback); + + assertEquals(1, headersCaptor.getValue().size()); + assertEquals("123", headersCaptor.getValue().getFirst("auth")); + assertEquals(1, this.xhrTransport.getRequest().getHttpRequestHeaders().size()); + assertEquals("123", this.xhrTransport.getRequest().getHttpRequestHeaders().getFirst("auth")); + } + @Test public void connectSockJsInfo() throws Exception { setupInfoRequest(true); this.sockJsClient.doHandshake(handler, URL); - verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any()); } @Test @@ -115,22 +167,27 @@ public class SockJsClientTests { this.sockJsClient.doHandshake(handler, URL); this.sockJsClient.doHandshake(handler, URL); this.sockJsClient.doHandshake(handler, URL); - verify(this.infoReceiver, times(1)).executeInfoRequest(any()); + verify(this.infoReceiver, times(1)).executeInfoRequest(any(), any()); } @Test public void connectInfoRequestFailure() throws URISyntaxException { HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE); - given(this.infoReceiver.executeInfoRequest(any())).willThrow(exception); + given(this.infoReceiver.executeInfoRequest(any(), any())).willThrow(exception); this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback); verify(this.connectCallback).onFailure(exception); assertFalse(this.webSocketTransport.invoked()); assertFalse(this.xhrTransport.invoked()); } - private void setupInfoRequest(boolean webSocketEnabled) { - given(this.infoReceiver.executeInfoRequest(any())).willReturn("{\"entropy\":123," + - "\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}"); + private ArgumentCaptor setupInfoRequest(boolean webSocketEnabled) { + ArgumentCaptor headersCaptor = ArgumentCaptor.forClass(HttpHeaders.class); + when(this.infoReceiver.executeInfoRequest(any(), headersCaptor.capture())).thenReturn( + "{\"entropy\":123," + + "\"origins\":[\"*:*\"]," + + "\"cookie_needed\":true," + + "\"websocket\":" + webSocketEnabled + "}"); + return headersCaptor; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java index 449f429e1a..329b72c699 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/TestTransport.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -18,9 +18,12 @@ package org.springframework.web.socket.sockjs.client; import java.net.URI; import java.util.Arrays; +import java.util.Collections; import java.util.List; import org.mockito.ArgumentCaptor; + +import org.springframework.http.HttpHeaders; import org.springframework.util.concurrent.ListenableFuture; import org.springframework.util.concurrent.ListenableFutureCallback; import org.springframework.web.socket.TextMessage; @@ -28,7 +31,8 @@ import org.springframework.web.socket.WebSocketHandler; import org.springframework.web.socket.WebSocketSession; import org.springframework.web.socket.sockjs.transport.TransportType; -import static org.mockito.Mockito.*; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Test SockJS Transport. @@ -51,7 +55,7 @@ class TestTransport implements Transport { @Override public List getTransportTypes() { - return Arrays.asList(TransportType.WEBSOCKET); + return Collections.singletonList(TransportType.WEBSOCKET); } public TransportRequest getRequest() { @@ -95,7 +99,7 @@ class TestTransport implements Transport { @Override public List getTransportTypes() { return (isXhrStreamingDisabled() ? - Arrays.asList(TransportType.XHR) : + Collections.singletonList(TransportType.XHR) : Arrays.asList(TransportType.XHR_STREAMING, TransportType.XHR)); } @@ -109,11 +113,11 @@ class TestTransport implements Transport { } @Override - public void executeSendRequest(URI transportUrl, TextMessage message) { + public void executeSendRequest(URI transportUrl, HttpHeaders headers, TextMessage message) { } @Override - public String executeInfoRequest(URI infoUrl) { + public String executeInfoRequest(URI infoUrl, HttpHeaders headers) { return null; } } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java index b54c93079a..778831a1ff 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/sockjs/client/XhrTransportTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2014 the original author or authors. + * Copyright 2002-2015 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -46,25 +46,25 @@ public class XhrTransportTests { public void infoResponse() throws Exception { TestXhrTransport transport = new TestXhrTransport(); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK); - assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null)); } @Test(expected = HttpServerErrorException.class) public void infoResponseError() throws Exception { TestXhrTransport transport = new TestXhrTransport(); transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST); - assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"))); + assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info"), null)); } @Test public void sendMessage() throws Exception { HttpHeaders requestHeaders = new HttpHeaders(); requestHeaders.set("foo", "bar"); + requestHeaders.setContentType(MediaType.APPLICATION_JSON); TestXhrTransport transport = new TestXhrTransport(); - transport.setRequestHeaders(requestHeaders); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT); URI url = new URI("http://example.com"); - transport.executeSendRequest(url, new TextMessage("payload")); + transport.executeSendRequest(url, requestHeaders, new TextMessage("payload")); assertEquals(2, transport.actualSendRequestHeaders.size()); assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo")); assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType()); @@ -75,9 +75,10 @@ public class XhrTransportTests { TestXhrTransport transport = new TestXhrTransport(); transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST); URI url = new URI("http://example.com"); - transport.executeSendRequest(url, new TextMessage("payload")); + transport.executeSendRequest(url, null, new TextMessage("payload")); } + @SuppressWarnings("deprecation") @Test public void connect() throws Exception { HttpHeaders handshakeHeaders = new HttpHeaders(); @@ -101,6 +102,7 @@ public class XhrTransportTests { verify(request).addTimeoutTask(captor.capture()); verify(request).getTransportUrl(); verify(request).getHandshakeHeaders(); + verify(request).getHttpRequestHeaders(); verifyNoMoreInteractions(request); assertEquals(2, transport.actualHandshakeHeaders.size()); @@ -127,7 +129,7 @@ public class XhrTransportTests { @Override - protected ResponseEntity executeInfoRequestInternal(URI infoUrl) { + protected ResponseEntity executeInfoRequestInternal(URI infoUrl, HttpHeaders headers) { return this.infoResponseToReturn; }