Commit 69fbd8f2 authored by Phillip Webb's avatar Phillip Webb

Merge branch '2.1.x'

Closes gh-17227
Closes gh-17228
parents bd140508 5e3438f0
......@@ -24,7 +24,6 @@ import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.cloud.CloudPlatform;
import org.springframework.boot.context.properties.PropertyMapper;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.embedded.netty.NettyServerCustomizer;
import org.springframework.boot.web.server.WebServerFactoryCustomizer;
import org.springframework.core.Ordered;
import org.springframework.core.env.Environment;
......@@ -58,11 +57,11 @@ public class NettyWebServerFactoryCustomizer
@Override
public void customize(NettyReactiveWebServerFactory factory) {
factory.setUseForwardHeaders(getOrDeduceUseForwardHeaders());
PropertyMapper propertyMapper = PropertyMapper.get();
propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize).whenNonNull().asInt(DataSize::toBytes)
PropertyMapper propertyMapper = PropertyMapper.get().alwaysApplyingWhenNonNull();
propertyMapper.from(this.serverProperties::getMaxHttpHeaderSize)
.to((maxHttpRequestHeaderSize) -> customizeMaxHttpHeaderSize(factory, maxHttpRequestHeaderSize));
propertyMapper.from(this.serverProperties::getConnectionTimeout).whenNonNull().asInt(Duration::toMillis)
.to((duration) -> factory.addServerCustomizers(getConnectionTimeOutCustomizer(duration)));
propertyMapper.from(this.serverProperties::getConnectionTimeout)
.to((connectionTimeout) -> customizeConnectionTimeout(factory, connectionTimeout));
}
private boolean getOrDeduceUseForwardHeaders() {
......@@ -73,14 +72,17 @@ public class NettyWebServerFactoryCustomizer
return this.serverProperties.getForwardHeadersStrategy().equals(ServerProperties.ForwardHeadersStrategy.NATIVE);
}
private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, Integer maxHttpHeaderSize) {
factory.addServerCustomizers((NettyServerCustomizer) (httpServer) -> httpServer.httpRequestDecoder(
(httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize(maxHttpHeaderSize)));
private void customizeMaxHttpHeaderSize(NettyReactiveWebServerFactory factory, DataSize maxHttpHeaderSize) {
factory.addServerCustomizers((httpServer) -> httpServer.httpRequestDecoder(
(httpRequestDecoderSpec) -> httpRequestDecoderSpec.maxHeaderSize((int) maxHttpHeaderSize.toBytes())));
}
private NettyServerCustomizer getConnectionTimeOutCustomizer(int duration) {
return (httpServer) -> httpServer.tcpConfiguration(
(tcpServer) -> tcpServer.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, duration));
private void customizeConnectionTimeout(NettyReactiveWebServerFactory factory, Duration connectionTimeout) {
if (!connectionTimeout.isZero()) {
long timeoutMillis = connectionTimeout.isNegative() ? 0 : connectionTimeout.toMillis();
factory.addServerCustomizers((httpServer) -> httpServer.tcpConfiguration((tcpServer) -> tcpServer
.selectorOption(ChannelOption.CONNECT_TIMEOUT_MILLIS, (int) timeoutMillis)));
}
}
}
......@@ -16,21 +16,38 @@
package org.springframework.boot.autoconfigure.web.embedded;
import java.time.Duration;
import java.util.Map;
import io.netty.bootstrap.ServerBootstrap;
import io.netty.channel.ChannelOption;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.MockitoAnnotations;
import reactor.netty.http.server.HttpServer;
import reactor.netty.tcp.TcpServer;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.context.properties.source.ConfigurationPropertySources;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.embedded.netty.NettyServerCustomizer;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Tests for {@link NettyWebServerFactoryCustomizer}.
*
* @author Brian Clozel
* @author Artsiom Yudovin
*/
class NettyWebServerFactoryCustomizerTests {
......@@ -40,8 +57,12 @@ class NettyWebServerFactoryCustomizerTests {
private NettyWebServerFactoryCustomizer customizer;
@Captor
private ArgumentCaptor<NettyServerCustomizer> customizerCaptor;
@BeforeEach
public void setup() {
MockitoAnnotations.initMocks(this);
this.environment = new MockEnvironment();
this.serverProperties = new ServerProperties();
ConfigurationPropertySources.attach(this.environment);
......@@ -71,4 +92,49 @@ class NettyWebServerFactoryCustomizerTests {
verify(factory).setUseForwardHeaders(true);
}
@Test
void setConnectionTimeoutAsZero() {
setupConnectionTimeout(Duration.ZERO);
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, null);
}
@Test
void setConnectionTimeoutAsMinusOne() {
setupConnectionTimeout(Duration.ofNanos(-1));
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, 0);
}
@Test
void setConnectionTimeout() {
setupConnectionTimeout(Duration.ofSeconds(1));
NettyReactiveWebServerFactory factory = mock(NettyReactiveWebServerFactory.class);
this.customizer.customize(factory);
verifyConnectionTimeout(factory, 1000);
}
@SuppressWarnings("unchecked")
private void verifyConnectionTimeout(NettyReactiveWebServerFactory factory, Integer expected) {
if (expected == null) {
verify(factory, never()).addServerCustomizers(any(NettyServerCustomizer.class));
return;
}
verify(factory, times(1)).addServerCustomizers(this.customizerCaptor.capture());
NettyServerCustomizer serverCustomizer = this.customizerCaptor.getValue();
HttpServer httpServer = serverCustomizer.apply(HttpServer.create());
TcpServer tcpConfiguration = ReflectionTestUtils.invokeMethod(httpServer, "tcpConfiguration");
ServerBootstrap bootstrap = tcpConfiguration.configure();
Map<Object, Object> options = (Map<Object, Object>) ReflectionTestUtils.getField(bootstrap, "options");
assertThat(options).containsEntry(ChannelOption.CONNECT_TIMEOUT_MILLIS, expected);
}
private void setupConnectionTimeout(Duration connectionTimeout) {
this.serverProperties.setUseForwardHeaders(null);
this.serverProperties.setMaxHttpHeaderSize(null);
this.serverProperties.setConnectionTimeout(connectionTimeout);
}
}
......@@ -50,6 +50,7 @@ import org.springframework.util.StringUtils;
* {@link Source#toInstance(Function) new instance}.
*
* @author Phillip Webb
* @author Artsiom Yudovin
* @since 2.0.0
*/
public final class PropertyMapper {
......@@ -288,7 +289,7 @@ public final class PropertyMapper {
*/
public Source<T> whenNot(Predicate<T> predicate) {
Assert.notNull(predicate, "Predicate must not be null");
return new Source<>(this.supplier, predicate.negate());
return when(predicate.negate());
}
/**
......@@ -299,7 +300,7 @@ public final class PropertyMapper {
*/
public Source<T> when(Predicate<T> predicate) {
Assert.notNull(predicate, "Predicate must not be null");
return new Source<>(this.supplier, predicate);
return new Source<>(this.supplier, (this.predicate != null) ? this.predicate.and(predicate) : predicate);
}
/**
......
......@@ -28,6 +28,7 @@ import static org.assertj.core.api.Assertions.assertThatIllegalArgumentException
* Tests for {@link PropertyMapper}.
*
* @author Phillip Webb
* @author Artsiom Yudovin
*/
class PropertyMapperTests {
......@@ -195,6 +196,17 @@ class PropertyMapperTests {
this.map.alwaysApplyingWhenNonNull().from(() -> null).toCall(Assertions::fail);
}
@Test
public void whenWhenValueNotMatchesShouldSupportChainedCalls() {
this.map.from("123").when("456"::equals).when("123"::equals).toCall(Assertions::fail);
}
@Test
public void whenWhenValueMatchesShouldSupportChainedCalls() {
String result = this.map.from("123").when((s) -> s.contains("2")).when("123"::equals).toInstance(String::new);
assertThat(result).isEqualTo("123");
}
private static class Count<T> implements Supplier<T> {
private final Supplier<T> source;
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment