Add SockJS client

This change adds a new implementation of WebSocketClient that can
connect to a SockJS server using one of the SockJS transports
"websocket", "xhr_streaming", or "xhr". From a client perspective
there is no implementation difference between "xhr_streaming" and
"xhr". Just keep receiving and when the response is complete,
start over. Other SockJS transports are browser specific
and therefore not relevant in Java ("eventsource", "htmlfile" or
iframe based variations).

The client loosely mimics the behavior of the JavaScript SockJS client.
First it sends an info request to find the server capabilities,
then it tries to connect with each configured transport, falling
back, or forcing a timeout and then falling back, until one of the
configured transports succeeds.

The WebSocketTransport can be configured with any Spring Framework
WebSocketClient implementation (currently JSR-356 or Jetty 9).

The XhrTransport currently has a RestTemplate-based and a Jetty
HttpClient-based implementations. To use those to simulate a large
number of users be sure to configure Jetty's HttpClient executor
and maxConnectionsPerDestination to high numbers. The same is true
for whichever underlying HTTP library is used with the RestTemplate
(e.g. maxConnPerRoute and maxConnTotal in Apache HttpComponents).

Issue: SPR-10797
This commit is contained in:
Rossen Stoyanchev
2014-04-28 22:01:14 -04:00
parent dc1d85d045
commit e82df99a22
32 changed files with 3835 additions and 13 deletions

View File

@@ -78,6 +78,7 @@ public class JettyWebSocketTestServer implements WebSocketTestServer {
@Override
public void stop() throws Exception {
if (this.jettyServer.isRunning()) {
this.jettyServer.setStopTimeout(0);
this.jettyServer.stop();
}
}

View File

@@ -0,0 +1,394 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.hamcrest.Matchers;
import org.junit.After;
import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.scheduling.concurrent.ThreadPoolTaskScheduler;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.context.support.AnnotationConfigWebApplicationContext;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.WebSocketTestServer;
import org.springframework.web.socket.config.annotation.EnableWebSocket;
import org.springframework.web.socket.config.annotation.WebSocketConfigurer;
import org.springframework.web.socket.config.annotation.WebSocketHandlerRegistry;
import org.springframework.web.socket.handler.TextWebSocketHandler;
import org.springframework.web.socket.server.HandshakeHandler;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.support.DefaultHandshakeHandler;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.FilterConfig;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import java.io.IOException;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.BlockingQueue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingQueue;
import java.util.concurrent.TimeUnit;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertNotNull;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.assertTrue;
import static org.hamcrest.Matchers.*;
/**
* Integration tests using the
* {@link org.springframework.web.socket.sockjs.client.SockJsClient}.
* against actual SockJS server endpoints.
*
* @author Rossen Stoyanchev
*/
public abstract class AbstractSockJsIntegrationTests {
protected Log logger = LogFactory.getLog(getClass());
private WebSocketTestServer server;
private AnnotationConfigWebApplicationContext wac;
private ErrorFilter errorFilter;
private String baseUrl;
@Before
public void setup() throws Exception {
this.errorFilter = new ErrorFilter();
this.wac = new AnnotationConfigWebApplicationContext();
this.wac.register(TestConfig.class, upgradeStrategyConfigClass());
this.wac.refresh();
this.server = createWebSocketTestServer();
this.server.deployConfig(this.wac, this.errorFilter);
this.server.start();
this.baseUrl = "http://localhost:" + this.server.getPort();
}
@After
public void teardown() throws Exception {
try {
this.server.undeployConfig();
}
catch (Throwable t) {
logger.error("Failed to undeploy application config", t);
}
try {
this.server.stop();
}
catch (Throwable t) {
logger.error("Failed to stop server", t);
}
}
protected abstract WebSocketTestServer createWebSocketTestServer();
protected abstract Class<?> upgradeStrategyConfigClass();
protected abstract Transport getWebSocketTransport();
protected abstract AbstractXhrTransport getXhrTransport();
protected SockJsClient createSockJsClient(Transport... transports) {
return new SockJsClient(Arrays.<Transport>asList(transports));
}
@Test
public void echoWebSocket() throws Exception {
testEcho(100, getWebSocketTransport());
}
@Test
public void echoXhrStreaming() throws Exception {
testEcho(100, getXhrTransport());
}
@Test
public void echoXhr() throws Exception {
AbstractXhrTransport xhrTransport = getXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testEcho(100, xhrTransport);
}
@Test
public void closeAfterOneMessageWebSocket() throws Exception {
testCloseAfterOneMessage(getWebSocketTransport());
}
@Test
public void closeAfterOneMessageXhrStreaming() throws Exception {
testCloseAfterOneMessage(getXhrTransport());
}
@Test
public void closeAfterOneMessageXhr() throws Exception {
AbstractXhrTransport xhrTransport = getXhrTransport();
xhrTransport.setXhrStreamingDisabled(true);
testCloseAfterOneMessage(xhrTransport);
}
@Test
public void infoRequestFailure() throws Exception {
TestClientHandler handler = new TestClientHandler();
this.errorFilter.responseStatusMap.put("/info", 500);
CountDownLatch latch = new CountDownLatch(1);
createSockJsClient(getWebSocketTransport()).doHandshake(handler, this.baseUrl + "/echo").addCallback(
new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession result) {
}
@Override
public void onFailure(Throwable t) {
latch.countDown();
}
}
);
assertTrue(latch.await(5000, TimeUnit.MILLISECONDS));
}
@Test
public void fallbackAfterTransportFailure() throws Exception {
this.errorFilter.responseStatusMap.put("/websocket", 200);
this.errorFilter.responseStatusMap.put("/xhr_streaming", 500);
TestClientHandler handler = new TestClientHandler();
Transport[] transports = { getWebSocketTransport(), getXhrTransport() };
WebSocketSession session = createSockJsClient(transports).doHandshake(handler, this.baseUrl + "/echo").get();
assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, session.getClass());
TextMessage message = new TextMessage("message1");
session.sendMessage(message);
handler.awaitMessage(message, 5000);
}
@Test(timeout = 5000)
public void fallbackAfterConnectTimeout() throws Exception {
TestClientHandler clientHandler = new TestClientHandler();
this.errorFilter.sleepDelayMap.put("/xhr_streaming", 10000L);
this.errorFilter.responseStatusMap.put("/xhr_streaming", 503);
SockJsClient sockJsClient = createSockJsClient(getXhrTransport());
sockJsClient.setTaskScheduler(this.wac.getBean(ThreadPoolTaskScheduler.class));
WebSocketSession clientSession = sockJsClient.doHandshake(clientHandler, this.baseUrl + "/echo").get();
assertEquals("Fallback didn't occur", XhrClientSockJsSession.class, clientSession.getClass());
TextMessage message = new TextMessage("message1");
clientSession.sendMessage(message);
clientHandler.awaitMessage(message, 5000);
clientSession.close();
}
private void testEcho(int messageCount, Transport transport) throws Exception {
List<TextMessage> messages = new ArrayList<>();
for (int i = 0; i < messageCount; i++) {
messages.add(new TextMessage("m" + i));
}
TestClientHandler handler = new TestClientHandler();
WebSocketSession session = createSockJsClient(transport).doHandshake(handler, this.baseUrl + "/echo").get();
for (TextMessage message : messages) {
session.sendMessage(message);
}
handler.awaitMessageCount(messageCount, 5000);
for (TextMessage message : messages) {
assertTrue("Message not received: " + message, handler.receivedMessages.remove(message));
}
assertEquals("Remaining messages: " + handler.receivedMessages, 0, handler.receivedMessages.size());
session.close();
}
private void testCloseAfterOneMessage(Transport transport) throws Exception {
TestClientHandler clientHandler = new TestClientHandler();
createSockJsClient(transport).doHandshake(clientHandler, this.baseUrl + "/test").get();
TestServerHandler serverHandler = this.wac.getBean(TestServerHandler.class);
assertNotNull("afterConnectionEstablished should have been called", clientHandler.session);
serverHandler.awaitSession(5000);
TextMessage message = new TextMessage("message1");
serverHandler.session.sendMessage(message);
clientHandler.awaitMessage(message, 5000);
CloseStatus expected = new CloseStatus(3500, "Oops");
serverHandler.session.close(expected);
CloseStatus actual = clientHandler.awaitCloseStatus(5000);
if (transport instanceof XhrTransport) {
assertThat(actual, Matchers.anyOf(equalTo(expected), equalTo(new CloseStatus(3000, "Go away!"))));
}
else {
assertEquals(expected, actual);
}
}
@Configuration
@EnableWebSocket
static class TestConfig implements WebSocketConfigurer {
@Autowired
private RequestUpgradeStrategy upgradeStrategy;
@Override
public void registerWebSocketHandlers(WebSocketHandlerRegistry registry) {
HandshakeHandler handshakeHandler = new DefaultHandshakeHandler(this.upgradeStrategy);
registry.addHandler(new EchoHandler(), "/echo").setHandshakeHandler(handshakeHandler).withSockJS();
registry.addHandler(testServerHandler(), "/test").setHandshakeHandler(handshakeHandler).withSockJS();
}
@Bean
public TestServerHandler testServerHandler() {
return new TestServerHandler();
}
}
private static interface Condition {
boolean match();
}
private static void awaitEvent(Condition condition, long timeToWait, String description) {
long timeToSleep = 200;
for (int i = 0 ; i < Math.floor(timeToWait / timeToSleep); i++) {
if (condition.match()) {
return;
}
try {
Thread.sleep(timeToSleep);
}
catch (InterruptedException e) {
throw new IllegalStateException("Interrupted while waiting for " + description, e);
}
}
throw new IllegalStateException("Timed out waiting for " + description);
}
private static class TestClientHandler extends TextWebSocketHandler {
private final BlockingQueue<TextMessage> receivedMessages = new LinkedBlockingQueue<>();
private volatile WebSocketSession session;
private volatile CloseStatus closeStatus;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
}
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
this.receivedMessages.add(message);
}
@Override
public void afterConnectionClosed(WebSocketSession session, CloseStatus status) throws Exception {
this.closeStatus = status;
}
public void awaitMessageCount(final int count, long timeToWait) throws Exception {
awaitEvent(() -> receivedMessages.size() >= count, timeToWait,
count + " number of messages. Received so far: " + this.receivedMessages);
}
public void awaitMessage(TextMessage expected, long timeToWait) throws InterruptedException {
TextMessage actual = this.receivedMessages.poll(timeToWait, TimeUnit.MILLISECONDS);
assertNotNull("Timed out waiting for [" + expected + "]", actual);
assertEquals(expected, actual);
}
public CloseStatus awaitCloseStatus(long timeToWait) throws InterruptedException {
awaitEvent(() -> this.closeStatus != null, timeToWait, " CloseStatus");
return this.closeStatus;
}
}
private static class TestServerHandler extends TextWebSocketHandler {
private WebSocketSession session;
@Override
public void afterConnectionEstablished(WebSocketSession session) throws Exception {
this.session = session;
}
public WebSocketSession awaitSession(long timeToWait) throws InterruptedException {
awaitEvent(() -> this.session != null, timeToWait, " session");
return this.session;
}
}
private static class EchoHandler extends TextWebSocketHandler {
@Override
protected void handleTextMessage(WebSocketSession session, TextMessage message) throws Exception {
session.sendMessage(message);
}
}
private static class ErrorFilter implements Filter {
private final Map<String, Integer> responseStatusMap = new HashMap<>();
private final Map<String, Long> sleepDelayMap = new HashMap<>();
@Override
public void doFilter(ServletRequest req, ServletResponse resp, FilterChain chain) throws IOException, ServletException {
for (String suffix : this.sleepDelayMap.keySet()) {
if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) {
try {
Thread.sleep(this.sleepDelayMap.get(suffix));
break;
}
catch (InterruptedException e) {
e.printStackTrace();
}
}
}
for (String suffix : this.responseStatusMap.keySet()) {
if (((HttpServletRequest) req).getRequestURI().endsWith(suffix)) {
((HttpServletResponse) resp).sendError(this.responseStatusMap.get(suffix));
return;
}
}
chain.doFilter(req, resp);
}
@Override
public void init(FilterConfig filterConfig) throws ServletException {
}
@Override
public void destroy() {
}
}
}

View File

@@ -0,0 +1,280 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketExtension;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.URI;
import java.util.List;
import static org.junit.Assert.assertThat;
import static org.hamcrest.CoreMatchers.*;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
/**
* Unit tests for
* {@link org.springframework.web.socket.sockjs.client.AbstractClientSockJsSession}.
*
* @author Rossen Stoyanchev
*/
public class ClientSockJsSessionTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
private TestClientSockJsSession session;
private WebSocketHandler handler;
private SettableListenableFuture<WebSocketSession> connectFuture;
@Rule
public final ExpectedException thrown = ExpectedException.none();
@Before
public void setup() throws Exception {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
Transport transport = mock(Transport.class);
TransportRequest request = new DefaultTransportRequest(urlInfo, null, transport, TransportType.XHR, CODEC);
this.handler = mock(WebSocketHandler.class);
this.connectFuture = new SettableListenableFuture<>();
this.session = new TestClientSockJsSession(request, this.handler, this.connectFuture);
}
@Test
public void handleFrameOpen() throws Exception {
assertThat(this.session.isOpen(), is(false));
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.isOpen(), is(true));
assertTrue(this.connectFuture.isDone());
assertThat(this.connectFuture.get(), sameInstance(this.session));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameOpenWhenStatusNotNew() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.isOpen(), is(true));
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1006, "Server lost session")));
}
@Test
public void handleFrameOpenWithWebSocketHandlerException() throws Exception {
doThrow(new IllegalStateException("Fake error")).when(this.handler).afterConnectionEstablished(this.session);
this.session.handleFrame(SockJsFrame.openFrame().getContent());
assertThat(this.session.isOpen(), is(true));
}
@Test
public void handleFrameMessage() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent());
verify(this.handler).afterConnectionEstablished(this.session);
verify(this.handler).handleMessage(this.session, new TextMessage("foo"));
verify(this.handler).handleMessage(this.session, new TextMessage("bar"));
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameMessageWhenNotOpen() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.close();
reset(this.handler);
this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent());
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameMessageWithBadData() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.handleFrame("a['bad data");
assertThat(this.session.isOpen(), equalTo(false));
assertThat(this.session.disconnectStatus, equalTo(CloseStatus.BAD_DATA));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameMessageWithWebSocketHandlerException() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("foo"));
doThrow(new IllegalStateException("Fake error")).when(this.handler).handleMessage(this.session, new TextMessage("bar"));
this.session.handleFrame(SockJsFrame.messageFrame(CODEC, "foo", "bar").getContent());
assertThat(this.session.isOpen(), equalTo(true));
verify(this.handler).afterConnectionEstablished(this.session);
verify(this.handler).handleMessage(this.session, new TextMessage("foo"));
verify(this.handler).handleMessage(this.session, new TextMessage("bar"));
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleFrameClose() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.handleFrame(SockJsFrame.closeFrame(1007, "").getContent());
assertThat(this.session.isOpen(), equalTo(false));
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(1007, "")));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void handleTransportError() throws Exception {
final IllegalStateException ex = new IllegalStateException("Fake error");
this.session.handleTransportError(ex);
verify(this.handler).handleTransportError(this.session, ex);
verifyNoMoreInteractions(this.handler);
}
@Test
public void afterTransportClosed() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.afterTransportClosed(CloseStatus.SERVER_ERROR);
assertThat(this.session.isOpen(), equalTo(false));
verify(this.handler).afterConnectionEstablished(this.session);
verify(this.handler).afterConnectionClosed(this.session, CloseStatus.SERVER_ERROR);
verifyNoMoreInteractions(this.handler);
}
@Test
public void close() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.close();
assertThat(this.session.isOpen(), equalTo(false));
assertThat(this.session.disconnectStatus, equalTo(CloseStatus.NORMAL));
verify(this.handler).afterConnectionEstablished(this.session);
verifyNoMoreInteractions(this.handler);
}
@Test
public void closeWithStatus() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.close(new CloseStatus(3000, "reason"));
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(3000, "reason")));
}
@Test
public void closeWithNullStatus() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Invalid close status");
this.session.close(null);
}
@Test
public void closeWithStatusOutOfRange() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Invalid close status");
this.session.close(new CloseStatus(2999, "reason"));
}
@Test
public void timeoutTask() {
this.session.getTimeoutTask().run();
assertThat(this.session.disconnectStatus, equalTo(new CloseStatus(2007, "Transport timed out")));
}
@Test
public void send() throws Exception {
this.session.handleFrame(SockJsFrame.openFrame().getContent());
this.session.sendMessage(new TextMessage("foo"));
assertThat(this.session.sentMessage, equalTo(new TextMessage("[\"foo\"]")));
}
private static class TestClientSockJsSession extends AbstractClientSockJsSession {
private TextMessage sentMessage;
private CloseStatus disconnectStatus;
protected TestClientSockJsSession(TransportRequest request, WebSocketHandler handler,
SettableListenableFuture<WebSocketSession> connectFuture) {
super(request, handler, connectFuture);
}
@Override
protected void sendInternal(TextMessage textMessage) throws IOException {
this.sentMessage = textMessage;
}
@Override
protected void disconnect(CloseStatus status) throws IOException {
this.disconnectStatus = status;
}
@Override
public InetSocketAddress getLocalAddress() {
return null;
}
@Override
public InetSocketAddress getRemoteAddress() {
return null;
}
@Override
public String getAcceptedProtocol() {
return null;
}
@Override
public void setTextMessageSizeLimit(int messageSizeLimit) {
}
@Override
public int getTextMessageSizeLimit() {
return 0;
}
@Override
public void setBinaryMessageSizeLimit(int messageSizeLimit) {
}
@Override
public int getBinaryMessageSizeLimit() {
return 0;
}
@Override
public List<WebSocketExtension> getExtensions() {
return null;
}
}
}

