Commit 20a1ce49 authored by Stephane Nicoll's avatar Stephane Nicoll

Merge pull request #15482 from dreis2211

* pr/15482:
  Avoid unnecessary usage of ReflectionTestUtils
parents d004ee9a 948902f0
......@@ -23,12 +23,10 @@ import org.springframework.boot.test.context.runner.ApplicationContextRunner;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Primary;
import org.springframework.ldap.core.ContextSource;
import org.springframework.ldap.core.LdapTemplate;
import org.springframework.ldap.core.support.LdapContextSource;
import org.springframework.ldap.pool2.factory.PoolConfig;
import org.springframework.ldap.pool2.factory.PooledContextSource;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -58,7 +56,7 @@ public class LdapAutoConfigurationTests {
public void contextSourceWithSingleUrl() {
this.contextRunner.withPropertyValues("spring.ldap.urls:ldap://localhost:123")
.run((context) -> {
ContextSource contextSource = context
LdapContextSource contextSource = context
.getBean(LdapContextSource.class);
String[] urls = getUrls(contextSource);
assertThat(urls).containsExactly("ldap://localhost:123");
......@@ -71,7 +69,7 @@ public class LdapAutoConfigurationTests {
.withPropertyValues(
"spring.ldap.urls:ldap://localhost:123,ldap://mycompany:123")
.run((context) -> {
ContextSource contextSource = context
LdapContextSource contextSource = context
.getBean(LdapContextSource.class);
LdapProperties ldapProperties = context.getBean(LdapProperties.class);
String[] urls = getUrls(contextSource);
......@@ -120,8 +118,8 @@ public class LdapAutoConfigurationTests {
});
}
private String[] getUrls(ContextSource contextSource) {
return (String[]) ReflectionTestUtils.getField(contextSource, "urls");
private String[] getUrls(LdapContextSource contextSource) {
return contextSource.getUrls();
}
@Configuration
......
......@@ -21,15 +21,12 @@ import java.util.List;
import com.mongodb.MongoClient;
import com.mongodb.MongoCredential;
import com.mongodb.ServerAddress;
import com.mongodb.connection.Cluster;
import com.mongodb.connection.ClusterSettings;
import org.junit.Test;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.context.annotation.Configuration;
import org.springframework.core.env.Environment;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -50,7 +47,7 @@ public class MongoClientFactoryTests {
MongoProperties properties = new MongoProperties();
properties.setPort(12345);
MongoClient client = createMongoClient(properties);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "localhost", 12345);
}
......@@ -60,7 +57,7 @@ public class MongoClientFactoryTests {
MongoProperties properties = new MongoProperties();
properties.setHost("mongo.example.com");
MongoClient client = createMongoClient(properties);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "mongo.example.com", 27017);
}
......@@ -103,7 +100,7 @@ public class MongoClientFactoryTests {
properties.setUri("mongodb://user:secret@mongo1.example.com:12345,"
+ "mongo2.example.com:23456/test");
MongoClient client = createMongoClient(properties);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(2);
assertServerAddress(allAddresses.get(0), "mongo1.example.com", 12345);
assertServerAddress(allAddresses.get(1), "mongo2.example.com", 23456);
......@@ -118,7 +115,7 @@ public class MongoClientFactoryTests {
properties.setUri("mongodb://mongo.example.com:1234/mydb");
this.environment.setProperty("local.mongo.port", "4000");
MongoClient client = createMongoClient(properties, this.environment);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "localhost", 4000);
}
......@@ -132,14 +129,6 @@ public class MongoClientFactoryTests {
return new MongoClientFactory(properties, environment).createMongoClient(null);
}
private List<ServerAddress> extractServerAddresses(MongoClient client) {
Cluster cluster = (Cluster) ReflectionTestUtils
.getField(ReflectionTestUtils.getField(client, "delegate"), "cluster");
ClusterSettings clusterSettings = (ClusterSettings) ReflectionTestUtils
.getField(cluster, "settings");
return clusterSettings.getHosts();
}
private void assertServerAddress(ServerAddress serverAddress, String expectedHost,
int expectedPort) {
assertThat(serverAddress.getHost()).isEqualTo(expectedHost);
......
......@@ -21,15 +21,12 @@ import java.util.List;
import com.mongodb.MongoClient;
import com.mongodb.MongoClientOptions;
import com.mongodb.ServerAddress;
import com.mongodb.connection.Cluster;
import com.mongodb.connection.ClusterSettings;
import org.junit.Test;
import org.springframework.boot.context.properties.EnableConfigurationProperties;
import org.springframework.boot.test.util.TestPropertyValues;
import org.springframework.context.annotation.AnnotationConfigApplicationContext;
import org.springframework.context.annotation.Configuration;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -120,7 +117,7 @@ public class MongoPropertiesTests {
properties.setUri("mongodb://mongo1.example.com:12345");
MongoClient client = new MongoClientFactory(properties, null)
.createMongoClient(null);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "mongo1.example.com", 12345);
}
......@@ -132,7 +129,7 @@ public class MongoPropertiesTests {
properties.setPort(27017);
MongoClient client = new MongoClientFactory(properties, null)
.createMongoClient(null);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "localhost", 27017);
}
......@@ -143,7 +140,7 @@ public class MongoPropertiesTests {
properties.setUri("mongodb://mongo1.example.com:12345");
MongoClient client = new MongoClientFactory(properties, null)
.createMongoClient(null);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "mongo1.example.com", 12345);
}
......@@ -153,19 +150,11 @@ public class MongoPropertiesTests {
MongoProperties properties = new MongoProperties();
MongoClient client = new MongoClientFactory(properties, null)
.createMongoClient(null);
List<ServerAddress> allAddresses = extractServerAddresses(client);
List<ServerAddress> allAddresses = client.getAllAddress();
assertThat(allAddresses).hasSize(1);
assertServerAddress(allAddresses.get(0), "localhost", 27017);
}
private List<ServerAddress> extractServerAddresses(MongoClient client) {
Cluster cluster = (Cluster) ReflectionTestUtils
.getField(ReflectionTestUtils.getField(client, "delegate"), "cluster");
ClusterSettings clusterSettings = (ClusterSettings) ReflectionTestUtils
.getField(cluster, "settings");
return clusterSettings.getHosts();
}
private void assertServerAddress(ServerAddress serverAddress, String expectedHost,
int expectedPort) {
assertThat(serverAddress.getHost()).isEqualTo(expectedHost);
......
......@@ -165,14 +165,12 @@ public class OAuth2WebSecurityConfigurationTests {
});
}
@SuppressWarnings("unchecked")
private List<Filter> getFilters(AssertableApplicationContext context,
Class<? extends Filter> filter) {
FilterChainProxy filterChain = (FilterChainProxy) context
.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
List<SecurityFilterChain> filterChains = filterChain.getFilterChains();
List<Filter> filters = (List<Filter>) ReflectionTestUtils
.getField(filterChains.get(0), "filters");
List<Filter> filters = filterChains.get(0).getFilters();
return filters.stream().filter(filter::isInstance).collect(Collectors.toList());
}
......
......@@ -18,8 +18,8 @@ package org.springframework.boot.autoconfigure.security.oauth2.resource.reactive
import java.io.IOException;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.stream.Stream;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
......@@ -175,14 +175,12 @@ public class ReactiveOAuth2ResourceServerAutoConfigurationTests {
});
}
@SuppressWarnings("unchecked")
private void assertFilterConfiguredWithJwtAuthenticationManager(
AssertableReactiveWebApplicationContext context) {
MatcherSecurityWebFilterChain filterChain = (MatcherSecurityWebFilterChain) context
.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
List<WebFilter> filters = (List<WebFilter>) ReflectionTestUtils
.getField(filterChain, "filters");
AuthenticationWebFilter webFilter = (AuthenticationWebFilter) filters.stream()
Stream<WebFilter> filters = filterChain.getWebFilters().toStream();
AuthenticationWebFilter webFilter = (AuthenticationWebFilter) filters
.filter((f) -> f instanceof AuthenticationWebFilter).findFirst()
.orElse(null);
ReactiveAuthenticationManager authenticationManager = (ReactiveAuthenticationManager) ReflectionTestUtils
......
......@@ -46,7 +46,6 @@ import org.springframework.security.oauth2.server.resource.authentication.JwtAut
import org.springframework.security.oauth2.server.resource.web.BearerTokenAuthenticationFilter;
import org.springframework.security.web.FilterChainProxy;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
......@@ -158,13 +157,11 @@ public class OAuth2ResourceServerAutoConfigurationTests {
.run((context) -> assertThat(getBearerTokenFilter(context)).isNull());
}
@SuppressWarnings("unchecked")
private Filter getBearerTokenFilter(AssertableWebApplicationContext context) {
FilterChainProxy filterChain = (FilterChainProxy) context
.getBean(BeanIds.SPRING_SECURITY_FILTER_CHAIN);
List<SecurityFilterChain> filterChains = filterChain.getFilterChains();
List<Filter> filters = (List<Filter>) ReflectionTestUtils
.getField(filterChains.get(0), "filters");
List<Filter> filters = filterChains.get(0).getFilters();
return filters.stream()
.filter((f) -> f instanceof BearerTokenAuthenticationFilter).findFirst()
.orElse(null);
......
......@@ -40,7 +40,6 @@ import org.springframework.boot.web.embedded.jetty.JettyServletWebServerFactory;
import org.springframework.boot.web.embedded.jetty.JettyWebServer;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.test.context.support.TestPropertySourceUtils;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.Mockito.mock;
......@@ -172,8 +171,10 @@ public class JettyWebServerFactoryCustomizerTests {
private List<Integer> getRequestHeaderSizes(JettyWebServer server) {
List<Integer> requestHeaderSizes = new ArrayList<>();
Connector[] connectors = (Connector[]) ReflectionTestUtils.getField(server,
"connectors");
// Start (and directly stop) server to have connectors available
server.start();
server.stop();
Connector[] connectors = server.getServer().getConnectors();
for (Connector connector : connectors) {
connector.getConnectionFactories().stream()
.filter((factory) -> factory instanceof ConnectionFactory)
......
......@@ -16,17 +16,14 @@
package org.springframework.boot.autoconfigure.web.embedded;
import java.util.Map;
import java.util.function.Consumer;
import org.apache.catalina.Context;
import org.apache.catalina.Valve;
import org.apache.catalina.mapper.Mapper;
import org.apache.catalina.startup.Tomcat;
import org.apache.catalina.valves.AccessLogValve;
import org.apache.catalina.valves.ErrorReportValve;
import org.apache.catalina.valves.RemoteIpValve;
import org.apache.catalina.webresources.StandardRoot;
import org.apache.coyote.AbstractProtocol;
import org.apache.coyote.http11.AbstractHttp11Protocol;
import org.junit.Before;
......@@ -40,7 +37,6 @@ import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactor
import org.springframework.boot.web.embedded.tomcat.TomcatWebServer;
import org.springframework.mock.env.MockEnvironment;
import org.springframework.test.context.support.TestPropertySourceUtils;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.util.unit.DataSize;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -180,18 +176,13 @@ public class TomcatWebServerFactoryCustomizerTests {
assertThat(remoteIpValve.getInternalProxies()).isEqualTo("192.168.0.1");
}
@SuppressWarnings("unchecked")
@Test
public void customStaticResourceAllowCaching() {
bind("server.tomcat.resource.allow-caching=false");
customizeAndRunServer((server) -> {
Mapper mapper = server.getTomcat().getService().getMapper();
Object contextObjectToContextVersionMap = ReflectionTestUtils.getField(mapper,
"contextObjectToContextVersionMap");
Object tomcatEmbeddedContext = ((Map<Context, Object>) contextObjectToContextVersionMap)
.values().toArray()[0];
assertThat(((StandardRoot) ReflectionTestUtils.getField(tomcatEmbeddedContext,
"resources")).isCachingAllowed()).isFalse();
Tomcat tomcat = server.getTomcat();
Context context = (Context) tomcat.getHost().findChildren()[0];
assertThat(context.getResources().isCachingAllowed()).isFalse();
});
}
......
......@@ -29,7 +29,6 @@ import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.boot.test.autoconfigure.web.servlet.WebMvcTest;
import org.springframework.security.test.context.support.WithMockUser;
import org.springframework.test.context.junit4.SpringRunner;
import org.springframework.test.util.ReflectionTestUtils;
import static org.assertj.core.api.Assertions.assertThat;
import static org.junit.Assert.fail;
......@@ -64,7 +63,7 @@ public class WebMvcTestWebDriverIntegrationTests {
WebElement element = this.webDriver.findElement(By.tagName("body"));
assertThat(element.getText()).isEqualTo("Hello");
try {
ReflectionTestUtils.invokeMethod(previousWebDriver, "getCurrentWindow");
previousWebDriver.getWindowHandle();
fail("Did not call quit()");
}
catch (NoSuchWindowException ex) {
......
......@@ -23,9 +23,9 @@ import java.nio.charset.Charset;
import java.nio.charset.StandardCharsets;
import java.time.Duration;
import java.util.Arrays;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.naming.InitialContext;
import javax.naming.NamingException;
......@@ -49,6 +49,7 @@ import org.apache.catalina.valves.RemoteIpValve;
import org.apache.catalina.webresources.TomcatURLStreamHandlerFactory;
import org.apache.jasper.servlet.JspServlet;
import org.apache.tomcat.JarScanFilter;
import org.apache.tomcat.JarScanType;
import org.junit.After;
import org.junit.Rule;
import org.junit.Test;
......@@ -399,7 +400,6 @@ public class TomcatServletWebServerFactoryTests
}
@Test
@SuppressWarnings("unchecked")
public void tldSkipPatternsShouldBeAppliedToContextJarScanner() {
TomcatServletWebServerFactory factory = getFactory();
factory.addTldSkipPatterns("foo.jar", "bar.jar");
......@@ -408,9 +408,9 @@ public class TomcatServletWebServerFactoryTests
Tomcat tomcat = ((TomcatWebServer) this.webServer).getTomcat();
Context context = (Context) tomcat.getHost().findChildren()[0];
JarScanFilter jarScanFilter = context.getJarScanner().getJarScanFilter();
Set<String> tldSkipSet = (Set<String>) ReflectionTestUtils.getField(jarScanFilter,
"tldSkipSet");
assertThat(tldSkipSet).contains("foo.jar", "bar.jar");
assertThat(jarScanFilter.check(JarScanType.TLD, "foo.jar")).isFalse();
assertThat(jarScanFilter.check(JarScanType.TLD, "bar.jar")).isFalse();
assertThat(jarScanFilter.check(JarScanType.TLD, "test.jar")).isTrue();
}
@Test
......@@ -463,13 +463,15 @@ public class TomcatServletWebServerFactoryTests
return (JspServlet) standardWrapper.getServlet();
}
@SuppressWarnings("unchecked")
@Override
protected Map<String, String> getActualMimeMappings() {
Context context = (Context) ((TomcatWebServer) this.webServer).getTomcat()
.getHost().findChildren()[0];
return (Map<String, String>) ReflectionTestUtils.getField(context,
"mimeMappings");
Map<String, String> mimeMappings = new HashMap<>();
for (String extension : context.findMimeMappings()) {
mimeMappings.put(extension, context.findMimeMapping(extension));
}
return mimeMappings;
}
@Override
......
......@@ -32,9 +32,9 @@ import java.util.concurrent.atomic.AtomicReference;
import javax.net.ssl.SSLException;
import javax.net.ssl.SSLHandshakeException;
import io.undertow.Undertow;
import io.undertow.Undertow.Builder;
import io.undertow.servlet.api.DeploymentInfo;
import io.undertow.servlet.api.DeploymentManager;
import io.undertow.servlet.api.ServletContainer;
import org.apache.jasper.servlet.JspServlet;
import org.junit.Test;
......@@ -263,8 +263,7 @@ public class UndertowServletWebServerFactoryTests
UndertowServletWebServer container = (UndertowServletWebServer) getFactory()
.getWebServer();
try {
return ((DeploymentManager) ReflectionTestUtils.getField(container,
"manager")).getDeployment().getServletContainer();
return container.getDeploymentManager().getDeployment().getServletContainer();
}
finally {
container.stop();
......@@ -273,8 +272,8 @@ public class UndertowServletWebServerFactoryTests
@Override
protected Map<String, String> getActualMimeMappings() {
return ((DeploymentManager) ReflectionTestUtils.getField(this.webServer,
"manager")).getDeployment().getMimeExtensionMappings();
return ((UndertowServletWebServer) this.webServer).getDeploymentManager()
.getDeployment().getMimeExtensionMappings();
}
@Override
......@@ -288,8 +287,8 @@ public class UndertowServletWebServerFactoryTests
@Override
protected Charset getCharset(Locale locale) {
DeploymentInfo info = ((DeploymentManager) ReflectionTestUtils
.getField(this.webServer, "manager")).getDeployment().getDeploymentInfo();
DeploymentInfo info = ((UndertowServletWebServer) this.webServer)
.getDeploymentManager().getDeployment().getDeploymentInfo();
String charsetName = info.getLocaleCharsetMapping().get(locale.toString());
return (charsetName != null) ? Charset.forName(charsetName) : null;
}
......@@ -299,9 +298,9 @@ public class UndertowServletWebServerFactoryTests
int blockedPort) {
assertThat(ex).isInstanceOf(PortInUseException.class);
assertThat(((PortInUseException) ex).getPort()).isEqualTo(blockedPort);
Object undertow = ReflectionTestUtils.getField(this.webServer, "undertow");
Object worker = ReflectionTestUtils.getField(undertow, "worker");
assertThat(worker).isNull();
Undertow undertow = (Undertow) ReflectionTestUtils.getField(this.webServer,
"undertow");
assertThat(undertow.getWorker()).isNull();
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment