diff --git a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/SessionEventRegistry.java b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/SessionEventRegistry.java index a82d4756..1f080dee 100644 --- a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/SessionEventRegistry.java +++ b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/SessionEventRegistry.java @@ -16,24 +16,29 @@ package org.springframework.session.data; +import java.util.ArrayList; import java.util.HashMap; +import java.util.List; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; import java.util.concurrent.ConcurrentMap; +import java.util.concurrent.TimeUnit; + +import org.jetbrains.annotations.Nullable; import org.springframework.context.ApplicationListener; import org.springframework.session.events.AbstractSessionEvent; public class SessionEventRegistry implements ApplicationListener { - private Map events = new HashMap<>(); + private Map> events = new HashMap<>(); private ConcurrentMap locks = new ConcurrentHashMap<>(); @Override public void onApplicationEvent(AbstractSessionEvent event) { String sessionId = event.getSessionId(); - this.events.put(sessionId, event); + this.events.computeIfAbsent(sessionId, (key) -> new ArrayList<>()).add(event); Object lock = getLock(sessionId); synchronized (lock) { lock.notifyAll(); @@ -45,24 +50,41 @@ public class SessionEventRegistry implements ApplicationListener boolean receivedEvent(String sessionId, Class type) + throws InterruptedException { + return waitForEvent(sessionId, type) != null; } @SuppressWarnings("unchecked") - public E getEvent(String sessionId) throws InterruptedException { - return (E) waitForEvent(sessionId); - } - - @SuppressWarnings("unchecked") - private E waitForEvent(String sessionId) throws InterruptedException { + public E waitForEvent(String sessionId, Class type) + throws InterruptedException { Object lock = getLock(sessionId); + long waitInMs = TimeUnit.SECONDS.toMillis(10); + long start = System.currentTimeMillis(); + boolean doneWaiting = false; synchronized (lock) { - if (!this.events.containsKey(sessionId)) { - lock.wait(10000); + while (!doneWaiting) { + E result = getEvent(sessionId, type); + if (result == null) { + // wait until timeout or notified + // might need to continue trying if the notification + // was for a different event + lock.wait(waitInMs); + } + long now = System.currentTimeMillis(); + doneWaiting = (now - start) >= waitInMs; } + return getEvent(sessionId, type); } - return (E) this.events.get(sessionId); + } + + private @Nullable E getEvent(String sessionId, Class type) { + List events = this.events.get(sessionId); + E result = (events != null) ? (E) events.stream() + .filter((event) -> type.isAssignableFrom(event.getClass())) + .findFirst() + .orElse(null) : null; + return result; } private Object getLock(String sessionId) { diff --git a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryConfigurationITests.java b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryConfigurationITests.java index 11aa3c95..458934d5 100644 --- a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryConfigurationITests.java +++ b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryConfigurationITests.java @@ -101,7 +101,7 @@ class ReactiveRedisIndexedSessionRepositoryConfigurationITests { RedisSession session = this.repository.createSession().block(); this.repository.save(session).block(); SessionEventRegistry registry = this.context.getBean(SessionEventRegistry.class); - SessionCreatedEvent event = registry.getEvent(session.getId()); + SessionCreatedEvent event = registry.waitForEvent(session.getId(), SessionCreatedEvent.class); Session eventSession = event.getSession(); assertThat(eventSession).usingRecursiveComparison() .withComparatorForFields(new InstantComparator(), "cached.creationTime", "cached.lastAccessedTime") diff --git a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryITests.java b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryITests.java index b59dc533..0c86527c 100644 --- a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryITests.java +++ b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/ReactiveRedisIndexedSessionRepositoryITests.java @@ -124,7 +124,7 @@ class ReactiveRedisIndexedSessionRepositoryITests { this.repository.save(session).block(); - SessionCreatedEvent event = this.eventRegistry.getEvent(session.getId()); + SessionCreatedEvent event = this.eventRegistry.waitForEvent(session.getId(), SessionCreatedEvent.class); assertThat(event).isNotNull(); RedisSession eventSession = event.getSession(); compareSessions(session, eventSession); @@ -168,7 +168,7 @@ class ReactiveRedisIndexedSessionRepositoryITests { assertThat(this.redis.expire(key, Duration.ofSeconds(1)).block()).isTrue(); await().atMost(Duration.ofSeconds(3)).untilAsserted(() -> { - SessionExpiredEvent event = this.eventRegistry.getEvent(toSave.getId()); + SessionExpiredEvent event = this.eventRegistry.waitForEvent(toSave.getId(), SessionExpiredEvent.class); RedisSession eventSession = event.getSession(); Map findByPrincipalName = this.repository .findByIndexNameAndIndexValue(INDEX_NAME, principalName) @@ -206,7 +206,7 @@ class ReactiveRedisIndexedSessionRepositoryITests { .block(); assertThat(findByPrincipalName).hasSize(0); assertThat(findByPrincipalName.keySet()).doesNotContain(toSave.getId()); - SessionDeletedEvent event = this.eventRegistry.getEvent(toSave.getId()); + SessionDeletedEvent event = this.eventRegistry.waitForEvent(toSave.getId(), SessionDeletedEvent.class); assertThat(event).isNotNull(); RedisSession eventSession = event.getSession(); compareSessions(toSave, eventSession); diff --git a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryITests.java b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryITests.java index b9798a95..a7f44fd1 100644 --- a/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryITests.java +++ b/spring-session-data-redis/src/integration-test/java/org/springframework/session/data/redis/RedisIndexedSessionRepositoryITests.java @@ -112,10 +112,10 @@ class RedisIndexedSessionRepositoryITests extends AbstractRedisITests { this.repository.save(toSave); - assertThat(this.registry.receivedEvent(toSave.getId())).isTrue(); + assertThat(this.registry.receivedEvent(toSave.getId(), SessionCreatedEvent.class)).isTrue(); assertThat(this.redis.boundSetOps(usernameSessionKey).members()).contains(toSave.getId()); - SessionCreatedEvent createdEvent = this.registry.getEvent(toSave.getId()); + SessionCreatedEvent createdEvent = this.registry.waitForEvent(toSave.getId(), SessionCreatedEvent.class); Session session = createdEvent.getSession(); assertThat(session.getId()).isEqualTo(toSave.getId()); @@ -128,11 +128,10 @@ class RedisIndexedSessionRepositoryITests extends AbstractRedisITests { this.repository.deleteById(toSave.getId()); assertThat(this.repository.findById(toSave.getId())).isNull(); - assertThat(this.registry.getEvent(toSave.getId())) - .isInstanceOf(SessionDestroyedEvent.class); assertThat(this.redis.boundSetOps(usernameSessionKey).members()).doesNotContain(toSave.getId()); - assertThat(this.registry.getEvent(toSave.getId()).getSession().getAttribute(expectedAttributeName)) + SessionDestroyedEvent destroyedEvent = this.registry.waitForEvent(toSave.getId(), SessionDestroyedEvent.class); + assertThat(destroyedEvent.getSession().getAttribute(expectedAttributeName)) .isEqualTo(expectedAttributeValue); } @@ -188,7 +187,8 @@ class RedisIndexedSessionRepositoryITests extends AbstractRedisITests { assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); this.repository.deleteById(toSave.getId()); - assertThat(this.registry.receivedEvent(toSave.getId())).isTrue(); + boolean sessionDestroyed = this.registry.receivedEvent(toSave.getId(), SessionDestroyedEvent.class); + assertThat(sessionDestroyed).isTrue(); findByPrincipalName = this.repository.findByIndexNameAndIndexValue(INDEX_NAME, principalName); @@ -351,7 +351,8 @@ class RedisIndexedSessionRepositoryITests extends AbstractRedisITests { assertThat(findByPrincipalName.keySet()).containsOnly(toSave.getId()); this.repository.deleteById(toSave.getId()); - assertThat(this.registry.receivedEvent(toSave.getId())).isTrue(); + boolean sessionDestroyed = this.registry.receivedEvent(toSave.getId(), SessionDestroyedEvent.class); + assertThat(sessionDestroyed).isTrue(); findByPrincipalName = this.repository.findByIndexNameAndIndexValue(INDEX_NAME, getSecurityName());