View File

@@ -0,0 +1,139 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.scheduling.TaskScheduler;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.io.IOException;
import java.net.URI;
import java.util.Date;
import java.util.concurrent.ExecutionException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertSame;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* Unit tests for {@link DefaultTransportRequest}.
*
* @author Rossen Stoyanchev
*/
public class DefaultTransportRequestTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
private SettableListenableFuture<WebSocketSession> connectFuture;
private ListenableFutureCallback<WebSocketSession> connectCallback;
private TestTransport webSocketTransport;
private TestTransport xhrTransport;
@Rule
public final ExpectedException thrown = ExpectedException.none();
@SuppressWarnings("unchecked")
@Before
public void setup() throws Exception {
this.connectCallback = mock(ListenableFutureCallback.class);
this.connectFuture = new SettableListenableFuture<>();
this.connectFuture.addCallback(this.connectCallback);
this.webSocketTransport = new TestTransport("WebSocketTestTransport");
this.xhrTransport = new TestTransport("XhrTestTransport");
}
@Test
@SuppressWarnings("unchecked")
public void connect() throws Exception {
DefaultTransportRequest request = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET);
request.connect(null, this.connectFuture);
WebSocketSession session = mock(WebSocketSession.class);
this.webSocketTransport.getConnectCallback().onSuccess(session);
assertSame(session, this.connectFuture.get());
}
@Test
public void fallbackAfterTransportError() throws Exception {
DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET);
DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING);
request1.setFallbackRequest(request2);
request1.connect(null, this.connectFuture);
// Transport error => fallback
this.webSocketTransport.getConnectCallback().onFailure(new IOException("Fake exception 1"));
assertFalse(this.connectFuture.isDone());
assertTrue(this.xhrTransport.invoked());
// Transport error => no more fallback
this.xhrTransport.getConnectCallback().onFailure(new IOException("Fake exception 2"));
assertTrue(this.connectFuture.isDone());
this.thrown.expect(ExecutionException.class);
this.thrown.expectMessage("Fake exception 2");
this.connectFuture.get();
}
@Test
public void fallbackAfterTimeout() throws Exception {
TaskScheduler scheduler = mock(TaskScheduler.class);
Runnable sessionCleanupTask = mock(Runnable.class);
DefaultTransportRequest request1 = createTransportRequest(this.webSocketTransport, TransportType.WEBSOCKET);
DefaultTransportRequest request2 = createTransportRequest(this.xhrTransport, TransportType.XHR_STREAMING);
request1.setFallbackRequest(request2);
request1.setTimeoutScheduler(scheduler);
request1.addTimeoutTask(sessionCleanupTask);
request1.connect(null, this.connectFuture);
assertTrue(this.webSocketTransport.invoked());
assertFalse(this.xhrTransport.invoked());
// Get and invoke the scheduled timeout task
ArgumentCaptor<Runnable> taskCaptor = ArgumentCaptor.forClass(Runnable.class);
verify(scheduler).schedule(taskCaptor.capture(), any(Date.class));
verifyNoMoreInteractions(scheduler);
taskCaptor.getValue().run();
assertTrue(this.xhrTransport.invoked());
verify(sessionCleanupTask).run();
}
protected DefaultTransportRequest createTransportRequest(Transport transport, TransportType type) throws Exception {
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
return new DefaultTransportRequest(urlInfo, new HttpHeaders(), transport, type, CODEC);
}
}

