Commit b4810b8b authored by cbono's avatar cbono Committed by Brian Clozel

Add SSL support to RSocketServer

See gh-19399
parent dd024048
...@@ -19,12 +19,15 @@ package org.springframework.boot.autoconfigure.rsocket; ...@@ -19,12 +19,15 @@ package org.springframework.boot.autoconfigure.rsocket;
import java.net.InetAddress; import java.net.InetAddress;
import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.boot.rsocket.server.RSocketServer; import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.web.server.Ssl;
/** /**
* {@link ConfigurationProperties properties} for RSocket support. * {@link ConfigurationProperties properties} for RSocket support.
* *
* @author Brian Clozel * @author Brian Clozel
* @author Chris Bono
* @since 2.2.0 * @since 2.2.0
*/ */
@ConfigurationProperties("spring.rsocket") @ConfigurationProperties("spring.rsocket")
...@@ -59,6 +62,9 @@ public class RSocketProperties { ...@@ -59,6 +62,9 @@ public class RSocketProperties {
*/ */
private String mappingPath; private String mappingPath;
@NestedConfigurationProperty
private Ssl ssl;
public Integer getPort() { public Integer getPort() {
return this.port; return this.port;
} }
...@@ -91,6 +97,14 @@ public class RSocketProperties { ...@@ -91,6 +97,14 @@ public class RSocketProperties {
this.mappingPath = mappingPath; this.mappingPath = mappingPath;
} }
public Ssl getSsl() {
return this.ssl;
}
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}
} }
} }
...@@ -97,6 +97,7 @@ public class RSocketServerAutoConfiguration { ...@@ -97,6 +97,7 @@ public class RSocketServerAutoConfiguration {
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull(); PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(properties.getServer().getAddress()).to(factory::setAddress); map.from(properties.getServer().getAddress()).to(factory::setAddress);
map.from(properties.getServer().getPort()).to(factory::setPort); map.from(properties.getServer().getPort()).to(factory::setPort);
map.from(properties.getServer().getSsl()).to(factory::setSsl);
factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList())); factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList()));
return factory; return factory;
} }
......
...@@ -91,6 +91,18 @@ class RSocketServerAutoConfigurationTests { ...@@ -91,6 +91,18 @@ class RSocketServerAutoConfigurationTests {
}); });
} }
@Test
void shouldUseSslWhenRocketServerSslIsConfigured() {
reactiveWebContextRunner()
.withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks",
"spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0")
.run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class)
.hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(RSocketServerCustomizer.class)
.getBean(RSocketServerFactory.class)
.hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks")
.hasFieldOrPropertyWithValue("ssl.keyPassword", "password"));
}
@Test @Test
void shouldUseCustomServerBootstrap() { void shouldUseCustomServerBootstrap() {
contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context) contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context)
......
...@@ -37,6 +37,9 @@ import org.springframework.boot.rsocket.server.ConfigurableRSocketServerFactory; ...@@ -37,6 +37,9 @@ import org.springframework.boot.rsocket.server.ConfigurableRSocketServerFactory;
import org.springframework.boot.rsocket.server.RSocketServer; import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer; import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.rsocket.server.RSocketServerFactory; import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.web.embedded.netty.SslServerCustomizer;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.http.client.reactive.ReactorResourceFactory; import org.springframework.http.client.reactive.ReactorResourceFactory;
import org.springframework.util.Assert; import org.springframework.util.Assert;
...@@ -45,6 +48,7 @@ import org.springframework.util.Assert; ...@@ -45,6 +48,7 @@ import org.springframework.util.Assert;
* by Netty. * by Netty.
* *
* @author Brian Clozel * @author Brian Clozel
* @author Chris Bono
* @since 2.2.0 * @since 2.2.0
*/ */
public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory { public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory {
...@@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
private List<RSocketServerCustomizer> rSocketServerCustomizers = new ArrayList<>(); private List<RSocketServerCustomizer> rSocketServerCustomizers = new ArrayList<>();
private Ssl ssl;
private SslStoreProvider sslStoreProvider;
@Override @Override
public void setPort(int port) { public void setPort(int port) {
this.port = port; this.port = port;
...@@ -76,6 +84,16 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -76,6 +84,16 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
this.transport = transport; this.transport = transport;
} }
@Override
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}
@Override
public void setSslStoreProvider(SslStoreProvider sslStoreProvider) {
this.sslStoreProvider = sslStoreProvider;
}
/** /**
* Set the {@link ReactorResourceFactory} to get the shared resources from. * Set the {@link ReactorResourceFactory} to get the shared resources from.
* @param resourceFactory the server resources * @param resourceFactory the server resources
...@@ -133,21 +151,41 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -133,21 +151,41 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
} }
private ServerTransport<CloseableChannel> createWebSocketTransport() { private ServerTransport<CloseableChannel> createWebSocketTransport() {
HttpServer httpServer;
if (this.resourceFactory != null) { if (this.resourceFactory != null) {
HttpServer httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources()) httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
.bindAddress(this::getListenAddress); .bindAddress(this::getListenAddress);
return WebsocketServerTransport.create(httpServer);
} }
return WebsocketServerTransport.create(getListenAddress()); else {
InetSocketAddress listenAddress = this.getListenAddress();
httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
}
if (this.ssl != null && this.ssl.isEnabled()) {
SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider);
httpServer = sslServerCustomizer.apply(httpServer);
}
return WebsocketServerTransport.create(httpServer);
} }
private ServerTransport<CloseableChannel> createTcpTransport() { private ServerTransport<CloseableChannel> createTcpTransport() {
TcpServer tcpServer;
if (this.resourceFactory != null) { if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources()) tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.bindAddress(this::getListenAddress); .bindAddress(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
} }
return TcpServerTransport.create(getListenAddress()); else {
InetSocketAddress listenAddress = this.getListenAddress();
tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
}
if (this.ssl != null && this.ssl.isEnabled()) {
TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider);
tcpServer = sslServerCustomizer.apply(tcpServer);
}
return TcpServerTransport.create(tcpServer);
} }
private InetSocketAddress getListenAddress() { private InetSocketAddress getListenAddress() {
...@@ -157,4 +195,24 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -157,4 +195,24 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
return new InetSocketAddress(this.port); return new InetSocketAddress(this.port);
} }
private static final class TcpSslServerCustomizer extends SslServerCustomizer {
private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) {
super(ssl, null, sslStoreProvider);
}
// This does not override the apply in parent - currently just leveraging the
// parent for its "getContextBuilder()" method. This should be refactored when
// we add the concept of http/tcp customizers for RSocket.
private TcpServer apply(TcpServer server) {
try {
return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder()));
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
}
} }
...@@ -18,6 +18,9 @@ package org.springframework.boot.rsocket.server; ...@@ -18,6 +18,9 @@ package org.springframework.boot.rsocket.server;
import java.net.InetAddress; import java.net.InetAddress;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
/** /**
* A configurable {@link RSocketServerFactory}. * A configurable {@link RSocketServerFactory}.
* *
...@@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory { ...@@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory {
*/ */
void setTransport(RSocketServer.Transport transport); void setTransport(RSocketServer.Transport transport);
/**
* Sets the SSL configuration that will be applied to the server's default connector.
* @param ssl the SSL configuration
*/
void setSsl(Ssl ssl);
/**
* Sets a provider that will be used to obtain SSL stores.
* @param sslStoreProvider the SSL store provider
*/
void setSslStoreProvider(SslStoreProvider sslStoreProvider);
} }
...@@ -17,15 +17,20 @@ ...@@ -17,15 +17,20 @@
package org.springframework.boot.rsocket.netty; package org.springframework.boot.rsocket.netty;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import java.util.concurrent.Callable; import java.util.concurrent.Callable;
import io.netty.buffer.PooledByteBufAllocator; import io.netty.buffer.PooledByteBufAllocator;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.rsocket.ConnectionSetupPayload; import io.rsocket.ConnectionSetupPayload;
import io.rsocket.Payload; import io.rsocket.Payload;
import io.rsocket.RSocket; import io.rsocket.RSocket;
import io.rsocket.SocketAcceptor; import io.rsocket.SocketAcceptor;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.client.WebsocketClientTransport; import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.util.DefaultPayload; import io.rsocket.util.DefaultPayload;
import org.assertj.core.api.Assertions; import org.assertj.core.api.Assertions;
...@@ -33,9 +38,13 @@ import org.junit.jupiter.api.AfterEach; ...@@ -33,9 +38,13 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.InOrder; import org.mockito.InOrder;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import reactor.netty.tcp.TcpClient;
import reactor.test.StepVerifier;
import org.springframework.boot.rsocket.server.RSocketServer; import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer; import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.rsocket.server.RSocketServer.Transport;
import org.springframework.boot.web.server.Ssl;
import org.springframework.core.codec.CharSequenceEncoder; import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder; import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.NettyDataBufferFactory; import org.springframework.core.io.buffer.NettyDataBufferFactory;
...@@ -45,6 +54,8 @@ import org.springframework.messaging.rsocket.RSocketStrategies; ...@@ -45,6 +54,8 @@ import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.util.SocketUtils; import org.springframework.util.SocketUtils;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.will; import static org.mockito.BDDMockito.will;
import static org.mockito.Mockito.inOrder; import static org.mockito.Mockito.inOrder;
...@@ -55,6 +66,7 @@ import static org.mockito.Mockito.mock; ...@@ -55,6 +66,7 @@ import static org.mockito.Mockito.mock;
* *
* @author Brian Clozel * @author Brian Clozel
* @author Leo Li * @author Leo Li
* @author Chris Bono
*/ */
class NettyRSocketServerFactoryTests { class NettyRSocketServerFactoryTests {
...@@ -93,7 +105,7 @@ class NettyRSocketServerFactoryTests { ...@@ -93,7 +105,7 @@ class NettyRSocketServerFactoryTests {
this.server.start(); this.server.start();
return port; return port;
}); });
this.requester = createRSocketTcpClient(); this.requester = createRSocketTcpClient(false);
String payload = "test payload"; String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(this.server.address().getPort()).isEqualTo(specificPort); assertThat(this.server.address().getPort()).isEqualTo(specificPort);
...@@ -106,7 +118,7 @@ class NettyRSocketServerFactoryTests { ...@@ -106,7 +118,7 @@ class NettyRSocketServerFactoryTests {
factory.setTransport(RSocketServer.Transport.WEBSOCKET); factory.setTransport(RSocketServer.Transport.WEBSOCKET);
this.server = factory.create(new EchoRequestResponseAcceptor()); this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start(); this.server.start();
this.requester = createRSocketWebSocketClient(); this.requester = createRSocketWebSocketClient(false);
String payload = "test payload"; String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload); assertThat(response).isEqualTo(payload);
...@@ -121,7 +133,7 @@ class NettyRSocketServerFactoryTests { ...@@ -121,7 +133,7 @@ class NettyRSocketServerFactoryTests {
factory.setResourceFactory(resourceFactory); factory.setResourceFactory(resourceFactory);
this.server = factory.create(new EchoRequestResponseAcceptor()); this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start(); this.server.start();
this.requester = createRSocketWebSocketClient(); this.requester = createRSocketWebSocketClient(false);
String payload = "test payload"; String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload); assertThat(response).isEqualTo(payload);
...@@ -144,16 +156,94 @@ class NettyRSocketServerFactoryTests { ...@@ -144,16 +156,94 @@ class NettyRSocketServerFactoryTests {
} }
} }
private RSocketRequester createRSocketTcpClient() { @Test
Assertions.assertThat(this.server).isNotNull(); void tcpTransportBasicSslFromClassPath() {
InetSocketAddress address = this.server.address(); testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP);
return createRSocketRequesterBuilder().tcp(address.getHostString(), address.getPort()); }
@Test
void tcpTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP);
}
@Test
void websocketTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET);
}
@Test
void websocketTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET);
}
private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(transport);
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true)
: createRSocketWebSocketClient(true);
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono).expectNext(payload).verifyComplete();
} }
private RSocketRequester createRSocketWebSocketClient() { @Test
void tcpTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.TCP);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketTcpClient(false);
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono)
.verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class));
}
@Test
void websocketTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.WEBSOCKET);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
// For WebSocket, the SSL failure results in a hang on the initial connect call
assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class)
.hasStackTraceContaining("Timeout on blocking read");
}
private RSocketRequester createRSocketTcpClient(boolean ssl) {
TcpClient tcpClient = createTcpClient(ssl);
return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT);
}
private RSocketRequester createRSocketWebSocketClient(boolean ssl) {
TcpClient tcpClient = createTcpClient(ssl);
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT);
}
private TcpClient createTcpClient(boolean ssl) {
Assertions.assertThat(this.server).isNotNull(); Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address(); InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(address)); TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort());
if (ssl) {
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE);
tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder));
}
return tcpClient;
} }
private RSocketRequester.Builder createRSocketRequesterBuilder() { private RSocketRequester.Builder createRSocketRequesterBuilder() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment