Commit 69fbd8f2 authored by Phillip Webb's avatar Phillip Webb

Merge branch '2.1.x'

Closes gh-17227
Closes gh-17228
parents bd140508 5e3438f0
...@@ -24,7 +24,6 @@ import org.springframework.boot.autoconfigure.web.ServerProperties; ...@@ -24,7 +24,6 @@ import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.cloud.CloudPlatform; import org.springframework.boot.cloud.CloudPlatform;
import org.springframework.boot.context.properties.PropertyMapper; import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory; import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.embedded.netty.NettyServerCustomizer;
import org.springframework.boot.web.server.WebServerFactoryCustomizer; import org.springframework.boot.web.server.WebServerFactoryCustomizer;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
...@@ -58,11 +57,11 @@ public class NettyWebServerFactoryCustomizer ...@@ -58,11 +57,11 @@ public class NettyWebServerFactoryCustomizer
@Override @Override
public void customize(NettyReactiveWebServerFactory factory) { public void customize(NettyReactiveWebServerFactory factory) {
factory.setUseForwardHeaders(getOrDeduceUseForwardHeaders()); factory.setUseForwardHeaders(getOrDeduceUseForwardHeaders());
PropertyMapper propertyMapper = PropertyMapper.get(); PropertyMapper propertyMapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize).whenNonNull().asInt(DataSize::toBytes) propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize)
.to((maxHttpRequestHeaderSize) -> customizeMaxHttpHeaderSize(factory, maxHttpRequestHeaderSize)); .to((maxHttpRequestHeaderSize) -> customizeMaxHttpHeaderSize(factory, maxHttpRequestHeaderSize));
propertyMapper.from(this.serverProperties::getConnectionTimeout).whenNonNull().asInt(Duration::toMillis) propertyMapper.from(this.serverProperties::getConnectionTimeout)
.to((duration) -> factory.addServerCustomizers(getConnectionTimeOutCustomizer(duration))); .to((connectionTimeout) -> customizeConnectionTimeout(factory, connectionTimeout));
} }
private boolean getOrDeduceUseForwardHeaders() { private boolean getOrDeduceUseForwardHeaders() {
...@@ -73,14 +72,17 @@ public class NettyWebServerFactoryCustomizer ...@@ -73,14 +72,17 @@ public class NettyWebServerFactoryCustomizer
return this.serverProperties.getForwardHeadersStrategy().equals(ServerProperties.ForwardHeadersStrategy.NATIVE); return this.serverProperties.getForwardHeadersStrategy().equals(ServerProperties.ForwardHeadersStrategy.NATIVE);
} }
private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, Integer maxHttpHeaderSize) { private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, DataSize maxHttpHeaderSize) {
factory.addServerCustomizers((NettyServerCustomizer) (httpServer) -> httpServer.httpRequestDecoder( factory.addServerCustomizers((httpServer) -> httpServer.httpRequestDecoder(
(httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize(maxHttpHeaderSize))); (httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize((int) maxHttpHeaderSize.toBytes())));
} }
private NettyServerCustomizer getConnectionTimeOutCustomizer(int duration) { private void customizeConnectionTimeout(NettyReactiveWebServerFactory factory, Duration connectionTimeout) {
return (httpServer) -> httpServer.tcpConfiguration( if (!connectionTimeout.isZero()) {
(tcpServer) -> tcpServer.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, duration)); long timeoutMillis = connectionTimeout.isNegative() ? 0 : connectionTimeout.toMillis();
factory.addServerCustomizers((httpServer) -> httpServer.tcpConfiguration((tcpServer) -> tcpServer
.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) timeoutMillis)));
}
} }
} }
...@@ -16,21 +16,38 @@ ...@@ -16,21 +16,38 @@
package org.springframework.boot.autoconfigure.web.embedded; package org.springframework.boot.autoconfigure.web.embedded;
import java.time.Duration;
import java.util.Map;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelOption;
import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test; import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.MockitoAnnotations;
import reactor.netty.http.server.HttpServer;
import reactor.netty.tcp.TcpServer;
import org.springframework.boot.autoconfigure.web.ServerProperties; import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.properties.source.ConfigurationPropertySources; import org.springframework.boot.context.properties.source.ConfigurationPropertySources;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory; import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.embedded.netty.NettyServerCustomizer;
import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.env.MockEnvironment;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
/** /**
* Tests for {@link NettyWebServerFactoryCustomizer}. * Tests for {@link NettyWebServerFactoryCustomizer}.
* *
* @author Brian Clozel * @author Brian Clozel
* @author Artsiom Yudovin
*/ */
class NettyWebServerFactoryCustomizerTests { class NettyWebServerFactoryCustomizerTests {
...@@ -40,8 +57,12 @@ class NettyWebServerFactoryCustomizerTests { ...@@ -40,8 +57,12 @@ class NettyWebServerFactoryCustomizerTests {
private NettyWebServerFactoryCustomizer customizer; private NettyWebServerFactoryCustomizer customizer;
@Captor
private ArgumentCaptor<NettyServerCustomizer> customizerCaptor;
@BeforeEach @BeforeEach
public void setup() { public void setup() {
MockitoAnnotations.initMocks(this);
this.environment = new MockEnvironment(); this.environment = new MockEnvironment();
this.serverProperties = new ServerProperties(); this.serverProperties = new ServerProperties();
ConfigurationPropertySources.attach(this.environment); ConfigurationPropertySources.attach(this.environment);
...@@ -71,4 +92,49 @@ class NettyWebServerFactoryCustomizerTests { ...@@ -71,4 +92,49 @@ class NettyWebServerFactoryCustomizerTests {
verify(factory).setUseForwardHeaders(true); verify(factory).setUseForwardHeaders(true);
} }
@Test
void setConnectionTimeoutAsZero() {
setupConnectionTimeout(Duration.ZERO);
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, null);
}
@Test
void setConnectionTimeoutAsMinusOne() {
setupConnectionTimeout(Duration.ofNanos(-1));
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, 0);
}
@Test
void setConnectionTimeout() {
setupConnectionTimeout(Duration.ofSeconds(1));
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, 1000);
}
@SuppressWarnings("unchecked")
private void verifyConnectionTimeout(NettyReactiveWebServerFactory factory, Integer expected) {
if (expected == null) {
verify(factory, never()).addServerCustomizers(any(NettyServerCustomizer.class));
return;
}
verify(factory, times(1)).addServerCustomizers(this.customizerCaptor.capture());
NettyServerCustomizer serverCustomizer = this.customizerCaptor.getValue();
HttpServer httpServer = serverCustomizer.apply(HttpServer.create());
TcpServer tcpConfiguration = ReflectionTestUtils.invokeMethod(httpServer, "tcpConfiguration");
ServerBootstrap bootstrap = tcpConfiguration.configure();
Map<Object, Object> options = (Map<Object, Object>) ReflectionTestUtils.getField(bootstrap, "options");
assertThat(options).containsEntry(ChannelOption.CONNECT_TIMEOUT_MILLIS, expected);
}
private void setupConnectionTimeout(Duration connectionTimeout) {
this.serverProperties.setUseForwardHeaders(null);
this.serverProperties.setMaxHttpHeaderSize(null);
this.serverProperties.setConnectionTimeout(connectionTimeout);
}
} }
...@@ -50,6 +50,7 @@ import org.springframework.util.StringUtils; ...@@ -50,6 +50,7 @@ import org.springframework.util.StringUtils;
* {@link Source#toInstance(Function) new instance}. * {@link Source#toInstance(Function) new instance}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Artsiom Yudovin
* @since 2.0.0 * @since 2.0.0
*/ */
public final class PropertyMapper { public final class PropertyMapper {
...@@ -288,7 +289,7 @@ public final class PropertyMapper { ...@@ -288,7 +289,7 @@ public final class PropertyMapper {
*/ */
public Source<T> whenNot(Predicate<T> predicate) { public Source<T> whenNot(Predicate<T> predicate) {
Assert.notNull(predicate, "Predicate must not be null"); Assert.notNull(predicate, "Predicate must not be null");
return new Source<>(this.supplier, predicate.negate()); return when(predicate.negate());
} }
/** /**
...@@ -299,7 +300,7 @@ public final class PropertyMapper { ...@@ -299,7 +300,7 @@ public final class PropertyMapper {
*/ */
public Source<T> when(Predicate<T> predicate) { public Source<T> when(Predicate<T> predicate) {
Assert.notNull(predicate, "Predicate must not be null"); Assert.notNull(predicate, "Predicate must not be null");
return new Source<>(this.supplier, predicate); return new Source<>(this.supplier, (this.predicate != null) ? this.predicate.and(predicate) : predicate);
} }
/** /**
......
...@@ -28,6 +28,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException ...@@ -28,6 +28,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
* Tests for {@link PropertyMapper}. * Tests for {@link PropertyMapper}.
* *
* @author Phillip Webb * @author Phillip Webb
* @author Artsiom Yudovin
*/ */
class PropertyMapperTests { class PropertyMapperTests {
...@@ -195,6 +196,17 @@ class PropertyMapperTests { ...@@ -195,6 +196,17 @@ class PropertyMapperTests {
this.map.alwaysApplyingWhenNonNull().from(() -> null).toCall(Assertions::fail); this.map.alwaysApplyingWhenNonNull().from(() -> null).toCall(Assertions::fail);
} }
@Test
public void whenWhenValueNotMatchesShouldSupportChainedCalls() {
this.map.from("123").when("456"::equals).when("123"::equals).toCall(Assertions::fail);
}
@Test
public void whenWhenValueMatchesShouldSupportChainedCalls() {
String result = this.map.from("123").when((s) -> s.contains("2")).when("123"::equals).toInstance(String::new);
assertThat(result).isEqualTo("123");
}
private static class Count<T> implements Supplier<T> { private static class Count<T> implements Supplier<T> {
private final Supplier<T> source; private final Supplier<T> source;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment