From a3fa9c979777e554efad0df429041767f05dfdb8 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Thu, 26 Jun 2014 09:38:46 -0400 Subject: [PATCH] Add check for unused WebSocket sessions Sessions connected to a STOMP endpoint are expected to receive some client messages. Having received none after successfully connecting could be an indication of proxy or network issue. This change adds periodic checks to see if we have not received any messages on a session which is an indication the session isn't going anywhere most likely due to a proxy issue (or unreliable network) and close those sessions. Issue: SPR-11884 --- .../SubProtocolWebSocketHandler.java | 136 +++++++++++++++--- .../SubProtocolWebSocketHandlerTests.java | 38 ++++- 2 files changed, 153 insertions(+), 21 deletions(-) diff --git a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java index 572f5ffc29..a5ae409c47 100644 --- a/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java +++ b/spring-websocket/src/main/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandler.java @@ -16,6 +16,7 @@ package org.springframework.web.socket.messaging; +import java.io.IOException; import java.util.ArrayList; import java.util.Arrays; import java.util.HashSet; @@ -24,6 +25,7 @@ import java.util.Map; import java.util.Set; import java.util.TreeMap; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.locks.ReentrantLock; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; @@ -64,8 +66,18 @@ import org.springframework.web.socket.handler.SessionLimitExceededException; public class SubProtocolWebSocketHandler implements WebSocketHandler, SubProtocolCapable, MessageHandler, SmartLifecycle { + /** + * Sessions connected to this handler use a sub-protocol. Hence we expect to + * receive some client messages. If we don't receive any within a minute, the + * connection isn't doing well (proxy issue, slow network?) and can be closed. + * @see #checkSessions() + */ + private final int TIME_TO_FIRST_MESSAGE = 60 * 1000; + + private final Log logger = LogFactory.getLog(SubProtocolWebSocketHandler.class); + private final MessageChannel clientInboundChannel; private final SubscribableChannel clientOutboundChannel; @@ -75,12 +87,16 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, private SubProtocolHandler defaultProtocolHandler; - private final Map sessions = new ConcurrentHashMap(); + private final Map sessions = new ConcurrentHashMap(); private int sendTimeLimit = 10 * 1000; private int sendBufferSizeLimit = 512 * 1024; + private volatile long lastSessionCheckTime = System.currentTimeMillis(); + + private final ReentrantLock sessionCheckLock = new ReentrantLock(); + private final Object lifecycleMonitor = new Object(); private volatile boolean running = false; @@ -214,12 +230,12 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, this.clientOutboundChannel.unsubscribe(this); // Notify sessions to stop flushing messages - for (WebSocketSession session : this.sessions.values()) { + for (WebSocketSessionHolder holder : this.sessions.values()) { try { - session.close(CloseStatus.GOING_AWAY); + holder.getSession().close(CloseStatus.GOING_AWAY); } catch (Throwable t) { - logger.error("Failed to close session id '" + session.getId() + "': " + t.getMessage()); + logger.error("Failed to close '" + holder.getSession() + "': " + t.getMessage()); } } } @@ -235,15 +251,11 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @Override public void afterConnectionEstablished(WebSocketSession session) throws Exception { - session = new ConcurrentWebSocketSessionDecorator(session, getSendTimeLimit(), getSendBufferSizeLimit()); - - this.sessions.put(session.getId(), session); + this.sessions.put(session.getId(), new WebSocketSessionHolder(session)); if (logger.isDebugEnabled()) { - logger.debug("Started WebSocket session=" + session.getId() + - ", number of sessions=" + this.sessions.size()); + logger.debug("Started session " + session.getId() + ", number of sessions=" + this.sessions.size()); } - findProtocolHandler(session).afterSessionStarted(session, this.clientInboundChannel); } @@ -283,41 +295,49 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, @Override public void handleMessage(WebSocketSession session, WebSocketMessage message) throws Exception { - findProtocolHandler(session).handleMessageFromClient(session, message, this.clientInboundChannel); + SubProtocolHandler protocolHandler = findProtocolHandler(session); + protocolHandler.handleMessageFromClient(session, message, this.clientInboundChannel); + WebSocketSessionHolder holder = this.sessions.get(session.getId()); + if (holder != null) { + holder.setHasHandledMessages(); + } + else { + // Should never happen + throw new IllegalStateException("Session not found: " + session); + } + checkSessions(); } @Override public void handleMessage(Message message) throws MessagingException { - String sessionId = resolveSessionId(message); if (sessionId == null) { logger.error("sessionId not found in message " + message); return; } - - WebSocketSession session = this.sessions.get(sessionId); - if (session == null) { + WebSocketSessionHolder holder = this.sessions.get(sessionId); + if (holder == null) { logger.error("Session not found for session with id '" + sessionId + "', ignoring message " + message); return; } - + WebSocketSession session = holder.getSession(); try { findProtocolHandler(session).handleMessageToClient(session, message); } catch (SessionLimitExceededException ex) { try { - logger.error("Terminating session id '" + sessionId + "'", ex); + logger.error("Terminating '" + session + "'", ex); // Session may be unresponsive so clear first clearSession(session, ex.getStatus()); session.close(ex.getStatus()); } catch (Exception secondException) { - logger.error("Exception terminating session id '" + sessionId + "'", secondException); + logger.error("Exception terminating '" + sessionId + "'", secondException); } } catch (Exception e) { - logger.error("Failed to send message to client " + message, e); + logger.error("Failed to send message to client " + message + " in " + session, e); } } @@ -337,6 +357,43 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return null; } + /** + * Periodically check sessions to ensure they have received at least one + * message or otherwise close them. + */ + private void checkSessions() throws IOException { + long currentTime = System.currentTimeMillis(); + if (!isRunning() && currentTime - this.lastSessionCheckTime < TIME_TO_FIRST_MESSAGE) { + return; + } + try { + if (this.sessionCheckLock.tryLock()) { + for (WebSocketSessionHolder holder : this.sessions.values()) { + if (holder.hasHandledMessages()) { + continue; + } + long timeSinceCreated = currentTime - holder.getCreateTime(); + if (holder.hasHandledMessages() || timeSinceCreated < TIME_TO_FIRST_MESSAGE) { + continue; + } + WebSocketSession session = holder.getSession(); + if (logger.isErrorEnabled()) { + logger.error("No messages received after " + timeSinceCreated + " ms. Closing " + holder); + } + try { + session.close(CloseStatus.PROTOCOL_ERROR); + } + catch (Throwable t) { + logger.error("Failed to close " + session, t); + } + } + } + } + finally { + this.sessionCheckLock.unlock(); + } + } + @Override public void handleTransportError(WebSocketSession session, Throwable exception) throws Exception { } @@ -356,4 +413,45 @@ public class SubProtocolWebSocketHandler implements WebSocketHandler, return false; } + + private static class WebSocketSessionHolder { + + private final WebSocketSession session; + + private final long createTime = System.currentTimeMillis(); + + private volatile boolean handledMessages; + + + private WebSocketSessionHolder(WebSocketSession session) { + this.session = session; + } + + public WebSocketSession getSession() { + return this.session; + } + + public long getCreateTime() { + return this.createTime; + } + + public void setHasHandledMessages() { + this.handledMessages = true; + } + + public boolean hasHandledMessages() { + return this.handledMessages; + } + + @Override + public String toString() { + if (this.session instanceof ConcurrentWebSocketSessionDecorator) { + return ((ConcurrentWebSocketSessionDecorator) this.session).getLastSession().toString(); + } + else { + return this.session.toString(); + } + } + } + } diff --git a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java index 31cf82ac76..3ef5d87f0b 100644 --- a/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java +++ b/spring-websocket/src/test/java/org/springframework/web/socket/messaging/SubProtocolWebSocketHandlerTests.java @@ -17,16 +17,24 @@ package org.springframework.web.socket.messaging; import java.util.Arrays; +import java.util.Map; import org.junit.Before; import org.junit.Test; import org.mockito.Mock; import org.mockito.MockitoAnnotations; +import org.springframework.beans.DirectFieldAccessor; import org.springframework.messaging.MessageChannel; import org.springframework.messaging.SubscribableChannel; +import org.springframework.web.socket.CloseStatus; +import org.springframework.web.socket.TextMessage; import org.springframework.web.socket.handler.ConcurrentWebSocketSessionDecorator; import org.springframework.web.socket.handler.TestWebSocketSession; +import static org.junit.Assert.assertEquals; +import static org.junit.Assert.assertFalse; +import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; import static org.mockito.Mockito.*; /** @@ -56,11 +64,9 @@ public class SubProtocolWebSocketHandlerTests { @Before public void setup() { MockitoAnnotations.initMocks(this); - this.webSocketHandler = new SubProtocolWebSocketHandler(this.inClientChannel, this.outClientChannel); when(stompHandler.getSupportedProtocols()).thenReturn(Arrays.asList("v10.stomp", "v11.stomp", "v12.stomp")); when(mqttHandler.getSupportedProtocols()).thenReturn(Arrays.asList("MQTT")); - this.session = new TestWebSocketSession(); this.session.setId("1"); } @@ -140,4 +146,32 @@ public class SubProtocolWebSocketHandlerTests { this.webSocketHandler.afterConnectionEstablished(session); } + @Test + public void checkSession() throws Exception { + TestWebSocketSession session1 = new TestWebSocketSession("id1"); + TestWebSocketSession session2 = new TestWebSocketSession("id2"); + session1.setAcceptedProtocol("v12.stomp"); + session2.setAcceptedProtocol("v12.stomp"); + + this.webSocketHandler.setProtocolHandlers(Arrays.asList(this.stompHandler)); + this.webSocketHandler.afterConnectionEstablished(session1); + this.webSocketHandler.afterConnectionEstablished(session2); + session1.setOpen(true); + session2.setOpen(true); + + long sixtyOneSecondsAgo = System.currentTimeMillis() - 61 * 1000; + new DirectFieldAccessor(this.webSocketHandler).setPropertyValue("lastSessionCheckTime", sixtyOneSecondsAgo); + Map sessions = (Map) new DirectFieldAccessor(this.webSocketHandler).getPropertyValue("sessions"); + new DirectFieldAccessor(sessions.get("id1")).setPropertyValue("createTime", sixtyOneSecondsAgo); + new DirectFieldAccessor(sessions.get("id2")).setPropertyValue("createTime", sixtyOneSecondsAgo); + + this.webSocketHandler.handleMessage(session1, new TextMessage("foo")); + + assertTrue(session1.isOpen()); + assertFalse(session2.isOpen()); + assertNull(session1.getCloseStatus()); + assertEquals(CloseStatus.PROTOCOL_ERROR, session2.getCloseStatus()); + } + + }