From a5ce8ae87f7f12ffb41602fd381dcf498a2ff766 Mon Sep 17 00:00:00 2001 From: Marcus Hert Da Coregio Date: Tue, 27 Feb 2024 11:12:41 -0300 Subject: [PATCH] Polish Max Sessions on WebFlux This commit changes the PreventLoginServerMaximumSessionsExceededHandler to invalidate the WebSession in addition to throwing the error, this is needed otherwise the session would still be saved with the security context. It also changes the SessionRegistryWebSession to first perform the operation on the delegate and then invoke the needed method on the ReactiveSessionRegistry Issue gh-6192 --- .../config/web/server/ServerHttpSecurity.java | 20 +++++++++------- .../server/SessionManagementSpecTests.java | 10 ++++++-- ...rolServerAuthenticationSuccessHandler.java | 4 ++-- .../MaximumSessionsContext.java | 12 ++++++++-- ...nServerMaximumSessionsExceededHandler.java | 8 +++---- ...erMaximumSessionsExceededHandlerTests.java | 6 ++--- ...erMaximumSessionsExceededHandlerTests.java | 24 +++++++++++++------ 7 files changed, 56 insertions(+), 28 deletions(-) diff --git a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java index 1a8d2e3a73..1939a4f157 100644 --- a/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java +++ b/config/src/main/java/org/springframework/security/config/web/server/ServerHttpSecurity.java @@ -2143,19 +2143,23 @@ public class ServerHttpSecurity { @Override public Mono changeSessionId() { String currentId = this.session.getId(); - return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId) - .flatMap((information) -> this.session.changeSessionId().thenReturn(information)) - .flatMap((information) -> { - information = information.withSessionId(this.session.getId()); - return SessionRegistryWebFilter.this.sessionRegistry.saveSessionInformation(information); - }); + return this.session.changeSessionId() + .then(Mono.defer( + () -> SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId) + .flatMap((information) -> { + information = information.withSessionId(this.session.getId()); + return SessionRegistryWebFilter.this.sessionRegistry + .saveSessionInformation(information); + }))); } @Override public Mono invalidate() { String currentId = this.session.getId(); - return SessionRegistryWebFilter.this.sessionRegistry.removeSessionInformation(currentId) - .flatMap((information) -> this.session.invalidate()); + return this.session.invalidate() + .then(Mono.defer(() -> SessionRegistryWebFilter.this.sessionRegistry + .removeSessionInformation(currentId))) + .then(); } @Override diff --git a/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java b/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java index 089a3916d4..22a5eee49c 100644 --- a/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java +++ b/config/src/test/java/org/springframework/security/config/web/server/SessionManagementSpecTests.java @@ -67,6 +67,7 @@ import org.springframework.web.reactive.function.BodyInserters; import org.springframework.web.server.adapter.WebHttpHandlerBuilder; import org.springframework.web.server.session.DefaultWebSessionManager; +import static org.assertj.core.api.Assertions.assertThat; import static org.mockito.ArgumentMatchers.any; import static org.mockito.BDDMockito.given; import static org.mockito.Mockito.mock; @@ -95,14 +96,19 @@ public class SessionManagementSpecTests { ResponseCookie firstLoginSessionCookie = loginReturningCookie(data); // second login should fail - this.client.mutateWith(csrf()) + ResponseCookie secondLoginSessionCookie = this.client.mutateWith(csrf()) .post() .uri("/login") .contentType(MediaType.MULTIPART_FORM_DATA) .body(BodyInserters.fromFormData(data)) .exchange() .expectHeader() - .location("/login?error"); + .location("/login?error") + .returnResult(Void.class) + .getResponseCookies() + .getFirst("SESSION"); + + assertThat(secondLoginSessionCookie).isNull(); // first login should still be valid this.client.mutateWith(csrf()) diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java index ca777da888..556bf042df 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/ConcurrentSessionControlServerAuthenticationSuccessHandler.java @@ -81,8 +81,8 @@ public final class ConcurrentSessionControlServerAuthenticationSuccessHandler } } } - return this.maximumSessionsExceededHandler - .handle(new MaximumSessionsContext(authentication, registeredSessions, maximumSessions)); + return this.maximumSessionsExceededHandler.handle(new MaximumSessionsContext(authentication, + registeredSessions, maximumSessions, currentSession)); }); } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java b/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java index 0875051b78..9ba11bc17d 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/MaximumSessionsContext.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,6 +20,7 @@ import java.util.List; import org.springframework.security.core.Authentication; import org.springframework.security.core.session.ReactiveSessionInformation; +import org.springframework.web.server.WebSession; public final class MaximumSessionsContext { @@ -29,11 +30,14 @@ public final class MaximumSessionsContext { private final int maximumSessionsAllowed; + private final WebSession currentSession; + public MaximumSessionsContext(Authentication authentication, List sessions, - int maximumSessionsAllowed) { + int maximumSessionsAllowed, WebSession currentSession) { this.authentication = authentication; this.sessions = sessions; this.maximumSessionsAllowed = maximumSessionsAllowed; + this.currentSession = currentSession; } public Authentication getAuthentication() { @@ -48,4 +52,8 @@ public final class MaximumSessionsContext { return this.maximumSessionsAllowed; } + public WebSession getCurrentSession() { + return this.currentSession; + } + } diff --git a/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java b/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java index a98f8795e6..1afb5771e3 100644 --- a/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java +++ b/web/src/main/java/org/springframework/security/web/server/authentication/PreventLoginServerMaximumSessionsExceededHandler.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -31,9 +31,9 @@ public final class PreventLoginServerMaximumSessionsExceededHandler implements S @Override public Mono handle(MaximumSessionsContext context) { - return Mono - .error(new SessionAuthenticationException("Maximum sessions of " + context.getMaximumSessionsAllowed() - + " for authentication '" + context.getAuthentication().getName() + "' exceeded")); + return context.getCurrentSession() + .invalidate() + .then(Mono.defer(() -> Mono.error(new SessionAuthenticationException("Maximum sessions exceeded")))); } } diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java index 60b6107418..3c16e6fdd9 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/session/InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -50,7 +50,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { given(session2.getLastAccessTime()).willReturn(Instant.ofEpochMilli(1700827760000L)); given(session2.invalidate()).willReturn(Mono.empty()); MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class), - List.of(session1, session2), 2); + List.of(session1, session2), 2, null); this.handler.handle(context).block(); @@ -72,7 +72,7 @@ class InvalidateLeastUsedServerMaximumSessionsExceededHandlerTests { given(session1.invalidate()).willReturn(Mono.empty()); given(session2.invalidate()).willReturn(Mono.empty()); MaximumSessionsContext context = new MaximumSessionsContext(mock(Authentication.class), - List.of(session1, session2, session3), 2); + List.of(session1, session2, session3), 2, null); this.handler.handle(context).block(); diff --git a/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java b/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java index 819489ee43..68f1f09650 100644 --- a/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java +++ b/web/src/test/java/org/springframework/security/web/server/authentication/session/PreventLoginServerMaximumSessionsExceededHandlerTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2023 the original author or authors. + * Copyright 2002-2024 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -19,13 +19,19 @@ package org.springframework.security.web.server.authentication.session; import java.util.Collections; import org.junit.jupiter.api.Test; +import reactor.core.publisher.Mono; +import reactor.test.StepVerifier; import org.springframework.security.authentication.TestAuthentication; import org.springframework.security.web.authentication.session.SessionAuthenticationException; import org.springframework.security.web.server.authentication.MaximumSessionsContext; import org.springframework.security.web.server.authentication.PreventLoginServerMaximumSessionsExceededHandler; +import org.springframework.web.server.WebSession; -import static org.assertj.core.api.Assertions.assertThatExceptionOfType; +import static org.assertj.core.api.Assertions.assertThat; +import static org.mockito.BDDMockito.given; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.verify; /** * Tests for {@link PreventLoginServerMaximumSessionsExceededHandler}. @@ -35,13 +41,17 @@ import static org.assertj.core.api.Assertions.assertThatExceptionOfType; class PreventLoginServerMaximumSessionsExceededHandlerTests { @Test - void handleWhenInvokedThenThrowsSessionAuthenticationException() { + void handleWhenInvokedThenInvalidateWebSessionAndThrowsSessionAuthenticationException() { PreventLoginServerMaximumSessionsExceededHandler handler = new PreventLoginServerMaximumSessionsExceededHandler(); + WebSession webSession = mock(); + given(webSession.invalidate()).willReturn(Mono.empty()); MaximumSessionsContext context = new MaximumSessionsContext(TestAuthentication.authenticatedUser(), - Collections.emptyList(), 1); - assertThatExceptionOfType(SessionAuthenticationException.class) - .isThrownBy(() -> handler.handle(context).block()) - .withMessage("Maximum sessions of 1 for authentication 'user' exceeded"); + Collections.emptyList(), 1, webSession); + StepVerifier.create(handler.handle(context)).expectErrorSatisfies((ex) -> { + assertThat(ex).isInstanceOf(SessionAuthenticationException.class); + assertThat(ex.getMessage()).isEqualTo("Maximum sessions exceeded"); + }).verify(); + verify(webSession).invalidate(); } }