commit fe6c5d650354afb71e1bf56a67112a797b29b6b9 Author: Rob Winch Date: Fri Jun 20 16:04:27 2014 -0500 Initial diff --git a/.gitignore b/.gitignore new file mode 100644 index 00000000..169c2c0c --- /dev/null +++ b/.gitignore @@ -0,0 +1,8 @@ +.gradle +.idea +build +*.iml +bin +.classpath +.settings +target \ No newline at end of file diff --git a/README.adoc b/README.adoc new file mode 100644 index 00000000..fbcc4cc1 --- /dev/null +++ b/README.adoc @@ -0,0 +1,12 @@ +Benefits + +* This can make clustering much easier. This is nice because the clustering setup is done in a vendor neutral way. Furthermore, in some environments (i.e. PaaS solutions) developers cannot modify the cluster settings easily. +* We can use different strategies for determining the session id. This gives us at least a few benefits +** Allowing for a single browser to have multiple simultaneous sessions in a transparent fashion. For example, many developers wish to allow a user to authenticate with multiple accounts and switch between them similar to how you can in gmail. +** When using a REST API, the session can be specified using a header instead of the JSESSIONID cookie (which leaks implementation details to the client). Many would argue that session is bad in REST because it has state, but it is important to note that session is just a form of cache and used responsibly it will increase performance & security. +** When a session id is acquired in a header, we can default CSRF protection to off. This is because if the session id is found in the header we know that it is impossible to be a CSRF attack since, unlike cookies, headers must be manually populated. +* We can easily keep the HttpSession and WebSocket Session in sync. Imagine a web application like gmail where you can authenticate and either write emails (HTTP requests) or chat (WebSocket). In standard servlet environment there is no way to keep the HttpSession alive through the WebSocket so you must ping the server. With our own session strategy we can have the WebSocket messages automatically keep the HttpSession alive. We can also destroy both sessions at once easily. +* We can provide hooks to allow users to invalidate sessions that should not be active. For example, if you look in the lower right of gmail you can see the last account activity and click "Details". This shows a listing of all the active sessions along with the IP address, location, and browser information for your account. +** Users can look through this and determine if anything is suspicious (i.e. if their account has a session that is associated to a country they have never been) and invalidate that session and change their password. +** Another useful example is perhaps they checked their mail at the library and forgot to log out. With this custom mechanism this is very possible. +* Spring Security currently supports restricting the number of concurrent sessions each user can have. The implementation works, but does so passively since we cannot get a handle to the session from the session id. Specifically, each time a user requests a page we check to see if that session id is valid in a separate data store. If it is no longer valid, we invalidate the session. With this new mechanism we can invalidate the session from the session id. \ No newline at end of file diff --git a/build.gradle b/build.gradle new file mode 100644 index 00000000..264c6e0c --- /dev/null +++ b/build.gradle @@ -0,0 +1,90 @@ +apply plugin: 'java' + +buildscript { + repositories { + maven { url "https://repo.spring.io/plugins-release" } + } + dependencies { + classpath("org.springframework.build.gradle:propdeps-plugin:0.0.6") + classpath("org.springframework.build.gradle:spring-io-plugin:0.0.3.RELEASE") + classpath('me.champeau.gradle:gradle-javadoc-hotfix-plugin:0.1') + classpath('org.asciidoctor:asciidoctor-gradle-plugin:0.7.0') + classpath('org.asciidoctor:asciidoctor-java-integration:0.1.4.preview.1') + } +} + +apply plugin: 'java' +apply plugin: 'groovy' +apply plugin: 'javadocHotfix' +apply plugin: 'eclipse-wtp' +apply plugin: 'propdeps' +apply plugin: 'propdeps-maven' +apply plugin: 'propdeps-idea' +apply plugin: 'propdeps-eclipse' + +group = 'org.springframework.session' + +sourceCompatibility = 1.5 +targetCompatibility = 1.5 + +ext.servletApiVersion = '3.0.1' +ext.springSecurityVersion = '3.2.4.RELEASE' +ext.springVersion = '4.0.2.RELEASE' + +repositories { + mavenCentral() + maven { url 'http://clojars.org/repo' } +} + +// Integration test setup +configurations { + integrationTestCompile { + extendsFrom testCompile, optional, provided + } + integrationTestRuntime { + extendsFrom integrationTestCompile, testRuntime + } +} + +sourceSets { + integrationTest { + java.srcDir file('src/integration-test/java') + groovy.srcDirs file('src/integration-test/groovy') + resources.srcDir file('src/integration-test/resources') + compileClasspath = sourceSets.main.output + sourceSets.test.output + configurations.integrationTestCompile + runtimeClasspath = output + compileClasspath + configurations.integrationTestRuntime + } +} + +dependencies { + optional "org.springframework.data:spring-data-redis:1.3.0.RELEASE" + provided "javax.servlet:javax.servlet-api:$servletApiVersion" + integrationTestCompile "redis.clients:jedis:2.4.1", + "org.apache.commons:commons-pool2:2.2", + "redis.embedded:embedded-redis:0.2" + testCompile 'junit:junit:4.11', + 'org.mockito:mockito-core:1.9.5', + "org.springframework:spring-test:$springVersion", + 'org.easytesting:fest-assert:1.4', + "org.springframework.security:spring-security-core:$springSecurityVersion" +} + + +task integrationTest(type: Test, dependsOn: jar) { + testClassesDir = sourceSets.integrationTest.output.classesDir + logging.captureStandardOutput(LogLevel.INFO) + classpath = sourceSets.integrationTest.runtimeClasspath + maxParallelForks = 1 + reports { + html.destination = project.file("$project.buildDir/reports/integration-tests/") + junitXml.destination = project.file("$project.buildDir/integration-test-results/") + } +} + +project.conf2ScopeMappings.addMapping(MavenPlugin.TEST_COMPILE_PRIORITY + 1, project.configurations.getByName("integrationTestCompile"), Conf2ScopeMappingContainer.TEST) +project.conf2ScopeMappings.addMapping(MavenPlugin.TEST_COMPILE_PRIORITY + 2, project.configurations.getByName("integrationTestRuntime"), Conf2ScopeMappingContainer.TEST) +check.dependsOn integrationTest + +project.idea.module { + scopes.TEST.plus += [project.configurations.integrationTestRuntime] +} diff --git a/gradle.properties b/gradle.properties new file mode 100644 index 00000000..d471ccf1 --- /dev/null +++ b/gradle.properties @@ -0,0 +1 @@ +version = '1.0.0.BUILD-SNAPSHOT' \ No newline at end of file diff --git a/gradle/wrapper/gradle-wrapper.jar b/gradle/wrapper/gradle-wrapper.jar new file mode 100644 index 00000000..0087cd3b Binary files /dev/null and b/gradle/wrapper/gradle-wrapper.jar differ diff --git a/gradle/wrapper/gradle-wrapper.properties b/gradle/wrapper/gradle-wrapper.properties new file mode 100644 index 00000000..6de9b343 --- /dev/null +++ b/gradle/wrapper/gradle-wrapper.properties @@ -0,0 +1,6 @@ +#Wed Jun 18 14:02:09 CDT 2014 +distributionBase=GRADLE_USER_HOME +distributionPath=wrapper/dists +zipStoreBase=GRADLE_USER_HOME +zipStorePath=wrapper/dists +distributionUrl=https\://services.gradle.org/distributions/gradle-1.12-all.zip diff --git a/gradlew b/gradlew new file mode 100755 index 00000000..91a7e269 --- /dev/null +++ b/gradlew @@ -0,0 +1,164 @@ +#!/usr/bin/env bash + +############################################################################## +## +## Gradle start up script for UN*X +## +############################################################################## + +# Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +DEFAULT_JVM_OPTS="" + +APP_NAME="Gradle" +APP_BASE_NAME=`basename "$0"` + +# Use the maximum available, or set MAX_FD != -1 to use that value. +MAX_FD="maximum" + +warn ( ) { + echo "$*" +} + +die ( ) { + echo + echo "$*" + echo + exit 1 +} + +# OS specific support (must be 'true' or 'false'). +cygwin=false +msys=false +darwin=false +case "`uname`" in + CYGWIN* ) + cygwin=true + ;; + Darwin* ) + darwin=true + ;; + MINGW* ) + msys=true + ;; +esac + +# For Cygwin, ensure paths are in UNIX format before anything is touched. +if $cygwin ; then + [ -n "$JAVA_HOME" ] && JAVA_HOME=`cygpath --unix "$JAVA_HOME"` +fi + +# Attempt to set APP_HOME +# Resolve links: $0 may be a link +PRG="$0" +# Need this for relative symlinks. +while [ -h "$PRG" ] ; do + ls=`ls -ld "$PRG"` + link=`expr "$ls" : '.*-> \(.*\)$'` + if expr "$link" : '/.*' > /dev/null; then + PRG="$link" + else + PRG=`dirname "$PRG"`"/$link" + fi +done +SAVED="`pwd`" +cd "`dirname \"$PRG\"`/" >&- +APP_HOME="`pwd -P`" +cd "$SAVED" >&- + +CLASSPATH=$APP_HOME/gradle/wrapper/gradle-wrapper.jar + +# Determine the Java command to use to start the JVM. +if [ -n "$JAVA_HOME" ] ; then + if [ -x "$JAVA_HOME/jre/sh/java" ] ; then + # IBM's JDK on AIX uses strange locations for the executables + JAVACMD="$JAVA_HOME/jre/sh/java" + else + JAVACMD="$JAVA_HOME/bin/java" + fi + if [ ! -x "$JAVACMD" ] ; then + die "ERROR: JAVA_HOME is set to an invalid directory: $JAVA_HOME + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." + fi +else + JAVACMD="java" + which java >/dev/null 2>&1 || die "ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. + +Please set the JAVA_HOME variable in your environment to match the +location of your Java installation." +fi + +# Increase the maximum file descriptors if we can. +if [ "$cygwin" = "false" -a "$darwin" = "false" ] ; then + MAX_FD_LIMIT=`ulimit -H -n` + if [ $? -eq 0 ] ; then + if [ "$MAX_FD" = "maximum" -o "$MAX_FD" = "max" ] ; then + MAX_FD="$MAX_FD_LIMIT" + fi + ulimit -n $MAX_FD + if [ $? -ne 0 ] ; then + warn "Could not set maximum file descriptor limit: $MAX_FD" + fi + else + warn "Could not query maximum file descriptor limit: $MAX_FD_LIMIT" + fi +fi + +# For Darwin, add options to specify how the application appears in the dock +if $darwin; then + GRADLE_OPTS="$GRADLE_OPTS \"-Xdock:name=$APP_NAME\" \"-Xdock:icon=$APP_HOME/media/gradle.icns\"" +fi + +# For Cygwin, switch paths to Windows format before running java +if $cygwin ; then + APP_HOME=`cygpath --path --mixed "$APP_HOME"` + CLASSPATH=`cygpath --path --mixed "$CLASSPATH"` + + # We build the pattern for arguments to be converted via cygpath + ROOTDIRSRAW=`find -L / -maxdepth 1 -mindepth 1 -type d 2>/dev/null` + SEP="" + for dir in $ROOTDIRSRAW ; do + ROOTDIRS="$ROOTDIRS$SEP$dir" + SEP="|" + done + OURCYGPATTERN="(^($ROOTDIRS))" + # Add a user-defined pattern to the cygpath arguments + if [ "$GRADLE_CYGPATTERN" != "" ] ; then + OURCYGPATTERN="$OURCYGPATTERN|($GRADLE_CYGPATTERN)" + fi + # Now convert the arguments - kludge to limit ourselves to /bin/sh + i=0 + for arg in "$@" ; do + CHECK=`echo "$arg"|egrep -c "$OURCYGPATTERN" -` + CHECK2=`echo "$arg"|egrep -c "^-"` ### Determine if an option + + if [ $CHECK -ne 0 ] && [ $CHECK2 -eq 0 ] ; then ### Added a condition + eval `echo args$i`=`cygpath --path --ignore --mixed "$arg"` + else + eval `echo args$i`="\"$arg\"" + fi + i=$((i+1)) + done + case $i in + (0) set -- ;; + (1) set -- "$args0" ;; + (2) set -- "$args0" "$args1" ;; + (3) set -- "$args0" "$args1" "$args2" ;; + (4) set -- "$args0" "$args1" "$args2" "$args3" ;; + (5) set -- "$args0" "$args1" "$args2" "$args3" "$args4" ;; + (6) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" ;; + (7) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" ;; + (8) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" ;; + (9) set -- "$args0" "$args1" "$args2" "$args3" "$args4" "$args5" "$args6" "$args7" "$args8" ;; + esac +fi + +# Split up the JVM_OPTS And GRADLE_OPTS values into an array, following the shell quoting and substitution rules +function splitJvmOpts() { + JVM_OPTS=("$@") +} +eval splitJvmOpts $DEFAULT_JVM_OPTS $JAVA_OPTS $GRADLE_OPTS +JVM_OPTS[${#JVM_OPTS[*]}]="-Dorg.gradle.appname=$APP_BASE_NAME" + +exec "$JAVACMD" "${JVM_OPTS[@]}" -classpath "$CLASSPATH" org.gradle.wrapper.GradleWrapperMain "$@" diff --git a/gradlew.bat b/gradlew.bat new file mode 100644 index 00000000..aec99730 --- /dev/null +++ b/gradlew.bat @@ -0,0 +1,90 @@ +@if "%DEBUG%" == "" @echo off +@rem ########################################################################## +@rem +@rem Gradle startup script for Windows +@rem +@rem ########################################################################## + +@rem Set local scope for the variables with windows NT shell +if "%OS%"=="Windows_NT" setlocal + +@rem Add default JVM options here. You can also use JAVA_OPTS and GRADLE_OPTS to pass JVM options to this script. +set DEFAULT_JVM_OPTS= + +set DIRNAME=%~dp0 +if "%DIRNAME%" == "" set DIRNAME=. +set APP_BASE_NAME=%~n0 +set APP_HOME=%DIRNAME% + +@rem Find java.exe +if defined JAVA_HOME goto findJavaFromJavaHome + +set JAVA_EXE=java.exe +%JAVA_EXE% -version >NUL 2>&1 +if "%ERRORLEVEL%" == "0" goto init + +echo. +echo ERROR: JAVA_HOME is not set and no 'java' command could be found in your PATH. +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:findJavaFromJavaHome +set JAVA_HOME=%JAVA_HOME:"=% +set JAVA_EXE=%JAVA_HOME%/bin/java.exe + +if exist "%JAVA_EXE%" goto init + +echo. +echo ERROR: JAVA_HOME is set to an invalid directory: %JAVA_HOME% +echo. +echo Please set the JAVA_HOME variable in your environment to match the +echo location of your Java installation. + +goto fail + +:init +@rem Get command-line arguments, handling Windowz variants + +if not "%OS%" == "Windows_NT" goto win9xME_args +if "%@eval[2+2]" == "4" goto 4NT_args + +:win9xME_args +@rem Slurp the command line arguments. +set CMD_LINE_ARGS= +set _SKIP=2 + +:win9xME_args_slurp +if "x%~1" == "x" goto execute + +set CMD_LINE_ARGS=%* +goto execute + +:4NT_args +@rem Get arguments from the 4NT Shell from JP Software +set CMD_LINE_ARGS=%$ + +:execute +@rem Setup the command line + +set CLASSPATH=%APP_HOME%\gradle\wrapper\gradle-wrapper.jar + +@rem Execute Gradle +"%JAVA_EXE%" %DEFAULT_JVM_OPTS% %JAVA_OPTS% %GRADLE_OPTS% "-Dorg.gradle.appname=%APP_BASE_NAME%" -classpath "%CLASSPATH%" org.gradle.wrapper.GradleWrapperMain %CMD_LINE_ARGS% + +:end +@rem End local scope for the variables with windows NT shell +if "%ERRORLEVEL%"=="0" goto mainEnd + +:fail +rem Set variable GRADLE_EXIT_CONSOLE if you need the _script_ return code instead of +rem the _cmd.exe /c_ return code! +if not "" == "%GRADLE_EXIT_CONSOLE%" exit 1 +exit /b 1 + +:mainEnd +if "%OS%"=="Windows_NT" endlocal + +:omega diff --git a/settings.gradle b/settings.gradle new file mode 100644 index 00000000..5f6f7f5e --- /dev/null +++ b/settings.gradle @@ -0,0 +1,2 @@ +rootProject.name = 'spring-session' + diff --git a/src/integration-test/java/org/springframework/session/redis/RedisOperationsSessionRepositoryITests.java b/src/integration-test/java/org/springframework/session/redis/RedisOperationsSessionRepositoryITests.java new file mode 100644 index 00000000..02d2b79a --- /dev/null +++ b/src/integration-test/java/org/springframework/session/redis/RedisOperationsSessionRepositoryITests.java @@ -0,0 +1,107 @@ +package org.springframework.session.redis; + +import static org.fest.assertions.Assertions.*; + +import org.junit.After; +import org.junit.Before; +import org.junit.Test; +import org.junit.runner.RunWith; +import org.springframework.beans.factory.annotation.Autowired; +import org.springframework.context.annotation.Bean; +import org.springframework.context.annotation.Configuration; +import org.springframework.data.redis.connection.RedisConnectionFactory; +import org.springframework.data.redis.connection.jedis.JedisConnectionFactory; +import org.springframework.data.redis.core.RedisTemplate; +import org.springframework.data.redis.serializer.StringRedisSerializer; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; +import org.springframework.session.Session; +import org.springframework.session.SessionRepository; +import org.springframework.test.context.ContextConfiguration; +import org.springframework.test.context.junit4.SpringJUnit4ClassRunner; +import redis.embedded.RedisServer; + +import java.io.IOException; +import java.net.ServerSocket; + +@RunWith(SpringJUnit4ClassRunner.class) +@ContextConfiguration +public class RedisOperationsSessionRepositoryITests { + private RedisServer redisServer; + + @Autowired + private SessionRepository repository; + + @Before + public void setup() throws IOException { + redisServer = new RedisServer(getPort()); + redisServer.start(); + } + + @After + public void shutdown() throws InterruptedException { + redisServer.stop(); + } + + @Test + public void saves() { + Session toSave = repository.createSession(); + toSave.setAttribute("a", "b"); + Authentication toSaveToken = new UsernamePasswordAuthenticationToken("user","password", AuthorityUtils.createAuthorityList("ROLE_USER")); + SecurityContext toSaveContext = SecurityContextHolder.createEmptyContext(); + toSaveContext.setAuthentication(toSaveToken); + toSave.setAttribute("SPRING_SECURITY_CONTEXT", toSaveContext); + + repository.save(toSave); + + Session session = repository.getSession(toSave.getId()); + + assertThat(session.getId()).isEqualTo(toSave.getId()); + assertThat(session.getAttributeNames()).isEqualTo(session.getAttributeNames()); + assertThat(session.getAttribute("a")).isEqualTo(toSave.getAttribute("a")); + + SecurityContext context = (SecurityContext) session.getAttribute("SPRING_SECURITY_CONTEXT"); + + repository.delete(toSave.getId()); + + assertThat(repository.getSession(toSave.getId())).isNull(); + } + + @Configuration + static class Config { + @Bean + public JedisConnectionFactory connectionFactory() throws Exception { + JedisConnectionFactory factory = new JedisConnectionFactory(); + factory.setPort(getPort()); + return factory; + } + + @Bean + public RedisTemplate redisTemplate(RedisConnectionFactory connectionFactory) { + RedisTemplate template = new RedisTemplate(); + template.setKeySerializer(new StringRedisSerializer()); + template.setHashKeySerializer(new StringRedisSerializer()); + template.setConnectionFactory(connectionFactory); + return template; + } + + @Bean + public RedisOperationsSessionRepository sessionRepository(RedisTemplate redisTemplate) { + return new RedisOperationsSessionRepository(redisTemplate); + } + } + + private static Integer availablePort; + + private static int getPort() throws IOException { + if(availablePort == null) { + ServerSocket socket = new ServerSocket(0); + availablePort = socket.getLocalPort(); + socket.close(); + } + return availablePort; + } +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/MapSession.java b/src/main/java/org/springframework/session/MapSession.java new file mode 100644 index 00000000..34d899c5 --- /dev/null +++ b/src/main/java/org/springframework/session/MapSession.java @@ -0,0 +1,155 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session; + +import org.springframework.util.Assert; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.UUID; + +/** + *

+ * A {@link Session} implementation that is backed by a {@link java.util.Map}. The defaults for the properties are: + *

+ *
    + *
  • id - a secure random generated id
  • + *
  • creationTime - the moment the {@link MapSession} was instantiated
  • + *
  • lastAccessedTime - the moment the {@link MapSession} was instantiated
  • + *
  • maxInactiveInterval - 30 minutes
  • + *
+ * + *

+ * This implementation has no synchronization, so it is best to use the copy constructor when working on multiple threads. + *

+ * + * @author Rob Winch + */ +public final class MapSession implements Session { + private String id = UUID.randomUUID().toString(); + private Map sessionAttrs = new HashMap(); + private long creationTime = System.currentTimeMillis(); + private long lastAccessedTime = creationTime; + + /** + * Defaults to 30 minutes + */ + private int maxInactiveInterval = 1800; + + /** + * Creates a new instance + */ + public MapSession() { + } + + /** + * Creates a new instance from the provided {@link Session} + * + * @param session the {@link Session} to initialize this {@link Session} with. Cannot be null. + */ + public MapSession(Session session) { + Assert.notNull(session, "session cannot be null"); + this.id = session.getId(); + this.sessionAttrs = new HashMap(session.getAttributeNames().size()); + for (String attrName : session.getAttributeNames()) { + Object attrValue = session.getAttribute(attrName); + this.sessionAttrs.put(attrName, attrValue); + } + this.lastAccessedTime = session.getLastAccessedTime(); + this.creationTime = session.getCreationTime(); + this.maxInactiveInterval = session.getMaxInactiveInterval(); + } + + @Override + public void setLastAccessedTime(long lastAccessedTime) { + this.lastAccessedTime = lastAccessedTime; + } + + @Override + public long getCreationTime() { + return creationTime; + } + + @Override + public String getId() { + return id; + } + + @Override + public long getLastAccessedTime() { + return lastAccessedTime; + } + + @Override + public void setMaxInactiveInterval(int interval) { + this.maxInactiveInterval = interval; + } + + @Override + public int getMaxInactiveInterval() { + return maxInactiveInterval; + } + + @Override + public Object getAttribute(String attributeName) { + return sessionAttrs.get(attributeName); + } + + @Override + public Set getAttributeNames() { + return sessionAttrs.keySet(); + } + + @Override + public void setAttribute(String attributeName, Object attributeValue) { + if (attributeValue == null) { + removeAttribute(attributeName); + } else { + sessionAttrs.put(attributeName, attributeValue); + } + } + + @Override + public void removeAttribute(String attributeName) { + sessionAttrs.remove(attributeName); + } + + /** + * Sets the time that this {@link Session} was created in milliseconds since midnight of 1/1/1970 GMT. The default is when the {@link Session} was instantiated. + * @param creationTime the time that this {@link Session} was created in milliseconds since midnight of 1/1/1970 GMT. + */ + public void setCreationTime(long creationTime) { + this.creationTime = creationTime; + } + + /** + * Sets the identifier for this {@link Session}. The id should be a secure random generated value to prevent malicious users from guessing this value. The default is a secure random generated identifier. + * + * @param id the identifier for this session. + */ + public void setId(String id) { + this.id = id; + } + + public boolean equals(Object obj) { + return obj instanceof Session && id.equals(((Session) obj).getId()); + } + + public int hashCode() { + return id.hashCode(); + } +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/MapSessionRepository.java b/src/main/java/org/springframework/session/MapSessionRepository.java new file mode 100644 index 00000000..f8705aa2 --- /dev/null +++ b/src/main/java/org/springframework/session/MapSessionRepository.java @@ -0,0 +1,67 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session; + +import org.springframework.util.Assert; + +import java.util.Map; +import java.util.concurrent.ConcurrentHashMap; + +/** + * A {@link SessionRepository} backed by a {@link java.util.Map} and that uses a {@link MapSession}. By default a + * {@link java.util.concurrent.ConcurrentHashMap} is used, but a custom {@link java.util.Map} can be injected to use + * distributed maps provided by NoSQL stores like Redis and Hazelcast. + * + * @author Rob Winch + * @since 4.0 + */ +public class MapSessionRepository implements SessionRepository { + private final Map sessions; + + /** + * Creates an instance backed by a {@link java.util.concurrent.ConcurrentHashMap} + */ + public MapSessionRepository() { + this(new ConcurrentHashMap()); + } + + /** + * Creates a new instance backed by the provided {@link java.util.Map}. This allows injecting a distributed {@link java.util.Map}. + * + * @param sessions the {@link java.util.Map} to use. Cannot be null. + */ + public MapSessionRepository(Map sessions) { + Assert.notNull(sessions, "sessions cannot be null"); + this.sessions = sessions; + } + + public void save(Session session) { + sessions.put(session.getId(), new MapSession(session)); + } + + public Session getSession(String id) { + Session result = sessions.get(id); + return result == null ? null : new MapSession(result); + } + + public void delete(String id) { + sessions.remove(id); + } + + public Session createSession() { + return new MapSession(); + } +} diff --git a/src/main/java/org/springframework/session/Session.java b/src/main/java/org/springframework/session/Session.java new file mode 100644 index 00000000..9c661e34 --- /dev/null +++ b/src/main/java/org/springframework/session/Session.java @@ -0,0 +1,100 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session; + +import java.io.Serializable; +import java.util.Set; + +/** + * Provides a way to identify a user in an agnostic way. This allows the session to be used by an HttpSession, WebSocket + * Session, or even non web related sessions. + * + * @author Rob Winch + * @since 4.0 + */ +public interface Session extends Serializable { + /** + * Allows setting the last time this {@link Session} was accessed. + * + * @param lastAccessedTime the last time the client sent a request associated with the session expressed in milliseconds since midnight of 1/1/1970 GMT + */ + void setLastAccessedTime(long lastAccessedTime); + + /** + * Gets the time when this session was created in milliseconds since midnight of 1/1/1970 GMT. + * + * @return the time when this session was created in milliseconds since midnight of 1/1/1970 GMT. + */ + long getCreationTime(); + + /** + * Gets a unique string that identifies the {@link Session} + * + * @return a unique string that identifies the {@link Session} + */ + String getId(); + + /** + * Gets the last time this {@link Session} was accessed expressed in milliseconds since midnight of 1/1/1970 GMT + * + * @return the last time the client sent a request associated with the session expressed in milliseconds since midnight of 1/1/1970 GMT + */ + long getLastAccessedTime(); + + /** + * Sets the maximum inactive interval in seconds between requests before this session will be invalidated. A negative time indicates that the session will never timeout. + * + * @param interval the number of seconds that the {@link Session} should be kept alive between client requests. + */ + void setMaxInactiveInterval(int interval); + + /** + * Gets the maximum inactive interval in seconds between requests before this session will be invalidated. A negative time indicates that the session will never timeout. + * + * @return the maximum inactive interval in seconds between requests before this session will be invalidated. A negative time indicates that the session will never timeout. + */ + int getMaxInactiveInterval(); + + /** + * Gets the Object associated with the specified name or null if no Object is associated to that name. + * + * @param attributeName the name of the attribute to get + * @return the Object associated with the specified name or null if no Object is associated to that name + */ + Object getAttribute(String attributeName); + + /** + * Gets the attribute names that have a value associated with it. Each value can be passed into {@link org.springframework.session.Session#getAttribute(String)} to obtain the attribute value. + * + * @return the attribute names that have a value associated with it. + * @see #getAttribute(String) + */ + Set getAttributeNames(); + + /** + * Sets the attribute value for the provided attribute name. If the attributeValue is null, it has the same result as removing the attribute with {@link org.springframework.session.Session#removeAttribute(String)} . + * + * @param attributeName the attribute name to set + * @param attributeValue the value of the attribute to set. If null, the attribute will be removed. + */ + void setAttribute(String attributeName, Object attributeValue); + + /** + * Removes the attribute with the provided attribute name + * @param attributeName the name of the attribute to remove + */ + void removeAttribute(String attributeName); +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/SessionRepository.java b/src/main/java/org/springframework/session/SessionRepository.java new file mode 100644 index 00000000..5e97f232 --- /dev/null +++ b/src/main/java/org/springframework/session/SessionRepository.java @@ -0,0 +1,60 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session; + +/** + * A repository interface for managing {@link Session} instances. + * + * @author Rob Winch + * @since 4.0 + */ +public interface SessionRepository { + /** + * Ensures the {@link Session} created by {@link org.springframework.security.session.SessionRepository#createSession()} is saved. + * + *

+ * Some implementations may choose to save as the {@link Session} is updated by returning a {@link Session} that + * immediately persists any changes. In this case, this method may not actually do anything. + *

+ * + * @param session the {@link Session} to save + */ + void save(S session); + + /** + * Gets the {@link Session} by the {@link Session#getId()} or null if no {@link Session} is found. + * @param id the {@link org.springframework.security.session.Session#getId()} to lookup + * @return the {@link Session} by the {@link Session#getId()} or null if no {@link Session} is found. + */ + Session getSession(String id); + + /** + * Deletes the {@link Session} with the given {@link Session#getId()} or does nothing if the {@link Session} is not found. + * @param id the {@link org.springframework.security.session.Session#getId()} to delete + */ + void delete(String id); + + /** + * Creates a new {@link Session} that is capable of being persisted by this {@link SessionRepository}. + * + *

This allows optimizations and customizations in how the {@link Session} is persisted. For example, the + * implementation returned might keep track of the changes ensuring that only the delta needs to be persisted on + * a save.

+ * + * @return a new {@link Session} that is capable of being persisted by this {@link SessionRepository} + */ + S createSession(); +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/redis/RedisOperationsSessionRepository.java b/src/main/java/org/springframework/session/redis/RedisOperationsSessionRepository.java new file mode 100644 index 00000000..5957e035 --- /dev/null +++ b/src/main/java/org/springframework/session/redis/RedisOperationsSessionRepository.java @@ -0,0 +1,169 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session.redis; + +import org.springframework.data.redis.core.BoundHashOperations; +import org.springframework.data.redis.core.RedisOperations; +import org.springframework.session.MapSession; +import org.springframework.session.Session; +import org.springframework.session.SessionRepository; + +import java.util.HashMap; +import java.util.Map; +import java.util.Set; +import java.util.concurrent.TimeUnit; + +/** + * @author Rob Winch + */ +public class RedisOperationsSessionRepository implements SessionRepository { + private final String BOUNDED_HASH_KEY_PREFIX = "spring-security-sessions:"; + private final String CREATION_TIME_ATTR = "creationTime"; + private final String MAX_INACTIVE_ATTR = "maxInactiveInterval"; + private final String LAST_ACCESSED_ATTR = "lastAccessedTime"; + private final String SESSION_ATTR_PREFIX = "sessionAttr:"; + + + private final RedisOperations redisTemplate; + + public RedisOperationsSessionRepository(RedisOperations redisTemplate) { + this.redisTemplate = redisTemplate; + } + + @Override + public void save(RedisSession session) { + session.saveDelta(); + } + + @Override + public Session getSession(String id) { + Map entries = getOperations(id).entries(); + if(entries.isEmpty()) { + return null; + } + MapSession loaded = new MapSession(); + loaded.setId(id); + for(Map.Entry entry : entries.entrySet()) { + String key = (String) entry.getKey(); + if(CREATION_TIME_ATTR.equals(key)) { + loaded.setCreationTime((Long) entry.getValue()); + } else if(MAX_INACTIVE_ATTR.equals(key)) { + loaded.setMaxInactiveInterval((Integer) entry.getValue()); + } else if(LAST_ACCESSED_ATTR.equals(key)) { + loaded.setLastAccessedTime((Long) entry.getValue()); + } else if(key.startsWith(SESSION_ATTR_PREFIX)) { + loaded.setAttribute(key.substring(SESSION_ATTR_PREFIX.length()), entry.getValue()); + } + } + return new RedisSession(loaded); + } + + @Override + public void delete(String sessionId) { + String key = getKey(sessionId); + this.redisTemplate.delete(key); + } + + @Override + public RedisSession createSession() { + return new RedisSession(); + } + + private String getKey(String sessionId) { + return BOUNDED_HASH_KEY_PREFIX + sessionId; + } + + private BoundHashOperations getOperations(String sessionId) { + String key = getKey(sessionId); + return this.redisTemplate.boundHashOps(key); + } + + class RedisSession implements Session { + private final MapSession cached; + private Map delta = new HashMap(); + + private RedisSession() { + this(new MapSession()); + delta.put(CREATION_TIME_ATTR, getCreationTime()); + delta.put(MAX_INACTIVE_ATTR, getMaxInactiveInterval()); + delta.put(LAST_ACCESSED_ATTR, getLastAccessedTime()); + } + + private RedisSession(MapSession cached) { + this.cached = cached; + } + + @Override + public void setLastAccessedTime(long lastAccessedTime) { + cached.setLastAccessedTime(lastAccessedTime); + delta.put(LAST_ACCESSED_ATTR, getLastAccessedTime()); + } + + @Override + public long getCreationTime() { + return cached.getCreationTime(); + } + + @Override + public String getId() { + return cached.getId(); + } + + @Override + public long getLastAccessedTime() { + return cached.getLastAccessedTime(); + } + + @Override + public void setMaxInactiveInterval(int interval) { + cached.setMaxInactiveInterval(interval); + delta.put(MAX_INACTIVE_ATTR, getMaxInactiveInterval()); + } + + @Override + public int getMaxInactiveInterval() { + return cached.getMaxInactiveInterval(); + } + + @Override + public Object getAttribute(String attributeName) { + return cached.getAttribute(attributeName); + } + + @Override + public Set getAttributeNames() { + return cached.getAttributeNames(); + } + + @Override + public void setAttribute(String attributeName, Object attributeValue) { + cached.setAttribute(attributeName, attributeValue); + delta.put(SESSION_ATTR_PREFIX + attributeName, attributeValue); + } + + @Override + public void removeAttribute(String attributeName) { + cached.removeAttribute(attributeName); + delta.put(SESSION_ATTR_PREFIX + attributeName, null); + } + + private void saveDelta() { + getOperations(getId()).putAll(delta); + getOperations(getId()).expire(getMaxInactiveInterval(), TimeUnit.SECONDS); + delta.clear(); + } + } +} diff --git a/src/main/java/org/springframework/session/web/CookieHttpSessionStrategy.java b/src/main/java/org/springframework/session/web/CookieHttpSessionStrategy.java new file mode 100644 index 00000000..63b70152 --- /dev/null +++ b/src/main/java/org/springframework/session/web/CookieHttpSessionStrategy.java @@ -0,0 +1,112 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session.web; + +import org.springframework.session.Session; +import org.springframework.util.Assert; + +import javax.servlet.http.Cookie; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * A {@link HttpSessionStrategy} that uses a cookie to obtain the session from. Specifically, this implementation will + * allow specifying a cookie name using {@link CookieHttpSessionStrategy#setCookieName(String)}. The default is "SESSION". + * + * When a session is created, the HTTP response will have a cookie with the specified cookie name and the value of the + * session id. The cookie will be marked as a session cookie, marked as HTTPOnly, and if + * {@link javax.servlet.http.HttpServletRequest#isSecure()} returns true, the cookie will be marked as secure. For example: + * + *
+ * HTTP/1.1 200 OK
+ * Set-Cookie: SESSION=f81d4fae-7dec-11d0-a765-00a0c91e6bf6; Secure; HttpOnly
+ * 
+ * + * The client should now include the session in each request by specifying the same cookie in their request. For example: + * + *
+ * GET /messages/ HTTP/1.1
+ * Host: example.com
+ * Cookie: SESSION=f81d4fae-7dec-11d0-a765-00a0c91e6bf6
+ * 
+ * + * When the session is invalidated, the server will send an HTTP response that expires the cookie. For example: + * + *
+ * HTTP/1.1 200 OK
+ * Set-Cookie: SESSION=f81d4fae-7dec-11d0-a765-00a0c91e6bf6; Expires=Thur, 1 Jan 1970 00:00:00 GMT; Secure; HttpOnly
+ * 
+ * + * @author Rob Winch + */ +public final class CookieHttpSessionStrategy implements HttpSessionStrategy { + private String cookieName = "SESSION"; + + @Override + public String getRequestedSessionId(HttpServletRequest request) { + Cookie session = getCookie(request, cookieName); + return session == null ? null : session.getValue(); + } + + @Override + public void onNewSession(Session session, HttpServletRequest request, HttpServletResponse response) { + Cookie cookie = new Cookie(cookieName, session.getId()); + cookie.setHttpOnly(true); + cookie.setSecure(request.isSecure()); + response.addCookie(cookie); + // TODO set the path? + // TODO set the domain? + } + + @Override + public void onInvalidateSession(HttpServletRequest request, HttpServletResponse response) { + Cookie sessionCookie = new Cookie(cookieName,""); + sessionCookie.setMaxAge(0); + sessionCookie.setHttpOnly(true); + sessionCookie.setSecure(request.isSecure()); + response.addCookie(sessionCookie); + } + + /** + * Sets the name of the cookie to be used + * @param cookieName + */ + public void setCookieName(String cookieName) { + Assert.notNull(cookieName, "cookieName cannot be null"); + this.cookieName = cookieName; + } + + /** + * Retrieve the first cookie with the given name. Note that multiple + * cookies can have the same name but different paths or domains. + * @param request current servlet request + * @param name cookie name + * @return the first cookie with the given name, or {@code null} if none is found + */ + private static Cookie getCookie(HttpServletRequest request, String name) { + // TODO what if there are multiple by the same name w/ different path + Assert.notNull(request, "Request must not be null"); + Cookie cookies[] = request.getCookies(); + if (cookies != null) { + for (Cookie cookie : cookies) { + if (name.equals(cookie.getName())) { + return cookie; + } + } + } + return null; + } +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/web/HeaderHttpSessionStrategy.java b/src/main/java/org/springframework/session/web/HeaderHttpSessionStrategy.java new file mode 100644 index 00000000..6ed89ae4 --- /dev/null +++ b/src/main/java/org/springframework/session/web/HeaderHttpSessionStrategy.java @@ -0,0 +1,79 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session.web; + +import org.springframework.session.Session; +import org.springframework.util.Assert; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * A {@link HttpSessionStrategy} that uses a header to obtain the session from. Specifically, this implementation will + * allow specifying a header name using {@link HeaderHttpSessionStrategy#setHeaderName(String)}. The default is "x-auth-token". + * + * When a session is created, the HTTP response will have a response header of the specified name and the value of the session id. For example: + * + *
+ * HTTP/1.1 200 OK
+ * x-auth-token: f81d4fae-7dec-11d0-a765-00a0c91e6bf6
+ * 
+ * + * The client should now include the session in each request by specifying the same header in their request. For example: + * + *
+ * GET /messages/ HTTP/1.1
+ * Host: example.com
+ * x-auth-token: f81d4fae-7dec-11d0-a765-00a0c91e6bf6
+ * 
+ * + * When the session is invalidated, the server will send an HTTP response that has the header name and a blank value. For example: + * + *
+ * HTTP/1.1 200 OK
+ * x-auth-token:
+ * 
+ * + * @author Rob Winch + */ +public class HeaderHttpSessionStrategy implements HttpSessionStrategy { + private String headerName = "x-auth-token"; + + @Override + public String getRequestedSessionId(HttpServletRequest request) { + return request.getHeader(headerName); + } + + @Override + public void onNewSession(Session session, HttpServletRequest request, HttpServletResponse response) { + response.addHeader(headerName, session.getId()); + } + + @Override + public void onInvalidateSession(HttpServletRequest request, HttpServletResponse response) { + response.addHeader(headerName, ""); + } + + /** + * The name of the header to obtain the session id from. Default is "x-auth-token". + * + * @param headerName the name of the header to obtain the session id from. + */ + public void setHeaderName(String headerName) { + Assert.notNull(headerName, "headerName cannot be null"); + this.headerName = headerName; + } +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/web/HttpSessionStrategy.java b/src/main/java/org/springframework/session/web/HttpSessionStrategy.java new file mode 100644 index 00000000..db240a6f --- /dev/null +++ b/src/main/java/org/springframework/session/web/HttpSessionStrategy.java @@ -0,0 +1,62 @@ +/* + * Copyright 2002-2014 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session.web; + +import org.springframework.session.Session; + +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * A strategy for mapping HTTP request and responses to a {@link Session}. + * + * @author Rob Winch + */ +public interface HttpSessionStrategy { + + /** + * Obtains the requested session id from the provided {@link javax.servlet.http.HttpServletRequest}. For example, + * the session id might come from a cookie or a request header. + * + * @param request the {@link javax.servlet.http.HttpServletRequest} to obtain the session id from. Cannot be null. + * @return the {@link javax.servlet.http.HttpServletRequest} to obtain the session id from. + */ + String getRequestedSessionId(HttpServletRequest request); + + /** + * This method is invoked when a new session is created and should inform a client what the new session id is. For + * example, it might create a new cookie with the session id in it or set an HTTP response header with the value of + * the new session id. + * + * Some implementations may wish to associate additional information to the {@link Session} at this time. For example, they + * may wish to add the IP Address, browser headers, the username, etc to the {@link org.springframework.session.Session}. + * + * @param session the {@link org.springframework.session.Session} that is being sent to the client. Cannot be null. + * @param request the {@link javax.servlet.http.HttpServletRequest} that create the new {@link org.springframework.session.Session} Cannot be null. + * @param response the {@link javax.servlet.http.HttpServletResponse} that is associated with the {@link javax.servlet.http.HttpServletRequest} that created the new {@link org.springframework.session.Session} Cannot be null. + */ + void onNewSession(Session session, HttpServletRequest request, HttpServletResponse response); + + /** + * This method is invoked when a session is invalidated and should inform a client that the session id is no longer valid. For + * example, it might remove a cookie with the session id in it or set an HTTP response header with an empty value indicating + * to the client to no longer submit that session id. + * + * @param request the {@link javax.servlet.http.HttpServletRequest} that invalidated the {@link org.springframework.session.Session} Cannot be null. + * @param response the {@link javax.servlet.http.HttpServletResponse} that is associated with the {@link javax.servlet.http.HttpServletRequest} that invalidated the {@link org.springframework.session.Session} Cannot be null. + */ + void onInvalidateSession(HttpServletRequest request, HttpServletResponse response); +} diff --git a/src/main/java/org/springframework/session/web/OnCommittedResponseWrapper.java b/src/main/java/org/springframework/session/web/OnCommittedResponseWrapper.java new file mode 100644 index 00000000..1ed0f3d5 --- /dev/null +++ b/src/main/java/org/springframework/session/web/OnCommittedResponseWrapper.java @@ -0,0 +1,400 @@ +/* + * Copyright 2002-2012 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software distributed under the License is distributed on + * an "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the License for the + * specific language governing permissions and limitations under the License. + */ +package org.springframework.session.web; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import javax.servlet.ServletOutputStream; +import javax.servlet.http.HttpServletResponse; +import javax.servlet.http.HttpServletResponseWrapper; +import java.io.IOException; +import java.io.PrintWriter; +import java.util.Locale; + +/** + * Base class for response wrappers which encapsulate the logic for handling an event when the + * {@link javax.servlet.http.HttpServletResponse} is committed. + * + * @author Rob Winch + */ +abstract class OnCommittedResponseWrapper extends HttpServletResponseWrapper { + private final Log logger = LogFactory.getLog(getClass()); + + private boolean disableOnCommitted; + + /** + * @param response the response to be wrapped + */ + public OnCommittedResponseWrapper(HttpServletResponse response) { + super(response); + } + + /** + * Invoke this method to disable invoking {@link OnCommittedResponseWrapper#onResponseCommitted()} when the {@link javax.servlet.http.HttpServletResponse} is + * committed. This can be useful in the event that Async Web Requests are + * made. + */ + public void disableOnResponseCommitted() { + this.disableOnCommitted = true; + } + + /** + * Implement the logic for handling the {@link javax.servlet.http.HttpServletResponse} being committed + */ + protected abstract void onResponseCommitted(); + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass sendError() + */ + @Override + public final void sendError(int sc) throws IOException { + doOnResponseCommitted(); + super.sendError(sc); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass sendError() + */ + @Override + public final void sendError(int sc, String msg) throws IOException { + doOnResponseCommitted(); + super.sendError(sc, msg); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass sendRedirect() + */ + @Override + public final void sendRedirect(String location) throws IOException { + doOnResponseCommitted(); + super.sendRedirect(location); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the calling + * getOutputStream().close() or getOutputStream().flush() + */ + @Override + public ServletOutputStream getOutputStream() throws IOException { + return new SaveContextServletOutputStream(super.getOutputStream()); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * getWriter().close() or getWriter().flush() + */ + @Override + public PrintWriter getWriter() throws IOException { + return new SaveContextPrintWriter(super.getWriter()); + } + + /** + * Makes sure {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the + * superclass flushBuffer() + */ + @Override + public void flushBuffer() throws IOException { + doOnResponseCommitted(); + super.flushBuffer(); + } + + /** + * Calls onResponseCommmitted() with the current contents as long as + * {@link #disableOnResponseCommitted()()} was not invoked. + */ + private void doOnResponseCommitted() { + if(!disableOnCommitted) { + onResponseCommitted(); + } else if(logger.isDebugEnabled()){ + logger.debug("Skip invoking on"); + } + } + + /** + * Ensures {@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling the prior to methods that commit the response. We delegate all methods + * to the original {@link java.io.PrintWriter} to ensure that the behavior is as close to the original {@link java.io.PrintWriter} + * as possible. See SEC-2039 + * @author Rob Winch + */ + private class SaveContextPrintWriter extends PrintWriter { + private final PrintWriter delegate; + + public SaveContextPrintWriter(PrintWriter delegate) { + super(delegate); + this.delegate = delegate; + } + + public void flush() { + doOnResponseCommitted(); + delegate.flush(); + } + + public void close() { + doOnResponseCommitted(); + delegate.close(); + } + + public int hashCode() { + return delegate.hashCode(); + } + + public boolean equals(Object obj) { + return delegate.equals(obj); + } + + public String toString() { + return getClass().getName() + "[delegate=" + delegate.toString() + "]"; + } + + public boolean checkError() { + return delegate.checkError(); + } + + public void write(int c) { + delegate.write(c); + } + + public void write(char[] buf, int off, int len) { + delegate.write(buf, off, len); + } + + public void write(char[] buf) { + delegate.write(buf); + } + + public void write(String s, int off, int len) { + delegate.write(s, off, len); + } + + public void write(String s) { + delegate.write(s); + } + + public void print(boolean b) { + delegate.print(b); + } + + public void print(char c) { + delegate.print(c); + } + + public void print(int i) { + delegate.print(i); + } + + public void print(long l) { + delegate.print(l); + } + + public void print(float f) { + delegate.print(f); + } + + public void print(double d) { + delegate.print(d); + } + + public void print(char[] s) { + delegate.print(s); + } + + public void print(String s) { + delegate.print(s); + } + + public void print(Object obj) { + delegate.print(obj); + } + + public void println() { + delegate.println(); + } + + public void println(boolean x) { + delegate.println(x); + } + + public void println(char x) { + delegate.println(x); + } + + public void println(int x) { + delegate.println(x); + } + + public void println(long x) { + delegate.println(x); + } + + public void println(float x) { + delegate.println(x); + } + + public void println(double x) { + delegate.println(x); + } + + public void println(char[] x) { + delegate.println(x); + } + + public void println(String x) { + delegate.println(x); + } + + public void println(Object x) { + delegate.println(x); + } + + public PrintWriter printf(String format, Object... args) { + return delegate.printf(format, args); + } + + public PrintWriter printf(Locale l, String format, Object... args) { + return delegate.printf(l, format, args); + } + + public PrintWriter format(String format, Object... args) { + return delegate.format(format, args); + } + + public PrintWriter format(Locale l, String format, Object... args) { + return delegate.format(l, format, args); + } + + public PrintWriter append(CharSequence csq) { + return delegate.append(csq); + } + + public PrintWriter append(CharSequence csq, int start, int end) { + return delegate.append(csq, start, end); + } + + public PrintWriter append(char c) { + return delegate.append(c); + } + } + + /** + * Ensures{@link OnCommittedResponseWrapper#onResponseCommitted()} is invoked before calling methods that commit the response. We delegate all methods + * to the original {@link javax.servlet.ServletOutputStream} to ensure that the behavior is as close to the original {@link javax.servlet.ServletOutputStream} + * as possible. See SEC-2039 + * + * @author Rob Winch + */ + private class SaveContextServletOutputStream extends ServletOutputStream { + private final ServletOutputStream delegate; + + public SaveContextServletOutputStream(ServletOutputStream delegate) { + this.delegate = delegate; + } + + public void write(int b) throws IOException { + this.delegate.write(b); + } + + public void flush() throws IOException { + doOnResponseCommitted(); + delegate.flush(); + } + + public void close() throws IOException { + doOnResponseCommitted(); + delegate.close(); + } + + public int hashCode() { + return delegate.hashCode(); + } + + public boolean equals(Object obj) { + return delegate.equals(obj); + } + + public void print(boolean b) throws IOException { + delegate.print(b); + } + + public void print(char c) throws IOException { + delegate.print(c); + } + + public void print(double d) throws IOException { + delegate.print(d); + } + + public void print(float f) throws IOException { + delegate.print(f); + } + + public void print(int i) throws IOException { + delegate.print(i); + } + + public void print(long l) throws IOException { + delegate.print(l); + } + + public void print(String arg0) throws IOException { + delegate.print(arg0); + } + + public void println() throws IOException { + delegate.println(); + } + + public void println(boolean b) throws IOException { + delegate.println(b); + } + + public void println(char c) throws IOException { + delegate.println(c); + } + + public void println(double d) throws IOException { + delegate.println(d); + } + + public void println(float f) throws IOException { + delegate.println(f); + } + + public void println(int i) throws IOException { + delegate.println(i); + } + + public void println(long l) throws IOException { + delegate.println(l); + } + + public void println(String s) throws IOException { + delegate.println(s); + } + + public void write(byte[] b) throws IOException { + delegate.write(b); + } + + public void write(byte[] b, int off, int len) throws IOException { + delegate.write(b, off, len); + } + + public String toString() { + return getClass().getName() + "[delegate=" + delegate.toString() + "]"; + } + } +} \ No newline at end of file diff --git a/src/main/java/org/springframework/session/web/OncePerRequestFilter.java b/src/main/java/org/springframework/session/web/OncePerRequestFilter.java new file mode 100644 index 00000000..0d068c18 --- /dev/null +++ b/src/main/java/org/springframework/session/web/OncePerRequestFilter.java @@ -0,0 +1,88 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session.web; + +import javax.servlet.*; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; + +/** + * Allows for easily ensuring that a request is only invoked once per request. This is a simplified version of spring-web's + * OncePerRequestFilter and copied to reduce the foot print required to use the session support. + * + * @author Rob Winch + */ +abstract class OncePerRequestFilter implements Filter { + /** + * Suffix that gets appended to the filter name for the + * "already filtered" request attribute. + */ + public static final String ALREADY_FILTERED_SUFFIX = ".FILTERED"; + + private String alreadyFilteredAttributeName = getClass().getName().concat(ALREADY_FILTERED_SUFFIX); + + + /** + * This {@code doFilter} implementation stores a request attribute for + * "already filtered", proceeding without filtering again if the + * attribute is already there. + */ + @Override + public final void doFilter(ServletRequest request, ServletResponse response, FilterChain filterChain) + throws ServletException, IOException { + + if (!(request instanceof HttpServletRequest) || !(response instanceof HttpServletResponse)) { + throw new ServletException("OncePerRequestFilter just supports HTTP requests"); + } + HttpServletRequest httpRequest = (HttpServletRequest) request; + HttpServletResponse httpResponse = (HttpServletResponse) response; + boolean hasAlreadyFilteredAttribute = request.getAttribute(alreadyFilteredAttributeName) != null; + + + if (hasAlreadyFilteredAttribute) { + + // Proceed without invoking this filter... + filterChain.doFilter(request, response); + } + else { + // Do invoke this filter... + request.setAttribute(alreadyFilteredAttributeName, Boolean.TRUE); + try { + doFilterInternal(httpRequest, httpResponse, filterChain); + } + finally { + // Remove the "already filtered" request attribute for this request. + request.removeAttribute(alreadyFilteredAttributeName); + } + } + } + + + /** + * Same contract as for {@code doFilter}, but guaranteed to be + * just invoked once per request within a single request thread. + *

Provides HttpServletRequest and HttpServletResponse arguments instead of the + * default ServletRequest and ServletResponse ones. + */ + protected abstract void doFilterInternal( + HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) + throws ServletException, IOException; + + public void init(FilterConfig config) {} + + public void destroy() {} +} diff --git a/src/main/java/org/springframework/session/web/SessionRepositoryFilter.java b/src/main/java/org/springframework/session/web/SessionRepositoryFilter.java new file mode 100644 index 00000000..2f64df7f --- /dev/null +++ b/src/main/java/org/springframework/session/web/SessionRepositoryFilter.java @@ -0,0 +1,319 @@ +/* + * Copyright 2002-2013 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); you may not + * use this file except in compliance with the License. You may obtain a copy of + * the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, WITHOUT + * WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. See the + * License for the specific language governing permissions and limitations under + * the License. + */ +package org.springframework.session.web; + +import org.springframework.session.Session; +import org.springframework.session.SessionRepository; +import org.springframework.util.Assert; + +import javax.servlet.FilterChain; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.*; +import java.io.IOException; +import java.util.Collections; +import java.util.Enumeration; +import java.util.NoSuchElementException; +import java.util.Set; + +/** + * Switches the {@link javax.servlet.http.HttpSession} implementation to be backed by a {@link org.springframework.session.Session}. + * + * The {@link SessionRepositoryFilter} wraps the {@link javax.servlet.http.HttpServletRequest} and overrides the methods + * to get an {@link javax.servlet.http.HttpSession} to be backed by a {@link org.springframework.session.Session} returned + * by the {@link org.springframework.session.SessionRepository}. + * + * The {@link SessionRepositoryFilter} uses a {@link HttpSessionStrategy} (default {@link CookieHttpSessionStrategy} to + * bridge logic between an {@link javax.servlet.http.HttpSession} and the {@link org.springframework.session.Session} + * abstraction. Specifically: + * + *

    + *
  • The session id is looked up using {@link HttpSessionStrategy#getRequestedSessionId(javax.servlet.http.HttpServletRequest)}. + * The default is to look in a cookie named SESSION.
  • + *
  • The session id of newly created {@link org.springframework.session.Session} is sent to the client using + * {@link HttpSessionStrategy#onNewSession(org.springframework.session.Session, javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse)}
  • + *
  • The client is notified that the session id is no longer valid with {@link HttpSessionStrategy#onInvalidateSession(javax.servlet.http.HttpServletRequest, javax.servlet.http.HttpServletResponse)}
  • + *
+ * + * session id is looked up using the provided {@link HttpSessionStrategy}. The same strategy is used to convey the + * session id of newly created {@link org.springframework.session.Session}s to the client. + * + * @author Rob Winch + */ +public class SessionRepositoryFilter extends OncePerRequestFilter { + private final SessionRepository sessionRepository; + + private HttpSessionStrategy httpSessionStrategy = new CookieHttpSessionStrategy(); + + public SessionRepositoryFilter(SessionRepository sessionRepository) { + this.sessionRepository = sessionRepository; + } + + /** + * Sets the {@link HttpSessionStrategy} to be used. The default is a {@link CookieHttpSessionStrategy}. + * + * @param httpSessionStrategy the {@link HttpSessionStrategy} to use. Cannot be null. + */ + public void setHttpSessionStrategy(HttpSessionStrategy httpSessionStrategy) { + Assert.notNull(httpSessionStrategy,"httpSessionIdStrategy cannot be null"); + this.httpSessionStrategy = httpSessionStrategy; + } + + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + SessionRepositoryRequestWrapper wrappedRequest = new SessionRepositoryRequestWrapper(request, response); + SessionRepositoryResponseWrapper wrappedResponse = new SessionRepositoryResponseWrapper(wrappedRequest,response); + try { + filterChain.doFilter(wrappedRequest, wrappedResponse); + } finally { + wrappedRequest.commitSession(); + } + } + + private static final class SessionRepositoryResponseWrapper extends OnCommittedResponseWrapper { + + private final SessionRepositoryRequestWrapper request; + + /** + * @param response the response to be wrapped + */ + public SessionRepositoryResponseWrapper(SessionRepositoryRequestWrapper request, HttpServletResponse response) { + super(response); + this.request = request; + } + + @Override + protected void onResponseCommitted() { + request.commitSession(); + } + } + + /** + * A {@link javax.servlet.http.HttpServletRequest} that retrieves the {@link javax.servlet.http.HttpSession} using a + * {@link org.springframework.session.SessionRepository}. + * + * @author Rob Winch + * @since 4.0 + */ + private final class SessionRepositoryRequestWrapper extends HttpServletRequestWrapper { + private HttpSessionWrapper currentSession; + private boolean requestedValidSession; + private final HttpServletResponse response; + + private SessionRepositoryRequestWrapper(HttpServletRequest request, HttpServletResponse response) { + super(request); + this.response = response; + } + + private void commitSession() { + HttpSessionWrapper wrappedSession = currentSession; + if(wrappedSession == null) { + if(isInvalidateClientSession()) { + httpSessionStrategy.onInvalidateSession(this, response); + } + } else { + Session session = wrappedSession.session; + sessionRepository.save(session); + httpSessionStrategy.onNewSession(session, this, response); + } + } + + private boolean isInvalidateClientSession() { + return currentSession == null && requestedValidSession; + } + + @Override + public HttpSession getSession(boolean create) { + if(currentSession != null) { + return currentSession; + } + String requestedSessionId = getRequestedSessionId(); + if(requestedSessionId != null) { + Session session = sessionRepository.getSession(requestedSessionId); + if(session != null) { + this.requestedValidSession = true; + session.setLastAccessedTime(System.currentTimeMillis()); + currentSession = new HttpSessionWrapper(session, getServletContext()); + currentSession.setNew(false); + return currentSession; + } + } + if(!create) { + return null; + } + Session session = sessionRepository.createSession(); + currentSession = new HttpSessionWrapper(session, getServletContext()); + return currentSession; + } + + @Override + public HttpSession getSession() { + return getSession(true); + } + + @Override + public String getRequestedSessionId() { + return httpSessionStrategy.getRequestedSessionId(this); + } + + private final class HttpSessionWrapper implements HttpSession { + final Session session; + private final ServletContext servletContext; + private boolean invalidated; + private boolean old; + + public HttpSessionWrapper(Session session, ServletContext servletContext) { + this.session = session; + this.servletContext = servletContext; + } + + void updateLastAccessedTime() { + checkState(); + session.setLastAccessedTime(System.currentTimeMillis()); + } + + @Override + public long getCreationTime() { + checkState(); + return session.getCreationTime(); + } + + @Override + public String getId() { + return session.getId(); + } + + @Override + public long getLastAccessedTime() { + checkState(); + return session.getLastAccessedTime(); + } + + @Override + public ServletContext getServletContext() { + return servletContext; + } + + @Override + public void setMaxInactiveInterval(int interval) { + session.setMaxInactiveInterval(interval); + } + + @Override + public int getMaxInactiveInterval() { + return session.getMaxInactiveInterval(); + } + + @Override + public HttpSessionContext getSessionContext() { + return NOOP_SESSION_CONTEXT; + } + + @Override + public Object getAttribute(String name) { + checkState(); + return session.getAttribute(name); + } + + @Override + public Object getValue(String name) { + return getAttribute(name); + } + + @Override + public Enumeration getAttributeNames() { + checkState(); + return Collections.enumeration(session.getAttributeNames()); + } + + @Override + public String[] getValueNames() { + checkState(); + Set attrs = session.getAttributeNames(); + return attrs.toArray(new String[0]); + } + + @Override + public void setAttribute(String name, Object value) { + checkState(); + session.setAttribute(name, value); + } + + @Override + public void putValue(String name, Object value) { + setAttribute(name, value); + } + + @Override + public void removeAttribute(String name) { + checkState(); + session.removeAttribute(name); + } + + @Override + public void removeValue(String name) { + removeAttribute(name); + } + + @Override + public final void invalidate() { + checkState(); + this.invalidated = true; + currentSession = null; + sessionRepository.delete(getId()); + } + + public void setNew(boolean isNew) { + this.old = !isNew; + } + + @Override + public boolean isNew() { + checkState(); + return !old; + } + + private void checkState() { + if(invalidated) { + throw new IllegalStateException("The HttpSession has already be invalidated."); + } + } + } + } + + private static final HttpSessionContext NOOP_SESSION_CONTEXT = new HttpSessionContext() { + @Override + public HttpSession getSession(String sessionId) { + return null; + } + + @Override + public Enumeration getIds() { + return EMPTY_ENUMERATION; + } + }; + + private final static Enumeration EMPTY_ENUMERATION = new Enumeration() { + @Override + public boolean hasMoreElements() { + return false; + } + + @Override + public String nextElement() { + throw new NoSuchElementException("a"); + } + }; +} diff --git a/src/test/java/org/springframework/session/MapSessionTests.java b/src/test/java/org/springframework/session/MapSessionTests.java new file mode 100644 index 00000000..175946f7 --- /dev/null +++ b/src/test/java/org/springframework/session/MapSessionTests.java @@ -0,0 +1,106 @@ +package org.springframework.session; + +import org.junit.Before; +import org.junit.Test; + +import java.util.Set; + +import static org.fest.assertions.Assertions.assertThat; + +public class MapSessionTests { + + private MapSession session; + + @Before + public void setup() { + session = new MapSession(); + } + + @Test(expected = IllegalArgumentException.class) + public void constructorNullSession() { + new MapSession(null); + } + + /** + * Ensure conforms to the javadoc of {@link Session} + */ + @Test + public void setAttributeNullObjectRemoves() { + String attr = "attr"; + session.setAttribute(attr, new Object()); + session.setAttribute(attr, null); + assertThat(session.getAttributeNames()).isEmpty(); + } + + @Test + public void equalsNonSessionFalse() { + assertThat(session.equals(new Object())).isFalse(); + } + + @Test + public void equalsCustomSession() { + CustomSession other = new CustomSession(); + session.setId(other.getId()); + assertThat(session.equals(other)).isTrue(); + } + + @Test + public void hashCodeEqualsIdHashCode() { + session.setId("constantId"); + assertThat(session.hashCode()).isEqualTo(session.getId().hashCode()); + } + + static class CustomSession implements Session { + + @Override + public void setLastAccessedTime(long lastAccessedTime) { + + } + + @Override + public long getCreationTime() { + return 0; + } + + @Override + public String getId() { + return "id"; + } + + @Override + public long getLastAccessedTime() { + return 0; + } + + @Override + public void setMaxInactiveInterval(int interval) { + + } + + @Override + public int getMaxInactiveInterval() { + return 0; + } + + @Override + public Object getAttribute(String attributeName) { + return null; + } + + @Override + public Set getAttributeNames() { + return null; + } + + @Override + public void setAttribute(String attributeName, Object attributeValue) { + + } + + @Override + public void removeAttribute(String attributeName) { + + } + } + +} \ No newline at end of file diff --git a/src/test/java/org/springframework/session/web/CookieHttpSessionStrategyTests.java b/src/test/java/org/springframework/session/web/CookieHttpSessionStrategyTests.java new file mode 100644 index 00000000..d3fe3da2 --- /dev/null +++ b/src/test/java/org/springframework/session/web/CookieHttpSessionStrategyTests.java @@ -0,0 +1,92 @@ +package org.springframework.session.web; + +import static org.fest.assertions.Assertions.*; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.session.MapSession; +import org.springframework.session.Session; + +import javax.servlet.http.Cookie; + +public class CookieHttpSessionStrategyTests { + private MockHttpServletRequest request; + private MockHttpServletResponse response; + + private CookieHttpSessionStrategy strategy; + private String cookieName; + private Session session; + + @Before + public void setup() throws Exception { + cookieName = "SESSION"; + session = new MapSession(); + request = new MockHttpServletRequest(); + response = new MockHttpServletResponse(); + strategy = new CookieHttpSessionStrategy(); + } + + @Test + public void getRequestedSessionIdNull() throws Exception { + assertThat(strategy.getRequestedSessionId(request)).isNull(); + } + + @Test + public void getRequestedSessionIdNotNull() throws Exception { + setSessionId(session.getId()); + assertThat(strategy.getRequestedSessionId(request)).isEqualTo(session.getId()); + } + + @Test + public void getRequestedSessionIdNotNullCustomCookieName() throws Exception { + setCookieName("CUSTOM"); + setSessionId(session.getId()); + assertThat(strategy.getRequestedSessionId(request)).isEqualTo(session.getId()); + } + + @Test + public void onNewSession() throws Exception { + strategy.onNewSession(session, request, response); + assertThat(getSessionId()).isEqualTo(session.getId()); + } + + @Test + public void onNewSessionCustomCookieName() throws Exception { + setCookieName("CUSTOM"); + strategy.onNewSession(session, request, response); + assertThat(getSessionId()).isEqualTo(session.getId()); + } + + @Test + public void onDeleteSession() throws Exception { + strategy.onInvalidateSession(request, response); + assertThat(getSessionId()).isEmpty(); + } + + @Test + public void onDeleteSessionCustomCookieName() throws Exception { + setCookieName("CUSTOM"); + strategy.onInvalidateSession(request, response); + assertThat(getSessionId()).isEmpty(); + } + + @Test(expected = IllegalArgumentException.class) + public void setCookieNameNull() throws Exception { + strategy.setCookieName(null); + } + + public void setCookieName(String cookieName) { + strategy.setCookieName(cookieName); + this.cookieName = cookieName; + } + + public void setSessionId(String id) { + request.setCookies(new Cookie(cookieName, id)); + } + + public String getSessionId() { + return response.getCookie(cookieName).getValue(); + } +} \ No newline at end of file diff --git a/src/test/java/org/springframework/session/web/HeaderSessionStrategyTests.java b/src/test/java/org/springframework/session/web/HeaderSessionStrategyTests.java new file mode 100644 index 00000000..edf0eae8 --- /dev/null +++ b/src/test/java/org/springframework/session/web/HeaderSessionStrategyTests.java @@ -0,0 +1,90 @@ +package org.springframework.session.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.session.MapSession; +import org.springframework.session.Session; + +import static org.fest.assertions.Assertions.assertThat; + +public class HeaderSessionStrategyTests { + private MockHttpServletRequest request; + private MockHttpServletResponse response; + + private HeaderHttpSessionStrategy strategy; + private String headerName; + private Session session; + + @Before + public void setup() throws Exception { + headerName = "x-auth-token"; + session = new MapSession(); + request = new MockHttpServletRequest(); + response = new MockHttpServletResponse(); + strategy = new HeaderHttpSessionStrategy(); + } + + @Test + public void getRequestedSessionIdNull() throws Exception { + assertThat(strategy.getRequestedSessionId(request)).isNull(); + } + + @Test + public void getRequestedSessionIdNotNull() throws Exception { + setSessionId(session.getId()); + assertThat(strategy.getRequestedSessionId(request)).isEqualTo(session.getId()); + } + + @Test + public void getRequestedSessionIdNotNullCustomHeaderName() throws Exception { + setHeaderName("CUSTOM"); + setSessionId(session.getId()); + assertThat(strategy.getRequestedSessionId(request)).isEqualTo(session.getId()); + } + + @Test + public void onNewSession() throws Exception { + strategy.onNewSession(session, request, response); + assertThat(getSessionId()).isEqualTo(session.getId()); + } + + @Test + public void onNewSessionCustomHeaderName() throws Exception { + setHeaderName("CUSTOM"); + strategy.onNewSession(session, request, response); + assertThat(getSessionId()).isEqualTo(session.getId()); + } + + @Test + public void onDeleteSession() throws Exception { + strategy.onInvalidateSession(request, response); + assertThat(getSessionId()).isEmpty(); + } + + @Test + public void onDeleteSessionCustomHeaderName() throws Exception { + setHeaderName("CUSTOM"); + strategy.onInvalidateSession(request, response); + assertThat(getSessionId()).isEmpty(); + } + + @Test(expected = IllegalArgumentException.class) + public void setHeaderNameNull() throws Exception { + strategy.setHeaderName(null); + } + + public void setHeaderName(String headerName) { + strategy.setHeaderName(headerName); + this.headerName = headerName; + } + + public void setSessionId(String id) { + request.addHeader(headerName, id); + } + + public String getSessionId() { + return response.getHeader(headerName); + } +} \ No newline at end of file diff --git a/src/test/java/org/springframework/session/web/OncePerRequestFilterTests.java b/src/test/java/org/springframework/session/web/OncePerRequestFilterTests.java new file mode 100644 index 00000000..a155ed98 --- /dev/null +++ b/src/test/java/org/springframework/session/web/OncePerRequestFilterTests.java @@ -0,0 +1,73 @@ +package org.springframework.session.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; + +import javax.servlet.FilterChain; +import javax.servlet.ServletException; +import javax.servlet.http.HttpServlet; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import static org.fest.assertions.Assertions.*; + +public class OncePerRequestFilterTests { + private MockHttpServletRequest request; + private MockHttpServletResponse response; + private MockFilterChain chain; + private OncePerRequestFilter filter; + private HttpServlet servlet; + + + private List invocations; + + @Before + public void setup() { + servlet = new HttpServlet() {}; + request = new MockHttpServletRequest(); + response = new MockHttpServletResponse(); + chain = new MockFilterChain(); + invocations = new ArrayList(); + filter = new OncePerRequestFilter() { + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + invocations.add(this); + filterChain.doFilter(request, response); + } + }; + } + + @Test + public void doFilterOnce() throws ServletException, IOException { + filter.doFilter(request, response, chain); + + assertThat(invocations).containsOnly(filter); + } + + @Test + public void doFilterMultiOnlyIvokesOnce() throws ServletException, IOException { + filter.doFilter(request, response, new MockFilterChain(servlet, filter)); + + assertThat(invocations).containsOnly(filter); + } + + @Test + public void doFilterOtherSubclassInvoked() throws ServletException, IOException { + OncePerRequestFilter filter2 = new OncePerRequestFilter() { + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + invocations.add(this); + filterChain.doFilter(request, response); + } + }; + filter.doFilter(request, response, new MockFilterChain(servlet, filter2)); + + assertThat(invocations).containsOnly(filter, filter2); + } +} \ No newline at end of file diff --git a/src/test/java/org/springframework/session/web/SessionRepositoryFilterTests.java b/src/test/java/org/springframework/session/web/SessionRepositoryFilterTests.java new file mode 100644 index 00000000..b931a6cd --- /dev/null +++ b/src/test/java/org/springframework/session/web/SessionRepositoryFilterTests.java @@ -0,0 +1,882 @@ +package org.springframework.session.web; + +import org.junit.Before; +import org.junit.Test; +import org.springframework.mock.web.MockFilterChain; +import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockHttpServletResponse; +import org.springframework.session.MapSessionRepository; +import org.springframework.session.SessionRepository; + +import javax.servlet.FilterChain; +import javax.servlet.ServletContext; +import javax.servlet.ServletException; +import javax.servlet.http.*; +import java.io.IOException; +import java.util.Arrays; +import java.util.Collections; +import java.util.NoSuchElementException; + +import static org.fest.assertions.Assertions.assertThat; +import static org.junit.Assert.fail; + +public class SessionRepositoryFilterTests { + private final static String SESSION_ATTR_NAME = HttpSession.class.getName(); + + private SessionRepository sessionRepository; + + private SessionRepositoryFilter filter; + + private MockHttpServletRequest request; + + private MockHttpServletResponse response; + + private MockFilterChain chain; + + @Before + public void setup() throws Exception { + sessionRepository = new MapSessionRepository(); + filter = new SessionRepositoryFilter(sessionRepository); + request = new MockHttpServletRequest(); + response = new MockHttpServletResponse(); + chain = new MockFilterChain(); + } + + @Test + public void doFilterCreateDate() throws Exception { + final String CREATE_ATTR = "create"; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + long creationTime = wrappedRequest.getSession().getCreationTime(); + long now = System.currentTimeMillis(); + assertThat(now - creationTime).isGreaterThanOrEqualTo(0).isLessThan(5000); + request.setAttribute(CREATE_ATTR, creationTime); + } + }); + + final long expectedCreationTime = (Long) request.getAttribute(CREATE_ATTR); + Thread.sleep(50L); + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + long creationTime = wrappedRequest.getSession().getCreationTime(); + + assertThat(creationTime).isEqualTo(expectedCreationTime); + } + }); + } + + @Test + public void doFilterLastAccessedTime() throws Exception { + final String ACCESS_ATTR = "create"; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + long lastAccessed = wrappedRequest.getSession().getLastAccessedTime(); + assertThat(lastAccessed).isEqualTo(wrappedRequest.getSession().getCreationTime()); + request.setAttribute(ACCESS_ATTR, lastAccessed); + } + }); + + final long creationTime = (Long) request.getAttribute(ACCESS_ATTR); + Thread.sleep(10L); + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + long lastAccessed = wrappedRequest.getSession().getLastAccessedTime(); + + assertThat(lastAccessed).isGreaterThan(wrappedRequest.getSession().getCreationTime()); + } + }); + } + + @Test + public void doFilterId() throws Exception { + final String ID_ATTR = "create"; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + String id = wrappedRequest.getSession().getId(); + assertThat(id).isNotNull(); + assertThat(wrappedRequest.getSession().getId()).isEqualTo(id); + request.setAttribute(ID_ATTR, id); + } + }); + + final String id = (String) request.getAttribute(ID_ATTR); + assertThat(getSessionCookie().getValue()).isEqualTo(id); + setSessionCookie(id); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getId()).isEqualTo(id); + } + }); + } + + @Test + public void doFilterIdChanges() throws Exception { + final String ID_ATTR = "create"; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + String id = wrappedRequest.getSession().getId(); + request.setAttribute(ID_ATTR, id); + } + }); + + final String id = (String) request.getAttribute(ID_ATTR); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getId()).isNotEqualTo(id); + } + }); + } + + @Test + public void doFilterServletContext() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + ServletContext context = wrappedRequest.getSession().getServletContext(); + assertThat(context).isSameAs(wrappedRequest.getServletContext()); + } + }); + } + + @Test + public void doFilterMaxInactiveIntervalDefault() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + int interval = wrappedRequest.getSession().getMaxInactiveInterval(); + assertThat(interval).isEqualTo(1800); // 30 minute default (same as Tomcat) + } + }); + } + + @Test + public void doFilterMaxInactiveIntervalOverride() throws Exception { + final int interval = 600; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession().setMaxInactiveInterval(interval); + assertThat(wrappedRequest.getSession().getMaxInactiveInterval()).isEqualTo(interval); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getMaxInactiveInterval()).isEqualTo(interval); + } + }); + } + + @Test + public void doFilterAttribute() throws Exception { + final String ATTR = "ATTR"; + final String VALUE = "VALUE"; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession().setAttribute(ATTR, VALUE); + assertThat(wrappedRequest.getSession().getAttribute(ATTR)).isEqualTo(VALUE); + assertThat(Collections.list(wrappedRequest.getSession().getAttributeNames())).containsOnly(ATTR); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getAttribute(ATTR)).isEqualTo(VALUE); + assertThat(Collections.list(wrappedRequest.getSession().getAttributeNames())).containsOnly(ATTR); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getAttribute(ATTR)).isEqualTo(VALUE); + + wrappedRequest.getSession().removeAttribute(ATTR); + + assertThat(wrappedRequest.getSession().getAttribute(ATTR)).isNull(); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getAttribute(ATTR)).isNull(); + } + }); + } + + @Test + public void doFilterValue() throws Exception { + final String ATTR = "ATTR"; + final String VALUE = "VALUE"; + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession().putValue(ATTR, VALUE); + assertThat(wrappedRequest.getSession().getValue(ATTR)).isEqualTo(VALUE); + assertThat(Arrays.asList(wrappedRequest.getSession().getValueNames())).containsOnly(ATTR); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getValue(ATTR)).isEqualTo(VALUE); + assertThat(Arrays.asList(wrappedRequest.getSession().getValueNames())).containsOnly(ATTR); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getValue(ATTR)).isEqualTo(VALUE); + + wrappedRequest.getSession().removeValue(ATTR); + + assertThat(wrappedRequest.getSession().getValue(ATTR)).isNull(); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getValue(ATTR)).isNull(); + } + }); + } + + @Test + public void doFilterIsNewTrue() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().isNew()).isTrue(); + assertThat(wrappedRequest.getSession().isNew()).isTrue(); + } + }); + } + + @Test + public void doFilterIsNewFalse() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().isNew()).isFalse(); + } + }); + } + + @Test + public void doFilterGetSessionNew() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(); + } + }); + + assertNewSession(); + } + + @Test + public void doFilterGetSessionTrueNew() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(true); + } + }); + + assertNewSession(); + } + + @Test + public void doFilterGetSessionFalseNew() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(false); + } + }); + + assertNoSession(); + } + + @Test + public void doFilterGetSessionGetSessionFalse() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(); + } + }); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession(false)).isNotNull(); + } + }); + } + + @Test + public void doFilterCookieSecuritySettings() throws Exception { + request.setSecure(true); + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(); + } + }); + + Cookie session = getSessionCookie(); + assertThat(session.isHttpOnly()).describedAs("Session Cookie should be HttpOnly").isTrue(); + assertThat(session.getSecure()).describedAs("Session Cookie should be marked as Secure").isTrue(); + } + + @Test + public void doFilterSessionContext() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSessionContext sessionContext = wrappedRequest.getSession().getSessionContext(); + assertThat(sessionContext).isNotNull(); + assertThat(sessionContext.getSession("a")).isNull(); + assertThat(sessionContext.getIds()).isNotNull(); + assertThat(sessionContext.getIds().hasMoreElements()).isFalse(); + + try { + sessionContext.getIds().nextElement(); + fail("Expected Exception"); + } catch(NoSuchElementException success) {} + } + }); + } + + + + // --- saving + + @Test + public void doFilterGetAttr() throws Exception { + final String ATTR_NAME = "attr"; + final String ATTR_VALUE = "value"; + final String ATTR_NAME2 = "attr2"; + final String ATTR_VALUE2 = "value2"; + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession().setAttribute(ATTR_NAME, ATTR_VALUE); + wrappedRequest.getSession().setAttribute(ATTR_NAME2, ATTR_VALUE2); + } + }); + + assertNewSession(); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getAttribute(ATTR_NAME)).isEqualTo(ATTR_VALUE); + assertThat(wrappedRequest.getSession().getAttribute(ATTR_NAME2)).isEqualTo(ATTR_VALUE2); + } + }); + } + + // --- invalidate + + @Test + public void doFilterInvalidateInvalidateIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.invalidate(); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateCreationTimeIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.getCreationTime(); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateAttributeIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.getAttribute("attr"); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateValueIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.getValue("attr"); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateAttributeNamesIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.getAttributeNames(); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateValueNamesIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.getValueNames(); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateSetAttributeIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.setAttribute("a", "b"); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidatePutValueIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.putValue("a", "b"); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateRemoveAttributeIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.removeAttribute("name"); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateRemoveValueIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.removeValue("name"); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateNewIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.isNew(); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateLastAccessedTimeIllegalState() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + try { + session.getLastAccessedTime(); + fail("Expected Exception"); + } catch(IllegalStateException success) {} + } + }); + } + + @Test + public void doFilterInvalidateId() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + // no exception + session.getId(); + } + }); + } + + @Test + public void doFilterInvalidateServletContext() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + + // no exception + session.getServletContext(); + } + }); + } + + @Test + public void doFilterInvalidateSessionContext() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + + // no exception + session.getSessionContext(); + } + }); + } + + @Test + public void doFilterInvalidateMaxInteractiveInterval() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + HttpSession session = wrappedRequest.getSession(); + session.invalidate(); + + // no exception + session.getMaxInactiveInterval(); + session.setMaxInactiveInterval(3600); + } + }); + } + + @Test + public void doFilterInvalidateAndGetSession() throws Exception { + final String ATTR_NAME = "attr"; + final String ATTR_VALUE = "value"; + final String ATTR_NAME2 = "attr2"; + final String ATTR_VALUE2 = "value2"; + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession().setAttribute(ATTR_NAME, ATTR_VALUE); + wrappedRequest.getSession().invalidate(); + wrappedRequest.getSession().setAttribute(ATTR_NAME2, ATTR_VALUE2); + } + }); + + assertNewSession(); + + setupSession(); + + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + assertThat(wrappedRequest.getSession().getAttribute(ATTR_NAME)).isNull(); + assertThat(wrappedRequest.getSession().getAttribute(ATTR_NAME2)).isEqualTo(ATTR_VALUE2); + } + }); + } + + // --- invalid session ids + + @Test + public void doFilterGetSessionInvalidSessionId() throws Exception { + setSessionCookie("INVALID"); + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(); + } + }); + + assertNewSession(); + } + + @Test + public void doFilterGetSessionTrueInvalidSessionId() throws Exception { + setSessionCookie("INVALID"); + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(true); + } + }); + + assertNewSession(); + } + + @Test + public void doFilterGetSessionFalseInvalidSessionId() throws Exception { + setSessionCookie("INVALID"); + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest) { + wrappedRequest.getSession(false); + } + }); + + assertNoSession(); + } + + // --- commit response saves immediately + + @Test + public void doFilterSendError() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterSendErrorAndMessage() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.sendError(HttpServletResponse.SC_INTERNAL_SERVER_ERROR, "Error"); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterSendRedirect() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.sendRedirect("/"); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterFlushBuffer() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.flushBuffer(); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterOutputFlush() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.getOutputStream().flush(); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterOutputClose() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.getOutputStream().close(); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterWriterFlush() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.getWriter().flush(); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + @Test + public void doFilterWriterClose() throws Exception { + doFilter(new DoInFilter() { + @Override + public void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws IOException { + String id = wrappedRequest.getSession().getId(); + wrappedResponse.getWriter().close(); + assertThat(sessionRepository.getSession(id)).isNotNull(); + } + }); + } + + // --- helper methods + + private void assertNewSession() { + Cookie cookie = getSessionCookie(); + assertThat(cookie).isNotNull(); + assertThat(cookie.getMaxAge()).isEqualTo(-1); + assertThat(cookie.getValue()).isNotEqualTo("INVALID"); + assertThat(cookie.isHttpOnly()).describedAs("Cookie is expected to be HTTP Only").isTrue(); + assertThat(cookie.getSecure()).describedAs("Cookie secured is expected to be " + request.isSecure()).isEqualTo(request.isSecure()); + assertThat(request.getSession(false)).describedAs("The original HttpServletRequest HttpSession should be null").isNull(); + } + + private void assertNoSession() { + Cookie cookie = getSessionCookie(); + assertThat(cookie).isNull(); + assertThat(request.getSession(false)).describedAs("The original HttpServletRequest HttpSession should be null").isNull(); + } + + private Cookie getSessionCookie() { + return response.getCookie("SESSION"); + } + + private void setSessionCookie(String sessionId) { + request.setCookies(new Cookie[]{new Cookie("SESSION", sessionId)}); + } + + private void setupSession() { + setSessionCookie(getSessionCookie().getValue()); + } + + private void doFilter(final DoInFilter doInFilter) throws ServletException, IOException { + chain = new MockFilterChain(new HttpServlet() {}, new OncePerRequestFilter() { + @Override + protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response, FilterChain filterChain) throws ServletException, IOException { + doInFilter.doFilter(request, response); + } + }); + filter.doFilter(request, response, chain); + } + + abstract class DoInFilter { + void doFilter(HttpServletRequest wrappedRequest, HttpServletResponse wrappedResponse) throws ServletException, IOException { + doFilter(wrappedRequest); + } + void doFilter(HttpServletRequest wrappedRequest) {} + } +} \ No newline at end of file