diff --git a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java index fa50120554..50555d1dc4 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/DefaultWebSessionManager.java @@ -80,7 +80,6 @@ public class DefaultWebSessionManager implements WebSessionManager { public Mono getSession(ServerWebExchange exchange) { return Mono.defer(() -> retrieveSession(exchange) - .flatMap(session -> removeSessionIfExpired(exchange, session)) .flatMap(this.getSessionStore()::updateLastAccessTime) .switchIfEmpty(this.sessionStore.createWebSession()) .doOnNext(session -> exchange.getResponse().beforeCommit(() -> save(exchange, session)))); @@ -92,14 +91,6 @@ public class DefaultWebSessionManager implements WebSessionManager { .next(); } - private Mono removeSessionIfExpired(ServerWebExchange exchange, WebSession session) { - if (session.isExpired()) { - this.sessionIdResolver.expireSession(exchange); - return this.sessionStore.removeSession(session.getId()).then(Mono.empty()); - } - return Mono.just(session); - } - private Mono save(ServerWebExchange exchange, WebSession session) { if (session.isExpired()) { return Mono.error(new IllegalStateException( @@ -110,11 +101,14 @@ public class DefaultWebSessionManager implements WebSessionManager { } if (!session.isStarted()) { + if (hasNewSessionId(exchange, session)) { + this.sessionIdResolver.expireSession(exchange); + } return Mono.empty(); } if (hasNewSessionId(exchange, session)) { - DefaultWebSessionManager.this.sessionIdResolver.setSessionId(exchange, session.getId()); + this.sessionIdResolver.setSessionId(exchange, session.getId()); } return session.save(); diff --git a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java index f3a5758596..b8db7090b5 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/InMemoryWebSessionStore.java @@ -77,7 +77,17 @@ public class InMemoryWebSessionStore implements WebSessionStore { @Override public Mono retrieveSession(String id) { - return (this.sessions.containsKey(id) ? Mono.just(this.sessions.get(id)) : Mono.empty()); + WebSession session = this.sessions.get(id); + if (session == null) { + return Mono.empty(); + } + else if (session.isExpired()) { + this.sessions.remove(id); + return Mono.empty(); + } + else { + return Mono.just(session); + } } @Override diff --git a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java index 4ff61176b3..14505425ad 100644 --- a/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java +++ b/spring-web/src/main/java/org/springframework/web/server/session/WebSessionStore.java @@ -40,8 +40,10 @@ public interface WebSessionStore { /** * Return the WebSession for the given id. + *

Note: This method should perform an expiration check, + * remove the session if it has expired and return empty. * @param sessionId the session to load - * @return the session, or an empty {@code Mono}. + * @return the session, or an empty {@code Mono} . */ Mono retrieveSession(String sessionId); diff --git a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java index 3bfe7ccf5a..0e0883d08b 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/DefaultWebSessionManagerTests.java @@ -40,7 +40,6 @@ import static org.mockito.ArgumentMatchers.any; import static org.mockito.ArgumentMatchers.eq; import static org.mockito.Mockito.never; import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.verifyZeroInteractions; import static org.mockito.Mockito.when; /** @@ -64,20 +63,16 @@ public class DefaultWebSessionManagerTests { @Mock private WebSession createSession; - @Mock - private WebSession retrieveSession; - @Mock private WebSession updateSession; + @Before public void setUp() throws Exception { when(this.store.createWebSession()).thenReturn(Mono.just(this.createSession)); when(this.store.updateLastAccessTime(any())).thenReturn(Mono.just(this.updateSession)); - when(this.store.retrieveSession(any())).thenReturn(Mono.just(this.retrieveSession)); when(this.createSession.save()).thenReturn(Mono.empty()); when(this.updateSession.getId()).thenReturn("update-session-id"); - when(this.retrieveSession.getId()).thenReturn("retrieve-session-id"); this.manager = new DefaultWebSessionManager(); this.manager.setSessionIdResolver(this.idResolver); @@ -97,7 +92,6 @@ public class DefaultWebSessionManagerTests { assertFalse(session.isStarted()); assertFalse(session.isExpired()); - verifyZeroInteractions(this.retrieveSession, this.updateSession); verify(this.createSession, never()).save(); verify(this.idResolver, never()).setSessionId(any(), any()); } @@ -138,19 +132,6 @@ public class DefaultWebSessionManagerTests { assertEquals(id, actual.getId()); } - @Test - public void existingSessionIsExpired() throws Exception { - String id = this.retrieveSession.getId(); - when(this.retrieveSession.isExpired()).thenReturn(true); - when(this.idResolver.resolveSessionIds(this.exchange)).thenReturn(Collections.singletonList(id)); - when(this.store.removeSession(any())).thenReturn(Mono.empty()); - - WebSession actual = this.manager.getSession(this.exchange).block(); - assertEquals(this.createSession.getId(), actual.getId()); - verify(this.store).removeSession(id); - verify(this.idResolver).expireSession(any()); - } - @Test public void multipleSessionIds() throws Exception { WebSession existing = this.updateSession; diff --git a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java index efc92a315a..7c3cf21ce4 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/InMemoryWebSessionStoreTests.java @@ -15,11 +15,16 @@ */ package org.springframework.web.server.session; +import java.time.Clock; +import java.time.Duration; + import org.junit.Test; import org.springframework.web.server.WebSession; +import static junit.framework.TestCase.assertSame; import static org.junit.Assert.assertNotNull; +import static org.junit.Assert.assertNull; import static org.junit.Assert.assertTrue; /** @@ -28,34 +33,34 @@ import static org.junit.Assert.assertTrue; */ public class InMemoryWebSessionStoreTests { - private InMemoryWebSessionStore sessionStore = new InMemoryWebSessionStore(); + private InMemoryWebSessionStore store = new InMemoryWebSessionStore(); @Test public void constructorWhenImplicitStartCopiedThenCopyIsStarted() { - WebSession original = this.sessionStore.createWebSession().block(); + WebSession original = this.store.createWebSession().block(); assertNotNull(original); original.getAttributes().put("foo", "bar"); - WebSession copy = this.sessionStore.updateLastAccessTime(original).block(); + WebSession copy = this.store.updateLastAccessTime(original).block(); assertNotNull(copy); assertTrue(copy.isStarted()); } @Test public void constructorWhenExplicitStartCopiedThenCopyIsStarted() { - WebSession original = this.sessionStore.createWebSession().block(); + WebSession original = this.store.createWebSession().block(); assertNotNull(original); original.start(); - WebSession copy = this.sessionStore.updateLastAccessTime(original).block(); + WebSession copy = this.store.updateLastAccessTime(original).block(); assertNotNull(copy); assertTrue(copy.isStarted()); } @Test public void startsSessionExplicitly() { - WebSession session = this.sessionStore.createWebSession().block(); + WebSession session = this.store.createWebSession().block(); assertNotNull(session); session.start(); assertTrue(session.isStarted()); @@ -63,11 +68,27 @@ public class InMemoryWebSessionStoreTests { @Test public void startsSessionImplicitly() { - WebSession session = this.sessionStore.createWebSession().block(); + WebSession session = this.store.createWebSession().block(); assertNotNull(session); session.start(); session.getAttributes().put("foo", "bar"); assertTrue(session.isStarted()); } + @Test + public void retrieveExpiredSession() throws Exception { + WebSession session = this.store.createWebSession().block(); + assertNotNull(session); + session.getAttributes().put("foo", "bar"); + session.save(); + + String id = session.getId(); + WebSession retrieved = this.store.retrieveSession(id).block(); + assertNotNull(retrieved); + assertSame(session, retrieved); + + this.store.setClock(Clock.offset(this.store.getClock(), Duration.ofMinutes(31))); + WebSession retrievedAgain = this.store.retrieveSession(id).block(); + assertNull(retrievedAgain); + } } \ No newline at end of file diff --git a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java index 48958d75ad..f862f0d904 100644 --- a/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java +++ b/spring-web/src/test/java/org/springframework/web/server/session/WebSessionIntegrationTests.java @@ -43,6 +43,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertNotEquals; import static org.junit.Assert.assertNotNull; import static org.junit.Assert.assertNull; +import static org.junit.Assert.assertTrue; /** * Integration tests for with a server-side session. @@ -109,7 +110,7 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe assertNull(response.getHeaders().get("Set-Cookie")); assertEquals(2, this.handler.getSessionRequestCount()); - // Now set the clock of the session back by 31 minutes + // Now fast-forward by 31 minutes InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore(); WebSession session = store.retrieveSession(id).block(); assertNotNull(session); @@ -125,6 +126,33 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe assertEquals(1, this.handler.getSessionRequestCount()); } + @Test + public void expiredSessionEnds() throws Exception { + + // First request: no session yet, new session created + RequestEntity request = RequestEntity.get(createUri()).build(); + ResponseEntity response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String id = extractSessionId(response.getHeaders()); + assertNotNull(id); + assertEquals(1, this.handler.getSessionRequestCount()); + + // Now fast-forward by 31 minutes + InMemoryWebSessionStore store = (InMemoryWebSessionStore) this.sessionManager.getSessionStore(); + store.setClock(Clock.offset(store.getClock(), Duration.ofMinutes(31))); + + // Second request: session expires + URI uri = new URI("http://localhost:" + this.port + "/?expiredSession"); + request = RequestEntity.get(uri).header("Cookie", "SESSION=" + id).build(); + response = this.restTemplate.exchange(request, Void.class); + + assertEquals(HttpStatus.OK, response.getStatusCode()); + String value = response.getHeaders().getFirst("Set-Cookie"); + assertNotNull(value); + assertTrue("Actual value: " + value, value.contains("Max-Age=0")); + } + @Test public void changeSessionId() throws Exception { @@ -178,11 +206,18 @@ public class WebSessionIntegrationTests extends AbstractHttpHandlerIntegrationTe @Override public Mono handle(ServerWebExchange exchange) { - if (exchange.getRequest().getQueryParams().containsKey("changeId")) { + if (exchange.getRequest().getQueryParams().containsKey("expiredSession")) { + return exchange.getSession().doOnNext(session -> { + // Don't do anything, leave it expired... + }).then(); + } + else if (exchange.getRequest().getQueryParams().containsKey("changeId")) { return exchange.getSession().flatMap(session -> session.changeSessionId().doOnSuccess(aVoid -> updateSessionAttribute(session))); } - return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then(); + else { + return exchange.getSession().doOnSuccess(this::updateSessionAttribute).then(); + } } private void updateSessionAttribute(WebSession session) {