View File

@@ -0,0 +1,101 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.eclipse.jetty.client.HttpClient;
import org.eclipse.jetty.websocket.client.WebSocketClient;
import org.junit.After;
import org.junit.Before;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.web.socket.JettyWebSocketTestServer;
import org.springframework.web.socket.client.jetty.JettyWebSocketClient;
import org.springframework.web.socket.server.RequestUpgradeStrategy;
import org.springframework.web.socket.server.jetty.JettyRequestUpgradeStrategy;
import java.util.ArrayList;
import java.util.List;
/**
* SockJS integration tests using Jetty for client and server.
*
* @author Rossen Stoyanchev
*/
public class JettySockJsIntegrationTests extends AbstractSockJsIntegrationTests {
private WebSocketClient webSocketClient;
private HttpClient httpClient;
@Before
public void setup() throws Exception {
super.setup();
this.webSocketClient = new WebSocketClient();
this.webSocketClient.start();
this.httpClient = new HttpClient();
this.httpClient.start();
}
@After
public void teardown() throws Exception {
super.teardown();
try {
this.webSocketClient.stop();
}
catch (Throwable ex) {
logger.error("Failed to stop Jetty WebSocketClient", ex);
}
try {
this.httpClient.stop();
}
catch (Throwable ex) {
logger.error("Failed to stop Jetty HttpClient", ex);
}
}
@Override
protected JettyWebSocketTestServer createWebSocketTestServer() {
return new JettyWebSocketTestServer();
}
@Override
protected Class<?> upgradeStrategyConfigClass() {
return JettyTestConfig.class;
}
@Override
protected Transport getWebSocketTransport() {
return new WebSocketTransport(new JettyWebSocketClient(this.webSocketClient));
}
@Override
protected AbstractXhrTransport getXhrTransport() {
return new JettyXhrTransport(this.httpClient);
}
@Configuration
static class JettyTestConfig {
@Bean
public RequestUpgradeStrategy upgradeStrategy() {
return new JettyRequestUpgradeStrategy();
}
}
}

View File

@@ -0,0 +1,228 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Test;
import org.springframework.core.task.SyncTaskExecutor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.messaging.Message;
import org.springframework.messaging.MessageHeaders;
import org.springframework.messaging.simp.stomp.StompCommand;
import org.springframework.messaging.simp.stomp.StompEncoder;
import org.springframework.messaging.simp.stomp.StompHeaderAccessor;
import org.springframework.messaging.support.MessageBuilder;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.client.RequestCallback;
import org.springframework.web.client.ResponseExtractor;
import org.springframework.web.client.RestClientException;
import org.springframework.web.client.RestOperations;
import org.springframework.web.client.RestTemplate;
import org.springframework.web.socket.CloseStatus;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.frame.SockJsFrame;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.io.ByteArrayInputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.URI;
import java.nio.charset.Charset;
import java.util.Arrays;
import java.util.Queue;
import java.util.concurrent.CountDownLatch;
import java.util.concurrent.LinkedBlockingDeque;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.*;
import static org.mockito.Mockito.verifyNoMoreInteractions;
/**
* Unit tests for {@link RestTemplateXhrTransport}.
*
* @author Rossen Stoyanchev
*/
public class RestTemplateXhrTransportTests {
private static final Jackson2SockJsMessageCodec CODEC = new Jackson2SockJsMessageCodec();
private WebSocketHandler webSocketHandler;
@Before
public void setup() throws Exception {
this.webSocketHandler = mock(WebSocketHandler.class);
}
@Test
public void connectReceiveAndClose() throws Exception {
String body = "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo")));
verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!")));
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void connectReceiveAndCloseWithPrelude() throws Exception {
StringBuilder sb = new StringBuilder(2048);
for (int i=0; i < 2048; i++) {
sb.append('h');
}
String body = sb.toString() + "\n" + "o\n" + "a[\"foo\"]\n" + "c[3000,\"Go away!\"]";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleMessage(any(), eq(new TextMessage("foo")));
verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!")));
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void connectReceiveAndCloseWithStompFrame() throws Exception {
StompHeaderAccessor accessor = StompHeaderAccessor.create(StompCommand.SEND);
accessor.setDestination("/destination");
MessageHeaders headers = accessor.getMessageHeaders();
Message<byte[]> message = MessageBuilder.createMessage("body".getBytes(Charset.forName("UTF-8")), headers);
byte[] bytes = new StompEncoder().encode(message);
TextMessage textMessage = new TextMessage(bytes);
SockJsFrame frame = SockJsFrame.messageFrame(new Jackson2SockJsMessageCodec(), textMessage.getPayload());
String body = "o\n" + frame.getContent() + "\n" + "c[3000,\"Go away!\"]";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleMessage(any(), eq(textMessage));
verify(this.webSocketHandler).afterConnectionClosed(any(), eq(new CloseStatus(3000, "Go away!")));
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void connectFailure() throws Exception {
final HttpServerErrorException expected = new HttpServerErrorException(HttpStatus.INTERNAL_SERVER_ERROR);
RestOperations restTemplate = mock(RestOperations.class);
when(restTemplate.execute(any(), eq(HttpMethod.POST), any(), any())).thenThrow(expected);
final CountDownLatch latch = new CountDownLatch(1);
connect(restTemplate).addCallback(
new ListenableFutureCallback<WebSocketSession>() {
@Override
public void onSuccess(WebSocketSession result) {
}
@Override
public void onFailure(Throwable actual) {
if (actual == expected) {
latch.countDown();
}
}
}
);
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void errorResponseStatus() throws Exception {
connect(response(HttpStatus.OK, "o\n"), response(HttpStatus.INTERNAL_SERVER_ERROR, "Oops"));
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).handleTransportError(any(), any());
verify(this.webSocketHandler).afterConnectionClosed(any(), any());
verifyNoMoreInteractions(this.webSocketHandler);
}
@Test
public void responseClosedAfterDisconnected() throws Exception {
String body = "o\n" + "c[3000,\"Go away!\"]\n" + "a[\"foo\"]\n";
ClientHttpResponse response = response(HttpStatus.OK, body);
connect(response);
verify(this.webSocketHandler).afterConnectionEstablished(any());
verify(this.webSocketHandler).afterConnectionClosed(any(), any());
verifyNoMoreInteractions(this.webSocketHandler);
verify(response).close();
}
private ListenableFuture<WebSocketSession> connect(ClientHttpResponse... responses) throws Exception {
return connect(new TestRestTemplate(responses));
}
private ListenableFuture<WebSocketSession> connect(RestOperations restTemplate, ClientHttpResponse... responses)
throws Exception {
RestTemplateXhrTransport transport = new RestTemplateXhrTransport(restTemplate);
transport.setTaskExecutor(new SyncTaskExecutor());
SockJsUrlInfo urlInfo = new SockJsUrlInfo(new URI("http://example.com"));
HttpHeaders headers = new HttpHeaders();
headers.add("h-foo", "h-bar");
TransportRequest request = new DefaultTransportRequest(urlInfo, headers, transport, TransportType.XHR, CODEC);
return transport.connect(request, this.webSocketHandler);
}
private ClientHttpResponse response(HttpStatus status, String body) throws IOException {
ClientHttpResponse response = mock(ClientHttpResponse.class);
InputStream inputStream = getInputStream(body);
when(response.getStatusCode()).thenReturn(status);
when(response.getBody()).thenReturn(inputStream);
return response;
}
private InputStream getInputStream(String content) {
byte[] bytes = content.getBytes(Charset.forName("UTF-8"));
return new ByteArrayInputStream(bytes);
}
private static class TestRestTemplate extends RestTemplate {
private Queue<ClientHttpResponse> responses = new LinkedBlockingDeque<>();
private TestRestTemplate(ClientHttpResponse... responses) {
this.responses.addAll(Arrays.asList(responses));
}
@Override
public <T> T execute(URI url, HttpMethod method, RequestCallback callback, ResponseExtractor<T> extractor) throws RestClientException {
try {
extractor.extractData(this.responses.remove());
}
catch (Throwable t) {
throw new RestClientException("Failed to invoke extractor", t);
}
return null;
}
}
}

View File

@@ -0,0 +1,137 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Before;
import org.junit.Test;
import org.springframework.http.HttpStatus;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import org.springframework.web.socket.sockjs.client.TestTransport.XhrTestTransport;
import java.net.URISyntaxException;
import java.util.ArrayList;
import java.util.List;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Mockito.*;
/**
* Unit tests for {@link org.springframework.web.socket.sockjs.client.SockJsClient}.
*
* @author Rossen Stoyanchev
*/
public class SockJsClientTests {
private static final String URL = "http://example.com";
private static final WebSocketHandler handler = mock(WebSocketHandler.class);
private SockJsClient sockJsClient;
private InfoReceiver infoReceiver;
private TestTransport webSocketTransport;
private XhrTestTransport xhrTransport;
private ListenableFutureCallback<WebSocketSession> connectCallback;
@Before
@SuppressWarnings("unchecked")
public void setup() {
this.infoReceiver = mock(InfoReceiver.class);
this.webSocketTransport = new TestTransport("WebSocketTestTransport");
this.xhrTransport = new XhrTestTransport("XhrTestTransport");
List<Transport> transports = new ArrayList<>();
transports.add(this.webSocketTransport);
transports.add(this.xhrTransport);
this.sockJsClient = new SockJsClient(transports);
this.sockJsClient.setInfoReceiver(this.infoReceiver);
this.connectCallback = mock(ListenableFutureCallback.class);
}
@Test
public void connectWebSocket() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
assertTrue(this.webSocketTransport.invoked());
WebSocketSession session = mock(WebSocketSession.class);
this.webSocketTransport.getConnectCallback().onSuccess(session);
verify(this.connectCallback).onSuccess(session);
verifyNoMoreInteractions(this.connectCallback);
}
@Test
public void connectWebSocketDisabled() throws URISyntaxException {
setupInfoRequest(false);
this.sockJsClient.doHandshake(handler, URL);
assertFalse(this.webSocketTransport.invoked());
assertTrue(this.xhrTransport.invoked());
assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr_streaming"));
}
@Test
public void connectXhrStreamingDisabled() throws Exception {
setupInfoRequest(false);
this.xhrTransport.setStreamingDisabled(true);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
assertFalse(this.webSocketTransport.invoked());
assertTrue(this.xhrTransport.invoked());
assertTrue(this.xhrTransport.getRequest().getTransportUrl().toString().endsWith("xhr"));
}
@Test
public void connectSockJsInfo() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL);
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
}
@Test
public void connectSockJsInfoCached() throws Exception {
setupInfoRequest(true);
this.sockJsClient.doHandshake(handler, URL);
this.sockJsClient.doHandshake(handler, URL);
this.sockJsClient.doHandshake(handler, URL);
verify(this.infoReceiver, times(1)).executeInfoRequest(any());
}
@Test
@SuppressWarnings("unchecked")
public void connectInfoRequestFailure() throws URISyntaxException {
HttpServerErrorException exception = new HttpServerErrorException(HttpStatus.SERVICE_UNAVAILABLE);
when(this.infoReceiver.executeInfoRequest(any())).thenThrow(exception);
this.sockJsClient.doHandshake(handler, URL).addCallback(this.connectCallback);
verify(this.connectCallback).onFailure(exception);
assertFalse(this.webSocketTransport.invoked());
assertFalse(this.xhrTransport.invoked());
}
private void setupInfoRequest(boolean webSocketEnabled) {
when(this.infoReceiver.executeInfoRequest(any())).thenReturn("{\"entropy\":123," +
"\"origins\":[\"*:*\"],\"cookie_needed\":true,\"websocket\":" + webSocketEnabled + "}");
}
}

View File

@@ -0,0 +1,90 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Assert;
import org.junit.Test;
import org.springframework.web.socket.sockjs.frame.Jackson2SockJsMessageCodec;
import org.springframework.web.socket.sockjs.transport.TransportType;
import java.net.URI;
import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.CoreMatchers.is;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertTrue;
/**
* Unit tests for {@code SockJsUrlInfo}.
* @author Rossen Stoyanchev
*/
public class SockJsUrlInfoTests {
@Test
public void serverId() throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com"));
int serverId = Integer.valueOf(info.getServerId());
assertTrue("Invalid serverId: " + serverId, serverId > 0 && serverId < 1000);
}
@Test
public void sessionId() throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI("http://example.com"));
assertEquals("Invalid sessionId: " + info.getSessionId(), 32, info.getSessionId().length());
}
@Test
public void infoUrl() throws Exception {
testInfoUrl("http", "http");
testInfoUrl("http", "http");
testInfoUrl("https", "https");
testInfoUrl("https", "https");
testInfoUrl("ws", "http");
testInfoUrl("ws", "http");
testInfoUrl("wss", "https");
testInfoUrl("wss", "https");
}
private void testInfoUrl(String scheme, String expectedScheme) throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com"));
Assert.assertThat(info.getInfoUrl(), is(equalTo(new URI(expectedScheme + "://example.com/info"))));
}
@Test
public void transportUrl() throws Exception {
testTransportUrl("http", "http", TransportType.XHR_STREAMING);
testTransportUrl("http", "ws", TransportType.WEBSOCKET);
testTransportUrl("https", "https", TransportType.XHR_STREAMING);
testTransportUrl("https", "wss", TransportType.WEBSOCKET);
testTransportUrl("ws", "http", TransportType.XHR_STREAMING);
testTransportUrl("ws", "ws", TransportType.WEBSOCKET);
testTransportUrl("wss", "https", TransportType.XHR_STREAMING);
testTransportUrl("wss", "wss", TransportType.WEBSOCKET);
}
private void testTransportUrl(String scheme, String expectedScheme, TransportType transportType) throws Exception {
SockJsUrlInfo info = new SockJsUrlInfo(new URI(scheme + "://example.com"));
String serverId = info.getServerId();
String sessionId = info.getSessionId();
String transport = transportType.toString().toLowerCase();
URI expected = new URI(expectedScheme + "://example.com/" + serverId + "/" + sessionId + "/" + transport);
assertThat(info.getTransportUrl(transportType), equalTo(expected));
}
}

View File

@@ -0,0 +1,106 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.mockito.ArgumentCaptor;
import org.springframework.util.concurrent.ListenableFuture;
import org.springframework.util.concurrent.ListenableFutureCallback;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import java.net.URI;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Test SockJS Transport.
*
* @author Rossen Stoyanchev
*/
class TestTransport implements Transport {
private final String name;
private TransportRequest request;
private ListenableFuture future;
public TestTransport(String name) {
this.name = name;
}
public TransportRequest getRequest() {
return this.request;
}
public boolean invoked() {
return this.future != null;
}
@SuppressWarnings("unchecked")
public ListenableFutureCallback<WebSocketSession> getConnectCallback() {
ArgumentCaptor<ListenableFutureCallback> captor = ArgumentCaptor.forClass(ListenableFutureCallback.class);
verify(this.future).addCallback(captor.capture());
return captor.getValue();
}
@SuppressWarnings("unchecked")
@Override
public ListenableFuture<WebSocketSession> connect(TransportRequest request, WebSocketHandler handler) {
this.request = request;
this.future = mock(ListenableFuture.class);
return this.future;
}
@Override
public String toString() {
return "TestTransport[" + name + "]";
}
static class XhrTestTransport extends TestTransport implements XhrTransport {
private boolean streamingDisabled;
XhrTestTransport(String name) {
super(name);
}
public void setStreamingDisabled(boolean streamingDisabled) {
this.streamingDisabled = streamingDisabled;
}
@Override
public boolean isXhrStreamingDisabled() {
return this.streamingDisabled;
}
@Override
public void executeSendRequest(URI transportUrl, TextMessage message) {
}
@Override
public String executeInfoRequest(URI infoUrl) {
return null;
}
}
}

View File

@@ -0,0 +1,155 @@
/*
* Copyright 2002-2014 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.web.socket.sockjs.client;
import org.junit.Test;
import org.mockito.ArgumentCaptor;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.ResponseEntity;
import org.springframework.util.concurrent.SettableListenableFuture;
import org.springframework.web.client.HttpServerErrorException;
import org.springframework.web.socket.TextMessage;
import org.springframework.web.socket.WebSocketHandler;
import org.springframework.web.socket.WebSocketSession;
import java.net.URI;
import static org.junit.Assert.assertEquals;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.mockito.Matchers.notNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
import static org.mockito.Mockito.verifyNoMoreInteractions;
import static org.mockito.Mockito.when;
/**
* Unit tests for
* {@link org.springframework.web.socket.sockjs.client.AbstractXhrTransport}.
*
* @author Rossen Stoyanchev
*/
public class XhrTransportTests {
@Test
public void infoResponse() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.OK);
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
}
@Test(expected = HttpServerErrorException.class)
public void infoResponseError() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.infoResponseToReturn = new ResponseEntity<>("body", HttpStatus.BAD_REQUEST);
assertEquals("body", transport.executeInfoRequest(new URI("http://example.com/info")));
}
@Test
public void sendMessage() throws Exception {
HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar");
TestXhrTransport transport = new TestXhrTransport();
transport.setRequestHeaders(requestHeaders);
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.NO_CONTENT);
URI url = new URI("http://example.com");
transport.executeSendRequest(url, new TextMessage("payload"));
assertEquals(2, transport.actualSendRequestHeaders.size());
assertEquals("bar", transport.actualSendRequestHeaders.getFirst("foo"));
assertEquals(MediaType.APPLICATION_JSON, transport.actualSendRequestHeaders.getContentType());
}
@Test(expected = HttpServerErrorException.class)
public void sendMessageError() throws Exception {
TestXhrTransport transport = new TestXhrTransport();
transport.sendMessageResponseToReturn = new ResponseEntity<>(HttpStatus.BAD_REQUEST);
URI url = new URI("http://example.com");
transport.executeSendRequest(url, new TextMessage("payload"));
}
@Test
public void connect() throws Exception {
HttpHeaders handshakeHeaders = new HttpHeaders();
handshakeHeaders.setOrigin("foo");
TransportRequest request = mock(TransportRequest.class);
when(request.getSockJsUrlInfo()).thenReturn(new SockJsUrlInfo(new URI("http://example.com")));
when(request.getHandshakeHeaders()).thenReturn(handshakeHeaders);
HttpHeaders requestHeaders = new HttpHeaders();
requestHeaders.set("foo", "bar");
TestXhrTransport transport = new TestXhrTransport();
transport.setRequestHeaders(requestHeaders);
WebSocketHandler handler = mock(WebSocketHandler.class);
transport.connect(request, handler);
ArgumentCaptor<Runnable> captor = ArgumentCaptor.forClass(Runnable.class);
verify(request).getSockJsUrlInfo();
verify(request).addTimeoutTask(captor.capture());
verify(request).getTransportUrl();
verify(request).getHandshakeHeaders();
verifyNoMoreInteractions(request);
assertEquals(2, transport.actualHandshakeHeaders.size());
assertEquals("foo", transport.actualHandshakeHeaders.getOrigin());
assertEquals("bar", transport.actualHandshakeHeaders.getFirst("foo"));
assertFalse(transport.actualSession.isDisconnected());
captor.getValue().run();
assertTrue(transport.actualSession.isDisconnected());
}
private static class TestXhrTransport extends AbstractXhrTransport {
private ResponseEntity<String> infoResponseToReturn;
private ResponseEntity<String> sendMessageResponseToReturn;
private HttpHeaders actualSendRequestHeaders;
private HttpHeaders actualHandshakeHeaders;
private XhrClientSockJsSession actualSession;
@Override
protected ResponseEntity<String> executeInfoRequestInternal(URI infoUrl) {
return this.infoResponseToReturn;
}
@Override
protected ResponseEntity<String> executeSendRequestInternal(URI url, HttpHeaders headers, TextMessage message) {
this.actualSendRequestHeaders = headers;
return this.sendMessageResponseToReturn;
}
@Override
protected void connectInternal(TransportRequest request, WebSocketHandler handler, URI receiveUrl,
HttpHeaders handshakeHeaders, XhrClientSockJsSession session,
SettableListenableFuture<WebSocketSession> connectFuture) {
this.actualHandshakeHeaders = handshakeHeaders;
this.actualSession = session;
}
}
}