Commit b4810b8b authored by cbono's avatar cbono Committed by Brian Clozel

Add SSL support to RSocketServer

See gh-19399
parent dd024048
......@@ -19,12 +19,15 @@ package org.springframework.boot.autoconfigure.rsocket;
import java.net.InetAddress;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.context.properties.NestedConfigurationProperty;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.web.server.Ssl;
/**
* {@link ConfigurationProperties properties} for RSocket support.
*
* @author Brian Clozel
* @author Chris Bono
* @since 2.2.0
*/
@ConfigurationProperties("spring.rsocket")
......@@ -59,6 +62,9 @@ public class RSocketProperties {
*/
private String mappingPath;
@NestedConfigurationProperty
private Ssl ssl;
public Integer getPort() {
return this.port;
}
......@@ -91,6 +97,14 @@ public class RSocketProperties {
this.mappingPath = mappingPath;
}
public Ssl getSsl() {
return this.ssl;
}
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}
}
}
......@@ -97,6 +97,7 @@ public class RSocketServerAutoConfiguration {
PropertyMapper map = PropertyMapper.get().alwaysApplyingWhenNonNull();
map.from(properties.getServer().getAddress()).to(factory::setAddress);
map.from(properties.getServer().getPort()).to(factory::setPort);
map.from(properties.getServer().getSsl()).to(factory::setSsl);
factory.setRSocketServerCustomizers(customizers.orderedStream().collect(Collectors.toList()));
return factory;
}
......
......@@ -91,6 +91,18 @@ class RSocketServerAutoConfigurationTests {
});
}
@Test
void shouldUseSslWhenRocketServerSslIsConfigured() {
reactiveWebContextRunner()
.withPropertyValues("spring.rsocket.server.ssl.keyStore=classpath:rsocket/test.jks",
"spring.rsocket.server.ssl.keyPassword=password", "spring.rsocket.server.port=0")
.run((context) -> assertThat(context).hasSingleBean(RSocketServerFactory.class)
.hasSingleBean(RSocketServerBootstrap.class).hasSingleBean(RSocketServerCustomizer.class)
.getBean(RSocketServerFactory.class)
.hasFieldOrPropertyWithValue("ssl.keyStore", "classpath:rsocket/test.jks")
.hasFieldOrPropertyWithValue("ssl.keyPassword", "password"));
}
@Test
void shouldUseCustomServerBootstrap() {
contextRunner().withUserConfiguration(CustomServerBootstrapConfig.class).run((context) -> assertThat(context)
......
......@@ -37,6 +37,9 @@ import org.springframework.boot.rsocket.server.ConfigurableRSocketServerFactory;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.web.embedded.netty.SslServerCustomizer;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
import org.springframework.http.client.reactive.ReactorResourceFactory;
import org.springframework.util.Assert;
......@@ -45,6 +48,7 @@ import org.springframework.util.Assert;
* by Netty.
*
* @author Brian Clozel
* @author Chris Bono
* @since 2.2.0
*/
public class NettyRSocketServerFactory implements RSocketServerFactory, ConfigurableRSocketServerFactory {
......@@ -61,6 +65,10 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
private List<RSocketServerCustomizer> rSocketServerCustomizers = new ArrayList<>();
private Ssl ssl;
private SslStoreProvider sslStoreProvider;
@Override
public void setPort(int port) {
this.port = port;
......@@ -76,6 +84,16 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
this.transport = transport;
}
@Override
public void setSsl(Ssl ssl) {
this.ssl = ssl;
}
@Override
public void setSslStoreProvider(SslStoreProvider sslStoreProvider) {
this.sslStoreProvider = sslStoreProvider;
}
/**
* Set the {@link ReactorResourceFactory} to get the shared resources from.
* @param resourceFactory the server resources
......@@ -133,21 +151,41 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
}
private ServerTransport<CloseableChannel> createWebSocketTransport() {
HttpServer httpServer;
if (this.resourceFactory != null) {
HttpServer httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
httpServer = HttpServer.create().runOn(this.resourceFactory.getLoopResources())
.bindAddress(this::getListenAddress);
return WebsocketServerTransport.create(httpServer);
}
return WebsocketServerTransport.create(getListenAddress());
else {
InetSocketAddress listenAddress = this.getListenAddress();
httpServer = HttpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
}
if (this.ssl != null && this.ssl.isEnabled()) {
SslServerCustomizer sslServerCustomizer = new SslServerCustomizer(this.ssl, null, this.sslStoreProvider);
httpServer = sslServerCustomizer.apply(httpServer);
}
return WebsocketServerTransport.create(httpServer);
}
private ServerTransport<CloseableChannel> createTcpTransport() {
TcpServer tcpServer;
if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.bindAddress(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
}
return TcpServerTransport.create(getListenAddress());
else {
InetSocketAddress listenAddress = this.getListenAddress();
tcpServer = TcpServer.create().host(listenAddress.getHostName()).port(listenAddress.getPort());
}
if (this.ssl != null && this.ssl.isEnabled()) {
TcpSslServerCustomizer sslServerCustomizer = new TcpSslServerCustomizer(this.ssl, this.sslStoreProvider);
tcpServer = sslServerCustomizer.apply(tcpServer);
}
return TcpServerTransport.create(tcpServer);
}
private InetSocketAddress getListenAddress() {
......@@ -157,4 +195,24 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
return new InetSocketAddress(this.port);
}
private static final class TcpSslServerCustomizer extends SslServerCustomizer {
private TcpSslServerCustomizer(Ssl ssl, SslStoreProvider sslStoreProvider) {
super(ssl, null, sslStoreProvider);
}
// This does not override the apply in parent - currently just leveraging the
// parent for its "getContextBuilder()" method. This should be refactored when
// we add the concept of http/tcp customizers for RSocket.
private TcpServer apply(TcpServer server) {
try {
return server.secure((contextSpec) -> contextSpec.sslContext(getContextBuilder()));
}
catch (Exception ex) {
throw new IllegalStateException(ex);
}
}
}
}
......@@ -18,6 +18,9 @@ package org.springframework.boot.rsocket.server;
import java.net.InetAddress;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.server.SslStoreProvider;
/**
* A configurable {@link RSocketServerFactory}.
*
......@@ -45,4 +48,16 @@ public interface ConfigurableRSocketServerFactory {
*/
void setTransport(RSocketServer.Transport transport);
/**
* Sets the SSL configuration that will be applied to the server's default connector.
* @param ssl the SSL configuration
*/
void setSsl(Ssl ssl);
/**
* Sets a provider that will be used to obtain SSL stores.
* @param sslStoreProvider the SSL store provider
*/
void setSslStoreProvider(SslStoreProvider sslStoreProvider);
}
......@@ -17,15 +17,20 @@
package org.springframework.boot.rsocket.netty;
import java.net.InetSocketAddress;
import java.nio.channels.ClosedChannelException;
import java.time.Duration;
import java.util.Arrays;
import java.util.concurrent.Callable;
import io.netty.buffer.PooledByteBufAllocator;
import io.netty.handler.ssl.SslContextBuilder;
import io.netty.handler.ssl.SslProvider;
import io.netty.handler.ssl.util.InsecureTrustManagerFactory;
import io.rsocket.ConnectionSetupPayload;
import io.rsocket.Payload;
import io.rsocket.RSocket;
import io.rsocket.SocketAcceptor;
import io.rsocket.transport.netty.client.TcpClientTransport;
import io.rsocket.transport.netty.client.WebsocketClientTransport;
import io.rsocket.util.DefaultPayload;
import org.assertj.core.api.Assertions;
......@@ -33,9 +38,13 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.Test;
import org.mockito.InOrder;
import reactor.core.publisher.Mono;
import reactor.netty.tcp.TcpClient;
import reactor.test.StepVerifier;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerCustomizer;
import org.springframework.boot.rsocket.server.RSocketServer.Transport;
import org.springframework.boot.web.server.Ssl;
import org.springframework.core.codec.CharSequenceEncoder;
import org.springframework.core.codec.StringDecoder;
import org.springframework.core.io.buffer.NettyDataBufferFactory;
......@@ -45,6 +54,8 @@ import org.springframework.messaging.rsocket.RSocketStrategies;
import org.springframework.util.SocketUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.assertj.core.api.Assertions.assertThatExceptionOfType;
import static org.assertj.core.api.Assertions.assertThatThrownBy;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.BDDMockito.will;
import static org.mockito.Mockito.inOrder;
......@@ -55,6 +66,7 @@ import static org.mockito.Mockito.mock;
*
* @author Brian Clozel
* @author Leo Li
* @author Chris Bono
*/
class NettyRSocketServerFactoryTests {
......@@ -93,7 +105,7 @@ class NettyRSocketServerFactoryTests {
this.server.start();
return port;
});
this.requester = createRSocketTcpClient();
this.requester = createRSocketTcpClient(false);
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
......@@ -106,7 +118,7 @@ class NettyRSocketServerFactoryTests {
factory.setTransport(RSocketServer.Transport.WEBSOCKET);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
this.requester = createRSocketWebSocketClient(false);
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
......@@ -121,7 +133,7 @@ class NettyRSocketServerFactoryTests {
factory.setResourceFactory(resourceFactory);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
this.requester = createRSocketWebSocketClient(false);
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
......@@ -144,16 +156,94 @@ class NettyRSocketServerFactoryTests {
}
}
private RSocketRequester createRSocketTcpClient() {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().tcp(address.getHostString(), address.getPort());
@Test
void tcpTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.TCP);
}
@Test
void tcpTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.TCP);
}
@Test
void websocketTransportBasicSslFromClassPath() {
testBasicSslWithKeyStore("classpath:test.jks", "password", Transport.WEBSOCKET);
}
@Test
void websocketTransportBasicSslFromFileSystem() {
testBasicSslWithKeyStore("src/test/resources/test.jks", "password", Transport.WEBSOCKET);
}
private void testBasicSslWithKeyStore(String keyStore, String keyPassword, Transport transport) {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(transport);
Ssl ssl = new Ssl();
ssl.setKeyStore(keyStore);
ssl.setKeyPassword(keyPassword);
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = (transport == Transport.TCP) ? createRSocketTcpClient(true)
: createRSocketWebSocketClient(true);
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono).expectNext(payload).verifyComplete();
}
private RSocketRequester createRSocketWebSocketClient() {
@Test
void tcpTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.TCP);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketTcpClient(false);
String payload = "test payload";
Mono<String> responseMono = this.requester.route("test").data(payload).retrieveMono(String.class);
StepVerifier.create(responseMono)
.verifyErrorSatisfies((ex) -> assertThatExceptionOfType(ClosedChannelException.class));
}
@Test
void websocketTransportSslRejectsInsecureClient() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(Transport.WEBSOCKET);
Ssl ssl = new Ssl();
ssl.setKeyStore("classpath:test.jks");
ssl.setKeyPassword("password");
factory.setSsl(ssl);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
// For WebSocket, the SSL failure results in a hang on the initial connect call
assertThatThrownBy(() -> createRSocketWebSocketClient(false)).isInstanceOf(IllegalStateException.class)
.hasStackTraceContaining("Timeout on blocking read");
}
private RSocketRequester createRSocketTcpClient(boolean ssl) {
TcpClient tcpClient = createTcpClient(ssl);
return createRSocketRequesterBuilder().connect(TcpClientTransport.create(tcpClient)).block(TIMEOUT);
}
private RSocketRequester createRSocketWebSocketClient(boolean ssl) {
TcpClient tcpClient = createTcpClient(ssl);
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(tcpClient)).block(TIMEOUT);
}
private TcpClient createTcpClient(boolean ssl) {
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().transport(WebsocketClientTransport.create(address));
TcpClient tcpClient = TcpClient.create().host(address.getHostName()).port(address.getPort());
if (ssl) {
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE);
tcpClient = tcpClient.secure((spec) -> spec.sslContext(builder));
}
return tcpClient;
}
private RSocketRequester.Builder createRSocketRequesterBuilder() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment