Commit 615c6d4e authored by Phillip Webb's avatar Phillip Webb

Restructure RSocket packages and polish

Polish code and relocate `RSocketServerBootstrap` from `server` to
`context` since it's really an `ApplicationContext` concern.

Closes gh-18391
parent de393aba
...@@ -51,7 +51,7 @@ public class RSocketProperties { ...@@ -51,7 +51,7 @@ public class RSocketProperties {
/** /**
* RSocket transport protocol. * RSocket transport protocol.
*/ */
private RSocketServer.TRANSPORT transport = RSocketServer.TRANSPORT.TCP; private RSocketServer.Transport transport = RSocketServer.Transport.TCP;
/** /**
* Path under which RSocket handles requests (only works with websocket * Path under which RSocket handles requests (only works with websocket
...@@ -75,11 +75,11 @@ public class RSocketProperties { ...@@ -75,11 +75,11 @@ public class RSocketProperties {
this.address = address; this.address = address;
} }
public RSocketServer.TRANSPORT getTransport() { public RSocketServer.Transport getTransport() {
return this.transport; return this.transport;
} }
public void setTransport(RSocketServer.TRANSPORT transport) { public void setTransport(RSocketServer.Transport transport) {
this.transport = transport; this.transport = transport;
} }
......
...@@ -34,8 +34,8 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty; ...@@ -34,8 +34,8 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication; import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties; import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.context.properties.PropertyMapper; import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.rsocket.context.RSocketServerBootstrap;
import org.springframework.boot.rsocket.netty.NettyRSocketServerFactory; import org.springframework.boot.rsocket.netty.NettyRSocketServerFactory;
import org.springframework.boot.rsocket.server.RSocketServerBootstrap;
import org.springframework.boot.rsocket.server.RSocketServerFactory; import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.rsocket.server.ServerRSocketFactoryCustomizer; import org.springframework.boot.rsocket.server.ServerRSocketFactoryCustomizer;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
......
...@@ -20,7 +20,7 @@ import org.junit.jupiter.api.Test; ...@@ -20,7 +20,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.rsocket.context.RSocketPortInfoApplicationContextInitializer; import org.springframework.boot.rsocket.context.RSocketPortInfoApplicationContextInitializer;
import org.springframework.boot.rsocket.server.RSocketServerBootstrap; import org.springframework.boot.rsocket.context.RSocketServerBootstrap;
import org.springframework.boot.rsocket.server.RSocketServerFactory; import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.rsocket.server.ServerRSocketFactoryCustomizer; import org.springframework.boot.rsocket.server.ServerRSocketFactoryCustomizer;
import org.springframework.boot.test.context.runner.ApplicationContextRunner; import org.springframework.boot.test.context.runner.ApplicationContextRunner;
......
...@@ -63,7 +63,7 @@ public class RSocketPortInfoApplicationContextInitializer ...@@ -63,7 +63,7 @@ public class RSocketPortInfoApplicationContextInitializer
@Override @Override
public void onApplicationEvent(RSocketServerInitializedEvent event) { public void onApplicationEvent(RSocketServerInitializedEvent event) {
setPortProperty(this.applicationContext, event.getrSocketServer().address().getPort()); setPortProperty(this.applicationContext, event.getServer().address().getPort());
} }
private void setPortProperty(ApplicationContext context, int port) { private void setPortProperty(ApplicationContext context, int port) {
......
...@@ -14,14 +14,16 @@ ...@@ -14,14 +14,16 @@
* limitations under the License. * limitations under the License.
*/ */
package org.springframework.boot.rsocket.server; package org.springframework.boot.rsocket.context;
import io.rsocket.SocketAcceptor; import io.rsocket.SocketAcceptor;
import org.springframework.boot.rsocket.context.RSocketServerInitializedEvent; import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.SmartLifecycle; import org.springframework.context.SmartLifecycle;
import org.springframework.util.Assert;
/** /**
* Bootstrap an {@link RSocketServer} and start it with the application context. * Bootstrap an {@link RSocketServer} and start it with the application context.
...@@ -31,33 +33,34 @@ import org.springframework.context.SmartLifecycle; ...@@ -31,33 +33,34 @@ import org.springframework.context.SmartLifecycle;
*/ */
public class RSocketServerBootstrap implements ApplicationEventPublisherAware, SmartLifecycle { public class RSocketServerBootstrap implements ApplicationEventPublisherAware, SmartLifecycle {
private final RSocketServer rSocketServer; private final RSocketServer server;
private ApplicationEventPublisher applicationEventPublisher; private ApplicationEventPublisher eventPublisher;
public RSocketServerBootstrap(RSocketServerFactory serverFactoryProvider, SocketAcceptor socketAcceptor) { public RSocketServerBootstrap(RSocketServerFactory serverFactory, SocketAcceptor socketAcceptor) {
this.rSocketServer = serverFactoryProvider.create(socketAcceptor); Assert.notNull(serverFactory, "ServerFactory must not be null");
this.server = serverFactory.create(socketAcceptor);
} }
@Override @Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) { public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.applicationEventPublisher = applicationEventPublisher; this.eventPublisher = applicationEventPublisher;
} }
@Override @Override
public void start() { public void start() {
this.rSocketServer.start(); this.server.start();
this.applicationEventPublisher.publishEvent(new RSocketServerInitializedEvent(this.rSocketServer)); this.eventPublisher.publishEvent(new RSocketServerInitializedEvent(this.server));
} }
@Override @Override
public void stop() { public void stop() {
this.rSocketServer.stop(); this.server.stop();
} }
@Override @Override
public boolean isRunning() { public boolean isRunning() {
RSocketServer server = this.rSocketServer; RSocketServer server = this.server;
if (server != null) { if (server != null) {
return server.address() != null; return server.address() != null;
} }
......
...@@ -27,18 +27,17 @@ import org.springframework.context.ApplicationEvent; ...@@ -27,18 +27,17 @@ import org.springframework.context.ApplicationEvent;
* @author Brian Clozel * @author Brian Clozel
* @since 2.2.0 * @since 2.2.0
*/ */
@SuppressWarnings("serial")
public class RSocketServerInitializedEvent extends ApplicationEvent { public class RSocketServerInitializedEvent extends ApplicationEvent {
public RSocketServerInitializedEvent(RSocketServer rSocketServer) { public RSocketServerInitializedEvent(RSocketServer server) {
super(rSocketServer); super(server);
} }
/** /**
* Access the {@link RSocketServer}. * Access the {@link RSocketServer}.
* @return the embedded RSocket server * @return the embedded RSocket server
*/ */
public RSocketServer getrSocketServer() { public RSocketServer getServer() {
return getSource(); return getSource();
} }
......
...@@ -61,25 +61,13 @@ public class NettyRSocketServer implements RSocketServer { ...@@ -61,25 +61,13 @@ public class NettyRSocketServer implements RSocketServer {
@Override @Override
public void start() throws RSocketServerException { public void start() throws RSocketServerException {
if (this.lifecycleTimeout != null) { this.channel = block(this.starter, this.lifecycleTimeout);
this.channel = this.starter.block(this.lifecycleTimeout);
}
else {
this.channel = this.starter.block();
}
logger.info("Netty RSocket started on port(s): " + address().getPort()); logger.info("Netty RSocket started on port(s): " + address().getPort());
startDaemonAwaitThread(this.channel); startDaemonAwaitThread(this.channel);
} }
private void startDaemonAwaitThread(CloseableChannel channel) { private void startDaemonAwaitThread(CloseableChannel channel) {
Thread awaitThread = new Thread("rsocket") { Thread awaitThread = new Thread(() -> channel.onClose().block(), "rsocket");
@Override
public void run() {
channel.onClose().block();
}
};
awaitThread.setContextClassLoader(getClass().getClassLoader()); awaitThread.setContextClassLoader(getClass().getClassLoader());
awaitThread.setDaemon(false); awaitThread.setDaemon(false);
awaitThread.start(); awaitThread.start();
...@@ -93,4 +81,8 @@ public class NettyRSocketServer implements RSocketServer { ...@@ -93,4 +81,8 @@ public class NettyRSocketServer implements RSocketServer {
} }
} }
private <T> T block(Mono<T> mono, Duration timeout) {
return (timeout != null) ? mono.block(timeout) : mono.block();
}
} }
...@@ -54,7 +54,7 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -54,7 +54,7 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
private InetAddress address; private InetAddress address;
private RSocketServer.TRANSPORT transport = RSocketServer.TRANSPORT.TCP; private RSocketServer.Transport transport = RSocketServer.Transport.TCP;
private ReactorResourceFactory resourceFactory; private ReactorResourceFactory resourceFactory;
...@@ -73,7 +73,7 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -73,7 +73,7 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
} }
@Override @Override
public void setTransport(RSocketServer.TRANSPORT transport) { public void setTransport(RSocketServer.Transport transport) {
this.transport = transport; this.transport = transport;
} }
...@@ -126,26 +126,28 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur ...@@ -126,26 +126,28 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
} }
private ServerTransport<CloseableChannel> createTransport() { private ServerTransport<CloseableChannel> createTransport() {
if (this.transport == RSocketServer.TRANSPORT.WEBSOCKET) { if (this.transport == RSocketServer.Transport.WEBSOCKET) {
if (this.resourceFactory != null) { return createWebSocketTransport();
HttpServer httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer
.runOn(this.resourceFactory.getLoopResources()).addressSupplier(this::getListenAddress));
return WebsocketServerTransport.create(httpServer);
}
else {
return WebsocketServerTransport.create(getListenAddress());
}
} }
else { return createTcpTransport();
if (this.resourceFactory != null) { }
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.addressSupplier(this::getListenAddress); private ServerTransport<CloseableChannel> createWebSocketTransport() {
return TcpServerTransport.create(tcpServer); if (this.resourceFactory != null) {
} HttpServer httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer
else { .runOn(this.resourceFactory.getLoopResources()).addressSupplier(this::getListenAddress));
return TcpServerTransport.create(getListenAddress()); return WebsocketServerTransport.create(httpServer);
} }
return WebsocketServerTransport.create(getListenAddress());
}
private ServerTransport<CloseableChannel> createTcpTransport() {
if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.addressSupplier(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
} }
return TcpServerTransport.create(getListenAddress());
} }
private InetSocketAddress getListenAddress() { private InetSocketAddress getListenAddress() {
......
...@@ -43,6 +43,6 @@ public interface ConfigurableRSocketServerFactory { ...@@ -43,6 +43,6 @@ public interface ConfigurableRSocketServerFactory {
* Set the transport that the RSocket server should use. * Set the transport that the RSocket server should use.
* @param transport the transport protocol to use * @param transport the transport protocol to use
*/ */
void setTransport(RSocketServer.TRANSPORT transport); void setTransport(RSocketServer.Transport transport);
} }
...@@ -50,9 +50,17 @@ public interface RSocketServer { ...@@ -50,9 +50,17 @@ public interface RSocketServer {
/** /**
* Choice of transport protocol for the RSocket server. * Choice of transport protocol for the RSocket server.
*/ */
enum TRANSPORT { enum Transport {
TCP, WEBSOCKET /**
* TCP transport protocol.
*/
TCP,
/**
* WebSocket transport protocol.
*/
WEBSOCKET
} }
......
...@@ -22,7 +22,6 @@ package org.springframework.boot.rsocket.server; ...@@ -22,7 +22,6 @@ package org.springframework.boot.rsocket.server;
* @author Brian Clozel * @author Brian Clozel
* @since 2.2.0 * @since 2.2.0
*/ */
@SuppressWarnings("serial")
public class RSocketServerException extends RuntimeException { public class RSocketServerException extends RuntimeException {
public RSocketServerException(String message, Throwable cause) { public RSocketServerException(String message, Throwable cause) {
......
...@@ -59,7 +59,7 @@ import static org.mockito.Mockito.mock; ...@@ -59,7 +59,7 @@ import static org.mockito.Mockito.mock;
*/ */
class NettyRSocketServerFactoryTests { class NettyRSocketServerFactoryTests {
private NettyRSocketServer rSocketServer; private NettyRSocketServer server;
private RSocketRequester requester; private RSocketRequester requester;
...@@ -67,9 +67,9 @@ class NettyRSocketServerFactoryTests { ...@@ -67,9 +67,9 @@ class NettyRSocketServerFactoryTests {
@AfterEach @AfterEach
void tearDown() { void tearDown() {
if (this.rSocketServer != null) { if (this.server != null) {
try { try {
this.rSocketServer.stop(); this.server.stop();
} }
catch (Exception ex) { catch (Exception ex) {
// Ignore // Ignore
...@@ -89,47 +89,44 @@ class NettyRSocketServerFactoryTests { ...@@ -89,47 +89,44 @@ class NettyRSocketServerFactoryTests {
NettyRSocketServerFactory factory = getFactory(); NettyRSocketServerFactory factory = getFactory();
int specificPort = SocketUtils.findAvailableTcpPort(41000); int specificPort = SocketUtils.findAvailableTcpPort(41000);
factory.setPort(specificPort); factory.setPort(specificPort);
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor()); this.server = factory.create(new EchoRequestResponseAcceptor());
this.rSocketServer.start(); this.server.start();
this.requester = createRSocketTcpClient(); this.requester = createRSocketTcpClient();
String payload = "test payload"; String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
assertThat(this.rSocketServer.address().getPort()).isEqualTo(specificPort);
assertThat(response).isEqualTo(payload); assertThat(response).isEqualTo(payload);
assertThat(this.rSocketServer.address().getPort()).isEqualTo(specificPort); assertThat(this.server.address().getPort()).isEqualTo(specificPort);
} }
@Test @Test
void websocketTransport() { void websocketTransport() {
NettyRSocketServerFactory factory = getFactory(); NettyRSocketServerFactory factory = getFactory();
factory.setTransport(RSocketServer.TRANSPORT.WEBSOCKET); factory.setTransport(RSocketServer.Transport.WEBSOCKET);
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor()); this.server = factory.create(new EchoRequestResponseAcceptor());
this.rSocketServer.start(); this.server.start();
this.requester = createRSocketWebSocketClient(); this.requester = createRSocketWebSocketClient();
String payload = "test payload"; String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload); assertThat(response).isEqualTo(payload);
} }
@Test @Test
void websocketTransportWithReactorResource() { void websocketTransportWithReactorResource() {
NettyRSocketServerFactory factory = getFactory(); NettyRSocketServerFactory factory = getFactory();
factory.setTransport(RSocketServer.TRANSPORT.WEBSOCKET); factory.setTransport(RSocketServer.Transport.WEBSOCKET);
ReactorResourceFactory resourceFactory = new ReactorResourceFactory(); ReactorResourceFactory resourceFactory = new ReactorResourceFactory();
resourceFactory.afterPropertiesSet(); resourceFactory.afterPropertiesSet();
factory.setResourceFactory(resourceFactory); factory.setResourceFactory(resourceFactory);
int specificPort = SocketUtils.findAvailableTcpPort(41000); int specificPort = SocketUtils.findAvailableTcpPort(41000);
factory.setPort(specificPort); factory.setPort(specificPort);
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor()); this.server = factory.create(new EchoRequestResponseAcceptor());
this.rSocketServer.start(); this.server.start();
this.requester = createRSocketWebSocketClient(); this.requester = createRSocketWebSocketClient();
String payload = "test payload"; String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT); String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload); assertThat(response).isEqualTo(payload);
assertThat(this.rSocketServer.address().getPort()).isEqualTo(specificPort); assertThat(this.server.address().getPort()).isEqualTo(specificPort);
} }
@Test @Test
...@@ -142,7 +139,7 @@ class NettyRSocketServerFactoryTests { ...@@ -142,7 +139,7 @@ class NettyRSocketServerFactoryTests {
.will((invocation) -> invocation.getArgument(0)); .will((invocation) -> invocation.getArgument(0));
} }
factory.setServerCustomizers(Arrays.asList(customizers[0], customizers[1])); factory.setServerCustomizers(Arrays.asList(customizers[0], customizers[1]));
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor()); this.server = factory.create(new EchoRequestResponseAcceptor());
InOrder ordered = inOrder((Object[]) customizers); InOrder ordered = inOrder((Object[]) customizers);
for (ServerRSocketFactoryCustomizer customizer : customizers) { for (ServerRSocketFactoryCustomizer customizer : customizers) {
ordered.verify(customizer).apply(any(RSocketFactory.ServerRSocketFactory.class)); ordered.verify(customizer).apply(any(RSocketFactory.ServerRSocketFactory.class));
...@@ -150,14 +147,14 @@ class NettyRSocketServerFactoryTests { ...@@ -150,14 +147,14 @@ class NettyRSocketServerFactoryTests {
} }
private RSocketRequester createRSocketTcpClient() { private RSocketRequester createRSocketTcpClient() {
Assertions.assertThat(this.rSocketServer).isNotNull(); Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.rSocketServer.address(); InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().connectTcp(address.getHostString(), address.getPort()).block(); return createRSocketRequesterBuilder().connectTcp(address.getHostString(), address.getPort()).block();
} }
private RSocketRequester createRSocketWebSocketClient() { private RSocketRequester createRSocketWebSocketClient() {
Assertions.assertThat(this.rSocketServer).isNotNull(); Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.rSocketServer.address(); InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(address)).block(); return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(address)).block();
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment