Commit 3904f49c authored by Andy Wilkinson's avatar Andy Wilkinson

Configure ServletContext before initializing S…C…Initializer beans

Previously, the ServletContext was configured after any
ServletContextInitializer beans had been initialized. This meant that
any configuration class that provided such a bean would be initialized
before the ServletContext was configured. If the configuration class
used the ServletContext in its initializtaion that it would see it in
its default, unconfigured state.

This commit reworks the configuration of the ServletContext so that
it happens before any ServletContextInitializer beans are initialized.

Closes gh-10699
parent d8b3c7cc
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -23,9 +23,9 @@ import java.util.Set; ...@@ -23,9 +23,9 @@ import java.util.Set;
import org.springframework.beans.factory.ObjectProvider; import org.springframework.beans.factory.ObjectProvider;
import org.springframework.boot.autoconfigure.web.ServerProperties; import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties.Servlet.Session;
import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.ConfigurationProperties;
import org.springframework.boot.web.servlet.DispatcherType; import org.springframework.boot.web.servlet.DispatcherType;
import org.springframework.boot.web.servlet.server.Session;
import org.springframework.session.web.http.SessionRepositoryFilter; import org.springframework.session.web.http.SessionRepositoryFilter;
/** /**
......
...@@ -27,7 +27,6 @@ import java.util.HashMap; ...@@ -27,7 +27,6 @@ import java.util.HashMap;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set;
import java.util.TimeZone; import java.util.TimeZone;
import org.springframework.boot.context.properties.ConfigurationProperties; import org.springframework.boot.context.properties.ConfigurationProperties;
...@@ -37,6 +36,7 @@ import org.springframework.boot.web.server.Compression; ...@@ -37,6 +36,7 @@ import org.springframework.boot.web.server.Compression;
import org.springframework.boot.web.server.Http2; import org.springframework.boot.web.server.Http2;
import org.springframework.boot.web.server.Ssl; import org.springframework.boot.web.server.Ssl;
import org.springframework.boot.web.servlet.server.Jsp; import org.springframework.boot.web.servlet.server.Jsp;
import org.springframework.boot.web.servlet.server.Session;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.StringUtils; import org.springframework.util.StringUtils;
...@@ -319,196 +319,8 @@ public class ServerProperties { ...@@ -319,196 +319,8 @@ public class ServerProperties {
return result; return result;
} }
/**
* Session properties.
*/
public static class Session {
/**
* Session timeout. If a duration suffix is not specified, seconds will be used.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration timeout;
/**
* Session tracking modes (one or more of the following: "cookie", "url", "ssl").
*/
private Set<SessionTrackingMode> trackingModes;
/**
* Whether to persist session data between restarts.
*/
private boolean persistent;
/**
* Directory used to store session data.
*/
private File storeDir;
private final Cookie cookie = new Cookie();
public Cookie getCookie() {
return this.cookie;
}
public Duration getTimeout() {
return this.timeout;
}
public void setTimeout(Duration timeout) {
this.timeout = timeout;
}
public Set<SessionTrackingMode> getTrackingModes() {
return this.trackingModes;
}
public void setTrackingModes(Set<SessionTrackingMode> trackingModes) {
this.trackingModes = trackingModes;
}
public boolean isPersistent() {
return this.persistent;
}
public void setPersistent(boolean persistent) {
this.persistent = persistent;
}
public File getStoreDir() {
return this.storeDir;
}
public void setStoreDir(File storeDir) {
this.storeDir = storeDir;
}
/**
* Cookie properties.
*/
public static class Cookie {
/**
* Session cookie name.
*/
private String name;
/**
* Domain for the session cookie.
*/
private String domain;
/**
* Path of the session cookie.
*/
private String path;
/**
* Comment for the session cookie.
*/
private String comment;
/**
* "HttpOnly" flag for the session cookie.
*/
private Boolean httpOnly;
/**
* "Secure" flag for the session cookie.
*/
private Boolean secure;
/**
* Maximum age of the session cookie.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration maxAge;
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public String getDomain() {
return this.domain;
}
public void setDomain(String domain) {
this.domain = domain;
}
public String getPath() {
return this.path;
}
public void setPath(String path) {
this.path = path;
}
public String getComment() {
return this.comment;
}
public void setComment(String comment) {
this.comment = comment;
}
public Boolean getHttpOnly() {
return this.httpOnly;
}
public void setHttpOnly(Boolean httpOnly) {
this.httpOnly = httpOnly;
}
public Boolean getSecure() {
return this.secure;
}
public void setSecure(Boolean secure) {
this.secure = secure;
}
public Duration getMaxAge() {
return this.maxAge;
}
public void setMaxAge(Duration maxAge) {
this.maxAge = maxAge;
}
}
/**
* Available session tracking modes (mirrors
* {@link javax.servlet.SessionTrackingMode}.
*/
public enum SessionTrackingMode {
/**
* Send a cookie in response to the client's first request.
*/
COOKIE,
/**
* Rewrite the URL to append a session ID.
*/
URL,
/**
* Use SSL build-in mechanism to track the session.
*/
SSL
}
}
} }
/** /**
* Tomcat properties. * Tomcat properties.
*/ */
......
...@@ -16,15 +16,7 @@ ...@@ -16,15 +16,7 @@
package org.springframework.boot.autoconfigure.web.servlet; package org.springframework.boot.autoconfigure.web.servlet;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.SessionCookieConfig;
import org.springframework.boot.autoconfigure.web.ServerProperties; import org.springframework.boot.autoconfigure.web.ServerProperties;
import org.springframework.boot.autoconfigure.web.ServerProperties.Servlet.Session;
import org.springframework.boot.autoconfigure.web.embedded.jetty.JettyCustomizer; import org.springframework.boot.autoconfigure.web.embedded.jetty.JettyCustomizer;
import org.springframework.boot.autoconfigure.web.embedded.tomcat.TomcatCustomizer; import org.springframework.boot.autoconfigure.web.embedded.tomcat.TomcatCustomizer;
import org.springframework.boot.autoconfigure.web.embedded.undertow.UndertowCustomizer; import org.springframework.boot.autoconfigure.web.embedded.undertow.UndertowCustomizer;
...@@ -33,9 +25,7 @@ import org.springframework.boot.web.embedded.tomcat.ConfigurableTomcatWebServerF ...@@ -33,9 +25,7 @@ import org.springframework.boot.web.embedded.tomcat.ConfigurableTomcatWebServerF
import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory; import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactory;
import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFactory; import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFactory;
import org.springframework.boot.web.server.WebServerFactoryCustomizer; import org.springframework.boot.web.server.WebServerFactoryCustomizer;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory; import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory;
import org.springframework.boot.web.servlet.server.InitParameterConfiguringServletContextInitializer;
import org.springframework.context.EnvironmentAware; import org.springframework.context.EnvironmentAware;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
...@@ -89,11 +79,7 @@ public class DefaultServletWebServerFactoryCustomizer ...@@ -89,11 +79,7 @@ public class DefaultServletWebServerFactoryCustomizer
if (this.serverProperties.getDisplayName() != null) { if (this.serverProperties.getDisplayName() != null) {
factory.setDisplayName(this.serverProperties.getDisplayName()); factory.setDisplayName(this.serverProperties.getDisplayName());
} }
if (this.serverProperties.getServlet().getSession().getTimeout() != null) { factory.setSession(this.serverProperties.getServlet().getSession());
factory.setSessionTimeout(this.serverProperties.getServlet().getSession().getTimeout());
}
factory.setPersistSession(this.serverProperties.getServlet().getSession().isPersistent());
factory.setSessionStoreDir(this.serverProperties.getServlet().getSession().getStoreDir());
if (this.serverProperties.getSsl() != null) { if (this.serverProperties.getSsl() != null) {
factory.setSsl(this.serverProperties.getSsl()); factory.setSsl(this.serverProperties.getSsl());
} }
...@@ -109,8 +95,10 @@ public class DefaultServletWebServerFactoryCustomizer ...@@ -109,8 +95,10 @@ public class DefaultServletWebServerFactoryCustomizer
factory.setServerHeader(this.serverProperties.getServerHeader()); factory.setServerHeader(this.serverProperties.getServerHeader());
if (factory instanceof TomcatServletWebServerFactory) { if (factory instanceof TomcatServletWebServerFactory) {
TomcatServletWebServerFactory tomcatFactory = (TomcatServletWebServerFactory) factory; TomcatServletWebServerFactory tomcatFactory = (TomcatServletWebServerFactory) factory;
TomcatCustomizer.customizeTomcat(this.serverProperties, this.environment, tomcatFactory); TomcatCustomizer.customizeTomcat(this.serverProperties, this.environment,
TomcatServletCustomizer.customizeTomcat(this.serverProperties, this.environment, tomcatFactory); tomcatFactory);
TomcatServletCustomizer.customizeTomcat(this.serverProperties,
this.environment, tomcatFactory);
} }
if (factory instanceof JettyServletWebServerFactory) { if (factory instanceof JettyServletWebServerFactory) {
JettyCustomizer.customizeJetty(this.serverProperties, this.environment, JettyCustomizer.customizeJetty(this.serverProperties, this.environment,
...@@ -120,71 +108,8 @@ public class DefaultServletWebServerFactoryCustomizer ...@@ -120,71 +108,8 @@ public class DefaultServletWebServerFactoryCustomizer
UndertowCustomizer.customizeUndertow(this.serverProperties, this.environment, UndertowCustomizer.customizeUndertow(this.serverProperties, this.environment,
(UndertowServletWebServerFactory) factory); (UndertowServletWebServerFactory) factory);
} }
factory.addInitializers( factory.setInitParameters(
new SessionConfiguringInitializer(this.serverProperties.getServlet().getSession())); this.serverProperties.getServlet().getContextParameters());
factory.addInitializers(new InitParameterConfiguringServletContextInitializer(
this.serverProperties.getServlet().getContextParameters()));
}
/**
* {@link ServletContextInitializer} to apply appropriate parts of the {@link Session}
* configuration.
*/
private static class SessionConfiguringInitializer
implements ServletContextInitializer {
private final Session session;
SessionConfiguringInitializer(Session session) {
this.session = session;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
if (this.session.getTrackingModes() != null) {
servletContext
.setSessionTrackingModes(unwrap(this.session.getTrackingModes()));
}
configureSessionCookie(servletContext.getSessionCookieConfig());
}
private void configureSessionCookie(SessionCookieConfig config) {
Session.Cookie cookie = this.session.getCookie();
if (cookie.getName() != null) {
config.setName(cookie.getName());
}
if (cookie.getDomain() != null) {
config.setDomain(cookie.getDomain());
}
if (cookie.getPath() != null) {
config.setPath(cookie.getPath());
}
if (cookie.getComment() != null) {
config.setComment(cookie.getComment());
}
if (cookie.getHttpOnly() != null) {
config.setHttpOnly(cookie.getHttpOnly());
}
if (cookie.getSecure() != null) {
config.setSecure(cookie.getSecure());
}
if (cookie.getMaxAge() != null) {
config.setMaxAge((int) cookie.getMaxAge().getSeconds());
}
}
private Set<javax.servlet.SessionTrackingMode> unwrap(
Set<Session.SessionTrackingMode> modes) {
if (modes == null) {
return null;
}
Set<javax.servlet.SessionTrackingMode> result = new LinkedHashSet<>();
for (Session.SessionTrackingMode mode : modes) {
result.add(javax.servlet.SessionTrackingMode.valueOf(mode.name()));
}
return result;
}
} }
private static class TomcatServletCustomizer { private static class TomcatServletCustomizer {
...@@ -213,7 +138,8 @@ public class DefaultServletWebServerFactoryCustomizer ...@@ -213,7 +138,8 @@ public class DefaultServletWebServerFactoryCustomizer
} }
private static void customizeUseRelativeRedirects( private static void customizeUseRelativeRedirects(
ConfigurableTomcatWebServerFactory factory, boolean useRelativeRedirects) { ConfigurableTomcatWebServerFactory factory,
boolean useRelativeRedirects) {
factory.addContextCustomizers( factory.addContextCustomizers(
(context) -> context.setUseRelativeRedirects(useRelativeRedirects)); (context) -> context.setUseRelativeRedirects(useRelativeRedirects));
} }
......
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -25,13 +25,17 @@ import org.junit.Test; ...@@ -25,13 +25,17 @@ import org.junit.Test;
import org.springframework.beans.factory.BeanCreationException; import org.springframework.beans.factory.BeanCreationException;
import org.springframework.boot.autoconfigure.AutoConfigurations; import org.springframework.boot.autoconfigure.AutoConfigurations;
import org.springframework.boot.autoconfigure.web.servlet.ServletWebServerFactoryAutoConfiguration;
import org.springframework.boot.test.context.runner.WebApplicationContextRunner; import org.springframework.boot.test.context.runner.WebApplicationContextRunner;
import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.session.MapSessionRepository; import org.springframework.session.MapSessionRepository;
import org.springframework.session.SessionRepository; import org.springframework.session.SessionRepository;
import org.springframework.session.config.annotation.web.http.EnableSpringHttpSession; import org.springframework.session.config.annotation.web.http.EnableSpringHttpSession;
import org.springframework.session.web.http.CookieHttpSessionIdResolver;
import org.springframework.session.web.http.DefaultCookieSerializer;
import org.springframework.session.web.http.SessionRepositoryFilter; import org.springframework.session.web.http.SessionRepositoryFilter;
import org.springframework.test.util.ReflectionTestUtils; import org.springframework.test.util.ReflectionTestUtils;
...@@ -147,6 +151,28 @@ public class SessionAutoConfigurationTests extends AbstractSessionAutoConfigurat ...@@ -147,6 +151,28 @@ public class SessionAutoConfigurationTests extends AbstractSessionAutoConfigurat
}); });
} }
@Test
public void sessionCookieConfigurationIsPickedUp() {
new WebApplicationContextRunner(
AnnotationConfigServletWebServerApplicationContext::new)
.withConfiguration(AutoConfigurations
.of(ServletWebServerFactoryAutoConfiguration.class))
.withUserConfiguration(SessionRepositoryConfiguration.class)
.withPropertyValues("server.port=0",
"server.servlet.session.cookie.name=testname")
.run((context) -> {
SessionRepositoryFilter<?> filter = context
.getBean(SessionRepositoryFilter.class);
CookieHttpSessionIdResolver sessionIdResolver = (CookieHttpSessionIdResolver) ReflectionTestUtils
.getField(filter, "httpSessionIdResolver");
DefaultCookieSerializer cookieSerializer = (DefaultCookieSerializer) ReflectionTestUtils
.getField(sessionIdResolver, "cookieSerializer");
String cookieName = (String) ReflectionTestUtils
.getField(cookieSerializer, "cookieName");
assertThat(cookieName).isEqualTo("testname");
});
}
@Configuration @Configuration
@EnableSpringHttpSession @EnableSpringHttpSession
static class SessionRepositoryConfiguration { static class SessionRepositoryConfiguration {
......
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -19,17 +19,11 @@ package org.springframework.boot.autoconfigure.web.servlet; ...@@ -19,17 +19,11 @@ package org.springframework.boot.autoconfigure.web.servlet;
import java.io.File; import java.io.File;
import java.io.IOException; import java.io.IOException;
import java.time.Duration; import java.time.Duration;
import java.util.EnumSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.TimeZone; import java.util.TimeZone;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.SessionCookieConfig;
import javax.servlet.SessionTrackingMode;
import org.apache.catalina.Context; import org.apache.catalina.Context;
import org.apache.catalina.Valve; import org.apache.catalina.Valve;
import org.apache.catalina.startup.Tomcat; import org.apache.catalina.startup.Tomcat;
...@@ -57,12 +51,12 @@ import org.springframework.boot.web.embedded.tomcat.TomcatWebServer; ...@@ -57,12 +51,12 @@ import org.springframework.boot.web.embedded.tomcat.TomcatWebServer;
import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFactory; import org.springframework.boot.web.embedded.undertow.UndertowServletWebServerFactory;
import org.springframework.boot.web.servlet.ServletContextInitializer; import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory; import org.springframework.boot.web.servlet.server.ConfigurableServletWebServerFactory;
import org.springframework.boot.web.servlet.server.Session;
import org.springframework.boot.web.servlet.server.Session.Cookie;
import org.springframework.mock.env.MockEnvironment; import org.springframework.mock.env.MockEnvironment;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
import static org.mockito.ArgumentMatchers.anyBoolean; import static org.mockito.ArgumentMatchers.anyBoolean;
import static org.mockito.BDDMockito.given;
import static org.mockito.Mockito.atLeastOnce;
import static org.mockito.Mockito.mock; import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.never; import static org.mockito.Mockito.never;
import static org.mockito.Mockito.spy; import static org.mockito.Mockito.spy;
...@@ -224,21 +218,19 @@ public class DefaultServletWebServerFactoryCustomizerTests { ...@@ -224,21 +218,19 @@ public class DefaultServletWebServerFactoryCustomizerTests {
bindProperties(map); bindProperties(map);
ConfigurableServletWebServerFactory factory = mock( ConfigurableServletWebServerFactory factory = mock(
ConfigurableServletWebServerFactory.class); ConfigurableServletWebServerFactory.class);
ServletContext servletContext = mock(ServletContext.class);
SessionCookieConfig sessionCookieConfig = mock(SessionCookieConfig.class);
given(servletContext.getSessionCookieConfig()).willReturn(sessionCookieConfig);
this.customizer.customize(factory); this.customizer.customize(factory);
triggerInitializers(factory, servletContext); ArgumentCaptor<Session> sessionCaptor = ArgumentCaptor.forClass(Session.class);
verify(factory).setSessionTimeout(Duration.ofSeconds(123)); verify(factory).setSession(sessionCaptor.capture());
verify(servletContext).setSessionTrackingModes( assertThat(sessionCaptor.getValue().getTimeout())
EnumSet.of(SessionTrackingMode.COOKIE, SessionTrackingMode.URL)); .isEqualTo(Duration.ofSeconds(123));
verify(sessionCookieConfig).setName("testname"); Cookie cookie = sessionCaptor.getValue().getCookie();
verify(sessionCookieConfig).setDomain("testdomain"); assertThat(cookie.getName()).isEqualTo("testname");
verify(sessionCookieConfig).setPath("/testpath"); assertThat(cookie.getDomain()).isEqualTo("testdomain");
verify(sessionCookieConfig).setComment("testcomment"); assertThat(cookie.getPath()).isEqualTo("/testpath");
verify(sessionCookieConfig).setHttpOnly(true); assertThat(cookie.getComment()).isEqualTo("testcomment");
verify(sessionCookieConfig).setSecure(true); assertThat(cookie.getHttpOnly()).isTrue();
verify(sessionCookieConfig).setMaxAge(60); assertThat(cookie.getMaxAge()).isEqualTo(Duration.ofSeconds(60));
} }
@Test @Test
...@@ -540,7 +532,10 @@ public class DefaultServletWebServerFactoryCustomizerTests { ...@@ -540,7 +532,10 @@ public class DefaultServletWebServerFactoryCustomizerTests {
bindProperties(map); bindProperties(map);
JettyServletWebServerFactory factory = spy(new JettyServletWebServerFactory()); JettyServletWebServerFactory factory = spy(new JettyServletWebServerFactory());
this.customizer.customize(factory); this.customizer.customize(factory);
verify(factory).setSessionStoreDir(new File("myfolder")); ArgumentCaptor<Session> sessionCaptor = ArgumentCaptor.forClass(Session.class);
verify(factory).setSession(sessionCaptor.capture());
assertThat(sessionCaptor.getValue().getStoreDir())
.isEqualTo(new File("myfolder"));
} }
@Test @Test
...@@ -638,21 +633,6 @@ public class DefaultServletWebServerFactoryCustomizerTests { ...@@ -638,21 +633,6 @@ public class DefaultServletWebServerFactoryCustomizerTests {
} }
} }
private void triggerInitializers(ConfigurableServletWebServerFactory factory,
ServletContext servletContext) throws ServletException {
verify(factory, atLeastOnce()).addInitializers(this.initializersCaptor.capture());
for (Object initializers : this.initializersCaptor.getAllValues()) {
if (initializers instanceof ServletContextInitializer) {
((ServletContextInitializer) initializers).onStartup(servletContext);
}
else {
for (ServletContextInitializer initializer : (ServletContextInitializer[]) initializers) {
initializer.onStartup(servletContext);
}
}
}
}
private void bindProperties(Map<String, String> map) { private void bindProperties(Map<String, String> map) {
ConfigurationPropertySource source = new MapConfigurationPropertySource(map); ConfigurationPropertySource source = new MapConfigurationPropertySource(map);
new Binder(source).bind("server", Bindable.ofInstance(this.properties)); new Binder(source).bind("server", Bindable.ofInstance(this.properties));
......
...@@ -24,6 +24,7 @@ import java.net.MalformedURLException; ...@@ -24,6 +24,7 @@ import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
import java.nio.channels.ReadableByteChannel; import java.nio.channels.ReadableByteChannel;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
...@@ -231,10 +232,11 @@ public class JettyServletWebServerFactory extends AbstractServletWebServerFactor ...@@ -231,10 +232,11 @@ public class JettyServletWebServerFactory extends AbstractServletWebServerFactor
private void configureSession(WebAppContext context) { private void configureSession(WebAppContext context) {
SessionHandler handler = context.getSessionHandler(); SessionHandler handler = context.getSessionHandler();
Duration sessionTimeout = getSession().getTimeout();
handler.setMaxInactiveInterval( handler.setMaxInactiveInterval(
(getSessionTimeout() == null || getSessionTimeout().isNegative()) ? -1 (sessionTimeout == null || sessionTimeout.isNegative()) ? -1
: (int) getSessionTimeout().getSeconds()); : (int) sessionTimeout.getSeconds());
if (isPersistSession()) { if (getSession().isPersistent()) {
DefaultSessionCache cache = new DefaultSessionCache(handler); DefaultSessionCache cache = new DefaultSessionCache(handler);
FileSessionDataStore store = new FileSessionDataStore(); FileSessionDataStore store = new FileSessionDataStore();
store.setStoreDir(getValidSessionStoreDir()); store.setStoreDir(getValidSessionStoreDir());
......
...@@ -362,7 +362,7 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto ...@@ -362,7 +362,7 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
private void configureSession(Context context) { private void configureSession(Context context) {
long sessionTimeout = getSessionTimeoutInMinutes(); long sessionTimeout = getSessionTimeoutInMinutes();
context.setSessionTimeout((int) sessionTimeout); context.setSessionTimeout((int) sessionTimeout);
if (isPersistSession()) { if (getSession().isPersistent()) {
Manager manager = context.getManager(); Manager manager = context.getManager();
if (manager == null) { if (manager == null) {
manager = new StandardManager(); manager = new StandardManager();
...@@ -385,7 +385,7 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto ...@@ -385,7 +385,7 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
} }
private long getSessionTimeoutInMinutes() { private long getSessionTimeoutInMinutes() {
Duration sessionTimeout = getSessionTimeout(); Duration sessionTimeout = getSession().getTimeout();
if (sessionTimeout == null || sessionTimeout.isNegative() if (sessionTimeout == null || sessionTimeout.isNegative()
|| sessionTimeout.isZero()) { || sessionTimeout.isZero()) {
return 0; return 0;
...@@ -516,8 +516,8 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto ...@@ -516,8 +516,8 @@ public class TomcatServletWebServerFactory extends AbstractServletWebServerFacto
} }
/** /**
* Set {@link LifecycleListener}s that should be applied to the Tomcat {@link Context}. * Set {@link LifecycleListener}s that should be applied to the Tomcat
* Calling this method will replace any existing listeners. * {@link Context}. Calling this method will replace any existing listeners.
* @param contextLifecycleListeners the listeners to set * @param contextLifecycleListeners the listeners to set
*/ */
public void setContextLifecycleListeners( public void setContextLifecycleListeners(
......
...@@ -21,6 +21,7 @@ import java.io.IOException; ...@@ -21,6 +21,7 @@ import java.io.IOException;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.URL; import java.net.URL;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
...@@ -239,8 +240,7 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac ...@@ -239,8 +240,7 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac
new SslBuilderCustomizer(getPort(), getAddress(), getSsl(), getSslStoreProvider()) new SslBuilderCustomizer(getPort(), getAddress(), getSsl(), getSslStoreProvider())
.customize(builder); .customize(builder);
if (getHttp2() != null) { if (getHttp2() != null) {
builder.setServerOption(UndertowOptions.ENABLE_HTTP2, builder.setServerOption(UndertowOptions.ENABLE_HTTP2, getHttp2().isEnabled());
getHttp2().isEnabled());
} }
} }
...@@ -274,7 +274,7 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac ...@@ -274,7 +274,7 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac
if (isAccessLogEnabled()) { if (isAccessLogEnabled()) {
configureAccessLog(deployment); configureAccessLog(deployment);
} }
if (isPersistSession()) { if (getSession().isPersistent()) {
File dir = getValidSessionStoreDir(); File dir = getValidSessionStoreDir();
deployment.setSessionPersistenceManager(new FileSessionPersistence(dir)); deployment.setSessionPersistenceManager(new FileSessionPersistence(dir));
} }
...@@ -282,9 +282,10 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac ...@@ -282,9 +282,10 @@ public class UndertowServletWebServerFactory extends AbstractServletWebServerFac
DeploymentManager manager = Servlets.newContainer().addDeployment(deployment); DeploymentManager manager = Servlets.newContainer().addDeployment(deployment);
manager.deploy(); manager.deploy();
SessionManager sessionManager = manager.getDeployment().getSessionManager(); SessionManager sessionManager = manager.getDeployment().getSessionManager();
int sessionTimeout = (getSessionTimeout() == null || getSessionTimeout().isZero() Duration timeoutDuration = getSession().getTimeout();
|| getSessionTimeout().isNegative() ? -1 int sessionTimeout = (timeoutDuration == null || timeoutDuration.isZero()
: (int) getSessionTimeout().getSeconds()); || timeoutDuration.isNegative() ? -1
: (int) timeoutDuration.getSeconds());
sessionManager.setDefaultSessionTimeout(sessionTimeout); sessionManager.setDefaultSessionTimeout(sessionTimeout);
return manager; return manager;
} }
......
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -19,13 +19,19 @@ package org.springframework.boot.web.servlet.server; ...@@ -19,13 +19,19 @@ package org.springframework.boot.web.servlet.server;
import java.io.File; import java.io.File;
import java.net.URL; import java.net.URL;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.time.Duration;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collections;
import java.util.HashMap; import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import java.util.Set;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import javax.servlet.SessionCookieConfig;
import org.apache.commons.logging.Log; import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory; import org.apache.commons.logging.LogFactory;
...@@ -58,9 +64,7 @@ public abstract class AbstractServletWebServerFactory ...@@ -58,9 +64,7 @@ public abstract class AbstractServletWebServerFactory
private String displayName; private String displayName;
private Duration sessionTimeout = Duration.ofMinutes(30); private Session session = new Session();
private boolean persistSession;
private boolean registerDefaultServlet = true; private boolean registerDefaultServlet = true;
...@@ -72,7 +76,7 @@ public abstract class AbstractServletWebServerFactory ...@@ -72,7 +76,7 @@ public abstract class AbstractServletWebServerFactory
private Map<Locale, Charset> localeCharsetMappings = new HashMap<>(); private Map<Locale, Charset> localeCharsetMappings = new HashMap<>();
private final SessionStoreDirectory sessionStoreDir = new SessionStoreDirectory(); private Map<String, String> initParameters = Collections.emptyMap();
private final DocumentRoot documentRoot = new DocumentRoot(this.logger); private final DocumentRoot documentRoot = new DocumentRoot(this.logger);
...@@ -143,37 +147,6 @@ public abstract class AbstractServletWebServerFactory ...@@ -143,37 +147,6 @@ public abstract class AbstractServletWebServerFactory
this.displayName = displayName; this.displayName = displayName;
} }
/**
* Return the session timeout or {@code null}.
* @return the session timeout
*/
public Duration getSessionTimeout() {
return this.sessionTimeout;
}
@Override
public void setSessionTimeout(Duration sessionTimeout) {
this.sessionTimeout = sessionTimeout;
}
public boolean isPersistSession() {
return this.persistSession;
}
@Override
public void setPersistSession(boolean persistSession) {
this.persistSession = persistSession;
}
public File getSessionStoreDir() {
return this.sessionStoreDir.getDirectory();
}
@Override
public void setSessionStoreDir(File sessionStoreDir) {
this.sessionStoreDir.setDirectory(sessionStoreDir);
}
/** /**
* Flag to indicate that the default servlet should be registered. * Flag to indicate that the default servlet should be registered.
* @return true if the default servlet is to be registered * @return true if the default servlet is to be registered
...@@ -235,6 +208,15 @@ public abstract class AbstractServletWebServerFactory ...@@ -235,6 +208,15 @@ public abstract class AbstractServletWebServerFactory
this.jsp = jsp; this.jsp = jsp;
} }
public Session getSession() {
return this.session;
}
@Override
public void setSession(Session session) {
this.session = session;
}
/** /**
* Return the Locale to Charset mappings. * Return the Locale to Charset mappings.
* @return the charset mappings * @return the charset mappings
...@@ -249,6 +231,15 @@ public abstract class AbstractServletWebServerFactory ...@@ -249,6 +231,15 @@ public abstract class AbstractServletWebServerFactory
this.localeCharsetMappings = localeCharsetMappings; this.localeCharsetMappings = localeCharsetMappings;
} }
@Override
public void setInitParameters(Map<String, String> initParameters) {
this.initParameters = initParameters;
}
public Map<String, String> getInitParameters() {
return this.initParameters;
}
/** /**
* Utility method that can be used by subclasses wishing to combine the specified * Utility method that can be used by subclasses wishing to combine the specified
* {@link ServletContextInitializer} parameters with those defined in this instance. * {@link ServletContextInitializer} parameters with those defined in this instance.
...@@ -259,6 +250,9 @@ public abstract class AbstractServletWebServerFactory ...@@ -259,6 +250,9 @@ public abstract class AbstractServletWebServerFactory
protected final ServletContextInitializer[] mergeInitializers( protected final ServletContextInitializer[] mergeInitializers(
ServletContextInitializer... initializers) { ServletContextInitializer... initializers) {
List<ServletContextInitializer> mergedInitializers = new ArrayList<>(); List<ServletContextInitializer> mergedInitializers = new ArrayList<>();
mergedInitializers.add((servletContext) -> this.initParameters
.forEach(servletContext::setInitParameter));
mergedInitializers.add(new SessionConfiguringInitializer(this.session));
mergedInitializers.addAll(Arrays.asList(initializers)); mergedInitializers.addAll(Arrays.asList(initializers));
mergedInitializers.addAll(this.initializers); mergedInitializers.addAll(this.initializers);
return mergedInitializers return mergedInitializers
...@@ -288,11 +282,72 @@ public abstract class AbstractServletWebServerFactory ...@@ -288,11 +282,72 @@ public abstract class AbstractServletWebServerFactory
} }
protected final File getValidSessionStoreDir() { protected final File getValidSessionStoreDir() {
return this.sessionStoreDir.getValidDirectory(true); return getValidSessionStoreDir(true);
} }
protected final File getValidSessionStoreDir(boolean mkdirs) { protected final File getValidSessionStoreDir(boolean mkdirs) {
return this.sessionStoreDir.getValidDirectory(mkdirs); return this.session.getSessionStoreDirectory().getValidDirectory(mkdirs);
}
/**
* {@link ServletContextInitializer} to apply appropriate parts of the {@link Session}
* configuration.
*/
private static class SessionConfiguringInitializer
implements ServletContextInitializer {
private final Session session;
SessionConfiguringInitializer(Session session) {
this.session = session;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
if (this.session.getTrackingModes() != null) {
servletContext
.setSessionTrackingModes(unwrap(this.session.getTrackingModes()));
}
configureSessionCookie(servletContext.getSessionCookieConfig());
}
private void configureSessionCookie(SessionCookieConfig config) {
Session.Cookie cookie = this.session.getCookie();
if (cookie.getName() != null) {
config.setName(cookie.getName());
}
if (cookie.getDomain() != null) {
config.setDomain(cookie.getDomain());
}
if (cookie.getPath() != null) {
config.setPath(cookie.getPath());
}
if (cookie.getComment() != null) {
config.setComment(cookie.getComment());
}
if (cookie.getHttpOnly() != null) {
config.setHttpOnly(cookie.getHttpOnly());
}
if (cookie.getSecure() != null) {
config.setSecure(cookie.getSecure());
}
if (cookie.getMaxAge() != null) {
config.setMaxAge((int) cookie.getMaxAge().getSeconds());
}
}
private Set<javax.servlet.SessionTrackingMode> unwrap(
Set<Session.SessionTrackingMode> modes) {
if (modes == null) {
return null;
}
Set<javax.servlet.SessionTrackingMode> result = new LinkedHashSet<>();
for (Session.SessionTrackingMode mode : modes) {
result.add(javax.servlet.SessionTrackingMode.valueOf(mode.name()));
}
return result;
}
} }
} }
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -18,11 +18,12 @@ package org.springframework.boot.web.servlet.server; ...@@ -18,11 +18,12 @@ package org.springframework.boot.web.servlet.server;
import java.io.File; import java.io.File;
import java.nio.charset.Charset; import java.nio.charset.Charset;
import java.time.Duration;
import java.util.List; import java.util.List;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
import javax.servlet.ServletContext;
import org.springframework.boot.web.server.ConfigurableWebServerFactory; import org.springframework.boot.web.server.ConfigurableWebServerFactory;
import org.springframework.boot.web.server.MimeMappings; import org.springframework.boot.web.server.MimeMappings;
import org.springframework.boot.web.server.WebServerFactoryCustomizer; import org.springframework.boot.web.server.WebServerFactoryCustomizer;
...@@ -59,23 +60,32 @@ public interface ConfigurableServletWebServerFactory ...@@ -59,23 +60,32 @@ public interface ConfigurableServletWebServerFactory
void setDisplayName(String displayName); void setDisplayName(String displayName);
/** /**
* The session timeout in seconds (default 30 minutes). If {@code null} then sessions * Sets the configuration that will be applied to the container's HTTP session
* never expire. * support.
* @param sessionTimeout the session timeout *
*/ * @param session the session configuration
void setSessionTimeout(Duration sessionTimeout);
/**
* Sets if session data should be persisted between restarts.
* @param persistSession {@code true} if session data should be persisted
*/
void setPersistSession(boolean persistSession);
/**
* Set the directory used to store serialized session data.
* @param sessionStoreDir the directory or {@code null} to use a default location.
*/ */
void setSessionStoreDir(File sessionStoreDir); void setSession(Session session);
// /**
// * The session timeout in seconds (default 30 minutes). If {@code null} then
// sessions
// * never expire.
// * @param sessionTimeout the session timeout
// */
// void setSessionTimeout(Duration sessionTimeout);
//
// /**
// * Sets if session data should be persisted between restarts.
// * @param persistSession {@code true} if session data should be persisted
// */
// void setPersistSession(boolean persistSession);
//
// /**
// * Set the directory used to store serialized session data.
// * @param sessionStoreDir the directory or {@code null} to use a default location.
// */
// void setSessionStoreDir(File sessionStoreDir);
/** /**
* Set if the DefaultServlet should be registered. Defaults to {@code true} so that * Set if the DefaultServlet should be registered. Defaults to {@code true} so that
...@@ -128,4 +138,12 @@ public interface ConfigurableServletWebServerFactory ...@@ -128,4 +138,12 @@ public interface ConfigurableServletWebServerFactory
*/ */
void setLocaleCharsetMappings(Map<Locale, Charset> localeCharsetMappings); void setLocaleCharsetMappings(Map<Locale, Charset> localeCharsetMappings);
/**
* Sets the init parameters that are applied to the container's
* {@link ServletContext}.
*
* @param initParameters the init parameters
*/
void setInitParameters(Map<String, String> initParameters);
} }
/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.server;
import java.util.Map;
import java.util.Map.Entry;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import org.springframework.boot.web.servlet.ServletContextInitializer;
/**
* A {@code ServletContextInitializer} that configures init parameters on the
* {@code ServletContext}.
*
* @author Andy Wilkinson
* @since 2.0.0
* @see ServletContext#setInitParameter(String, String)
*/
public class InitParameterConfiguringServletContextInitializer
implements ServletContextInitializer {
private final Map<String, String> parameters;
public InitParameterConfiguringServletContextInitializer(
Map<String, String> parameters) {
this.parameters = parameters;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
for (Entry<String, String> entry : this.parameters.entrySet()) {
servletContext.setInitParameter(entry.getKey(), entry.getValue());
}
}
}
/*
* Copyright 2012-2017 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.web.servlet.server;
import java.io.File;
import java.time.Duration;
import java.time.temporal.ChronoUnit;
import java.util.Set;
import org.springframework.boot.context.properties.bind.convert.DefaultDurationUnit;
/**
* Session properties.
*
* @author Andy Wilkinson
* @since 2.0.0
*/
public class Session {
/**
* Session timeout. If a duration suffix is not specified, seconds will be used.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration timeout = Duration.ofMinutes(30);
/**
* Session tracking modes (one or more of the following: "cookie", "url", "ssl").
*/
private Set<Session.SessionTrackingMode> trackingModes;
/**
* Whether to persist session data between restarts.
*/
private boolean persistent;
/**
* Directory used to store session data.
*/
private File storeDir;
private final Cookie cookie = new Cookie();
private final SessionStoreDirectory sessionStoreDirectory = new SessionStoreDirectory();
public Cookie getCookie() {
return this.cookie;
}
public Duration getTimeout() {
return this.timeout;
}
public void setTimeout(Duration timeout) {
this.timeout = timeout;
}
public Set<Session.SessionTrackingMode> getTrackingModes() {
return this.trackingModes;
}
public void setTrackingModes(Set<Session.SessionTrackingMode> trackingModes) {
this.trackingModes = trackingModes;
}
public boolean isPersistent() {
return this.persistent;
}
public void setPersistent(boolean persistent) {
this.persistent = persistent;
}
public File getStoreDir() {
return this.storeDir;
}
public void setStoreDir(File storeDir) {
this.sessionStoreDirectory.setDirectory(storeDir);
this.storeDir = storeDir;
}
SessionStoreDirectory getSessionStoreDirectory() {
return this.sessionStoreDirectory;
}
/**
* Cookie properties.
*/
public static class Cookie {
/**
* Session cookie name.
*/
private String name;
/**
* Domain for the session cookie.
*/
private String domain;
/**
* Path of the session cookie.
*/
private String path;
/**
* Comment for the session cookie.
*/
private String comment;
/**
* "HttpOnly" flag for the session cookie.
*/
private Boolean httpOnly;
/**
* "Secure" flag for the session cookie.
*/
private Boolean secure;
/**
* Maximum age of the session cookie.
*/
@DefaultDurationUnit(ChronoUnit.SECONDS)
private Duration maxAge;
public String getName() {
return this.name;
}
public void setName(String name) {
this.name = name;
}
public String getDomain() {
return this.domain;
}
public void setDomain(String domain) {
this.domain = domain;
}
public String getPath() {
return this.path;
}
public void setPath(String path) {
this.path = path;
}
public String getComment() {
return this.comment;
}
public void setComment(String comment) {
this.comment = comment;
}
public Boolean getHttpOnly() {
return this.httpOnly;
}
public void setHttpOnly(Boolean httpOnly) {
this.httpOnly = httpOnly;
}
public Boolean getSecure() {
return this.secure;
}
public void setSecure(Boolean secure) {
this.secure = secure;
}
public Duration getMaxAge() {
return this.maxAge;
}
public void setMaxAge(Duration maxAge) {
this.maxAge = maxAge;
}
}
/**
* Available session tracking modes (mirrors
* {@link javax.servlet.SessionTrackingMode}.
*/
public enum SessionTrackingMode {
/**
* Send a cookie in response to the client's first request.
*/
COOKIE,
/**
* Rewrite the URL to append a session ID.
*/
URL,
/**
* Use SSL build-in mechanism to track the session.
*/
SSL
}
}
...@@ -99,14 +99,14 @@ public class JettyServletWebServerFactoryTests ...@@ -99,14 +99,14 @@ public class JettyServletWebServerFactoryTests
@Test @Test
public void sessionTimeout() { public void sessionTimeout() {
JettyServletWebServerFactory factory = getFactory(); JettyServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofSeconds(10)); factory.getSession().setTimeout(Duration.ofSeconds(10));
assertTimeout(factory, 10); assertTimeout(factory, 10);
} }
@Test @Test
public void sessionTimeoutInMins() { public void sessionTimeoutInMins() {
JettyServletWebServerFactory factory = getFactory(); JettyServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofMinutes(1)); factory.getSession().setTimeout(Duration.ofMinutes(1));
assertTimeout(factory, 60); assertTimeout(factory, 60);
} }
......
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -187,21 +187,21 @@ public class TomcatServletWebServerFactoryTests ...@@ -187,21 +187,21 @@ public class TomcatServletWebServerFactoryTests
@Test @Test
public void sessionTimeout() { public void sessionTimeout() {
TomcatServletWebServerFactory factory = getFactory(); TomcatServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofSeconds(10)); factory.getSession().setTimeout(Duration.ofSeconds(10));
assertTimeout(factory, 1); assertTimeout(factory, 1);
} }
@Test @Test
public void sessionTimeoutInMins() { public void sessionTimeoutInMins() {
TomcatServletWebServerFactory factory = getFactory(); TomcatServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(Duration.ofMinutes(1)); factory.getSession().setTimeout(Duration.ofMinutes(1));
assertTimeout(factory, 1); assertTimeout(factory, 1);
} }
@Test @Test
public void noSessionTimeout() { public void noSessionTimeout() {
TomcatServletWebServerFactory factory = getFactory(); TomcatServletWebServerFactory factory = getFactory();
factory.setSessionTimeout(null); factory.getSession().setTimeout(null);
assertTimeout(factory, -1); assertTimeout(factory, -1);
} }
......
/* /*
* Copyright 2012-2017 the original author or authors. * Copyright 2012-2018 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -40,6 +40,7 @@ import java.util.Arrays; ...@@ -40,6 +40,7 @@ import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Date; import java.util.Date;
import java.util.EnumSet;
import java.util.HashMap; import java.util.HashMap;
import java.util.Locale; import java.util.Locale;
import java.util.Map; import java.util.Map;
...@@ -59,6 +60,7 @@ import javax.servlet.ServletContext; ...@@ -59,6 +60,7 @@ import javax.servlet.ServletContext;
import javax.servlet.ServletException; import javax.servlet.ServletException;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
import javax.servlet.SessionCookieConfig;
import javax.servlet.http.HttpServlet; import javax.servlet.http.HttpServlet;
import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse; import javax.servlet.http.HttpServletResponse;
...@@ -100,6 +102,7 @@ import org.springframework.boot.web.server.WebServerException; ...@@ -100,6 +102,7 @@ import org.springframework.boot.web.server.WebServerException;
import org.springframework.boot.web.servlet.FilterRegistrationBean; import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.boot.web.servlet.ServletContextInitializer; import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletRegistrationBean; import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.server.Session.SessionTrackingMode;
import org.springframework.core.io.ClassPathResource; import org.springframework.core.io.ClassPathResource;
import org.springframework.core.io.Resource; import org.springframework.core.io.Resource;
import org.springframework.http.HttpMethod; import org.springframework.http.HttpMethod;
...@@ -713,13 +716,14 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -713,13 +716,14 @@ public abstract class AbstractServletWebServerFactoryTests {
@Test @Test
public void defaultSessionTimeout() { public void defaultSessionTimeout() {
assertThat(getFactory().getSessionTimeout()).isEqualTo(Duration.ofMinutes(30)); assertThat(getFactory().getSession().getTimeout())
.isEqualTo(Duration.ofMinutes(30));
} }
@Test @Test
public void persistSession() throws Exception { public void persistSession() throws Exception {
AbstractServletWebServerFactory factory = getFactory(); AbstractServletWebServerFactory factory = getFactory();
factory.setPersistSession(true); factory.getSession().setPersistent(true);
this.webServer = factory.getWebServer(sessionServletRegistration()); this.webServer = factory.getWebServer(sessionServletRegistration());
this.webServer.start(); this.webServer.start();
String s1 = getResponse(getLocalUrl("/session")); String s1 = getResponse(getLocalUrl("/session"));
...@@ -737,8 +741,8 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -737,8 +741,8 @@ public abstract class AbstractServletWebServerFactoryTests {
public void persistSessionInSpecificSessionStoreDir() throws Exception { public void persistSessionInSpecificSessionStoreDir() throws Exception {
AbstractServletWebServerFactory factory = getFactory(); AbstractServletWebServerFactory factory = getFactory();
File sessionStoreDir = this.temporaryFolder.newFolder(); File sessionStoreDir = this.temporaryFolder.newFolder();
factory.setPersistSession(true); factory.getSession().setPersistent(true);
factory.setSessionStoreDir(sessionStoreDir); factory.getSession().setStoreDir(sessionStoreDir);
this.webServer = factory.getWebServer(sessionServletRegistration()); this.webServer = factory.getWebServer(sessionServletRegistration());
this.webServer.start(); this.webServer.start();
getResponse(getLocalUrl("/session")); getResponse(getLocalUrl("/session"));
...@@ -759,7 +763,7 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -759,7 +763,7 @@ public abstract class AbstractServletWebServerFactoryTests {
@Test @Test
public void getValidSessionStoreWhenSessionStoreIsRelative() { public void getValidSessionStoreWhenSessionStoreIsRelative() {
AbstractServletWebServerFactory factory = getFactory(); AbstractServletWebServerFactory factory = getFactory();
factory.setSessionStoreDir(new File("sessions")); factory.getSession().setStoreDir(new File("sessions"));
File dir = factory.getValidSessionStoreDir(false); File dir = factory.getValidSessionStoreDir(false);
assertThat(dir.getName()).isEqualTo("sessions"); assertThat(dir.getName()).isEqualTo("sessions");
assertThat(dir.getParentFile()).isEqualTo(new ApplicationHome().getDir()); assertThat(dir.getParentFile()).isEqualTo(new ApplicationHome().getDir());
...@@ -768,12 +772,35 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -768,12 +772,35 @@ public abstract class AbstractServletWebServerFactoryTests {
@Test @Test
public void getValidSessionStoreWhenSessionStoreReferencesFile() throws Exception { public void getValidSessionStoreWhenSessionStoreReferencesFile() throws Exception {
AbstractServletWebServerFactory factory = getFactory(); AbstractServletWebServerFactory factory = getFactory();
factory.setSessionStoreDir(this.temporaryFolder.newFile()); factory.getSession().setStoreDir(this.temporaryFolder.newFile());
this.thrown.expect(IllegalStateException.class); this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage("points to a file"); this.thrown.expectMessage("points to a file");
factory.getValidSessionStoreDir(false); factory.getValidSessionStoreDir(false);
} }
@Test
public void sessionCookieConfiguration() {
AbstractServletWebServerFactory factory = getFactory();
factory.getSession().getCookie().setName("testname");
factory.getSession().getCookie().setDomain("testdomain");
factory.getSession().getCookie().setPath("/testpath");
factory.getSession().getCookie().setComment("testcomment");
factory.getSession().getCookie().setHttpOnly(true);
factory.getSession().getCookie().setSecure(true);
factory.getSession().getCookie().setMaxAge(Duration.ofSeconds(60));
final AtomicReference<SessionCookieConfig> configReference = new AtomicReference<>();
this.webServer = factory.getWebServer(
(context) -> configReference.set(context.getSessionCookieConfig()));
SessionCookieConfig sessionCookieConfig = configReference.get();
assertThat(sessionCookieConfig.getName()).isEqualTo("testname");
assertThat(sessionCookieConfig.getDomain()).isEqualTo("testdomain");
assertThat(sessionCookieConfig.getPath()).isEqualTo("/testpath");
assertThat(sessionCookieConfig.getComment()).isEqualTo("testcomment");
assertThat(sessionCookieConfig.isHttpOnly()).isTrue();
assertThat(sessionCookieConfig.isSecure()).isTrue();
assertThat(sessionCookieConfig.getMaxAge()).isEqualTo(60);
}
@Test @Test
public void compressionOfResponseToGetRequest() throws Exception { public void compressionOfResponseToGetRequest() throws Exception {
assertThat(doTestCompression(10000, null, null)).isTrue(); assertThat(doTestCompression(10000, null, null)).isTrue();
...@@ -969,6 +996,48 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -969,6 +996,48 @@ public abstract class AbstractServletWebServerFactoryTests {
factory.getWebServer().start(); factory.getWebServer().start();
} }
@Test
public void sessionConfiguration() {
AbstractServletWebServerFactory factory = getFactory();
// map.put("server.servlet.session.timeout", "123");
// map.put("server.servlet.session.tracking-modes", "cookie,url");
// map.put("server.servlet.session.cookie.name", "testname");
// map.put("server.servlet.session.cookie.domain", "testdomain");
// map.put("server.servlet.session.cookie.path", "/testpath");
// map.put("server.servlet.session.cookie.comment", "testcomment");
// map.put("server.servlet.session.cookie.http-only", "true");
// map.put("server.servlet.session.cookie.secure", "true");
// map.put("server.servlet.session.cookie.max-age", "60");
factory.getSession().setTimeout(Duration.ofSeconds(123));
factory.getSession().setTrackingModes(
EnumSet.of(SessionTrackingMode.COOKIE, SessionTrackingMode.URL));
factory.getSession().getCookie().setName("testname");
factory.getSession().getCookie().setDomain("testdomain");
factory.getSession().getCookie().setPath("/testpath");
factory.getSession().getCookie().setComment("testcomment");
factory.getSession().getCookie().setHttpOnly(true);
factory.getSession().getCookie().setSecure(true);
factory.getSession().getCookie().setMaxAge(Duration.ofMinutes(1));
AtomicReference<ServletContext> contextReference = new AtomicReference<ServletContext>();
factory.getWebServer(contextReference::set).start();
ServletContext servletContext = contextReference.get();
assertThat(servletContext.getEffectiveSessionTrackingModes())
.isEqualTo(EnumSet.of(javax.servlet.SessionTrackingMode.COOKIE,
javax.servlet.SessionTrackingMode.URL));
assertThat(servletContext.getSessionCookieConfig().getName())
.isEqualTo("testname");
assertThat(servletContext.getSessionCookieConfig().getDomain())
.isEqualTo("testdomain");
assertThat(servletContext.getSessionCookieConfig().getPath())
.isEqualTo("/testpath");
assertThat(servletContext.getSessionCookieConfig().getComment())
.isEqualTo("testcomment");
assertThat(servletContext.getSessionCookieConfig().isHttpOnly()).isTrue();
assertThat(servletContext.getSessionCookieConfig().isSecure()).isTrue();
assertThat(servletContext.getSessionCookieConfig().getMaxAge()).isEqualTo(60);
}
protected abstract void addConnector(int port, protected abstract void addConnector(int port,
AbstractServletWebServerFactory factory); AbstractServletWebServerFactory factory);
...@@ -1016,8 +1085,7 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -1016,8 +1085,7 @@ public abstract class AbstractServletWebServerFactoryTests {
@Override @Override
protected void service(HttpServletRequest req, protected void service(HttpServletRequest req,
HttpServletResponse resp) HttpServletResponse resp) throws IOException {
throws IOException {
resp.setContentType("text/plain"); resp.setContentType("text/plain");
resp.setContentLength(testContent.length()); resp.setContentLength(testContent.length());
resp.getWriter().write(testContent); resp.getWriter().write(testContent);
...@@ -1140,7 +1208,8 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -1140,7 +1208,8 @@ public abstract class AbstractServletWebServerFactoryTests {
new ExampleServlet() { new ExampleServlet() {
@Override @Override
public void service(ServletRequest request, ServletResponse response) { public void service(ServletRequest request,
ServletResponse response) {
throw new RuntimeException("Planned"); throw new RuntimeException("Planned");
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment