Commit dfb97eb0 authored by Madhura Bhave's avatar Madhura Bhave

Convert environment if webApplicationType changes

If the web application type is set via properties,
it is available only after binding. The environment needs
to be converted to the appropriate type if it does not match.
If a custom environment is set, it is not converted.

Fixes gh-13977
parent 6e5ff77b
...@@ -26,7 +26,6 @@ import org.springframework.core.env.MutablePropertySources; ...@@ -26,7 +26,6 @@ import org.springframework.core.env.MutablePropertySources;
import org.springframework.core.env.PropertySource; import org.springframework.core.env.PropertySource;
import org.springframework.core.env.StandardEnvironment; import org.springframework.core.env.StandardEnvironment;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
import org.springframework.web.context.ConfigurableWebEnvironment;
import org.springframework.web.context.support.StandardServletEnvironment; import org.springframework.web.context.support.StandardServletEnvironment;
/** /**
...@@ -34,6 +33,7 @@ import org.springframework.web.context.support.StandardServletEnvironment; ...@@ -34,6 +33,7 @@ import org.springframework.web.context.support.StandardServletEnvironment;
* *
* @author Ethan Rubinson * @author Ethan Rubinson
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Madhura Bhave
*/ */
final class EnvironmentConverter { final class EnvironmentConverter {
...@@ -61,46 +61,44 @@ final class EnvironmentConverter { ...@@ -61,46 +61,44 @@ final class EnvironmentConverter {
} }
/** /**
* Converts the given {@code environment} to a {@link StandardEnvironment}. If the * Converts the given {@code environment} to the given {@link StandardEnvironment}
* environment is already a {@code StandardEnvironment} and is not a * type. If the environment is already of the same type, no conversion is performed
* {@link ConfigurableWebEnvironment} no conversion is performed and it is returned * and it is returned unchanged.
* unchanged.
* @param environment the Environment to convert * @param environment the Environment to convert
* @param conversionType the type to convert the Environment to
* @return the converted Environment * @return the converted Environment
*/ */
StandardEnvironment convertToStandardEnvironmentIfNecessary( StandardEnvironment convertEnvironmentIfNecessary(ConfigurableEnvironment environment,
ConfigurableEnvironment environment) { Class<? extends StandardEnvironment> conversionType) {
if (environment instanceof StandardEnvironment if (conversionType.equals(environment.getClass())) {
&& !isWebEnvironment(environment, this.classLoader)) {
return (StandardEnvironment) environment; return (StandardEnvironment) environment;
} }
return convertToStandardEnvironment(environment); return convertEnvironment(environment, conversionType);
} }
private boolean isWebEnvironment(ConfigurableEnvironment environment, private StandardEnvironment convertEnvironment(ConfigurableEnvironment environment,
ClassLoader classLoader) { Class<? extends StandardEnvironment> conversionType) {
try { StandardEnvironment result = createEnvironment(conversionType);
Class<?> webEnvironmentClass = ClassUtils
.forName(CONFIGURABLE_WEB_ENVIRONMENT_CLASS, classLoader);
return (webEnvironmentClass.isInstance(environment));
}
catch (Throwable ex) {
return false;
}
}
private StandardEnvironment convertToStandardEnvironment(
ConfigurableEnvironment environment) {
StandardEnvironment result = new StandardEnvironment();
result.setActiveProfiles(environment.getActiveProfiles()); result.setActiveProfiles(environment.getActiveProfiles());
result.setConversionService(environment.getConversionService()); result.setConversionService(environment.getConversionService());
copyNonServletPropertySources(environment, result); copyPropertySources(environment, result);
return result; return result;
} }
private void copyNonServletPropertySources(ConfigurableEnvironment source, private StandardEnvironment createEnvironment(
Class<? extends StandardEnvironment> conversionType) {
try {
return conversionType.newInstance();
}
catch (Exception ex) {
return new StandardEnvironment();
}
}
private void copyPropertySources(ConfigurableEnvironment source,
StandardEnvironment target) { StandardEnvironment target) {
removeAllPropertySources(target.getPropertySources()); removePropertySources(target.getPropertySources(),
isServletEnvironment(target.getClass(), this.classLoader));
for (PropertySource<?> propertySource : source.getPropertySources()) { for (PropertySource<?> propertySource : source.getPropertySources()) {
if (!SERVLET_ENVIRONMENT_SOURCE_NAMES.contains(propertySource.getName())) { if (!SERVLET_ENVIRONMENT_SOURCE_NAMES.contains(propertySource.getName())) {
target.getPropertySources().addLast(propertySource); target.getPropertySources().addLast(propertySource);
...@@ -108,13 +106,31 @@ final class EnvironmentConverter { ...@@ -108,13 +106,31 @@ final class EnvironmentConverter {
} }
} }
private void removeAllPropertySources(MutablePropertySources propertySources) { private boolean isServletEnvironment(Class<?> conversionType,
ClassLoader classLoader) {
try {
Class<?> webEnvironmentClass = ClassUtils
.forName(CONFIGURABLE_WEB_ENVIRONMENT_CLASS, classLoader);
return webEnvironmentClass.isAssignableFrom(conversionType);
}
catch (Throwable ex) {
return false;
}
}
private void removePropertySources(MutablePropertySources propertySources,
boolean isServletEnvironment) {
Set<String> names = new HashSet<>(); Set<String> names = new HashSet<>();
for (PropertySource<?> propertySource : propertySources) { for (PropertySource<?> propertySource : propertySources) {
names.add(propertySource.getName()); names.add(propertySource.getName());
} }
for (String name : names) { for (String name : names) {
propertySources.remove(name); if (!isServletEnvironment) {
propertySources.remove(name);
}
else if (!SERVLET_ENVIRONMENT_SOURCE_NAMES.contains(name)) {
propertySources.remove(name);
}
} }
} }
......
...@@ -43,6 +43,7 @@ import org.springframework.boot.Banner.Mode; ...@@ -43,6 +43,7 @@ import org.springframework.boot.Banner.Mode;
import org.springframework.boot.context.properties.bind.Bindable; import org.springframework.boot.context.properties.bind.Bindable;
import org.springframework.boot.context.properties.bind.Binder; import org.springframework.boot.context.properties.bind.Binder;
import org.springframework.boot.context.properties.source.ConfigurationPropertySources; import org.springframework.boot.context.properties.source.ConfigurationPropertySources;
import org.springframework.boot.web.reactive.context.StandardReactiveWebEnvironment;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextInitializer; import org.springframework.context.ApplicationContextInitializer;
import org.springframework.context.ApplicationListener; import org.springframework.context.ApplicationListener;
...@@ -235,6 +236,8 @@ public class SpringApplication { ...@@ -235,6 +236,8 @@ public class SpringApplication {
private Set<String> additionalProfiles = new HashSet<>(); private Set<String> additionalProfiles = new HashSet<>();
private boolean isCustomEnvironment = false;
/** /**
* Create a new {@link SpringApplication} instance. The application context will load * Create a new {@link SpringApplication} instance. The application context will load
* beans from the specified primary sources (see {@link SpringApplication class-level} * beans from the specified primary sources (see {@link SpringApplication class-level}
...@@ -360,14 +363,24 @@ public class SpringApplication { ...@@ -360,14 +363,24 @@ public class SpringApplication {
configureEnvironment(environment, applicationArguments.getSourceArgs()); configureEnvironment(environment, applicationArguments.getSourceArgs());
listeners.environmentPrepared(environment); listeners.environmentPrepared(environment);
bindToSpringApplication(environment); bindToSpringApplication(environment);
if (this.webApplicationType == WebApplicationType.NONE) { if (!this.isCustomEnvironment) {
environment = new EnvironmentConverter(getClassLoader()) environment = new EnvironmentConverter(getClassLoader())
.convertToStandardEnvironmentIfNecessary(environment); .convertEnvironmentIfNecessary(environment, deduceEnvironmentClass());
} }
ConfigurationPropertySources.attach(environment); ConfigurationPropertySources.attach(environment);
return environment; return environment;
} }
private Class<? extends StandardEnvironment> deduceEnvironmentClass() {
if (this.webApplicationType == WebApplicationType.SERVLET) {
return StandardServletEnvironment.class;
}
if (this.webApplicationType == WebApplicationType.REACTIVE) {
return StandardReactiveWebEnvironment.class;
}
return StandardEnvironment.class;
}
private void prepareContext(ConfigurableApplicationContext context, private void prepareContext(ConfigurableApplicationContext context,
ConfigurableEnvironment environment, SpringApplicationRunListeners listeners, ConfigurableEnvironment environment, SpringApplicationRunListeners listeners,
ApplicationArguments applicationArguments, Banner printedBanner) { ApplicationArguments applicationArguments, Banner printedBanner) {
...@@ -462,6 +475,9 @@ public class SpringApplication { ...@@ -462,6 +475,9 @@ public class SpringApplication {
if (this.webApplicationType == WebApplicationType.SERVLET) { if (this.webApplicationType == WebApplicationType.SERVLET) {
return new StandardServletEnvironment(); return new StandardServletEnvironment();
} }
if (this.webApplicationType == WebApplicationType.REACTIVE) {
return new StandardReactiveWebEnvironment();
}
return new StandardEnvironment(); return new StandardEnvironment();
} }
...@@ -1077,6 +1093,7 @@ public class SpringApplication { ...@@ -1077,6 +1093,7 @@ public class SpringApplication {
* @param environment the environment * @param environment the environment
*/ */
public void setEnvironment(ConfigurableEnvironment environment) { public void setEnvironment(ConfigurableEnvironment environment) {
this.isCustomEnvironment = true;
this.environment = environment; this.environment = environment;
} }
......
...@@ -16,10 +16,14 @@ ...@@ -16,10 +16,14 @@
package org.springframework.boot; package org.springframework.boot;
import java.util.HashSet;
import java.util.Set;
import org.junit.Test; import org.junit.Test;
import org.springframework.core.convert.support.ConfigurableConversionService; import org.springframework.core.convert.support.ConfigurableConversionService;
import org.springframework.core.env.AbstractEnvironment; import org.springframework.core.env.AbstractEnvironment;
import org.springframework.core.env.PropertySource;
import org.springframework.core.env.StandardEnvironment; import org.springframework.core.env.StandardEnvironment;
import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.env.MockEnvironment;
import org.springframework.web.context.support.StandardServletEnvironment; import org.springframework.web.context.support.StandardServletEnvironment;
...@@ -32,6 +36,7 @@ import static org.mockito.Mockito.mock; ...@@ -32,6 +36,7 @@ import static org.mockito.Mockito.mock;
* *
* @author Ethan Rubinson * @author Ethan Rubinson
* @author Andy Wilkinson * @author Andy Wilkinson
* @author Madhura Bhave
*/ */
public class EnvironmentConverterTests { public class EnvironmentConverterTests {
...@@ -43,7 +48,8 @@ public class EnvironmentConverterTests { ...@@ -43,7 +48,8 @@ public class EnvironmentConverterTests {
AbstractEnvironment originalEnvironment = new MockEnvironment(); AbstractEnvironment originalEnvironment = new MockEnvironment();
originalEnvironment.setActiveProfiles("activeProfile1", "activeProfile2"); originalEnvironment.setActiveProfiles("activeProfile1", "activeProfile2");
StandardEnvironment convertedEnvironment = this.environmentConverter StandardEnvironment convertedEnvironment = this.environmentConverter
.convertToStandardEnvironmentIfNecessary(originalEnvironment); .convertEnvironmentIfNecessary(originalEnvironment,
StandardEnvironment.class);
assertThat(convertedEnvironment.getActiveProfiles()) assertThat(convertedEnvironment.getActiveProfiles())
.containsExactly("activeProfile1", "activeProfile2"); .containsExactly("activeProfile1", "activeProfile2");
} }
...@@ -55,16 +61,18 @@ public class EnvironmentConverterTests { ...@@ -55,16 +61,18 @@ public class EnvironmentConverterTests {
ConfigurableConversionService.class); ConfigurableConversionService.class);
originalEnvironment.setConversionService(conversionService); originalEnvironment.setConversionService(conversionService);
StandardEnvironment convertedEnvironment = this.environmentConverter StandardEnvironment convertedEnvironment = this.environmentConverter
.convertToStandardEnvironmentIfNecessary(originalEnvironment); .convertEnvironmentIfNecessary(originalEnvironment,
StandardEnvironment.class);
assertThat(convertedEnvironment.getConversionService()) assertThat(convertedEnvironment.getConversionService())
.isEqualTo(conversionService); .isEqualTo(conversionService);
} }
@Test @Test
public void standardEnvironmentIsReturnedUnconverted() { public void envClassSameShouldReturnEnvironmentUnconverted() {
StandardEnvironment standardEnvironment = new StandardEnvironment(); StandardEnvironment standardEnvironment = new StandardEnvironment();
StandardEnvironment convertedEnvironment = this.environmentConverter StandardEnvironment convertedEnvironment = this.environmentConverter
.convertToStandardEnvironmentIfNecessary(standardEnvironment); .convertEnvironmentIfNecessary(standardEnvironment,
StandardEnvironment.class);
assertThat(convertedEnvironment).isSameAs(standardEnvironment); assertThat(convertedEnvironment).isSameAs(standardEnvironment);
} }
...@@ -72,8 +80,53 @@ public class EnvironmentConverterTests { ...@@ -72,8 +80,53 @@ public class EnvironmentConverterTests {
public void standardServletEnvironmentIsConverted() { public void standardServletEnvironmentIsConverted() {
StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment(); StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment();
StandardEnvironment convertedEnvironment = this.environmentConverter StandardEnvironment convertedEnvironment = this.environmentConverter
.convertToStandardEnvironmentIfNecessary(standardServletEnvironment); .convertEnvironmentIfNecessary(standardServletEnvironment,
StandardEnvironment.class);
assertThat(convertedEnvironment).isNotSameAs(standardServletEnvironment); assertThat(convertedEnvironment).isNotSameAs(standardServletEnvironment);
} }
@Test
public void servletPropertySourcesAreNotCopiedOverIfNotWebEnvironment() {
StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment();
StandardEnvironment convertedEnvironment = this.environmentConverter
.convertEnvironmentIfNecessary(standardServletEnvironment,
StandardEnvironment.class);
assertThat(convertedEnvironment).isNotSameAs(standardServletEnvironment);
Set<String> names = new HashSet<>();
for (PropertySource<?> propertySource : convertedEnvironment
.getPropertySources()) {
names.add(propertySource.getName());
}
assertThat(names).doesNotContain(
StandardServletEnvironment.SERVLET_CONTEXT_PROPERTY_SOURCE_NAME,
StandardServletEnvironment.SERVLET_CONFIG_PROPERTY_SOURCE_NAME,
StandardServletEnvironment.JNDI_PROPERTY_SOURCE_NAME);
}
@Test
public void envClassSameShouldReturnEnvironmentUnconvertedEvenForWeb() {
StandardServletEnvironment standardServletEnvironment = new StandardServletEnvironment();
StandardEnvironment convertedEnvironment = this.environmentConverter
.convertEnvironmentIfNecessary(standardServletEnvironment,
StandardServletEnvironment.class);
assertThat(convertedEnvironment).isSameAs(standardServletEnvironment);
}
@Test
public void servletPropertySourcesArePresentWhenTypeToConvertIsWeb() {
StandardEnvironment standardEnvironment = new StandardEnvironment();
StandardEnvironment convertedEnvironment = this.environmentConverter
.convertEnvironmentIfNecessary(standardEnvironment,
StandardServletEnvironment.class);
assertThat(convertedEnvironment).isNotSameAs(standardEnvironment);
Set<String> names = new HashSet<>();
for (PropertySource<?> propertySource : convertedEnvironment
.getPropertySources()) {
names.add(propertySource.getName());
}
assertThat(names).contains(
StandardServletEnvironment.SERVLET_CONTEXT_PROPERTY_SOURCE_NAME,
StandardServletEnvironment.SERVLET_CONFIG_PROPERTY_SOURCE_NAME);
}
} }
...@@ -60,6 +60,7 @@ import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory ...@@ -60,6 +60,7 @@ import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory
import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory; import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory;
import org.springframework.boot.web.reactive.context.AnnotationConfigReactiveWebServerApplicationContext; import org.springframework.boot.web.reactive.context.AnnotationConfigReactiveWebServerApplicationContext;
import org.springframework.boot.web.reactive.context.ReactiveWebApplicationContext; import org.springframework.boot.web.reactive.context.ReactiveWebApplicationContext;
import org.springframework.boot.web.reactive.context.StandardReactiveWebEnvironment;
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware; import org.springframework.context.ApplicationContextAware;
...@@ -413,6 +414,25 @@ public class SpringApplicationTests { ...@@ -413,6 +414,25 @@ public class SpringApplicationTests {
.isInstanceOf(AnnotationConfigReactiveWebServerApplicationContext.class); .isInstanceOf(AnnotationConfigReactiveWebServerApplicationContext.class);
} }
@Test
public void environmentForWeb() {
SpringApplication application = new SpringApplication(ExampleWebConfig.class);
application.setWebApplicationType(WebApplicationType.SERVLET);
this.context = application.run();
assertThat(this.context.getEnvironment())
.isInstanceOf(StandardServletEnvironment.class);
}
@Test
public void environmentForReactiveWeb() {
SpringApplication application = new SpringApplication(
ExampleReactiveWebConfig.class);
application.setWebApplicationType(WebApplicationType.REACTIVE);
this.context = application.run();
assertThat(this.context.getEnvironment())
.isInstanceOf(StandardReactiveWebEnvironment.class);
}
@Test @Test
public void customEnvironment() { public void customEnvironment() {
TestSpringApplication application = new TestSpringApplication( TestSpringApplication application = new TestSpringApplication(
...@@ -1098,6 +1118,35 @@ public class SpringApplicationTests { ...@@ -1098,6 +1118,35 @@ public class SpringApplicationTests {
.isNotInstanceOfAny(ConfigurableWebEnvironment.class); .isNotInstanceOfAny(ConfigurableWebEnvironment.class);
} }
@Test
public void webApplicationConfiguredViaAPropertyHasTheCorrectTypeOfContextAndEnvironment() {
ConfigurableApplicationContext context = new SpringApplication(
ExampleWebConfig.class).run("--spring.main.web-application-type=servlet");
assertThat(context).isInstanceOfAny(WebApplicationContext.class);
assertThat(context.getEnvironment())
.isInstanceOfAny(StandardServletEnvironment.class);
}
@Test
public void reactiveApplicationConfiguredViaAPropertyHasTheCorrectTypeOfContextAndEnvironment() {
ConfigurableApplicationContext context = new SpringApplication(
ExampleReactiveWebConfig.class)
.run("--spring.main.web-application-type=reactive");
assertThat(context).isInstanceOfAny(ReactiveWebApplicationContext.class);
assertThat(context.getEnvironment())
.isInstanceOfAny(StandardReactiveWebEnvironment.class);
}
@Test
public void environmentIsConvertedIfTypeDoesNotMatch() {
ConfigurableApplicationContext context = new SpringApplication(
ExampleReactiveWebConfig.class)
.run("--spring.profiles.active=withwebapplicationtype");
assertThat(context).isInstanceOfAny(ReactiveWebApplicationContext.class);
assertThat(context.getEnvironment())
.isInstanceOfAny(StandardReactiveWebEnvironment.class);
}
@Test @Test
public void failureResultsInSingleStackTrace() throws Exception { public void failureResultsInSingleStackTrace() throws Exception {
ThreadGroup group = new ThreadGroup("main"); ThreadGroup group = new ThreadGroup("main");
......
spring.main.web-application-type: reactive
\ No newline at end of file
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment