Commit 3904f49c authored by Andy Wilkinson's avatar Andy Wilkinson

Configure ServletContext before initializing S…C…Initializer beans

Previously, the ServletContext was configured after any
ServletContextInitializer beans had been initialized. This meant that
any configuration class that provided such a bean would be initialized
before the ServletContext was configured. If the configuration class
used the ServletContext in its initializtaion that it would see it in
its default, unconfigured state.

This commit reworks the configuration of the ServletContext so that
it happens before any ServletContextInitializer beans are initialized.

Closes gh-10699
parent d8b3c7cc
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -23,9 +23,9 @@ import java.util.Set;
import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties.Servlet.Session;
import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.web.servlet.DispatcherType;
import org.springframework.boot.web.servlet.server.Session;
import org.springframework.session.web.http.SessionRepositoryFilter;
/**
......
......@@ -27,7 +27,6 @@ import java.util.HashMap;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import java.util.TimeZone;
import org.springframework.boot.context.properties.ConfigurationProperties;
......@@ -37,6 +36,7 @@ import org.springframework.boot.web.server.Compression;
import org.springframework.boot.web.server.Http2;
import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.servlet.server.Jsp;
import org.springframework.boot.web.servlet.server.Session;
import org.springframework.util.Assert;
import org.springframework.util.StringUtils;
......@@ -319,196 +319,8 @@ public class ServerProperties {
return result;
}
/**
* Session properties.
*/
public static class Session {
/**
* Session timeout. If a duration suffix is not specified, seconds will be used.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration timeout;
/**
* Session tracking modes (one or more of the following: "cookie", "url", "ssl").
*/
private Set<SessionTrackingMode> trackingModes;
/**
* Whether to persist session data between restarts.
*/
private boolean persistent;
/**
* Directory used to store session data.
*/
private File storeDir;
private final Cookie cookie = new Cookie();
public Cookie getCookie() {
return this.cookie;
}
public Duration getTimeout() {
return this.timeout;
}
public void setTimeout(Duration timeout) {
this.timeout = timeout;
}
public Set<SessionTrackingMode> getTrackingModes() {
return this.trackingModes;
}
public void setTrackingModes(Set<SessionTrackingMode> trackingModes) {
this.trackingModes = trackingModes;
}
public boolean isPersistent() {
return this.persistent;
}
public void setPersistent(boolean persistent) {
this.persistent = persistent;
}
public File getStoreDir() {
return this.storeDir;
}
public void setStoreDir(File storeDir) {
this.storeDir = storeDir;
}
/**
* Cookie properties.
*/
public static class Cookie {
/**
* Session cookie name.
*/
private String name;
/**
* Domain for the session cookie.
*/
private String domain;
/**
* Path of the session cookie.
*/
private String path;
/**
* Comment for the session cookie.
*/
private String comment;
/**
* "HttpOnly" flag for the session cookie.
*/
private Boolean httpOnly;
/**
* "Secure" flag for the session cookie.
*/
private Boolean secure;
/**
* Maximum age of the session cookie.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration maxAge;
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public String getDomain() {
return this.domain;
}
public void setDomain(String domain) {
this.domain = domain;
}
public String getPath() {
return this.path;
}
public void setPath(String path) {
this.path = path;
}
public String getComment() {
return this.comment;
}
public void setComment(String comment) {
this.comment = comment;
}
public Boolean getHttpOnly() {
return this.httpOnly;
}
public void setHttpOnly(Boolean httpOnly) {
this.httpOnly = httpOnly;
}
public Boolean getSecure() {
return this.secure;
}
public void setSecure(Boolean secure) {
this.secure = secure;
}
public Duration getMaxAge() {
return this.maxAge;
}
public void setMaxAge(Duration maxAge) {
this.maxAge = maxAge;
}
}
/**
* Available session tracking modes (mirrors
* {@link javax.servlet.SessionTrackingMode}.
*/
public enum SessionTrackingMode {
/**
* Send a cookie in response to the client's first request.
*/
COOKIE,
/**
* Rewrite the URL to append a session ID.
*/
URL,
/**
* Use SSL build-in mechanism to track the session.
*/
SSL
}
}
}
/**
* Tomcat properties.
*/
......
......@@ -16,15 +16,7 @@
package org.springframework.boot.autoconfigure.web.servlet;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.SessionCookieConfig;
import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties.Servlet.Session;
import org.springframework.boot.autoconfigure.web.embedded.jetty.JettyCustomizer;
import org.springframework.boot.autoconfigure.web.embedded.tomcat.TomcatCustomizer;
import org.springframework.boot.autoconfigure.web.embedded.undertow.UndertowCustomizer;
......@@ -33,9 +25,7 @@ import org.springframework.boot.web.embedded.tomcat.ConfigurableTomcatWebServerF
import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory;
import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFactory;
import org.springframework.boot.web.server.WebServerFactoryCustomizer;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory;
import org.springframework.boot.web.servlet.server.InitParameterConfiguringServletContextInitializer;
import org.springframework.context.EnvironmentAware;
import org.springframework.core.Ordered;
import org.springframework.core.env.Environment;
......@@ -89,11 +79,7 @@ public class DefaultServletWebServerFactoryCustomizer
if (this.serverProperties.getDisplayName() != null) {
factory.setDisplayName(this.serverProperties.getDisplayName());
}
if (this.serverProperties.getServlet().getSession().getTimeout() != null) {
factory.setSessionTimeout(this.serverProperties.getServlet().getSession().getTimeout());
}
factory.setPersistSession(this.serverProperties.getServlet().getSession().isPersistent());
factory.setSessionStoreDir(this.serverProperties.getServlet().getSession().getStoreDir());
factory.setSession(this.serverProperties.getServlet().getSession());
if (this.serverProperties.getSsl() != null) {
factory.setSsl(this.serverProperties.getSsl());
}
......@@ -109,8 +95,10 @@ public class DefaultServletWebServerFactoryCustomizer
factory.setServerHeader(this.serverProperties.getServerHeader());
if (factory instanceof TomcatServletWebServerFactory) {
TomcatServletWebServerFactory tomcatFactory = (TomcatServletWebServerFactory) factory;
TomcatCustomizer.customizeTomcat(this.serverProperties, this.environment, tomcatFactory);
TomcatServletCustomizer.customizeTomcat(this.serverProperties, this.environment, tomcatFactory);
TomcatCustomizer.customizeTomcat(this.serverProperties, this.environment,
tomcatFactory);
TomcatServletCustomizer.customizeTomcat(this.serverProperties,
this.environment, tomcatFactory);
}
if (factory instanceof JettyServletWebServerFactory) {
JettyCustomizer.customizeJetty(this.serverProperties, this.environment,
......@@ -120,71 +108,8 @@ public class DefaultServletWebServerFactoryCustomizer
UndertowCustomizer.customizeUndertow(this.serverProperties, this.environment,
(UndertowServletWebServerFactory) factory);
}
factory.addInitializers(
new SessionConfiguringInitializer(this.serverProperties.getServlet().getSession()));
factory.addInitializers(new InitParameterConfiguringServletContextInitializer(
this.serverProperties.getServlet().getContextParameters()));
}
/**
* {@link ServletContextInitializer} to apply appropriate parts of the {@link Session}
* configuration.
*/
private static class SessionConfiguringInitializer
implements ServletContextInitializer {
private final Session session;
SessionConfiguringInitializer(Session session) {
this.session = session;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
if (this.session.getTrackingModes() != null) {
servletContext
.setSessionTrackingModes(unwrap(this.session.getTrackingModes()));
}
configureSessionCookie(servletContext.getSessionCookieConfig());
}
private void configureSessionCookie(SessionCookieConfig config) {
Session.Cookie cookie = this.session.getCookie();
if (cookie.getName() != null) {
config.setName(cookie.getName());
}
if (cookie.getDomain() != null) {
config.setDomain(cookie.getDomain());
}
if (cookie.getPath() != null) {
config.setPath(cookie.getPath());
}
if (cookie.getComment() != null) {
config.setComment(cookie.getComment());
}
if (cookie.getHttpOnly() != null) {
config.setHttpOnly(cookie.getHttpOnly());
}
if (cookie.getSecure() != null) {
config.setSecure(cookie.getSecure());
}
if (cookie.getMaxAge() != null) {
config.setMaxAge((int) cookie.getMaxAge().getSeconds());
}
}
private Set<javax.servlet.SessionTrackingMode> unwrap(
Set<Session.SessionTrackingMode> modes) {
if (modes == null) {
return null;
}
Set<javax.servlet.SessionTrackingMode> result = new LinkedHashSet<>();
for (Session.SessionTrackingMode mode : modes) {
result.add(javax.servlet.SessionTrackingMode.valueOf(mode.name()));
}
return result;
}
factory.setInitParameters(
this.serverProperties.getServlet().getContextParameters());
}
private static class TomcatServletCustomizer {
......@@ -213,7 +138,8 @@ public class DefaultServletWebServerFactoryCustomizer
}
private static void customizeUseRelativeRedirects(
ConfigurableTomcatWebServerFactory factory, boolean useRelativeRedirects) {
ConfigurableTomcatWebServerFactory factory,
boolean useRelativeRedirects) {
factory.addContextCustomizers(
(context) -> context.setUseRelativeRedirects(useRelativeRedirects));
}
......
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -25,13 +25,17 @@ import org.junit.Test;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration;
import org.springframework.boot.test.context.runner.WebApplicationContextRunner;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.session.MapSessionRepository;
import org.springframework.session.SessionRepository;
import org.springframework.session.config.annotation.web.http.EnableSpringHttpSession;
import org.springframework.session.web.http.CookieHttpSessionIdResolver;
import org.springframework.session.web.http.DefaultCookieSerializer;
import org.springframework.session.web.http.SessionRepositoryFilter;
import org.springframework.test.util.ReflectionTestUtils;
......@@ -147,6 +151,28 @@ public class SessionAutoConfigurationTests extends AbstractSessionAutoConfigurat
});
}
@Test
public void sessionCookieConfigurationIsPickedUp() {
new WebApplicationContextRunner(
AnnotationConfigServletWebServerApplicationContext::new)
.withConfiguration(AutoConfigurations
.of(ServletWebServerFactoryAutoConfiguration.class))
.withUserConfiguration(SessionRepositoryConfiguration.class)
.withPropertyValues("server.port=0",
"server.servlet.session.cookie.name=testname")
.run((context) -> {
SessionRepositoryFilter<?> filter = context
.getBean(SessionRepositoryFilter.class);
CookieHttpSessionIdResolver sessionIdResolver = (CookieHttpSessionIdResolver) ReflectionTestUtils
.getField(filter, "httpSessionIdResolver");
DefaultCookieSerializer cookieSerializer = (DefaultCookieSerializer) ReflectionTestUtils
.getField(sessionIdResolver, "cookieSerializer");
String cookieName = (String) ReflectionTestUtils
.getField(cookieSerializer, "cookieName");
assertThat(cookieName).isEqualTo("testname");
});
}
@Configuration
@EnableSpringHttpSession
static class SessionRepositoryConfiguration {
......
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,17 +19,11 @@ package org.springframework.boot.autoconfigure.web.servlet;
import java.io.File;
import java.io.IOException;
import java.time.Duration;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
import java.util.TimeZone;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.SessionCookieConfig;
import javax.servlet.SessionTrackingMode;
import org.apache.catalina.Context;
import org.apache.catalina.Valve;
import org.apache.catalina.startup.Tomcat;
......@@ -57,12 +51,12 @@ import org.springframework.boot.web.embedded.tomcat.TomcatWebServer;
import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFactory;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory;
import org.springframework.boot.web.servlet.server.Session;
import org.springframework.boot.web.servlet.server.Session.Cookie;
import org.springframework.mock.env.MockEnvironment;
import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy;
......@@ -224,21 +218,19 @@ public class DefaultServletWebServerFactoryCustomizerTests {
bindProperties(map);
ConfigurableServletWebServerFactory factory = mock(
ConfigurableServletWebServerFactory.class);
ServletContext servletContext = mock(ServletContext.class);
SessionCookieConfig sessionCookieConfig = mock(SessionCookieConfig.class);
given(servletContext.getSessionCookieConfig()).willReturn(sessionCookieConfig);
this.customizer.customize(factory);
triggerInitializers(factory, servletContext);
verify(factory).setSessionTimeout(Duration.ofSeconds(123));
verify(servletContext).setSessionTrackingModes(
EnumSet.of(SessionTrackingMode.COOKIE, SessionTrackingMode.URL));
verify(sessionCookieConfig).setName("testname");
verify(sessionCookieConfig).setDomain("testdomain");
verify(sessionCookieConfig).setPath("/testpath");
verify(sessionCookieConfig).setComment("testcomment");
verify(sessionCookieConfig).setHttpOnly(true);
verify(sessionCookieConfig).setSecure(true);
verify(sessionCookieConfig).setMaxAge(60);
ArgumentCaptor<Session> sessionCaptor = ArgumentCaptor.forClass(Session.class);
verify(factory).setSession(sessionCaptor.capture());
assertThat(sessionCaptor.getValue().getTimeout())
.isEqualTo(Duration.ofSeconds(123));
Cookie cookie = sessionCaptor.getValue().getCookie();
assertThat(cookie.getName()).isEqualTo("testname");
assertThat(cookie.getDomain()).isEqualTo("testdomain");
assertThat(cookie.getPath()).isEqualTo("/testpath");
assertThat(cookie.getComment()).isEqualTo("testcomment");
assertThat(cookie.getHttpOnly()).isTrue();
assertThat(cookie.getMaxAge()).isEqualTo(Duration.ofSeconds(60));
}
@Test
......@@ -540,7 +532,10 @@ public class DefaultServletWebServerFactoryCustomizerTests {
bindProperties(map);
JettyServletWebServerFactory factory = spy(new JettyServletWebServerFactory());
this.customizer.customize(factory);
verify(factory).setSessionStoreDir(new File("myfolder"));
ArgumentCaptor<Session> sessionCaptor = ArgumentCaptor.forClass(Session.class);
verify(factory).setSession(sessionCaptor.capture());
assertThat(sessionCaptor.getValue().getStoreDir())
.isEqualTo(new File("myfolder"));
}
@Test
......@@ -638,21 +633,6 @@ public class DefaultServletWebServerFactoryCustomizerTests {
}
}
private void triggerInitializers(ConfigurableServletWebServerFactory factory,
ServletContext servletContext) throws ServletException {
verify(factory, atLeastOnce()).addInitializers(this.initializersCaptor.capture());
for (Object initializers : this.initializersCaptor.getAllValues()) {
if (initializers instanceof ServletContextInitializer) {
((ServletContextInitializer) initializers).onStartup(servletContext);
}
else {
for (ServletContextInitializer initializer : (ServletContextInitializer[]) initializers) {
initializer.onStartup(servletContext);
}
}
}
}
private void bindProperties(Map<String, String> map) {
ConfigurationPropertySource source = new MapConfigurationPropertySource(map);
new Binder(source).bind("server", Bindable.ofInstance(this.properties));
......
......@@ -24,6 +24,7 @@ import java.net.MalformedURLException;
import java.net.URL;
import java.nio.channels.ReadableByteChannel;
import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
......@@ -231,10 +232,11 @@ public class JettyServletWebServerFactory extends AbstractServletWebServerFactor
private void configureSession(WebAppContext context) {
SessionHandler handler = context.getSessionHandler();
Duration sessionTimeout = getSession().getTimeout();
handler.setMaxInactiveInterval(
(getSessionTimeout() == null || getSessionTimeout().isNegative()) ? -1
: (int) getSessionTimeout().getSeconds());
if (isPersistSession()) {
(sessionTimeout == null || sessionTimeout.isNegative()) ? -1
: (int) sessionTimeout.getSeconds());
if (getSession().isPersistent()) {
DefaultSessionCache cache = new DefaultSessionCache(handler);
FileSessionDataStore store = new FileSessionDataStore();
store.setStoreDir(getValidSessionStoreDir());
......
......@@ -362,7 +362,7 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
private void configureSession(Context context) {
long sessionTimeout = getSessionTimeoutInMinutes();
context.setSessionTimeout((int) sessionTimeout);
if (isPersistSession()) {
if (getSession().isPersistent()) {
Manager manager = context.getManager();
if (manager == null) {
manager = new StandardManager();
......@@ -385,7 +385,7 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
}
private long getSessionTimeoutInMinutes() {
Duration sessionTimeout = getSessionTimeout();
Duration sessionTimeout = getSession().getTimeout();
if (sessionTimeout == null || sessionTimeout.isNegative()
|| sessionTimeout.isZero()) {
return 0;
......@@ -516,8 +516,8 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
}
/**
* Set {@link LifecycleListener}s that should be applied to the Tomcat {@link Context}.
* Calling this method will replace any existing listeners.
* Set {@link LifecycleListener}s that should be applied to the Tomcat
* {@link Context}. Calling this method will replace any existing listeners.
* @param contextLifecycleListeners the listeners to set
*/
public void setContextLifecycleListeners(
......
......@@ -21,6 +21,7 @@ import java.io.IOException;
import java.net.MalformedURLException;
import java.net.URL;
import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collection;
......@@ -239,8 +240,7 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac
new SslBuilderCustomizer(getPort(), getAddress(), getSsl(), getSslStoreProvider())
.customize(builder);
if (getHttp2() != null) {
builder.setServerOption(UndertowOptions.ENABLE_HTTP2,
getHttp2().isEnabled());
builder.setServerOption(UndertowOptions.ENABLE_HTTP2, getHttp2().isEnabled());
}
}
......@@ -274,7 +274,7 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac
if (isAccessLogEnabled()) {
configureAccessLog(deployment);
}
if (isPersistSession()) {
if (getSession().isPersistent()) {
File dir = getValidSessionStoreDir();
deployment.setSessionPersistenceManager(new FileSessionPersistence(dir));
}
......@@ -282,9 +282,10 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac
DeploymentManager manager = Servlets.newContainer().addDeployment(deployment);
manager.deploy();
SessionManager sessionManager = manager.getDeployment().getSessionManager();
int sessionTimeout = (getSessionTimeout() == null || getSessionTimeout().isZero()
|| getSessionTimeout().isNegative() ? -1
: (int) getSessionTimeout().getSeconds());
Duration timeoutDuration = getSession().getTimeout();
int sessionTimeout = (timeoutDuration == null || timeoutDuration.isZero()
|| timeoutDuration.isNegative() ? -1
: (int) timeoutDuration.getSeconds());
sessionManager.setDefaultSessionTimeout(sessionTimeout);
return manager;
}
......
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -19,13 +19,19 @@ package org.springframework.boot.web.servlet.server;
import java.io.File;
import java.net.URL;
import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.SessionCookieConfig;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
......@@ -58,9 +64,7 @@ public abstract class AbstractServletWebServerFactory
private String displayName;
private Duration sessionTimeout = Duration.ofMinutes(30);
private boolean persistSession;
private Session session = new Session();
private boolean registerDefaultServlet = true;
......@@ -72,7 +76,7 @@ public abstract class AbstractServletWebServerFactory
private Map<Locale, Charset> localeCharsetMappings = new HashMap<>();
private final SessionStoreDirectory sessionStoreDir = new SessionStoreDirectory();
private Map<String, String> initParameters = Collections.emptyMap();
private final DocumentRoot documentRoot = new DocumentRoot(this.logger);
......@@ -143,37 +147,6 @@ public abstract class AbstractServletWebServerFactory
this.displayName = displayName;
}
/**
* Return the session timeout or {@code null}.
* @return the session timeout
*/
public Duration getSessionTimeout() {
return this.sessionTimeout;
}
@Override
public void setSessionTimeout(Duration sessionTimeout) {
this.sessionTimeout = sessionTimeout;
}
public boolean isPersistSession() {
return this.persistSession;
}
@Override
public void setPersistSession(boolean persistSession) {
this.persistSession = persistSession;
}
public File getSessionStoreDir() {
return this.sessionStoreDir.getDirectory();
}
@Override
public void setSessionStoreDir(File sessionStoreDir) {
this.sessionStoreDir.setDirectory(sessionStoreDir);
}
/**
* Flag to indicate that the default servlet should be registered.
* @return true if the default servlet is to be registered
......@@ -235,6 +208,15 @@ public abstract class AbstractServletWebServerFactory
this.jsp = jsp;
}
public Session getSession() {
return this.session;
}
@Override
public void setSession(Session session) {
this.session = session;
}
/**
* Return the Locale to Charset mappings.
* @return the charset mappings
......@@ -249,6 +231,15 @@ public abstract class AbstractServletWebServerFactory
this.localeCharsetMappings = localeCharsetMappings;
}
@Override
public void setInitParameters(Map<String, String> initParameters) {
this.initParameters = initParameters;
}
public Map<String, String> getInitParameters() {
return this.initParameters;
}
/**
* Utility method that can be used by subclasses wishing to combine the specified
* {@link ServletContextInitializer} parameters with those defined in this instance.
......@@ -259,6 +250,9 @@ public abstract class AbstractServletWebServerFactory
protected final ServletContextInitializer[] mergeInitializers(
ServletContextInitializer... initializers) {
List<ServletContextInitializer> mergedInitializers = new ArrayList<>();
mergedInitializers.add((servletContext) -> this.initParameters
.forEach(servletContext::setInitParameter));
mergedInitializers.add(new SessionConfiguringInitializer(this.session));
mergedInitializers.addAll(Arrays.asList(initializers));
mergedInitializers.addAll(this.initializers);
return mergedInitializers
......@@ -288,11 +282,72 @@ public abstract class AbstractServletWebServerFactory
}
protected final File getValidSessionStoreDir() {
return this.sessionStoreDir.getValidDirectory(true);
return getValidSessionStoreDir(true);
}
protected final File getValidSessionStoreDir(boolean mkdirs) {
return this.sessionStoreDir.getValidDirectory(mkdirs);
return this.session.getSessionStoreDirectory().getValidDirectory(mkdirs);
}
/**
* {@link ServletContextInitializer} to apply appropriate parts of the {@link Session}
* configuration.
*/
private static class SessionConfiguringInitializer
implements ServletContextInitializer {
private final Session session;
SessionConfiguringInitializer(Session session) {
this.session = session;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
if (this.session.getTrackingModes() != null) {
servletContext
.setSessionTrackingModes(unwrap(this.session.getTrackingModes()));
}
configureSessionCookie(servletContext.getSessionCookieConfig());
}
private void configureSessionCookie(SessionCookieConfig config) {
Session.Cookie cookie = this.session.getCookie();
if (cookie.getName() != null) {
config.setName(cookie.getName());
}
if (cookie.getDomain() != null) {
config.setDomain(cookie.getDomain());
}
if (cookie.getPath() != null) {
config.setPath(cookie.getPath());
}
if (cookie.getComment() != null) {
config.setComment(cookie.getComment());
}
if (cookie.getHttpOnly() != null) {
config.setHttpOnly(cookie.getHttpOnly());
}
if (cookie.getSecure() != null) {
config.setSecure(cookie.getSecure());
}
if (cookie.getMaxAge() != null) {
config.setMaxAge((int) cookie.getMaxAge().getSeconds());
}
}
private Set<javax.servlet.SessionTrackingMode> unwrap(
Set<Session.SessionTrackingMode> modes) {
if (modes == null) {
return null;
}
Set<javax.servlet.SessionTrackingMode> result = new LinkedHashSet<>();
for (Session.SessionTrackingMode mode : modes) {
result.add(javax.servlet.SessionTrackingMode.valueOf(mode.name()));
}
return result;
}
}
}
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -18,11 +18,12 @@ package org.springframework.boot.web.servlet.server;
import java.io.File;
import java.nio.charset.Charset;
import java.time.Duration;
import java.util.List;
import java.util.Locale;
import java.util.Map;
import javax.servlet.ServletContext;
import org.springframework.boot.web.server.ConfigurableWebServerFactory;
import org.springframework.boot.web.server.MimeMappings;
import org.springframework.boot.web.server.WebServerFactoryCustomizer;
......@@ -59,23 +60,32 @@ public interface ConfigurableServletWebServerFactory
void setDisplayName(String displayName);
/**
* The session timeout in seconds (default 30 minutes). If {@code null} then sessions
* never expire.
* @param sessionTimeout the session timeout
*/
void setSessionTimeout(Duration sessionTimeout);
/**
* Sets if session data should be persisted between restarts.
* @param persistSession {@code true} if session data should be persisted
*/
void setPersistSession(boolean persistSession);
/**
* Set the directory used to store serialized session data.
* @param sessionStoreDir the directory or {@code null} to use a default location.
* Sets the configuration that will be applied to the container's HTTP session
* support.
*
* @param session the session configuration
*/
void setSessionStoreDir(File sessionStoreDir);
void setSession(Session session);
// /**
// * The session timeout in seconds (default 30 minutes). If {@code null} then
// sessions
// * never expire.
// * @param sessionTimeout the session timeout
// */
// void setSessionTimeout(Duration sessionTimeout);
//
// /**
// * Sets if session data should be persisted between restarts.
// * @param persistSession {@code true} if session data should be persisted
// */
// void setPersistSession(boolean persistSession);
//
// /**
// * Set the directory used to store serialized session data.
// * @param sessionStoreDir the directory or {@code null} to use a default location.
// */
// void setSessionStoreDir(File sessionStoreDir);
/**
* Set if the DefaultServlet should be registered. Defaults to {@code true} so that
......@@ -128,4 +138,12 @@ public interface ConfigurableServletWebServerFactory
*/
void setLocaleCharsetMappings(Map<Locale, Charset> localeCharsetMappings);
/**
* Sets the init parameters that are applied to the container's
* {@link ServletContext}.
*
* @param initParameters the init parameters
*/
void setInitParameters(Map<String, String> initParameters);
}
/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.server;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import org.springframework.boot.web.servlet.ServletContextInitializer;
/**
* A {@code ServletContextInitializer} that configures init parameters on the
* {@code ServletContext}.
*
* @author Andy Wilkinson
* @since 2.0.0
* @see ServletContext#setInitParameter(String, String)
*/
public class InitParameterConfiguringServletContextInitializer
implements ServletContextInitializer {
private final Map<String, String> parameters;
public InitParameterConfiguringServletContextInitializer(
Map<String, String> parameters) {
this.parameters = parameters;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
for (Entry<String, String> entry : this.parameters.entrySet()) {
servletContext.setInitParameter(entry.getKey(), entry.getValue());
}
}
}
/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.server;
import java.io.File;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Set;
import org.springframework.boot.context.properties.bind.convert.DefaultDurationUnit;
/**
* Session properties.
*
* @author Andy Wilkinson
* @since 2.0.0
*/
public class Session {
/**
* Session timeout. If a duration suffix is not specified, seconds will be used.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration timeout = Duration.ofMinutes(30);
/**
* Session tracking modes (one or more of the following: "cookie", "url", "ssl").
*/
private Set<Session.SessionTrackingMode> trackingModes;
/**
* Whether to persist session data between restarts.
*/
private boolean persistent;
/**
* Directory used to store session data.
*/
private File storeDir;
private final Cookie cookie = new Cookie();
private final SessionStoreDirectory sessionStoreDirectory = new SessionStoreDirectory();
public Cookie getCookie() {
return this.cookie;
}
public Duration getTimeout() {
return this.timeout;
}
public void setTimeout(Duration timeout) {
this.timeout = timeout;
}
public Set<Session.SessionTrackingMode> getTrackingModes() {
return this.trackingModes;
}
public void setTrackingModes(Set<Session.SessionTrackingMode> trackingModes) {
this.trackingModes = trackingModes;
}
public boolean isPersistent() {
return this.persistent;
}
public void setPersistent(boolean persistent) {
this.persistent = persistent;
}
public File getStoreDir() {
return this.storeDir;
}
public void setStoreDir(File storeDir) {
this.sessionStoreDirectory.setDirectory(storeDir);
this.storeDir = storeDir;
}
SessionStoreDirectory getSessionStoreDirectory() {
return this.sessionStoreDirectory;
}
/**
* Cookie properties.
*/
public static class Cookie {
/**
* Session cookie name.
*/
private String name;
/**
* Domain for the session cookie.
*/
private String domain;
/**
* Path of the session cookie.
*/
private String path;
/**
* Comment for the session cookie.
*/
private String comment;
/**
* "HttpOnly" flag for the session cookie.
*/
private Boolean httpOnly;
/**
* "Secure" flag for the session cookie.
*/
private Boolean secure;
/**
* Maximum age of the session cookie.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration maxAge;
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public String getDomain() {
return this.domain;
}
public void setDomain(String domain) {
this.domain = domain;
}
public String getPath() {
return this.path;
}
public void setPath(String path) {
this.path = path;
}
public String getComment() {
return this.comment;
}
public void setComment(String comment) {
this.comment = comment;
}
public Boolean getHttpOnly() {
return this.httpOnly;
}
public void setHttpOnly(Boolean httpOnly) {
this.httpOnly = httpOnly;
}
public Boolean getSecure() {
return this.secure;
}
public void setSecure(Boolean secure) {
this.secure = secure;
}
public Duration getMaxAge() {
return this.maxAge;
}
public void setMaxAge(Duration maxAge) {
this.maxAge = maxAge;
}
}
/**
* Available session tracking modes (mirrors
* {@link javax.servlet.SessionTrackingMode}.
*/
public enum SessionTrackingMode {
/**
* Send a cookie in response to the client's first request.
*/
COOKIE,
/**
* Rewrite the URL to append a session ID.
*/
URL,
/**
* Use SSL build-in mechanism to track the session.
*/
SSL
}
}
......@@ -99,14 +99,14 @@ public class JettyServletWebServerFactoryTests
@Test
public void sessionTimeout() {
JettyServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofSeconds(10));
factory.getSession().setTimeout(Duration.ofSeconds(10));
assertTimeout(factory, 10);
}
@Test
public void sessionTimeoutInMins() {
JettyServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofMinutes(1));
factory.getSession().setTimeout(Duration.ofMinutes(1));
assertTimeout(factory, 60);
}
......
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -187,21 +187,21 @@ public class TomcatServletWebServerFactoryTests
@Test
public void sessionTimeout() {
TomcatServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofSeconds(10));
factory.getSession().setTimeout(Duration.ofSeconds(10));
assertTimeout(factory, 1);
}
@Test
public void sessionTimeoutInMins() {
TomcatServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofMinutes(1));
factory.getSession().setTimeout(Duration.ofMinutes(1));
assertTimeout(factory, 1);
}
@Test
public void noSessionTimeout() {
TomcatServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(null);
factory.getSession().setTimeout(null);
assertTimeout(factory, -1);
}
......
/*
* Copyright 2012-2017 the original author or authors.
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
......@@ -40,6 +40,7 @@ import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.Date;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.Locale;
import java.util.Map;
......@@ -59,6 +60,7 @@ import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse;
import javax.servlet.SessionCookieConfig;
import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
......@@ -100,6 +102,7 @@ import org.springframework.boot.web.server.WebServerException;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.server.Session.SessionTrackingMode;
import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod;
......@@ -713,13 +716,14 @@ public abstract class AbstractServletWebServerFactoryTests {
@Test
public void defaultSessionTimeout() {
assertThat(getFactory().getSessionTimeout()).isEqualTo(Duration.ofMinutes(30));
assertThat(getFactory().getSession().getTimeout())
.isEqualTo(Duration.ofMinutes(30));
}
@Test
public void persistSession() throws Exception {
AbstractServletWebServerFactory factory = getFactory();
factory.setPersistSession(true);
factory.getSession().setPersistent(true);
this.webServer = factory.getWebServer(sessionServletRegistration());
this.webServer.start();
String s1 = getResponse(getLocalUrl("/session"));
......@@ -737,8 +741,8 @@ public abstract class AbstractServletWebServerFactoryTests {
public void persistSessionInSpecificSessionStoreDir() throws Exception {
AbstractServletWebServerFactory factory = getFactory();
File sessionStoreDir = this.temporaryFolder.newFolder();
factory.setPersistSession(true);
factory.setSessionStoreDir(sessionStoreDir);
factory.getSession().setPersistent(true);
factory.getSession().setStoreDir(sessionStoreDir);
this.webServer = factory.getWebServer(sessionServletRegistration());
this.webServer.start();
getResponse(getLocalUrl("/session"));
......@@ -759,7 +763,7 @@ public abstract class AbstractServletWebServerFactoryTests {
@Test
public void getValidSessionStoreWhenSessionStoreIsRelative() {
AbstractServletWebServerFactory factory = getFactory();
factory.setSessionStoreDir(new File("sessions"));
factory.getSession().setStoreDir(new File("sessions"));
File dir = factory.getValidSessionStoreDir(false);
assertThat(dir.getName()).isEqualTo("sessions");
assertThat(dir.getParentFile()).isEqualTo(new ApplicationHome().getDir());
......@@ -768,12 +772,35 @@ public abstract class AbstractServletWebServerFactoryTests {
@Test
public void getValidSessionStoreWhenSessionStoreReferencesFile() throws Exception {
AbstractServletWebServerFactory factory = getFactory();
factory.setSessionStoreDir(this.temporaryFolder.newFile());
factory.getSession().setStoreDir(this.temporaryFolder.newFile());
this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage("points to a file");
factory.getValidSessionStoreDir(false);
}
@Test
public void sessionCookieConfiguration() {
AbstractServletWebServerFactory factory = getFactory();
factory.getSession().getCookie().setName("testname");
factory.getSession().getCookie().setDomain("testdomain");
factory.getSession().getCookie().setPath("/testpath");
factory.getSession().getCookie().setComment("testcomment");
factory.getSession().getCookie().setHttpOnly(true);
factory.getSession().getCookie().setSecure(true);
factory.getSession().getCookie().setMaxAge(Duration.ofSeconds(60));
final AtomicReference<SessionCookieConfig> configReference = new AtomicReference<>();
this.webServer = factory.getWebServer(
(context) -> configReference.set(context.getSessionCookieConfig()));
SessionCookieConfig sessionCookieConfig = configReference.get();
assertThat(sessionCookieConfig.getName()).isEqualTo("testname");
assertThat(sessionCookieConfig.getDomain()).isEqualTo("testdomain");
assertThat(sessionCookieConfig.getPath()).isEqualTo("/testpath");
assertThat(sessionCookieConfig.getComment()).isEqualTo("testcomment");
assertThat(sessionCookieConfig.isHttpOnly()).isTrue();
assertThat(sessionCookieConfig.isSecure()).isTrue();
assertThat(sessionCookieConfig.getMaxAge()).isEqualTo(60);
}
@Test
public void compressionOfResponseToGetRequest() throws Exception {
assertThat(doTestCompression(10000, null, null)).isTrue();
......@@ -969,6 +996,48 @@ public abstract class AbstractServletWebServerFactoryTests {
factory.getWebServer().start();
}
@Test
public void sessionConfiguration() {
AbstractServletWebServerFactory factory = getFactory();
// map.put("server.servlet.session.timeout", "123");
// map.put("server.servlet.session.tracking-modes", "cookie,url");
// map.put("server.servlet.session.cookie.name", "testname");
// map.put("server.servlet.session.cookie.domain", "testdomain");
// map.put("server.servlet.session.cookie.path", "/testpath");
// map.put("server.servlet.session.cookie.comment", "testcomment");
// map.put("server.servlet.session.cookie.http-only", "true");
// map.put("server.servlet.session.cookie.secure", "true");
// map.put("server.servlet.session.cookie.max-age", "60");
factory.getSession().setTimeout(Duration.ofSeconds(123));
factory.getSession().setTrackingModes(
EnumSet.of(SessionTrackingMode.COOKIE, SessionTrackingMode.URL));
factory.getSession().getCookie().setName("testname");
factory.getSession().getCookie().setDomain("testdomain");
factory.getSession().getCookie().setPath("/testpath");
factory.getSession().getCookie().setComment("testcomment");
factory.getSession().getCookie().setHttpOnly(true);
factory.getSession().getCookie().setSecure(true);
factory.getSession().getCookie().setMaxAge(Duration.ofMinutes(1));
AtomicReference<ServletContext> contextReference = new AtomicReference<ServletContext>();
factory.getWebServer(contextReference::set).start();
ServletContext servletContext = contextReference.get();
assertThat(servletContext.getEffectiveSessionTrackingModes())
.isEqualTo(EnumSet.of(javax.servlet.SessionTrackingMode.COOKIE,
javax.servlet.SessionTrackingMode.URL));
assertThat(servletContext.getSessionCookieConfig().getName())
.isEqualTo("testname");
assertThat(servletContext.getSessionCookieConfig().getDomain())
.isEqualTo("testdomain");
assertThat(servletContext.getSessionCookieConfig().getPath())
.isEqualTo("/testpath");
assertThat(servletContext.getSessionCookieConfig().getComment())
.isEqualTo("testcomment");
assertThat(servletContext.getSessionCookieConfig().isHttpOnly()).isTrue();
assertThat(servletContext.getSessionCookieConfig().isSecure()).isTrue();
assertThat(servletContext.getSessionCookieConfig().getMaxAge()).isEqualTo(60);
}
protected abstract void addConnector(int port,
AbstractServletWebServerFactory factory);
......@@ -1016,8 +1085,7 @@ public abstract class AbstractServletWebServerFactoryTests {
@Override
protected void service(HttpServletRequest req,
HttpServletResponse resp)
throws IOException {
HttpServletResponse resp) throws IOException {
resp.setContentType("text/plain");
resp.setContentLength(testContent.length());
resp.getWriter().write(testContent);
......@@ -1140,7 +1208,8 @@ public abstract class AbstractServletWebServerFactoryTests {
new ExampleServlet() {
@Override
public void service(ServletRequest request, ServletResponse response) {
public void service(ServletRequest request,
ServletResponse response) {
throw new RuntimeException("Planned");
}
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment