diff --git a/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketSerializer.java b/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketSerializer.java index 2acf77c..a3cfecd 100644 --- a/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketSerializer.java +++ b/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketSerializer.java @@ -95,6 +95,7 @@ public class WebSocketSerializer extends AbstractHttpSwitchingDeserializer imple int lenBytes; int payloadLen = this.server ? 0 : 0x80; //masked boolean close = theFrame.getType() == WebSocketFrame.TYPE_CLOSE; + boolean ping = theFrame.getType() == WebSocketFrame.TYPE_PING; boolean pong = theFrame.getType() == WebSocketFrame.TYPE_PONG; byte[] bytes = theFrame.getBinary() != null ? theFrame.getBinary() : data.getBytes("UTF-8"); @@ -116,7 +117,10 @@ public class WebSocketSerializer extends AbstractHttpSwitchingDeserializer imple } int mask = (int) System.currentTimeMillis(); ByteBuffer buffer = ByteBuffer.allocate(length + 6 + lenBytes); - if (pong) { + if (ping) { + buffer.put((byte) 0x89); + } + else if (pong) { buffer.put((byte) 0x8a); } else if (close) { diff --git a/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketTcpConnectionInterceptorFactory.java b/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketTcpConnectionInterceptorFactory.java index fac6b63..93ef3db 100644 --- a/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketTcpConnectionInterceptorFactory.java +++ b/spring-integration-ip-extensions/src/main/java/org/springframework/integration/x/ip/websocket/WebSocketTcpConnectionInterceptorFactory.java @@ -15,8 +15,10 @@ */ package org.springframework.integration.x.ip.websocket; +import java.util.Date; import java.util.HashMap; import java.util.Map; +import java.util.Map.Entry; import java.util.concurrent.ConcurrentHashMap; import org.apache.commons.logging.Log; @@ -29,6 +31,7 @@ import org.springframework.integration.MessagingException; import org.springframework.integration.aggregator.ResequencingMessageGroupProcessor; import org.springframework.integration.aggregator.ResequencingMessageHandler; import org.springframework.integration.channel.DirectChannel; +import org.springframework.integration.context.IntegrationObjectSupport; import org.springframework.integration.core.MessageHandler; import org.springframework.integration.endpoint.EventDrivenConsumer; import org.springframework.integration.ip.tcp.connection.TcpConnection; @@ -39,6 +42,7 @@ import org.springframework.integration.ip.tcp.connection.TcpNioConnection; import org.springframework.integration.support.MessageBuilder; import org.springframework.integration.x.ip.websocket.WebSocketEvent.WebSocketEventType; import org.springframework.integration.x.ip.websocket.WebSocketSerializer.WebSocketState; +import org.springframework.scheduling.TaskScheduler; import org.springframework.util.Assert; /** @@ -46,13 +50,73 @@ import org.springframework.util.Assert; * @since 3.0 * */ -public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionInterceptorFactory { +public class WebSocketTcpConnectionInterceptorFactory extends IntegrationObjectSupport + implements TcpConnectionInterceptorFactory { + + private static final Message PING = MessageBuilder.withPayload( + new WebSocketFrame(WebSocketFrame.TYPE_PING, "Ping from SI")).build(); private static final Log logger = LogFactory.getLog(WebSocketTcpConnectionInterceptor.class); private final Map connections = new ConcurrentHashMap(); + private volatile TaskScheduler taskScheduler; + + private volatile long pingInterval = 25000; + + private final Runnable pinger = new Runnable() { + + @Override + public void run() { + long pingFilter = System.currentTimeMillis() - pingInterval; + for (Entry entry : connections.entrySet()) { + TcpConnection connection = entry.getKey(); + String connectionId = connection.getConnectionId(); + if (entry.getValue().getLastSend() <= pingFilter) { + try { + if (logger.isDebugEnabled()) { + logger.debug("Sending Ping to " + connectionId); + } + connection.send(PING); + } + catch (Exception e) { + logger.error("Failed to send Ping to " + connectionId, e); + connection.close(); + } + } + else { + if (logger.isTraceEnabled()) { + logger.trace("Skipping PING for " + connectionId + " due to recent send"); + } + } + } + if (pingInterval > 0) { + taskScheduler.schedule(pinger, new Date(System.currentTimeMillis() + pingInterval)); + } + } + }; + + @Override + public void setTaskScheduler(TaskScheduler taskScheduler) { + this.taskScheduler = taskScheduler; + } + + public void setPingInterval(long pingInterval) { + this.pingInterval = pingInterval; + } + + @Override + protected void onInit() throws Exception { + super.onInit(); + if (this.pingInterval > 0) { + if (this.taskScheduler == null) { + this.taskScheduler = this.getTaskScheduler(); + } + this.taskScheduler.schedule(this.pinger, new Date(System.currentTimeMillis() + this.pingInterval)); + } + } + @Override public TcpConnectionInterceptorSupport getInterceptor() { return new WebSocketTcpConnectionInterceptor(); @@ -62,6 +126,7 @@ public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionIn return this.connections.get(connection); } + public class WebSocketTcpConnectionInterceptor extends TcpConnectionInterceptorSupport { private volatile boolean shook; @@ -70,6 +135,8 @@ public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionIn private final EventDrivenConsumer resequencer; + private long lastSend; + public WebSocketTcpConnectionInterceptor() { super(); ResequencingMessageHandler handler = new ResequencingMessageHandler(new ResequencingMessageGroupProcessor()); @@ -89,6 +156,10 @@ public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionIn this.resequencer.start(); } + public long getLastSend() { + return lastSend; + } + /** * When using NIO, we have to resequence the messages because frames may * arrive out of order. This is particularly an issue for some of the @@ -160,7 +231,8 @@ public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionIn else if (payload.getType() == WebSocketFrame.TYPE_PING) { try { if (logger.isDebugEnabled()) { - logger.debug("Ping:" + new String(payload.getBinary(), "UTF-8")); + logger.debug("Ping received on " + this.getConnectionId() + ":" + + new String(payload.getBinary(), "UTF-8")); } if (payload.getBinary().length > 125) { this.protocolViolation(message); @@ -178,7 +250,7 @@ public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionIn } else if (payload.getType() == WebSocketFrame.TYPE_PONG) { if (logger.isDebugEnabled()) { - logger.debug("Pong"); + logger.debug("Pong received on " + this.getConnectionId()); } } else if (this.shook) { @@ -243,6 +315,13 @@ public class WebSocketTcpConnectionInterceptorFactory implements TcpConnectionIn super.close(); } + + @Override + public void send(Message message) throws Exception { + super.send(message); + this.lastSend = System.currentTimeMillis(); + } + private void doHandshake(WebSocketFrame frame, MessageHeaders messageHeaders) throws Exception { try { WebSocketFrame handshake = this.getRequiredDeserializer().generateHandshake(frame);