diff --git a/samples/findbyusername/src/main/java/sample/config/SecurityConfig.java b/samples/findbyusername/src/main/java/sample/config/SecurityConfig.java index a2a90ef1..ade70f5f 100644 --- a/samples/findbyusername/src/main/java/sample/config/SecurityConfig.java +++ b/samples/findbyusername/src/main/java/sample/config/SecurityConfig.java @@ -20,26 +20,18 @@ import org.springframework.security.config.annotation.authentication.builders.Au import org.springframework.security.config.annotation.web.builders.HttpSecurity; import org.springframework.security.config.annotation.web.configuration.EnableWebSecurity; import org.springframework.security.config.annotation.web.configuration.WebSecurityConfigurerAdapter; -import org.springframework.security.web.authentication.SavedRequestAwareAuthenticationSuccessHandler; - -import sample.session.CompositeAuthenticationSuccessHandler; -import sample.session.SpringSessionPrincipalNameSuccessHandler; /** * @author Rob Winch */ - @EnableWebSecurity public class SecurityConfig extends WebSecurityConfigurerAdapter { // tag::config[] @Override protected void configure(HttpSecurity http) throws Exception { - CompositeAuthenticationSuccessHandler successHandler = createHandler(); - http .formLogin() - .successHandler(successHandler) .loginPage("/login") .permitAll() .and() @@ -52,19 +44,6 @@ public class SecurityConfig extends WebSecurityConfigurerAdapter { } // end::config[] - // tag::handler[] - private CompositeAuthenticationSuccessHandler createHandler() { - SpringSessionPrincipalNameSuccessHandler setUsernameHandler = - new SpringSessionPrincipalNameSuccessHandler(); - SavedRequestAwareAuthenticationSuccessHandler defaultHandler = - new SavedRequestAwareAuthenticationSuccessHandler(); - - CompositeAuthenticationSuccessHandler successHandler = - new CompositeAuthenticationSuccessHandler(setUsernameHandler, defaultHandler); - return successHandler; - } - // end::handler[] - @Autowired public void configureGlobal(AuthenticationManagerBuilder auth) throws Exception { auth diff --git a/samples/findbyusername/src/main/java/sample/session/CompositeAuthenticationSuccessHandler.java b/samples/findbyusername/src/main/java/sample/session/CompositeAuthenticationSuccessHandler.java deleted file mode 100644 index 4ef361b6..00000000 --- a/samples/findbyusername/src/main/java/sample/session/CompositeAuthenticationSuccessHandler.java +++ /dev/null @@ -1,49 +0,0 @@ -/* - * Copyright 2002-2015 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package sample.session; - -import java.io.IOException; -import java.util.Arrays; -import java.util.List; - -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; - -import org.springframework.security.core.Authentication; -import org.springframework.security.web.authentication.AuthenticationSuccessHandler; - -/** - * @author Rob Winch - * - */ -// tag::class[] -public class CompositeAuthenticationSuccessHandler implements AuthenticationSuccessHandler { - private List handlers; - - public CompositeAuthenticationSuccessHandler(AuthenticationSuccessHandler... handlers) { - super(); - this.handlers = Arrays.asList(handlers); - } - - public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, - Authentication authentication) throws IOException, ServletException { - for(AuthenticationSuccessHandler handler : handlers) { - handler.onAuthenticationSuccess(request, response, authentication); - } - } -} -// end::class[] diff --git a/samples/findbyusername/src/main/java/sample/session/SpringSessionPrincipalNameSuccessHandler.java b/samples/findbyusername/src/main/java/sample/session/SpringSessionPrincipalNameSuccessHandler.java deleted file mode 100644 index df77dcf8..00000000 --- a/samples/findbyusername/src/main/java/sample/session/SpringSessionPrincipalNameSuccessHandler.java +++ /dev/null @@ -1,51 +0,0 @@ -/* - * Copyright 2002-2015 the original author or authors. - * - * Licensed under the Apache License, Version 2.0 (the "License"); - * you may not use this file except in compliance with the License. - * You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ -package sample.session; - -import java.io.IOException; - -import javax.servlet.ServletException; -import javax.servlet.http.HttpServletRequest; -import javax.servlet.http.HttpServletResponse; -import javax.servlet.http.HttpSession; - -import org.springframework.security.core.Authentication; -import org.springframework.security.web.authentication.AuthenticationSuccessHandler; -import org.springframework.session.FindByIndexNameSessionRepository; - -/** - * Inserts the username into Spring session after we successfully authenticate. - * Adding the principal name to the session is a requirement of - * {@link FindByPrincipalNameSessionRepository} - * - * @author Rob Winch - */ -// tag::class[] -public class SpringSessionPrincipalNameSuccessHandler - implements AuthenticationSuccessHandler { - - public void onAuthenticationSuccess(HttpServletRequest request, HttpServletResponse response, - Authentication authentication) throws IOException, ServletException { - - HttpSession session = request.getSession(); - String currentUsername = authentication.getName(); - - // tag::set-username[] - session.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, currentUsername); - // end::set-username[] - } -} -// end::class[] diff --git a/spring-session/src/integration-test/java/org/springframework/session/data/gemfire/GemFireOperationsSessionRepositoryIntegrationTests.java b/spring-session/src/integration-test/java/org/springframework/session/data/gemfire/GemFireOperationsSessionRepositoryIntegrationTests.java index ba547ec8..07ece8bd 100644 --- a/spring-session/src/integration-test/java/org/springframework/session/data/gemfire/GemFireOperationsSessionRepositoryIntegrationTests.java +++ b/spring-session/src/integration-test/java/org/springframework/session/data/gemfire/GemFireOperationsSessionRepositoryIntegrationTests.java @@ -24,12 +24,17 @@ import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Properties; +import java.util.UUID; import org.junit.Before; import org.junit.Test; import org.junit.runner.RunWith; import org.springframework.context.annotation.Bean; import org.springframework.data.gemfire.CacheFactoryBean; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; +import org.springframework.security.core.authority.AuthorityUtils; +import org.springframework.security.core.context.SecurityContext; +import org.springframework.security.core.context.SecurityContextHolder; import org.springframework.session.ExpiringSession; import org.springframework.session.FindByIndexNameSessionRepository; import org.springframework.session.data.gemfire.config.annotation.web.http.EnableGemFireHttpSession; @@ -67,14 +72,25 @@ import com.gemstone.gemfire.pdx.PdxWriter; @ContextConfiguration @WebAppConfiguration public class GemFireOperationsSessionRepositoryIntegrationTests extends AbstractGemFireIntegrationTests { + private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; private static final int MAX_INACTIVE_INTERVAL_IN_SECONDS = 300; private static final String GEMFIRE_LOG_LEVEL = "warning"; private static final String SPRING_SESSION_GEMFIRE_REGION_NAME = "TestPartitionedSessions"; + SecurityContext context; + + SecurityContext changedContext; + @Before public void setup() { + context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(new UsernamePasswordAuthenticationToken("username-"+UUID.randomUUID(), "na", AuthorityUtils.createAuthorityList("ROLE_USER"))); + + changedContext = SecurityContextHolder.createEmptyContext(); + changedContext.setAuthentication(new UsernamePasswordAuthenticationToken("changedContext-"+UUID.randomUUID(), "na", AuthorityUtils.createAuthorityList("ROLE_USER"))); + assertThat(gemfireCache).isNotNull(); assertThat(gemfireSessionRepository).isNotNull(); assertThat(gemfireSessionRepository.getMaxInactiveIntervalInSeconds()).isEqualTo( @@ -159,6 +175,34 @@ public class GemFireOperationsSessionRepositoryIntegrationTests extends Abstract assertThat(robWinchSessions.get(sessionFive.getId())).isEqualTo(sessionFive); } + @Test + public void findSessionsBySecurityPrincipalName() { + ExpiringSession toSave = this.gemfireSessionRepository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + save(toSave); + + Map findByPrincipalName = doFindByPrincipalName(getSecurityName()); + assertThat(findByPrincipalName).hasSize(1); + assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); + } + + @Test + public void findSessionsByChangedSecurityPrincipalName() { + ExpiringSession toSave = this.gemfireSessionRepository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + save(toSave); + + toSave.setAttribute(SPRING_SECURITY_CONTEXT, changedContext); + save(toSave); + + Map findByPrincipalName = doFindByPrincipalName(getSecurityName()); + assertThat(findByPrincipalName).isEmpty(); + + findByPrincipalName = doFindByPrincipalName(getChangedSecurityName()); + assertThat(findByPrincipalName).hasSize(1); + } + @Test public void findsNoSessionsByNonExistingPrincipal() { Map nonExistingPrincipalSessions = doFindByPrincipalName("nonExistingPrincipalName"); @@ -217,6 +261,14 @@ public class GemFireOperationsSessionRepositoryIntegrationTests extends Abstract assertThat(savedSession.getAttribute(expectedAttributeNames.get(3))).isEqualTo(jonDoe); } + private String getSecurityName() { + return context.getAuthentication().getName(); + } + + private String getChangedSecurityName() { + return changedContext.getAuthentication().getName(); + } + @EnableGemFireHttpSession(regionName = SPRING_SESSION_GEMFIRE_REGION_NAME, maxInactiveIntervalInSeconds = MAX_INACTIVE_INTERVAL_IN_SECONDS) static class SpringSessionGemFireConfiguration { diff --git a/spring-session/src/integration-test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryITests.java b/spring-session/src/integration-test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryITests.java index 786ef2ce..5d3ecb31 100644 --- a/spring-session/src/integration-test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryITests.java +++ b/spring-session/src/integration-test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryITests.java @@ -49,6 +49,10 @@ import org.springframework.test.context.web.WebAppConfiguration; @ContextConfiguration @WebAppConfiguration public class RedisOperationsSessionRepositoryITests { + private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; + + private static final String INDEX_NAME = FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME; + @Autowired private RedisOperationsSessionRepository repository; @@ -58,26 +62,37 @@ public class RedisOperationsSessionRepositoryITests { @Autowired RedisOperations redis; + SecurityContext context; + + SecurityContext changedContext; + @Before public void setup() { registry.clear(); + context = SecurityContextHolder.createEmptyContext(); + context.setAuthentication(new UsernamePasswordAuthenticationToken("username-"+UUID.randomUUID(), "na", AuthorityUtils.createAuthorityList("ROLE_USER"))); + + changedContext = SecurityContextHolder.createEmptyContext(); + changedContext.setAuthentication(new UsernamePasswordAuthenticationToken("changedContext-"+UUID.randomUUID(), "na", AuthorityUtils.createAuthorityList("ROLE_USER"))); } @Test public void saves() throws InterruptedException { - String username = "saves-"+System.currentTimeMillis(); + String username = "saves-" + System.currentTimeMillis(); - String usernameSessionKey = "spring:session:RedisOperationsSessionRepositoryITests:index:" + FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME + ":" + username; + String usernameSessionKey = "spring:session:RedisOperationsSessionRepositoryITests:index:" + INDEX_NAME + ":" + + username; RedisSession toSave = repository.createSession(); String expectedAttributeName = "a"; String expectedAttributeValue = "b"; toSave.setAttribute(expectedAttributeName, expectedAttributeValue); - Authentication toSaveToken = new UsernamePasswordAuthenticationToken(username,"password", AuthorityUtils.createAuthorityList("ROLE_USER")); + Authentication toSaveToken = new UsernamePasswordAuthenticationToken(username, "password", + AuthorityUtils.createAuthorityList("ROLE_USER")); SecurityContext toSaveContext = SecurityContextHolder.createEmptyContext(); toSaveContext.setAuthentication(toSaveToken); - toSave.setAttribute("SPRING_SECURITY_CONTEXT", toSaveContext); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, username); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, toSaveContext); + toSave.setAttribute(INDEX_NAME, username); registry.clear(); repository.save(toSave); @@ -100,8 +115,8 @@ public class RedisOperationsSessionRepositoryITests { assertThat(registry.getEvent()).isInstanceOf(SessionDestroyedEvent.class); assertThat(redis.boundSetOps(usernameSessionKey).members()).doesNotContain(toSave.getId()); - - assertThat(registry.getEvent().getSession().getAttribute(expectedAttributeName)).isEqualTo(expectedAttributeValue); + assertThat(registry.getEvent().getSession().getAttribute(expectedAttributeName)) + .isEqualTo(expectedAttributeValue); } @Test @@ -125,15 +140,18 @@ public class RedisOperationsSessionRepositoryITests { repository.delete(toSave.getId()); } + + @Test public void findByPrincipalName() throws Exception { String principalName = "findByPrincipalName" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).hasSize(1); assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); @@ -141,7 +159,7 @@ public class RedisOperationsSessionRepositoryITests { repository.delete(toSave.getId()); registry.receivedEvent(); - findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, principalName); assertThat(findByPrincipalName).hasSize(0); assertThat(findByPrincipalName.keySet()).doesNotContain(toSave.getId()); @@ -151,7 +169,7 @@ public class RedisOperationsSessionRepositoryITests { public void findByPrincipalNameExpireRemovesIndex() throws Exception { String principalName = "findByPrincipalNameExpireRemovesIndex" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); @@ -159,9 +177,10 @@ public class RedisOperationsSessionRepositoryITests { String channel = ":expired"; DefaultMessage message = new DefaultMessage(channel.getBytes("UTF-8"), body.getBytes("UTF-8")); byte[] pattern = new byte[] {}; - repository.onMessage(message , pattern); + repository.onMessage(message, pattern); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).hasSize(0); assertThat(findByPrincipalName.keySet()).doesNotContain(toSave.getId()); @@ -171,14 +190,15 @@ public class RedisOperationsSessionRepositoryITests { public void findByPrincipalNameNoPrincipalNameChange() throws Exception { String principalName = "findByPrincipalNameNoPrincipalNameChange" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); toSave.setAttribute("other", "value"); repository.save(toSave); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).hasSize(1); assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); @@ -188,7 +208,7 @@ public class RedisOperationsSessionRepositoryITests { public void findByPrincipalNameNoPrincipalNameChangeReload() throws Exception { String principalName = "findByPrincipalNameNoPrincipalNameChangeReload" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); @@ -197,7 +217,8 @@ public class RedisOperationsSessionRepositoryITests { toSave.setAttribute("other", "value"); repository.save(toSave); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).hasSize(1); assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); @@ -207,14 +228,15 @@ public class RedisOperationsSessionRepositoryITests { public void findByDeletedPrincipalName() throws Exception { String principalName = "findByDeletedPrincipalName" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, null); + toSave.setAttribute(INDEX_NAME, null); repository.save(toSave); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).isEmpty(); } @@ -224,17 +246,18 @@ public class RedisOperationsSessionRepositoryITests { String principalName = "findByChangedPrincipalName" + UUID.randomUUID(); String principalNameChanged = "findByChangedPrincipalName" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalNameChanged); + toSave.setAttribute(INDEX_NAME, principalNameChanged); repository.save(toSave); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).isEmpty(); - findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalNameChanged); + findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, principalNameChanged); assertThat(findByPrincipalName).hasSize(1); assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); @@ -244,15 +267,16 @@ public class RedisOperationsSessionRepositoryITests { public void findByDeletedPrincipalNameReload() throws Exception { String principalName = "findByDeletedPrincipalName" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); RedisSession getSession = repository.getSession(toSave.getId()); - getSession.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, null); + getSession.setAttribute(INDEX_NAME, null); repository.save(getSession); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).isEmpty(); } @@ -262,24 +286,186 @@ public class RedisOperationsSessionRepositoryITests { String principalName = "findByChangedPrincipalName" + UUID.randomUUID(); String principalNameChanged = "findByChangedPrincipalName" + UUID.randomUUID(); RedisSession toSave = repository.createSession(); - toSave.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalName); + toSave.setAttribute(INDEX_NAME, principalName); repository.save(toSave); RedisSession getSession = repository.getSession(toSave.getId()); - getSession.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, principalNameChanged); + getSession.setAttribute(INDEX_NAME, principalNameChanged); repository.save(getSession); - Map findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalName); + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + principalName); assertThat(findByPrincipalName).isEmpty(); - findByPrincipalName = repository.findByIndexNameAndIndexValue(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME,principalNameChanged); + findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, principalNameChanged); assertThat(findByPrincipalName).hasSize(1); assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); } + @Test + public void findBySecurityPrincipalName() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + + assertThat(findByPrincipalName).hasSize(1); + assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); + + repository.delete(toSave.getId()); + registry.receivedEvent(); + + findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, getSecurityName()); + + assertThat(findByPrincipalName).hasSize(0); + assertThat(findByPrincipalName.keySet()).doesNotContain(toSave.getId()); + } + + @Test + public void findBySecurityPrincipalNameExpireRemovesIndex() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + String body = "spring:session:RedisOperationsSessionRepositoryITests:sessions:expires:" + toSave.getId(); + String channel = ":expired"; + DefaultMessage message = new DefaultMessage(channel.getBytes("UTF-8"), body.getBytes("UTF-8")); + byte[] pattern = new byte[] {}; + repository.onMessage(message, pattern); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + + assertThat(findByPrincipalName).hasSize(0); + assertThat(findByPrincipalName.keySet()).doesNotContain(toSave.getId()); + } + + @Test + public void findByPrincipalNameNoSecurityPrincipalNameChange() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + toSave.setAttribute("other", "value"); + repository.save(toSave); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + + assertThat(findByPrincipalName).hasSize(1); + assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); + } + + @Test + public void findByPrincipalNameNoSecurityPrincipalNameChangeReload() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + toSave = repository.getSession(toSave.getId()); + + toSave.setAttribute("other", "value"); + repository.save(toSave); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + + assertThat(findByPrincipalName).hasSize(1); + assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); + } + + @Test + public void findByDeletedSecurityPrincipalName() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + toSave.setAttribute(SPRING_SECURITY_CONTEXT, null); + repository.save(toSave); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + + assertThat(findByPrincipalName).isEmpty(); + } + + @Test + public void findByChangedSecurityPrincipalName() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + toSave.setAttribute(SPRING_SECURITY_CONTEXT, changedContext); + repository.save(toSave); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + assertThat(findByPrincipalName).isEmpty(); + + findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, getChangedSecurityName()); + + assertThat(findByPrincipalName).hasSize(1); + assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); + } + + @Test + public void findByDeletedSecurityPrincipalNameReload() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + RedisSession getSession = repository.getSession(toSave.getId()); + getSession.setAttribute(INDEX_NAME, null); + repository.save(getSession); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getChangedSecurityName()); + + assertThat(findByPrincipalName).isEmpty(); + } + + @Test + public void findByChangedSecurityPrincipalNameReload() throws Exception { + RedisSession toSave = repository.createSession(); + toSave.setAttribute(SPRING_SECURITY_CONTEXT, context); + + repository.save(toSave); + + RedisSession getSession = repository.getSession(toSave.getId()); + + getSession.setAttribute(SPRING_SECURITY_CONTEXT, changedContext); + repository.save(getSession); + + Map findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, + getSecurityName()); + assertThat(findByPrincipalName).isEmpty(); + + findByPrincipalName = repository.findByIndexNameAndIndexValue(INDEX_NAME, getChangedSecurityName()); + + assertThat(findByPrincipalName).hasSize(1); + assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); + } + + private String getSecurityName() { + return context.getAuthentication().getName(); + } + + private String getChangedSecurityName() { + return changedContext.getAuthentication().getName(); + } + @Configuration @EnableRedisHttpSession(redisNamespace = "RedisOperationsSessionRepositoryITests") static class Config { diff --git a/spring-session/src/main/java/org/springframework/session/data/gemfire/AbstractGemFireOperationsSessionRepository.java b/spring-session/src/main/java/org/springframework/session/data/gemfire/AbstractGemFireOperationsSessionRepository.java index 77d7d6e1..86cd377b 100644 --- a/spring-session/src/main/java/org/springframework/session/data/gemfire/AbstractGemFireOperationsSessionRepository.java +++ b/spring-session/src/main/java/org/springframework/session/data/gemfire/AbstractGemFireOperationsSessionRepository.java @@ -38,6 +38,8 @@ import org.springframework.context.ApplicationEventPublisher; import org.springframework.context.ApplicationEventPublisherAware; import org.springframework.data.gemfire.GemfireAccessor; import org.springframework.data.gemfire.GemfireOperations; +import org.springframework.expression.Expression; +import org.springframework.expression.spel.standard.SpelExpressionParser; import org.springframework.session.ExpiringSession; import org.springframework.session.FindByIndexNameSessionRepository; import org.springframework.session.Session; @@ -331,6 +333,10 @@ public abstract class AbstractGemFireOperationsSessionRepository extends CacheLi }); } + private String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; + + private SpelExpressionParser parser = new SpelExpressionParser(); + private transient boolean delta = false; private int maxInactiveIntervalInSeconds; @@ -463,7 +469,16 @@ public abstract class AbstractGemFireOperationsSessionRepository extends CacheLi /* (non-Javadoc) */ public synchronized String getPrincipalName() { - return getAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); + String principalName = getAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); + if(principalName != null) { + return principalName; + } + Object authentication = getAttribute(SPRING_SECURITY_CONTEXT); + if(authentication != null) { + Expression expression = parser.parseExpression("authentication?.name"); + return expression.getValue(authentication, String.class); + } + return null; } /* (non-Javadoc) */ diff --git a/spring-session/src/main/java/org/springframework/session/data/redis/RedisOperationsSessionRepository.java b/spring-session/src/main/java/org/springframework/session/data/redis/RedisOperationsSessionRepository.java index 5215c003..f81b8d89 100644 --- a/spring-session/src/main/java/org/springframework/session/data/redis/RedisOperationsSessionRepository.java +++ b/spring-session/src/main/java/org/springframework/session/data/redis/RedisOperationsSessionRepository.java @@ -252,6 +252,10 @@ import org.springframework.util.Assert; public class RedisOperationsSessionRepository implements FindByIndexNameSessionRepository, MessageListener { private static final Log logger = LogFactory.getLog(RedisOperationsSessionRepository.class); + private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; + + static PrincipalNameResolver PRINCIPAL_NAME_RESOLVER = new PrincipalNameResolver(); + /** * The default prefix for each key and channel in Redis used by Spring Session */ @@ -475,7 +479,7 @@ public class RedisOperationsSessionRepository implements FindByIndexNameSessionR logger.debug("Publishing SessionDestroyedEvent for session " + sessionId); } - String principal = (String) session.getAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); + String principal = PRINCIPAL_NAME_RESOLVER.resolvePrincipal(session); if(principal != null) { sessionRedisOperations.boundSetOps(getPrincipalKey(principal)).remove(sessionId); } @@ -629,7 +633,7 @@ public class RedisOperationsSessionRepository implements FindByIndexNameSessionR RedisSession(MapSession cached) { Assert.notNull("MapSession cannot be null"); this.cached = cached; - this.originalPrincipalName = cached.getAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); + this.originalPrincipalName = PRINCIPAL_NAME_RESOLVER.resolvePrincipal(this); } public void setNew(boolean isNew) { @@ -695,17 +699,18 @@ public class RedisOperationsSessionRepository implements FindByIndexNameSessionR private void saveDelta() { String sessionId = getId(); getSessionBoundHashOperations(sessionId).putAll(delta); - String key = getSessionAttrNameKey(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); - if(delta.containsKey(key)) { + String principalSessionKey = getSessionAttrNameKey(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); + String securityPrincipalSessionKey = getSessionAttrNameKey(SPRING_SECURITY_CONTEXT); + if(delta.containsKey(principalSessionKey) || delta.containsKey(securityPrincipalSessionKey)) { if(originalPrincipalName != null) { - String originalPrincipalKey = getPrincipalKey((String) originalPrincipalName); - sessionRedisOperations.boundSetOps(originalPrincipalKey).remove(sessionId); + String originalPrincipalRedisKey = getPrincipalKey((String) originalPrincipalName); + sessionRedisOperations.boundSetOps(originalPrincipalRedisKey).remove(sessionId); } - String principal = (String) delta.get(key); + String principal = PRINCIPAL_NAME_RESOLVER.resolvePrincipal(this); originalPrincipalName = principal; if(principal != null) { - String principalKey = getPrincipalKey( principal); - sessionRedisOperations.boundSetOps(principalKey).add(sessionId); + String principalRedisKey = getPrincipalKey(principal); + sessionRedisOperations.boundSetOps(principalRedisKey).add(sessionId); } } @@ -716,28 +721,20 @@ public class RedisOperationsSessionRepository implements FindByIndexNameSessionR } } - class PrincipalNameResolver { - private static final String SPRING_SECURITY_CONTEXT = "SPRING_SECURITY_CONTEXT"; - private static final String NO_VALUE = "org.springframework.session.data.redis.$PrincipalNameResolver.NO_VALUE"; + static class PrincipalNameResolver { private SpelExpressionParser parser = new SpelExpressionParser(); - public String resolvePrincipal(Map delta) { - String key = getSessionAttrNameKey(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME); - if(delta.containsKey(key)) { - Object principal = delta.get(key); - return (String) principal; + public String resolvePrincipal(Session session) { + String principalName = session.getAttribute(PRINCIPAL_NAME_INDEX_NAME); + if(principalName != null) { + return principalName; } - key = getSessionAttrNameKey(SPRING_SECURITY_CONTEXT); - if(delta.containsKey(key)) { - Object authentication = delta.get(key); - if(authentication == null) { - return null; - } + Object authentication = session.getAttribute(SPRING_SECURITY_CONTEXT); + if(authentication != null) { Expression expression = parser.parseExpression("authentication?.name"); return expression.getValue(authentication, String.class); } - - return NO_VALUE; + return null; } } diff --git a/spring-session/src/test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryTests.java b/spring-session/src/test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryTests.java index 722150c5..eecbf575 100644 --- a/spring-session/src/test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryTests.java +++ b/spring-session/src/test/java/org/springframework/session/data/redis/RedisOperationsSessionRepositoryTests.java @@ -51,8 +51,9 @@ import org.springframework.data.redis.core.BoundSetOperations; import org.springframework.data.redis.core.BoundValueOperations; import org.springframework.data.redis.core.RedisOperations; import org.springframework.data.redis.serializer.JdkSerializationRedisSerializer; -import org.springframework.security.authentication.TestingAuthenticationToken; +import org.springframework.security.authentication.UsernamePasswordAuthenticationToken; import org.springframework.security.core.Authentication; +import org.springframework.security.core.authority.AuthorityUtils; import org.springframework.security.core.context.SecurityContext; import org.springframework.security.core.context.SecurityContextImpl; import org.springframework.session.ExpiringSession; @@ -66,6 +67,8 @@ import org.springframework.session.events.AbstractSessionEvent; @RunWith(MockitoJUnitRunner.class) @SuppressWarnings({"unchecked","rawtypes"}) public class RedisOperationsSessionRepositoryTests { + static final String SPRING_SECURITY_CONTEXT_KEY = "SPRING_SECURITY_CONTEXT"; + @Mock RedisConnectionFactory factory; @Mock @@ -414,6 +417,31 @@ public class RedisOperationsSessionRepositoryTests { assertThat(event.getValue().getSessionId()).isEqualTo(session.getId()); } + @Test + public void resolvePrincipalIndex() { + PrincipalNameResolver resolver = RedisOperationsSessionRepository.PRINCIPAL_NAME_RESOLVER; + String username = "username"; + RedisSession session = redisRepository.createSession(); + session.setAttribute(FindByIndexNameSessionRepository.PRINCIPAL_NAME_INDEX_NAME, username); + + assertThat(resolver.resolvePrincipal(session)).isEqualTo(username); + } + + @Test + public void resolveIndexOnSecurityContext() { + String principal = "resolveIndexOnSecurityContext"; + Authentication authentication = new UsernamePasswordAuthenticationToken(principal, "notused", AuthorityUtils.createAuthorityList("ROLE_USER")); + SecurityContext context = new SecurityContextImpl(); + context.setAuthentication(authentication); + + PrincipalNameResolver resolver = RedisOperationsSessionRepository.PRINCIPAL_NAME_RESOLVER; + + RedisSession session = redisRepository.createSession(); + session.setAttribute(SPRING_SECURITY_CONTEXT_KEY, context); + + assertThat(resolver.resolvePrincipal(session)).isEqualTo(principal); + } + private String getKey(String id) { return "spring:session:sessions:" + id; }