From ec5969c5786b7810755a2fb0fb78d4e2764cfdf1 Mon Sep 17 00:00:00 2001 From: Rossen Stoyanchev Date: Wed, 27 Sep 2017 01:43:11 -0400 Subject: [PATCH] InMemoryWebSession cleans up expired sessions Issue: SPR-15963 --- .../session/InMemoryWebSessionStore.java | 64 +++++++++++++++---- .../session/InMemoryWebSessionStoreTests.java | 46 ++++++++++++- 2 files changed, 96 insertions(+), 14 deletions(-) diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index 6192d1304d..d65cfaf202 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -20,9 +20,12 @@ import java.time.Clock; import java.time.Duration; import java.time.Instant; import java.time.ZoneId; +import java.util.Iterator; import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import java.util.concurrent.ConcurrentMap; import java.util.concurrent.atomic.AtomicReference; +import java.util.concurrent.locks.ReentrantLock; import reactor.core.publisher.Mono; @@ -40,12 +43,20 @@ import org.springframework.web.server.WebSession; */ public class InMemoryWebSessionStore implements WebSessionStore { + /** Minimum period between expiration checks */ + private static final Duration EXPIRATION_CHECK_PERIOD = Duration.ofSeconds(60); + + private static final IdGenerator idGenerator = new JdkIdGenerator(); private Clock clock = Clock.system(ZoneId.of("GMT")); - private final Map sessions = new ConcurrentHashMap<>(); + private final ConcurrentMap sessions = new ConcurrentHashMap<>(); + + private volatile Instant nextExpirationCheckTime = Instant.now(this.clock).plus(EXPIRATION_CHECK_PERIOD); + + private final ReentrantLock expirationCheckLock = new ReentrantLock(); /** @@ -60,6 +71,8 @@ public class InMemoryWebSessionStore implements WebSessionStore { public void setClock(Clock clock) { Assert.notNull(clock, "Clock is required"); this.clock = clock; + // Force a check when clock changes.. + this.nextExpirationCheckTime = Instant.now(this.clock); } /** @@ -77,20 +90,46 @@ public class InMemoryWebSessionStore implements WebSessionStore { @Override public Mono retrieveSession(String id) { + + Instant currentTime = Instant.now(this.clock); + + if (!this.sessions.isEmpty() && !currentTime.isBefore(this.nextExpirationCheckTime)) { + checkExpiredSessions(currentTime); + } + InMemoryWebSession session = this.sessions.get(id); if (session == null) { return Mono.empty(); } - else if (session.isExpired()) { + else if (session.isExpired(currentTime)) { this.sessions.remove(id); return Mono.empty(); } else { - session.updateLastAccessTime(); + session.updateLastAccessTime(currentTime); return Mono.just(session); } } + private void checkExpiredSessions(Instant currentTime) { + if (this.expirationCheckLock.tryLock()) { + try { + Iterator iterator = this.sessions.values().iterator(); + while (iterator.hasNext()) { + InMemoryWebSession session = iterator.next(); + if (session.isExpired(currentTime)) { + iterator.remove(); + session.invalidate(); + } + } + } + finally { + this.nextExpirationCheckTime = currentTime.plus(EXPIRATION_CHECK_PERIOD); + this.expirationCheckLock.unlock(); + } + } + } + @Override public Mono removeSession(String id) { this.sessions.remove(id); @@ -101,7 +140,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { return Mono.fromSupplier(() -> { Assert.isInstanceOf(InMemoryWebSession.class, webSession); InMemoryWebSession session = (InMemoryWebSession) webSession; - session.updateLastAccessTime(); + session.updateLastAccessTime(Instant.now(getClock())); return session; }); } @@ -122,7 +161,7 @@ public class InMemoryWebSessionStore implements WebSessionStore { private final AtomicReference state = new AtomicReference<>(State.NEW); - InMemoryWebSession() { + public InMemoryWebSession() { this.creationTime = Instant.now(getClock()); this.lastAccessTime = this.creationTime; } @@ -201,25 +240,28 @@ public class InMemoryWebSessionStore implements WebSessionStore { @Override public boolean isExpired() { + return isExpired(Instant.now(getClock())); + } + + private boolean isExpired(Instant currentTime) { if (this.state.get().equals(State.EXPIRED)) { return true; } - if (checkExpired()) { + if (checkExpired(currentTime)) { this.state.set(State.EXPIRED); return true; } return false; } - private boolean checkExpired() { + private boolean checkExpired(Instant currentTime) { return isStarted() && !this.maxIdleTime.isNegative() && - Instant.now(getClock()).minus(this.maxIdleTime).isAfter(this.lastAccessTime); + currentTime.minus(this.maxIdleTime).isAfter(this.lastAccessTime); } - private void updateLastAccessTime() { - this.lastAccessTime = Instant.now(getClock()); + private void updateLastAccessTime(Instant currentTime) { + this.lastAccessTime = currentTime; } - } private enum State { NEW, STARTED, EXPIRED } diff --git a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java index d0bf5cce5a..f2713ad5f8 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java @@ -21,6 +21,7 @@ import java.time.Instant; import org.junit.Test; +import org.springframework.util.Assert; import org.springframework.web.server.WebSession; import static junit.framework.TestCase.assertSame; @@ -59,7 +60,7 @@ public class InMemoryWebSessionStoreTests { WebSession session = this.store.createWebSession().block(); assertNotNull(session); session.getAttributes().put("foo", "bar"); - session.save(); + session.save().block(); String id = session.getId(); WebSession retrieved = this.store.retrieveSession(id).block(); @@ -78,7 +79,7 @@ public class InMemoryWebSessionStoreTests { assertNotNull(session1); String id = session1.getId(); Instant time1 = session1.getLastAccessTime(); - session1.save(); + session1.save().block(); // Fast-forward a few seconds this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofSeconds(5))); @@ -91,7 +92,46 @@ public class InMemoryWebSessionStoreTests { } @Test - public void invalidate() throws Exception { + public void expirationChecks() throws Exception { + // Create 3 sessions + WebSession session1 = this.store.createWebSession().block(); + assertNotNull(session1); + session1.start(); + session1.save().block(); + WebSession session2 = this.store.createWebSession().block(); + assertNotNull(session2); + session2.start(); + session2.save().block(); + + WebSession session3 = this.store.createWebSession().block(); + assertNotNull(session3); + session3.start(); + session3.save().block(); + + // Fast-forward 31 minutes + this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); + + // Create 2 more sessions + WebSession session4 = this.store.createWebSession().block(); + assertNotNull(session4); + session4.start(); + session4.save().block(); + + WebSession session5 = this.store.createWebSession().block(); + assertNotNull(session5); + session5.start(); + session5.save().block(); + + // Retrieve, forcing cleanup of all expired.. + assertNull(this.store.retrieveSession(session1.getId()).block()); + assertNull(this.store.retrieveSession(session2.getId()).block()); + assertNull(this.store.retrieveSession(session3.getId()).block()); + + assertNotNull(this.store.retrieveSession(session4.getId()).block()); + assertNotNull(this.store.retrieveSession(session5.getId()).block()); } + + + } \ No newline at end of file