Commit 2123b267 authored by Phillip Webb's avatar Phillip Webb

Add HTTP tunnel support

Add server and client components to support tunneling of binary TCP
protocols over HTTP. Primarily designed to support Java's remote
debug protocol (JDWP).

See gh-3087
parent c27b63b3
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.Closeable;
import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.ByteBuffer;
import java.nio.channels.WritableByteChannel;
import java.util.concurrent.Executor;
import java.util.concurrent.Executors;
import java.util.concurrent.ThreadFactory;
import java.util.concurrent.atomic.AtomicLong;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayload;
import org.springframework.boot.developertools.tunnel.payload.HttpTunnelPayloadForwarder;
import org.springframework.http.HttpMethod;
import org.springframework.http.HttpStatus;
import org.springframework.http.client.ClientHttpRequest;
import org.springframework.http.client.ClientHttpRequestFactory;
import org.springframework.http.client.ClientHttpResponse;
import org.springframework.util.Assert;
/**
* {@link TunnelConnection} implementation that uses HTTP to transfer data.
*
* @author Phillip Webb
* @author Rob Winch
* @since 1.3.0
* @see TunnelClient
* @see org.springframework.boot.developertools.tunnel.server.HttpTunnelServer
*/
public class HttpTunnelConnection implements TunnelConnection {
private static Log logger = LogFactory.getLog(HttpTunnelConnection.class);
private final URI uri;
private final ClientHttpRequestFactory requestFactory;
private final Executor executor;
/**
* Create a new {@link HttpTunnelConnection} instance.
* @param url the URL to connect to
* @param requestFactory the HTTP request factory
*/
public HttpTunnelConnection(String url, ClientHttpRequestFactory requestFactory) {
this(url, requestFactory, null);
}
/**
* Create a new {@link HttpTunnelConnection} instance.
* @param url the URL to connect to
* @param requestFactory the HTTP request factory
* @param executor the executor used to handle connections
*/
protected HttpTunnelConnection(String url, ClientHttpRequestFactory requestFactory,
Executor executor) {
Assert.hasLength(url, "URL must not be empty");
Assert.notNull(requestFactory, "RequestFactory must not be null");
try {
this.uri = new URL(url).toURI();
}
catch (URISyntaxException ex) {
throw new IllegalArgumentException("Malformed URL '" + url + "'");
}
catch (MalformedURLException ex) {
throw new IllegalArgumentException("Malformed URL '" + url + "'");
}
this.requestFactory = requestFactory;
this.executor = (executor == null ? Executors
.newCachedThreadPool(new TunnelThreadFactory()) : executor);
}
@Override
public TunnelChannel open(WritableByteChannel incomingChannel, Closeable closeable)
throws Exception {
logger.trace("Opening HTTP tunnel to " + this.uri);
return new TunnelChannel(incomingChannel, closeable);
}
protected final ClientHttpRequest createRequest(boolean hasPayload)
throws IOException {
HttpMethod method = (hasPayload ? HttpMethod.POST : HttpMethod.GET);
return this.requestFactory.createRequest(this.uri, method);
}
/**
* A {@link WritableByteChannel} used to transfer traffic.
*/
protected class TunnelChannel implements WritableByteChannel {
private final HttpTunnelPayloadForwarder forwarder;
private final Closeable closeable;
private boolean open = true;
private AtomicLong requestSeq = new AtomicLong();
public TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) {
this.forwarder = new HttpTunnelPayloadForwarder(incomingChannel);
this.closeable = closeable;
openNewConnection(null);
}
@Override
public boolean isOpen() {
return this.open;
}
@Override
public void close() throws IOException {
if (this.open) {
this.open = false;
this.closeable.close();
}
}
@Override
public int write(ByteBuffer src) throws IOException {
int size = src.remaining();
if (size > 0) {
openNewConnection(new HttpTunnelPayload(
this.requestSeq.incrementAndGet(), src));
}
return size;
}
private synchronized void openNewConnection(final HttpTunnelPayload payload) {
HttpTunnelConnection.this.executor.execute(new Runnable() {
@Override
public void run() {
try {
sendAndReceive(payload);
}
catch (IOException ex) {
logger.trace("Unexpected connection error", ex);
closeQuitely();
}
}
private void closeQuitely() {
try {
close();
}
catch (IOException ex) {
}
}
});
}
private void sendAndReceive(HttpTunnelPayload payload) throws IOException {
ClientHttpRequest request = createRequest(payload != null);
if (payload != null) {
payload.logIncoming();
payload.assignTo(request);
}
handleResponse(request.execute());
}
private void handleResponse(ClientHttpResponse response) throws IOException {
if (response.getStatusCode() == HttpStatus.GONE) {
close();
return;
}
if (response.getStatusCode() == HttpStatus.OK) {
HttpTunnelPayload payload = HttpTunnelPayload.get(response);
if (payload != null) {
this.forwarder.forward(payload);
}
}
if (response.getStatusCode() != HttpStatus.TOO_MANY_REQUESTS) {
openNewConnection(null);
}
}
}
/**
* {@link ThreadFactory} used to create the tunnel thread.
*/
private static class TunnelThreadFactory implements ThreadFactory {
@Override
public Thread newThread(Runnable runnable) {
Thread thread = new Thread(runnable, "HTTP Tunnel Connection");
thread.setDaemon(true);
return thread;
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.ServerSocket;
import java.nio.ByteBuffer;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.SmartInitializingSingleton;
import org.springframework.util.Assert;
/**
* The client side component of a socket tunnel. Starts a {@link ServerSocket} of the
* specified port for local clients to connect to.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class TunnelClient implements SmartInitializingSingleton {
private static final int BUFFER_SIZE = 1024 * 100;
private static final Log logger = LogFactory.getLog(TunnelClient.class);
private final int listenPort;
private final TunnelConnection tunnelConnection;
private TunnelClientListeners listeners = new TunnelClientListeners();
private ServerThread serverThread;
public TunnelClient(int listenPort, TunnelConnection tunnelConnection) {
Assert.isTrue(listenPort > 0, "ListenPort must be positive");
Assert.notNull(tunnelConnection, "TunnelConnection must not be null");
this.listenPort = listenPort;
this.tunnelConnection = tunnelConnection;
}
@Override
public void afterSingletonsInstantiated() {
if (this.serverThread == null) {
try {
start();
}
catch (IOException ex) {
throw new IllegalStateException(ex);
}
}
}
/**
* Start the client and accept incoming connections on the port.
* @throws IOException
*/
public synchronized void start() throws IOException {
Assert.state(this.serverThread == null, "Server already started");
ServerSocketChannel serverSocketChannel = ServerSocketChannel.open();
serverSocketChannel.socket().bind(new InetSocketAddress(this.listenPort));
logger.trace("Listening for TCP traffic to tunnel on port " + this.listenPort);
this.serverThread = new ServerThread(serverSocketChannel);
this.serverThread.start();
}
/**
* Stop the client, disconnecting any servers.
* @throws IOException
*/
public synchronized void stop() throws IOException {
if (this.serverThread != null) {
logger.trace("Closing tunnel client on port " + this.listenPort);
this.serverThread.close();
try {
this.serverThread.join(2000);
}
catch (InterruptedException ex) {
}
this.serverThread = null;
}
}
protected final ServerThread getServerThread() {
return this.serverThread;
}
public void addListener(TunnelClientListener listener) {
this.listeners.addListener(listener);
}
public void removeListener(TunnelClientListener listener) {
this.listeners.removeListener(listener);
}
/**
* The main server thread.
*/
protected class ServerThread extends Thread {
private final ServerSocketChannel serverSocketChannel;
private boolean acceptConnections = true;
public ServerThread(ServerSocketChannel serverSocketChannel) {
this.serverSocketChannel = serverSocketChannel;
setName("Tunnel Server");
setDaemon(true);
}
public void close() throws IOException {
this.serverSocketChannel.close();
this.acceptConnections = false;
interrupt();
}
@Override
public void run() {
try {
while (this.acceptConnections) {
SocketChannel socket = this.serverSocketChannel.accept();
try {
handleConnection(socket);
}
finally {
socket.close();
}
}
}
catch (Exception ex) {
logger.trace("Unexpected exception from tunnel client", ex);
}
}
private void handleConnection(SocketChannel socketChannel) throws Exception {
Closeable closeable = new SocketCloseable(socketChannel);
WritableByteChannel outputChannel = TunnelClient.this.tunnelConnection.open(
socketChannel, closeable);
TunnelClient.this.listeners.fireOpenEvent(socketChannel);
try {
logger.trace("Accepted connection to tunnel client from "
+ socketChannel.socket().getRemoteSocketAddress());
while (true) {
ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
int amountRead = socketChannel.read(buffer);
if (amountRead == -1) {
outputChannel.close();
return;
}
if (amountRead > 0) {
buffer.flip();
outputChannel.write(buffer);
}
}
}
finally {
outputChannel.close();
}
}
protected void stopAcceptingConnections() {
this.acceptConnections = false;
}
}
/**
* {@link Closeable} used to close a {@link SocketChannel} and fire an event.
*/
private class SocketCloseable implements Closeable {
private final SocketChannel socketChannel;
private boolean closed = false;
public SocketCloseable(SocketChannel socketChannel) {
this.socketChannel = socketChannel;
}
@Override
public void close() throws IOException {
if (!this.closed) {
this.socketChannel.close();
TunnelClient.this.listeners.fireCloseEvent(this.socketChannel);
this.closed = true;
}
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.nio.channels.SocketChannel;
/**
* Listener that can be used to receive {@link TunnelClient} events.
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface TunnelClientListener {
/**
* Called when a socket channel is opened.
* @param socket the socket channel
*/
void onOpen(SocketChannel socket);
/**
* Called when a socket channel is closed.
* @param socket the socket channel
*/
void onClose(SocketChannel socket);
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.nio.channels.SocketChannel;
import java.util.ArrayList;
import java.util.List;
import org.springframework.util.Assert;
/**
* A collection of {@link TunnelClientListener}.
*
* @author Phillip Webb
*/
class TunnelClientListeners {
private final List<TunnelClientListener> listeners = new ArrayList<TunnelClientListener>();
public void addListener(TunnelClientListener listener) {
Assert.notNull(listener, "Listener must not be null");
this.listeners.add(listener);
}
public void removeListener(TunnelClientListener listener) {
Assert.notNull(listener, "Listener must not be null");
this.listeners.remove(listener);
}
public void fireOpenEvent(SocketChannel socket) {
for (TunnelClientListener listener : this.listeners) {
listener.onOpen(socket);
}
}
public void fireCloseEvent(SocketChannel socket) {
for (TunnelClientListener listener : this.listeners) {
listener.onClose(socket);
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.Closeable;
import java.nio.channels.WritableByteChannel;
/**
* Interface used to manage socket tunnel connections.
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface TunnelConnection {
/**
* Open the tunnel connection.
* @param incomingChannel A {@link WritableByteChannel} that should be used to write
* any incoming data received from the remote server.
* @param closeable
* @return A {@link WritableByteChannel} that should be used to send any outgoing data
* destined for the remote server
* @throws Exception
*/
WritableByteChannel open(WritableByteChannel incomingChannel, Closeable closeable)
throws Exception;
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Client side TCP tunnel support.
*/
package org.springframework.boot.developertools.tunnel.client;
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Provides support for tunneling TCP traffic over HTTP. Tunneling is primarily designed
* for the Java Debug Wire Protocol (JDWP) and as such only expects a single connection
* and isn't particularly worried about resource usage.
*/
package org.springframework.boot.developertools.tunnel;
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.IOException;
import java.io.InterruptedIOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.http.HttpHeaders;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.MediaType;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
/**
* Encapsulates a payload data sent via a HTTP tunnel.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class HttpTunnelPayload {
private static final String SEQ_HEADER = "x-seq";
private static final int BUFFER_SIZE = 1024 * 100;
final protected static char[] HEX_CHARS = "0123456789ABCDEF".toCharArray();
private static final Log logger = LogFactory.getLog(HttpTunnelPayload.class);
private final long sequence;
private final ByteBuffer data;
/**
* Create a new {@link HttpTunnelPayload} instance.
* @param sequence the sequence number of the payload
* @param data the payload data
*/
public HttpTunnelPayload(long sequence, ByteBuffer data) {
Assert.isTrue(sequence > 0, "Sequence must be positive");
Assert.notNull(data, "Data must not be null");
this.sequence = sequence;
this.data = data;
}
/**
* Return the sequence number of the payload.
* @return the sequence
*/
public long getSequence() {
return this.sequence;
}
/**
* Assign this payload to the given {@link HttpOutputMessage}.
* @param message the message to assign this payload to
* @throws IOException
*/
public void assignTo(HttpOutputMessage message) throws IOException {
Assert.notNull(message, "Message must not be null");
HttpHeaders headers = message.getHeaders();
headers.setContentLength(this.data.remaining());
headers.add(SEQ_HEADER, Long.toString(getSequence()));
headers.setContentType(MediaType.APPLICATION_OCTET_STREAM);
WritableByteChannel body = Channels.newChannel(message.getBody());
while (this.data.hasRemaining()) {
body.write(this.data);
}
body.close();
}
/**
* Write the content of this payload to the given target channel.
* @param channel the channel to write to
* @throws IOException
*/
public void writeTo(WritableByteChannel channel) throws IOException {
Assert.notNull(channel, "Channel must not be null");
while (this.data.hasRemaining()) {
channel.write(this.data);
}
}
/**
* Return the {@link HttpTunnelPayload} for the given message or {@code null} if there
* is no payload.
* @param message the HTTP message
* @return the payload or {@code null}
* @throws IOException
*/
public static HttpTunnelPayload get(HttpInputMessage message) throws IOException {
long length = message.getHeaders().getContentLength();
if (length <= 0) {
return null;
}
String seqHeader = message.getHeaders().getFirst(SEQ_HEADER);
Assert.state(StringUtils.hasLength(seqHeader), "Missing sequence header");
ReadableByteChannel body = Channels.newChannel(message.getBody());
ByteBuffer payload = ByteBuffer.allocate((int) length);
while (payload.hasRemaining()) {
body.read(payload);
}
body.close();
payload.flip();
return new HttpTunnelPayload(Long.valueOf(seqHeader), payload);
}
/**
* Return the payload data for the given source {@link ReadableByteChannel} or null if
* the channel timed out whilst reading.
* @param channel the source channel
* @return payload data or {@code null}
* @throws IOException
*/
public static ByteBuffer getPayloadData(ReadableByteChannel channel)
throws IOException {
ByteBuffer buffer = ByteBuffer.allocate(BUFFER_SIZE);
try {
int amountRead = channel.read(buffer);
Assert.state(amountRead != -1, "Target server connection closed");
buffer.flip();
return buffer;
}
catch (InterruptedIOException ex) {
return null;
}
}
/**
* Log incoming payload information at trace level to aid diagnostics.
*/
public void logIncoming() {
log("< ");
}
/**
* Log incoming payload information at trace level to aid diagnostics.
*/
public void logOutgoing() {
log("> ");
}
private void log(String prefix) {
if (logger.isTraceEnabled()) {
logger.trace(prefix + toHexString());
}
}
/**
* Return the payload as a hexadecimal string.
* @return the payload as a hex string
*/
public String toHexString() {
byte[] bytes = this.data.array();
char[] hex = new char[this.data.remaining() * 2];
for (int i = this.data.position(); i < this.data.remaining(); i++) {
int b = bytes[i] & 0xFF;
hex[i * 2] = HEX_CHARS[b >>> 4];
hex[i * 2 + 1] = HEX_CHARS[b & 0x0F];
}
return new String(hex);
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.IOException;
import java.nio.channels.WritableByteChannel;
import java.util.HashMap;
import java.util.Map;
import org.springframework.util.Assert;
/**
* Utility class that forwards {@link HttpTunnelPayload} instances to a destination
* channel, respecting sequence order.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class HttpTunnelPayloadForwarder {
private static final int MAXIMUM_QUEUE_SIZE = 100;
private final WritableByteChannel targetChannel;
private long lastRequestSeq = 0;
private final Map<Long, HttpTunnelPayload> queue = new HashMap<Long, HttpTunnelPayload>();
/**
* Create a new {@link HttpTunnelPayloadForwarder} instance.
* @param targetChannel the target channel
*/
public HttpTunnelPayloadForwarder(WritableByteChannel targetChannel) {
Assert.notNull(targetChannel, "TargetChannel must not be null");
this.targetChannel = targetChannel;
}
public synchronized void forward(HttpTunnelPayload payload) throws IOException {
long seq = payload.getSequence();
if (this.lastRequestSeq != seq - 1) {
Assert.state(this.queue.size() < MAXIMUM_QUEUE_SIZE,
"Too many messages queued");
this.queue.put(seq, payload);
return;
}
payload.logOutgoing();
payload.writeTo(this.targetChannel);
this.lastRequestSeq = seq;
HttpTunnelPayload queuedItem = this.queue.get(seq + 1);
if (queuedItem != null) {
forward(queuedItem);
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Classes to deal with payloads sent over a HTTP tunnel.
*/
package org.springframework.boot.developertools.tunnel.payload;
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import org.springframework.boot.developertools.remote.server.Handler;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import org.springframework.util.Assert;
/**
* Adapts a {@link HttpTunnelServer} to a {@link Handler}.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class HttpTunnelServerHandler implements Handler {
private HttpTunnelServer server;
/**
* Create a new {@link HttpTunnelServerHandler} instance.
* @param server the server to adapt
*/
public HttpTunnelServerHandler(HttpTunnelServer server) {
Assert.notNull(server, "Server must not be null");
this.server = server;
}
@Override
public void handle(ServerHttpRequest request, ServerHttpResponse response)
throws IOException {
this.server.handle(request, response);
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
/**
* Strategy interface to provide access to a port (which may change if an existing
* connection is closed).
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface PortProvider {
/**
* Return the port number
* @return the port number
*/
int getPort();
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.boot.lang.UsesUnsafeJava;
import org.springframework.util.Assert;
/**
* {@link PortProvider} that provides the port being used by the Java remote debugging.
*
* @author Phillip Webb
*/
public class RemoteDebugPortProvider implements PortProvider {
private static final String JDWP_ADDRESS_PROPERTY = "sun.jdwp.listenerAddress";
private static final Log logger = LogFactory.getLog(RemoteDebugPortProvider.class);
@Override
public int getPort() {
Assert.state(isRemoteDebugRunning(), "Remote debug is not running");
return getRemoteDebugPort();
}
public static boolean isRemoteDebugRunning() {
return getRemoteDebugPort() != -1;
}
@UsesUnsafeJava
@SuppressWarnings("restriction")
private static int getRemoteDebugPort() {
String property = sun.misc.VMSupport.getAgentProperties().getProperty(
JDWP_ADDRESS_PROPERTY);
try {
if (property != null && property.contains(":")) {
return Integer.valueOf(property.split(":")[1]);
}
}
catch (Exception ex) {
logger.trace("Unable to get JDWP port from property value '" + property + "'");
}
return -1;
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.SocketChannel;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
/**
* Socket based {@link TargetServerConnection}.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class SocketTargetServerConnection implements TargetServerConnection {
private static final Log logger = LogFactory
.getLog(SocketTargetServerConnection.class);
private final PortProvider portProvider;
/**
* Create a new {@link SocketTargetServerConnection}.
* @param portProvider the port provider
*/
public SocketTargetServerConnection(PortProvider portProvider) {
Assert.notNull(portProvider, "PortProvider must not be null");
this.portProvider = portProvider;
}
@Override
public ByteChannel open(int socketTimeout) throws IOException {
SocketAddress address = new InetSocketAddress(this.portProvider.getPort());
logger.trace("Opening tunnel connection to target server on " + address);
SocketChannel channel = SocketChannel.open(address);
channel.socket().setSoTimeout(socketTimeout);
return new TimeoutAwareChannel(channel);
}
/**
* Wrapper to expose the {@link SocketChannel} in such a way that
* {@code SocketTimeoutExceptions} are still thrown from read methods.
*/
private static class TimeoutAwareChannel implements ByteChannel {
private final SocketChannel socketChannel;
private final ReadableByteChannel readChannel;
public TimeoutAwareChannel(SocketChannel socketChannel) throws IOException {
this.socketChannel = socketChannel;
this.readChannel = Channels.newChannel(socketChannel.socket()
.getInputStream());
}
@Override
public int read(ByteBuffer dst) throws IOException {
return this.readChannel.read(dst);
}
@Override
public int write(ByteBuffer src) throws IOException {
return this.socketChannel.write(src);
}
@Override
public boolean isOpen() {
return this.socketChannel.isOpen();
}
@Override
public void close() throws IOException {
this.socketChannel.close();
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.springframework.util.Assert;
/**
* {@link PortProvider} for a static port that won't change.
*
* @author Phillip Webb
* @since 1.3.0
*/
public class StaticPortProvider implements PortProvider {
private final int port;
public StaticPortProvider(int port) {
Assert.isTrue(port > 0, "Port must be positive");
this.port = port;
}
@Override
public int getPort() {
return this.port;
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.nio.channels.ByteChannel;
/**
* Manages the connection to the ultimate tunnel target server.
*
* @author Phillip Webb
* @since 1.3.0
*/
public interface TargetServerConnection {
/**
* Open a connection to the target server with the specified timeout.
* @param timeout the read timeout
* @return a {@link ByteChannel} providing read/write access to the server
* @throws IOException
*/
ByteChannel open(int timeout) throws IOException;
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
/**
* Server side TCP tunnel support.
*/
package org.springframework.boot.developertools.tunnel.server;
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import java.util.concurrent.Executor;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import org.springframework.boot.developertools.test.MockClientHttpRequestFactory;
import org.springframework.boot.developertools.tunnel.client.HttpTunnelConnection.TunnelChannel;
import org.springframework.http.HttpStatus;
import org.springframework.util.SocketUtils;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThan;
import static org.junit.Assert.assertThat;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link HttpTunnelConnection}.
*
* @author Phillip Webb
* @author Rob Winch
*/
public class HttpTunnelConnectionTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private int port = SocketUtils.findAvailableTcpPort();
private String url;
private ByteArrayOutputStream incommingData;
private WritableByteChannel incomingChannel;
@Mock
private Closeable closeable;
private MockClientHttpRequestFactory requestFactory = new MockClientHttpRequestFactory();
@Before
public void setup() {
MockitoAnnotations.initMocks(this);
this.url = "http://localhost:" + this.port;
this.incommingData = new ByteArrayOutputStream();
this.incomingChannel = Channels.newChannel(this.incommingData);
}
@Test
public void urlMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("URL must not be empty");
new HttpTunnelConnection(null, this.requestFactory);
}
@Test
public void urlMustNotBeEmpty() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("URL must not be empty");
new HttpTunnelConnection("", this.requestFactory);
}
@Test
public void urlMustNotBeMalformed() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Malformed URL 'htttttp:///ttest'");
new HttpTunnelConnection("htttttp:///ttest", this.requestFactory);
}
@Test
public void requestFactoryMustNotBeNull() {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("RequestFactory must not be null");
new HttpTunnelConnection(this.url, null);
}
@Test
public void closeTunnelChangesIsOpen() throws Exception {
this.requestFactory.willRespondAfterDelay(1000, HttpStatus.GONE);
WritableByteChannel channel = openTunnel(false);
assertThat(channel.isOpen(), equalTo(true));
channel.close();
assertThat(channel.isOpen(), equalTo(false));
}
@Test
public void closeTunnelCallsCloseableOnce() throws Exception {
this.requestFactory.willRespondAfterDelay(1000, HttpStatus.GONE);
WritableByteChannel channel = openTunnel(false);
verify(this.closeable, never()).close();
channel.close();
channel.close();
verify(this.closeable, times(1)).close();
}
@Test
public void typicalTraffic() throws Exception {
this.requestFactory.willRespond("hi", "=2", "=3");
TunnelChannel channel = openTunnel(true);
write(channel, "hello");
write(channel, "1+1");
write(channel, "1+2");
assertThat(this.incommingData.toString(), equalTo("hi=2=3"));
}
@Test
public void trafficWithLongPollTimeouts() throws Exception {
for (int i = 0; i < 10; i++) {
this.requestFactory.willRespond(HttpStatus.NO_CONTENT);
}
this.requestFactory.willRespond("hi");
TunnelChannel channel = openTunnel(true);
write(channel, "hello");
assertThat(this.incommingData.toString(), equalTo("hi"));
assertThat(this.requestFactory.getExecutedRequests().size(), greaterThan(10));
}
private void write(TunnelChannel channel, String string) throws IOException {
channel.write(ByteBuffer.wrap(string.getBytes()));
}
private TunnelChannel openTunnel(boolean singleThreaded) throws Exception {
HttpTunnelConnection connection = new HttpTunnelConnection(this.url,
this.requestFactory,
(singleThreaded ? new CurrentThreadExecutor() : null));
return connection.open(this.incomingChannel, this.closeable);
}
private static class CurrentThreadExecutor implements Executor {
@Override
public void execute(Runnable command) {
command.run();
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.client;
import java.io.ByteArrayOutputStream;
import java.io.Closeable;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.SocketChannel;
import java.nio.channels.WritableByteChannel;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.util.SocketUtils;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link TunnelClient}.
*
* @author Phillip Webb
*/
public class TunnelClientTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private int listenPort = SocketUtils.findAvailableTcpPort();
private MockTunnelConnection tunnelConnection = new MockTunnelConnection();
@Test
public void listenPortMustBePositive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ListenPort must be positive");
new TunnelClient(0, this.tunnelConnection);
}
@Test
public void tunnelConnectionMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("TunnelConnection must not be null");
new TunnelClient(1, null);
}
@Test
public void typicalTraffic() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
channel.write(ByteBuffer.wrap("hello".getBytes()));
ByteBuffer buffer = ByteBuffer.allocate(5);
channel.read(buffer);
channel.close();
this.tunnelConnection.verifyWritten("hello");
assertThat(new String(buffer.array()), equalTo("olleh"));
}
@Test
public void socketChannelClosedTriggersTunnelClose() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
channel.close();
client.getServerThread().stopAcceptingConnections();
client.getServerThread().join(2000);
assertThat(this.tunnelConnection.getOpenedTimes(), equalTo(1));
assertThat(this.tunnelConnection.isOpen(), equalTo(false));
}
@Test
public void stopTriggersTunnelClose() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
client.stop();
assertThat(this.tunnelConnection.getOpenedTimes(), equalTo(1));
assertThat(this.tunnelConnection.isOpen(), equalTo(false));
assertThat(channel.read(ByteBuffer.allocate(1)), equalTo(-1));
}
@Test
public void addListener() throws Exception {
TunnelClient client = new TunnelClient(this.listenPort, this.tunnelConnection);
TunnelClientListener listener = mock(TunnelClientListener.class);
client.addListener(listener);
client.start();
SocketChannel channel = SocketChannel
.open(new InetSocketAddress(this.listenPort));
channel.close();
client.getServerThread().stopAcceptingConnections();
client.getServerThread().join(2000);
verify(listener).onOpen(any(SocketChannel.class));
verify(listener).onClose(any(SocketChannel.class));
}
private static class MockTunnelConnection implements TunnelConnection {
private final ByteArrayOutputStream written = new ByteArrayOutputStream();
private boolean open;
private int openedTimes;
@Override
public WritableByteChannel open(WritableByteChannel incomingChannel,
Closeable closeable) throws Exception {
this.openedTimes++;
this.open = true;
return new TunnelChannel(incomingChannel, closeable);
}
public void verifyWritten(String expected) {
verifyWritten(expected.getBytes());
}
public void verifyWritten(byte[] expected) {
synchronized (this.written) {
assertThat(this.written.toByteArray(), equalTo(expected));
this.written.reset();
}
}
public boolean isOpen() {
return this.open;
}
public int getOpenedTimes() {
return this.openedTimes;
}
private class TunnelChannel implements WritableByteChannel {
private final WritableByteChannel incomingChannel;
private final Closeable closeable;
public TunnelChannel(WritableByteChannel incomingChannel, Closeable closeable) {
this.incomingChannel = incomingChannel;
this.closeable = closeable;
}
@Override
public boolean isOpen() {
return MockTunnelConnection.this.open;
}
@Override
public void close() throws IOException {
MockTunnelConnection.this.open = false;
this.closeable.close();
}
@Override
public int write(ByteBuffer src) throws IOException {
int remaining = src.remaining();
ByteArrayOutputStream stream = new ByteArrayOutputStream();
Channels.newChannel(stream).write(src);
byte[] bytes = stream.toByteArray();
synchronized (MockTunnelConnection.this.written) {
MockTunnelConnection.this.written.write(bytes);
}
byte[] reversed = new byte[bytes.length];
for (int i = 0; i < reversed.length; i++) {
reversed[i] = bytes[bytes.length - 1 - i];
}
this.incomingChannel.write(ByteBuffer.wrap(reversed));
return remaining;
}
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.ByteArrayOutputStream;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.WritableByteChannel;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
/**
* Tests for {@link HttpTunnelPayloadForwarder}.
*
* @author Phillip Webb
*/
public class HttpTunnelPayloadForwarderTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void targetChannelMustNoBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("TargetChannel must not be null");
new HttpTunnelPayloadForwarder(null);
}
@Test
public void forwardInSequence() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel);
forwarder.forward(payload(1, "he"));
forwarder.forward(payload(2, "ll"));
forwarder.forward(payload(3, "o"));
assertThat(out.toByteArray(), equalTo("hello".getBytes()));
}
@Test
public void forwardOutOfSequence() throws Exception {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel);
forwarder.forward(payload(3, "o"));
forwarder.forward(payload(2, "ll"));
forwarder.forward(payload(1, "he"));
assertThat(out.toByteArray(), equalTo("hello".getBytes()));
}
@Test
public void overflow() throws Exception {
WritableByteChannel channel = Channels.newChannel(new ByteArrayOutputStream());
HttpTunnelPayloadForwarder forwarder = new HttpTunnelPayloadForwarder(channel);
this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage("Too many messages queued");
for (int i = 2; i < 130; i++) {
forwarder.forward(payload(i, "data" + i));
}
}
private HttpTunnelPayload payload(long sequence, String data) {
return new HttpTunnelPayload(sequence, ByteBuffer.wrap(data.getBytes()));
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.payload;
import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.Channels;
import java.nio.channels.ReadableByteChannel;
import java.nio.channels.WritableByteChannel;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.http.HttpInputMessage;
import org.springframework.http.HttpOutputMessage;
import org.springframework.http.server.ServletServerHttpRequest;
import org.springframework.http.server.ServletServerHttpResponse;
import org.springframework.mock.web.MockHttpServletRequest;
import org.springframework.mock.web.MockHttpServletResponse;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.nullValue;
import static org.junit.Assert.assertThat;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.any;
import static org.mockito.Mockito.mock;
/**
* Tests for {@link HttpTunnelPayload}.
*
* @author Phillip Webb
*/
public class HttpTunnelPayloadTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void sequenceMustBePositive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Sequence must be positive");
new HttpTunnelPayload(0, ByteBuffer.allocate(1));
}
@Test
public void dataMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Data must not be null");
new HttpTunnelPayload(1, null);
}
@Test
public void getSequence() throws Exception {
HttpTunnelPayload payload = new HttpTunnelPayload(1, ByteBuffer.allocate(1));
assertThat(payload.getSequence(), equalTo(1L));
}
@Test
public void getData() throws Exception {
ByteBuffer data = ByteBuffer.wrap("hello".getBytes());
HttpTunnelPayload payload = new HttpTunnelPayload(1, data);
assertThat(getData(payload), equalTo(data.array()));
}
@Test
public void assignTo() throws Exception {
ByteBuffer data = ByteBuffer.wrap("hello".getBytes());
HttpTunnelPayload payload = new HttpTunnelPayload(2, data);
MockHttpServletResponse servletResponse = new MockHttpServletResponse();
HttpOutputMessage response = new ServletServerHttpResponse(servletResponse);
payload.assignTo(response);
assertThat(servletResponse.getHeader("x-seq"), equalTo("2"));
assertThat(servletResponse.getContentAsString(), equalTo("hello"));
}
@Test
public void getNoData() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
HttpInputMessage request = new ServletServerHttpRequest(servletRequest);
HttpTunnelPayload payload = HttpTunnelPayload.get(request);
assertThat(payload, nullValue());
}
@Test
public void getWithMissingHeader() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
servletRequest.setContent("hello".getBytes());
HttpInputMessage request = new ServletServerHttpRequest(servletRequest);
this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage("Missing sequence header");
HttpTunnelPayload.get(request);
}
@Test
public void getWithData() throws Exception {
MockHttpServletRequest servletRequest = new MockHttpServletRequest();
servletRequest.setContent("hello".getBytes());
servletRequest.addHeader("x-seq", 123);
HttpInputMessage request = new ServletServerHttpRequest(servletRequest);
HttpTunnelPayload payload = HttpTunnelPayload.get(request);
assertThat(payload.getSequence(), equalTo(123L));
assertThat(getData(payload), equalTo("hello".getBytes()));
}
@Test
public void getPayloadData() throws Exception {
ReadableByteChannel channel = Channels.newChannel(new ByteArrayInputStream(
"hello".getBytes()));
ByteBuffer payloadData = HttpTunnelPayload.getPayloadData(channel);
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel writeChannel = Channels.newChannel(out);
while (payloadData.hasRemaining()) {
writeChannel.write(payloadData);
}
assertThat(out.toByteArray(), equalTo("hello".getBytes()));
}
@Test
public void getPayloadDataWithTimeout() throws Exception {
ReadableByteChannel channel = mock(ReadableByteChannel.class);
given(channel.read(any(ByteBuffer.class)))
.willThrow(new SocketTimeoutException());
ByteBuffer payload = HttpTunnelPayload.getPayloadData(channel);
assertThat(payload, nullValue());
}
private byte[] getData(HttpTunnelPayload payload) throws IOException {
ByteArrayOutputStream out = new ByteArrayOutputStream();
WritableByteChannel channel = Channels.newChannel(out);
payload.writeTo(channel);
return out.toByteArray();
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.springframework.http.server.ServerHttpRequest;
import org.springframework.http.server.ServerHttpResponse;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link HttpTunnelServerHandler}.
*
* @author Phillip Webb
*/
public class HttpTunnelServerHandlerTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void serverMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Server must not be null");
new HttpTunnelServerHandler(null);
}
@Test
public void handleDelegatesToServer() throws Exception {
HttpTunnelServer server = mock(HttpTunnelServer.class);
HttpTunnelServerHandler handler = new HttpTunnelServerHandler(server);
ServerHttpRequest request = mock(ServerHttpRequest.class);
ServerHttpResponse response = mock(ServerHttpResponse.class);
handler.handle(request, response);
verify(server).handle(request, response);
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import java.io.IOException;
import java.net.InetSocketAddress;
import java.net.SocketTimeoutException;
import java.nio.ByteBuffer;
import java.nio.channels.ByteChannel;
import java.nio.channels.ServerSocketChannel;
import java.nio.channels.SocketChannel;
import org.junit.Before;
import org.junit.Test;
import org.springframework.util.SocketUtils;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.greaterThanOrEqualTo;
import static org.hamcrest.Matchers.lessThan;
import static org.junit.Assert.assertThat;
import static org.junit.Assert.fail;
/**
* Tests for {@link SocketTargetServerConnection}.
*
* @author Phillip Webb
*/
public class SocketTargetServerConnectionTests {
private static final int DEFAULT_TIMEOUT = 1000;
private int port;
private MockServer server;
private SocketTargetServerConnection connection;
@Before
public void setup() throws IOException {
this.port = SocketUtils.findAvailableTcpPort();
this.server = new MockServer(this.port);
StaticPortProvider portProvider = new StaticPortProvider(this.port);
this.connection = new SocketTargetServerConnection(portProvider);
}
@Test
public void readData() throws Exception {
this.server.willSend("hello".getBytes());
this.server.start();
ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT);
ByteBuffer buffer = ByteBuffer.allocate(5);
channel.read(buffer);
assertThat(buffer.array(), equalTo("hello".getBytes()));
}
@Test
public void writeData() throws Exception {
this.server.expect("hello".getBytes());
this.server.start();
ByteChannel channel = this.connection.open(DEFAULT_TIMEOUT);
ByteBuffer buffer = ByteBuffer.wrap("hello".getBytes());
channel.write(buffer);
this.server.closeAndVerify();
}
@Test
public void timeout() throws Exception {
this.server.delay(1000);
this.server.start();
ByteChannel channel = this.connection.open(10);
long startTime = System.currentTimeMillis();
try {
channel.read(ByteBuffer.allocate(5));
fail("No socket timeout thrown");
}
catch (SocketTimeoutException ex) {
// Expected
long runTime = System.currentTimeMillis() - startTime;
assertThat(runTime, greaterThanOrEqualTo(10L));
assertThat(runTime, lessThan(10000L));
}
}
private static class MockServer {
private ServerSocketChannel serverSocket;
private byte[] send;
private byte[] expect;
private int delay;
private ByteBuffer actualRead;
private ServerThread thread;
public MockServer(int port) throws IOException {
this.serverSocket = ServerSocketChannel.open();
this.serverSocket.bind(new InetSocketAddress(port));
}
public void delay(int delay) {
this.delay = delay;
}
public void willSend(byte[] send) {
this.send = send;
}
public void expect(byte[] expect) {
this.expect = expect;
}
public void start() {
this.thread = new ServerThread();
this.thread.start();
}
public void closeAndVerify() throws InterruptedException {
close();
assertThat(this.actualRead.array(), equalTo(this.expect));
}
public void close() throws InterruptedException {
while (this.thread.isAlive()) {
Thread.sleep(10);
}
}
private class ServerThread extends Thread {
@Override
public void run() {
try {
SocketChannel channel = MockServer.this.serverSocket.accept();
Thread.sleep(MockServer.this.delay);
if (MockServer.this.send != null) {
ByteBuffer buffer = ByteBuffer.wrap(MockServer.this.send);
while (buffer.hasRemaining()) {
channel.write(buffer);
}
}
if (MockServer.this.expect != null) {
ByteBuffer buffer = ByteBuffer
.allocate(MockServer.this.expect.length);
while (buffer.hasRemaining()) {
channel.read(buffer);
}
MockServer.this.actualRead = buffer;
}
channel.close();
}
catch (Exception ex) {
ex.printStackTrace();
fail();
}
}
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.developertools.tunnel.server;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import static org.hamcrest.Matchers.equalTo;
import static org.junit.Assert.assertThat;
/**
* Tests for {@link StaticPortProvider}.
*
* @author Phillip Webb
*/
public class StaticPortProviderTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Test
public void portMustBePostive() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Port must be positive");
new StaticPortProvider(0);
}
@Test
public void getPort() throws Exception {
StaticPortProvider provider = new StaticPortProvider(123);
assertThat(provider.getPort(), equalTo(123));
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment