GH-1111 Fix filter registration for serverless-web

Resolves #1111
This commit is contained in:
Oleg Zhurakousky
2024-02-20 15:08:02 +01:00
parent 0b08e8b242
commit 062db0b13d
5 changed files with 98 additions and 25 deletions

View File

@@ -16,17 +16,18 @@
package org.springframework.cloud.function.serverless.web;
import jakarta.servlet.Filter;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.beans.BeansException;
import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.autoconfigure.condition.ConditionalOnMissingBean;
import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletRegistrationBean;
import org.springframework.boot.web.context.ConfigurableWebServerApplicationContext;
import org.springframework.boot.web.server.WebServer;
import org.springframework.boot.web.server.WebServerException;
import org.springframework.boot.web.servlet.ServletContextInitializer;
import org.springframework.boot.web.servlet.ServletContextInitializerBeans;
import org.springframework.boot.web.servlet.context.ServletWebServerApplicationContext;
import org.springframework.boot.web.servlet.server.ServletWebServerFactory;
import org.springframework.cloud.function.serverless.web.ServerlessMVC.ProxyServletConfig;
@@ -96,9 +97,11 @@ public class ServerlessAutoConfiguration {
logger.info("Configuring Serverless Web Container");
ServerlessServletContext servletContext = new ServerlessServletContext();
servletApplicationContet.setServletContext(servletContext);
this.applicationContext.getBeansOfType(Filter.class).entrySet().forEach(entry -> {
servletContext.addFilter(entry.getKey(), entry.getValue());
});
for (ServletContextInitializer beans : new ServletContextInitializerBeans(this.applicationContext)) {
if (!(beans instanceof DispatcherServletRegistrationBean)) {
beans.onStartup(servletContext);
}
}
}
}
}

View File

@@ -169,6 +169,7 @@ public class ServerlessHttpServletRequest implements HttpServletRequest {
this.servletContext = servletContext;
this.method = method;
this.requestURI = requestURI;
this.pathInfo = requestURI;
this.locales.add(Locale.ENGLISH);
}

View File

@@ -25,11 +25,11 @@ import org.junit.jupiter.api.AfterEach;
import org.junit.jupiter.api.BeforeEach;
import org.junit.jupiter.api.Test;
import org.springframework.cloud.function.test.app.Pet;
import org.springframework.cloud.function.test.app.PetStoreSpringAppConfig;
import org.springframework.cloud.function.test.app.PetStoreSpringAppConfig.AnotherFilter;
import org.springframework.cloud.function.test.app.PetStoreSpringAppConfig.SimpleFilter;
import org.springframework.http.HttpStatus;
import org.springframework.security.test.context.support.WithMockUser;
import static org.assertj.core.api.Assertions.assertThat;
@@ -47,6 +47,8 @@ public class RequestResponseTests {
@BeforeEach
public void before() {
System.setProperty("spring.main.banner-mode", "off");
System.setProperty("trace", "true");
System.setProperty("contextInitTimeout", "20000");
this.mvc = ServerlessMVC.INSTANCE(PetStoreSpringAppConfig.class);
}
@@ -57,11 +59,15 @@ public class RequestResponseTests {
@Test
public void validateAccessDeniedWithCustomHandler() throws Exception {
HttpServletRequest request = new ServerlessHttpServletRequest(null, "GET", "/foo");
HttpServletRequest request = new ServerlessHttpServletRequest(null, "GET", "/foo/deny");
ServerlessHttpServletResponse response = new ServerlessHttpServletResponse();
mvc.service(request, response);
assertThat(response.getErrorMessage()).isEqualTo("Can't touch this");
assertThat(response.getStatus()).isEqualTo(403);
SimpleFilter simpleFilter = this.mvc.getApplicationContext().getBean(SimpleFilter.class);
assertThat(simpleFilter.invoked).isTrue();
AnotherFilter anotherFilter = this.mvc.getApplicationContext().getBean(AnotherFilter.class);
assertThat(anotherFilter.invoked).isTrue();
}
@Test
@@ -89,7 +95,7 @@ public class RequestResponseTests {
assertThat(pets.get(0)).isInstanceOf(Pet.class);
}
@WithMockUser("spring")
//@WithMockUser("spring")
@Test
public void validateGetPojo() throws Exception {
HttpServletRequest request = new ServerlessHttpServletRequest(null, "GET", "/pets/6e3cc370-892f-4efe-a9eb-82926ff8cc5b");

View File

@@ -28,6 +28,7 @@ import jakarta.servlet.http.HttpServletResponse;
import jakarta.servlet.http.HttpSession;
import org.springframework.boot.autoconfigure.EnableAutoConfiguration;
import org.springframework.boot.web.servlet.FilterRegistrationBean;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
@@ -40,8 +41,11 @@ import org.springframework.security.core.context.SecurityContext;
import org.springframework.security.core.context.SecurityContextHolder;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.security.web.access.AccessDeniedHandler;
import org.springframework.security.web.authentication.logout.LogoutFilter;
import org.springframework.security.web.context.SecurityContextHolderFilter;
import org.springframework.security.web.savedrequest.RequestCacheAwareFilter;
import org.springframework.web.filter.GenericFilterBean;
import org.springframework.web.filter.OncePerRequestFilter;
import org.springframework.web.servlet.HandlerAdapter;
import org.springframework.web.servlet.HandlerMapping;
import org.springframework.web.servlet.mvc.method.annotation.RequestMappingHandlerAdapter;
@@ -73,38 +77,96 @@ public class PetStoreSpringAppConfig {
}
@Bean
public SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
public SecurityFilterChain securityFilterChain(HttpSecurity http, SimpleFilter simpleFilter,
AnotherFilter anotherFilter) throws Exception {
http
.csrf(csrf -> csrf.disable())
.cors(cors -> cors.disable())
.addFilterBefore(new GenericFilterBean() {
@Override
public void doFilter(ServletRequest request, ServletResponse response, FilterChain chain)
throws IOException, ServletException {
SecurityContext securityContext = SecurityContextHolder.getContext();
securityContext.setAuthentication(UsernamePasswordAuthenticationToken.authenticated("user", "password",
Collections.singleton(new SimpleGrantedAuthority("USER"))));
Collections.singleton(new SimpleGrantedAuthority("ROLE_USER"))));
HttpSession session = ((HttpServletRequest) request).getSession();
session.setAttribute("SPRING_SECURITY_CONTEXT", securityContext);
chain.doFilter(request, response);
}
}, SecurityContextHolderFilter.class)
.securityMatcher("/foo")
.authorizeHttpRequests(authorize -> authorize
.anyRequest().hasRole("FOO")
)
.exceptionHandling(f -> f.accessDeniedHandler(accessDeniedHandler()));
.securityMatcher("/foo/deny")
.authorizeHttpRequests(auth -> {
auth.anyRequest().hasRole("FOO");
})
.addFilterAfter(simpleFilter, LogoutFilter.class)
.addFilterAfter(anotherFilter, RequestCacheAwareFilter.class)
.exceptionHandling(f -> f.accessDeniedHandler(new MyAccessDeinedHandler()));
return http.build();
}
@Bean
public AccessDeniedHandler accessDeniedHandler() {
return new AccessDeniedHandler() {
@Override
public void handle(HttpServletRequest request, HttpServletResponse response,
AccessDeniedException accessDeniedException) throws IOException, ServletException {
response.sendError(403, "Can't touch this");
}
};
public FilterRegistrationBean<SimpleFilter> simpleFilterRegistration(SimpleFilter simpleFilter) {
FilterRegistrationBean<SimpleFilter> registration = new FilterRegistrationBean<>(simpleFilter);
registration.setEnabled(false);
return registration;
}
@Bean
public FilterRegistrationBean<AnotherFilter> anotherFilterRegistration(AnotherFilter simpleFilter) {
FilterRegistrationBean<AnotherFilter> registration = new FilterRegistrationBean<>(simpleFilter);
registration.setEnabled(false);
return registration;
}
@Bean
public SimpleFilter simpleFilter() {
return new SimpleFilter();
}
@Bean
public AnotherFilter anotherFilter() {
return new AnotherFilter();
}
public static class SimpleFilter extends OncePerRequestFilter {
public boolean invoked;
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
if (invoked) {
throw new IllegalStateException("Filter has already been invoked");
}
else {
invoked = true;
}
filterChain.doFilter(request, response);
}
}
public static class AnotherFilter extends OncePerRequestFilter {
public boolean invoked;
@Override
protected void doFilterInternal(HttpServletRequest request, HttpServletResponse response,
FilterChain filterChain) throws ServletException, IOException {
if (invoked) {
throw new IllegalStateException("Filter has already been invoked");
}
else {
invoked = true;
}
filterChain.doFilter(request, response);
}
}
public static class MyAccessDeinedHandler implements AccessDeniedHandler {
@Override
public void handle(HttpServletRequest request, HttpServletResponse response,
AccessDeniedException accessDeniedException) throws IOException, ServletException {
response.sendError(403, "Can't touch this");
}
}
}

View File

@@ -21,6 +21,7 @@ import java.util.Optional;
import java.util.UUID;
import org.springframework.http.HttpStatus;
import org.springframework.web.bind.annotation.GetMapping;
import org.springframework.web.bind.annotation.PathVariable;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RequestMapping;
@@ -94,7 +95,7 @@ public class PetsController {
return newPet;
}
@RequestMapping(path = "/foo", method = RequestMethod.GET)
@GetMapping("/foo/deny")
public Pet foo() {
Pet newPet = new Pet();
newPet.setId(UUID.randomUUID().toString());