Commit 615c6d4e authored by Phillip Webb's avatar Phillip Webb

Restructure RSocket packages and polish

Polish code and relocate `RSocketServerBootstrap` from `server` to
`context` since it's really an `ApplicationContext` concern.

Closes gh-18391
parent de393aba
......@@ -51,7 +51,7 @@ public class RSocketProperties {
/**
* RSocket transport protocol.
*/
private RSocketServer.TRANSPORT transport = RSocketServer.TRANSPORT.TCP;
private RSocketServer.Transport transport = RSocketServer.Transport.TCP;
/**
* Path under which RSocket handles requests (only works with websocket
......@@ -75,11 +75,11 @@ public class RSocketProperties {
this.address = address;
}
public RSocketServer.TRANSPORT getTransport() {
public RSocketServer.Transport getTransport() {
return this.transport;
}
public void setTransport(RSocketServer.TRANSPORT transport) {
public void setTransport(RSocketServer.Transport transport) {
this.transport = transport;
}
......
......@@ -34,8 +34,8 @@ import org.springframework.boot.autoconfigure.condition.ConditionalOnProperty;
import org.springframework.boot.autoconfigure.condition.ConditionalOnWebApplication;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.rsocket.context.RSocketServerBootstrap;
import org.springframework.boot.rsocket.netty.NettyRSocketServerFactory;
import org.springframework.boot.rsocket.server.RSocketServerBootstrap;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.rsocket.server.ServerRSocketFactoryCustomizer;
import org.springframework.context.annotation.Bean;
......
......@@ -20,7 +20,7 @@ import org.junit.jupiter.api.Test;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.rsocket.context.RSocketPortInfoApplicationContextInitializer;
import org.springframework.boot.rsocket.server.RSocketServerBootstrap;
import org.springframework.boot.rsocket.context.RSocketServerBootstrap;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.boot.rsocket.server.ServerRSocketFactoryCustomizer;
import org.springframework.boot.test.context.runner.ApplicationContextRunner;
......
......@@ -63,7 +63,7 @@ public class RSocketPortInfoApplicationContextInitializer
@Override
public void onApplicationEvent(RSocketServerInitializedEvent event) {
setPortProperty(this.applicationContext, event.getrSocketServer().address().getPort());
setPortProperty(this.applicationContext, event.getServer().address().getPort());
}
private void setPortProperty(ApplicationContext context, int port) {
......
......@@ -14,14 +14,16 @@
* limitations under the License.
*/
package org.springframework.boot.rsocket.server;
package org.springframework.boot.rsocket.context;
import io.rsocket.SocketAcceptor;
import org.springframework.boot.rsocket.context.RSocketServerInitializedEvent;
import org.springframework.boot.rsocket.server.RSocketServer;
import org.springframework.boot.rsocket.server.RSocketServerFactory;
import org.springframework.context.ApplicationEventPublisher;
import org.springframework.context.ApplicationEventPublisherAware;
import org.springframework.context.SmartLifecycle;
import org.springframework.util.Assert;
/**
* Bootstrap an {@link RSocketServer} and start it with the application context.
......@@ -31,33 +33,34 @@ import org.springframework.context.SmartLifecycle;
*/
public class RSocketServerBootstrap implements ApplicationEventPublisherAware, SmartLifecycle {
private final RSocketServer rSocketServer;
private final RSocketServer server;
private ApplicationEventPublisher applicationEventPublisher;
private ApplicationEventPublisher eventPublisher;
public RSocketServerBootstrap(RSocketServerFactory serverFactoryProvider, SocketAcceptor socketAcceptor) {
this.rSocketServer = serverFactoryProvider.create(socketAcceptor);
public RSocketServerBootstrap(RSocketServerFactory serverFactory, SocketAcceptor socketAcceptor) {
Assert.notNull(serverFactory, "ServerFactory must not be null");
this.server = serverFactory.create(socketAcceptor);
}
@Override
public void setApplicationEventPublisher(ApplicationEventPublisher applicationEventPublisher) {
this.applicationEventPublisher = applicationEventPublisher;
this.eventPublisher = applicationEventPublisher;
}
@Override
public void start() {
this.rSocketServer.start();
this.applicationEventPublisher.publishEvent(new RSocketServerInitializedEvent(this.rSocketServer));
this.server.start();
this.eventPublisher.publishEvent(new RSocketServerInitializedEvent(this.server));
}
@Override
public void stop() {
this.rSocketServer.stop();
this.server.stop();
}
@Override
public boolean isRunning() {
RSocketServer server = this.rSocketServer;
RSocketServer server = this.server;
if (server != null) {
return server.address() != null;
}
......
......@@ -27,18 +27,17 @@ import org.springframework.context.ApplicationEvent;
* @author Brian Clozel
* @since 2.2.0
*/
@SuppressWarnings("serial")
public class RSocketServerInitializedEvent extends ApplicationEvent {
public RSocketServerInitializedEvent(RSocketServer rSocketServer) {
super(rSocketServer);
public RSocketServerInitializedEvent(RSocketServer server) {
super(server);
}
/**
* Access the {@link RSocketServer}.
* @return the embedded RSocket server
*/
public RSocketServer getrSocketServer() {
public RSocketServer getServer() {
return getSource();
}
......
......@@ -61,25 +61,13 @@ public class NettyRSocketServer implements RSocketServer {
@Override
public void start() throws RSocketServerException {
if (this.lifecycleTimeout != null) {
this.channel = this.starter.block(this.lifecycleTimeout);
}
else {
this.channel = this.starter.block();
}
this.channel = block(this.starter, this.lifecycleTimeout);
logger.info("Netty RSocket started on port(s): " + address().getPort());
startDaemonAwaitThread(this.channel);
}
private void startDaemonAwaitThread(CloseableChannel channel) {
Thread awaitThread = new Thread("rsocket") {
@Override
public void run() {
channel.onClose().block();
}
};
Thread awaitThread = new Thread(() -> channel.onClose().block(), "rsocket");
awaitThread.setContextClassLoader(getClass().getClassLoader());
awaitThread.setDaemon(false);
awaitThread.start();
......@@ -93,4 +81,8 @@ public class NettyRSocketServer implements RSocketServer {
}
}
private <T> T block(Mono<T> mono, Duration timeout) {
return (timeout != null) ? mono.block(timeout) : mono.block();
}
}
......@@ -54,7 +54,7 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
private InetAddress address;
private RSocketServer.TRANSPORT transport = RSocketServer.TRANSPORT.TCP;
private RSocketServer.Transport transport = RSocketServer.Transport.TCP;
private ReactorResourceFactory resourceFactory;
......@@ -73,7 +73,7 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
}
@Override
public void setTransport(RSocketServer.TRANSPORT transport) {
public void setTransport(RSocketServer.Transport transport) {
this.transport = transport;
}
......@@ -126,26 +126,28 @@ public class NettyRSocketServerFactory implements RSocketServerFactory, Configur
}
private ServerTransport<CloseableChannel> createTransport() {
if (this.transport == RSocketServer.TRANSPORT.WEBSOCKET) {
if (this.resourceFactory != null) {
HttpServer httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer
.runOn(this.resourceFactory.getLoopResources()).addressSupplier(this::getListenAddress));
return WebsocketServerTransport.create(httpServer);
}
else {
return WebsocketServerTransport.create(getListenAddress());
}
if (this.transport == RSocketServer.Transport.WEBSOCKET) {
return createWebSocketTransport();
}
else {
if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.addressSupplier(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
}
else {
return TcpServerTransport.create(getListenAddress());
}
return createTcpTransport();
}
private ServerTransport<CloseableChannel> createWebSocketTransport() {
if (this.resourceFactory != null) {
HttpServer httpServer = HttpServer.create().tcpConfiguration((tcpServer) -> tcpServer
.runOn(this.resourceFactory.getLoopResources()).addressSupplier(this::getListenAddress));
return WebsocketServerTransport.create(httpServer);
}
return WebsocketServerTransport.create(getListenAddress());
}
private ServerTransport<CloseableChannel> createTcpTransport() {
if (this.resourceFactory != null) {
TcpServer tcpServer = TcpServer.create().runOn(this.resourceFactory.getLoopResources())
.addressSupplier(this::getListenAddress);
return TcpServerTransport.create(tcpServer);
}
return TcpServerTransport.create(getListenAddress());
}
private InetSocketAddress getListenAddress() {
......
......@@ -43,6 +43,6 @@ public interface ConfigurableRSocketServerFactory {
* Set the transport that the RSocket server should use.
* @param transport the transport protocol to use
*/
void setTransport(RSocketServer.TRANSPORT transport);
void setTransport(RSocketServer.Transport transport);
}
......@@ -50,9 +50,17 @@ public interface RSocketServer {
/**
* Choice of transport protocol for the RSocket server.
*/
enum TRANSPORT {
enum Transport {
TCP, WEBSOCKET
/**
* TCP transport protocol.
*/
TCP,
/**
* WebSocket transport protocol.
*/
WEBSOCKET
}
......
......@@ -22,7 +22,6 @@ package org.springframework.boot.rsocket.server;
* @author Brian Clozel
* @since 2.2.0
*/
@SuppressWarnings("serial")
public class RSocketServerException extends RuntimeException {
public RSocketServerException(String message, Throwable cause) {
......
......@@ -59,7 +59,7 @@ import static org.mockito.Mockito.mock;
*/
class NettyRSocketServerFactoryTests {
private NettyRSocketServer rSocketServer;
private NettyRSocketServer server;
private RSocketRequester requester;
......@@ -67,9 +67,9 @@ class NettyRSocketServerFactoryTests {
@AfterEach
void tearDown() {
if (this.rSocketServer != null) {
if (this.server != null) {
try {
this.rSocketServer.stop();
this.server.stop();
}
catch (Exception ex) {
// Ignore
......@@ -89,47 +89,44 @@ class NettyRSocketServerFactoryTests {
NettyRSocketServerFactory factory = getFactory();
int specificPort = SocketUtils.findAvailableTcpPort(41000);
factory.setPort(specificPort);
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor());
this.rSocketServer.start();
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketTcpClient();
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(this.rSocketServer.address().getPort()).isEqualTo(specificPort);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
assertThat(response).isEqualTo(payload);
assertThat(this.rSocketServer.address().getPort()).isEqualTo(specificPort);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
}
@Test
void websocketTransport() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(RSocketServer.TRANSPORT.WEBSOCKET);
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor());
this.rSocketServer.start();
factory.setTransport(RSocketServer.Transport.WEBSOCKET);
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
}
@Test
void websocketTransportWithReactorResource() {
NettyRSocketServerFactory factory = getFactory();
factory.setTransport(RSocketServer.TRANSPORT.WEBSOCKET);
factory.setTransport(RSocketServer.Transport.WEBSOCKET);
ReactorResourceFactory resourceFactory = new ReactorResourceFactory();
resourceFactory.afterPropertiesSet();
factory.setResourceFactory(resourceFactory);
int specificPort = SocketUtils.findAvailableTcpPort(41000);
factory.setPort(specificPort);
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor());
this.rSocketServer.start();
this.server = factory.create(new EchoRequestResponseAcceptor());
this.server.start();
this.requester = createRSocketWebSocketClient();
String payload = "test payload";
String response = this.requester.route("test").data(payload).retrieveMono(String.class).block(TIMEOUT);
assertThat(response).isEqualTo(payload);
assertThat(this.rSocketServer.address().getPort()).isEqualTo(specificPort);
assertThat(this.server.address().getPort()).isEqualTo(specificPort);
}
@Test
......@@ -142,7 +139,7 @@ class NettyRSocketServerFactoryTests {
.will((invocation) -> invocation.getArgument(0));
}
factory.setServerCustomizers(Arrays.asList(customizers[0], customizers[1]));
this.rSocketServer = factory.create(new EchoRequestResponseAcceptor());
this.server = factory.create(new EchoRequestResponseAcceptor());
InOrder ordered = inOrder((Object[]) customizers);
for (ServerRSocketFactoryCustomizer customizer : customizers) {
ordered.verify(customizer).apply(any(RSocketFactory.ServerRSocketFactory.class));
......@@ -150,14 +147,14 @@ class NettyRSocketServerFactoryTests {
}
private RSocketRequester createRSocketTcpClient() {
Assertions.assertThat(this.rSocketServer).isNotNull();
InetSocketAddress address = this.rSocketServer.address();
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().connectTcp(address.getHostString(), address.getPort()).block();
}
private RSocketRequester createRSocketWebSocketClient() {
Assertions.assertThat(this.rSocketServer).isNotNull();
InetSocketAddress address = this.rSocketServer.address();
Assertions.assertThat(this.server).isNotNull();
InetSocketAddress address = this.server.address();
return createRSocketRequesterBuilder().connect(WebsocketClientTransport.create(address)).block();
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment