Check if csrf protection is already disabled
If the user has explicitly disabled CSRF protection in a custom SecurityFilterChain, we should be able to back off and not try to disable it just for the gRPC endpoints (accidentally switching it back on for the other endpoints). Fixes gh-142
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
spring.application.name=grpc-tomcat-secure
|
||||
server.port=9090
|
||||
spring.security.user.name=user
|
||||
spring.security.user.password=user
|
||||
spring.security.user.password=user
|
||||
#logging.level.org.springframework.security=TRACE
|
||||
@@ -0,0 +1,110 @@
|
||||
package org.springframework.grpc.sample;
|
||||
|
||||
import static org.junit.jupiter.api.Assertions.assertEquals;
|
||||
import static org.junit.jupiter.api.Assertions.assertThrows;
|
||||
import static org.springframework.security.config.Customizer.withDefaults;
|
||||
|
||||
import java.util.List;
|
||||
|
||||
import org.apache.commons.logging.Log;
|
||||
import org.apache.commons.logging.LogFactory;
|
||||
import org.junit.jupiter.api.Test;
|
||||
import org.springframework.beans.factory.annotation.Autowired;
|
||||
import org.springframework.beans.factory.annotation.Qualifier;
|
||||
import org.springframework.boot.builder.SpringApplicationBuilder;
|
||||
import org.springframework.boot.test.context.SpringBootTest;
|
||||
import org.springframework.boot.test.context.TestConfiguration;
|
||||
import org.springframework.boot.test.web.server.LocalServerPort;
|
||||
import org.springframework.context.annotation.Bean;
|
||||
import org.springframework.context.annotation.Lazy;
|
||||
import org.springframework.grpc.client.ChannelBuilderOptions;
|
||||
import org.springframework.grpc.client.GrpcChannelFactory;
|
||||
import org.springframework.grpc.client.interceptor.security.BasicAuthenticationInterceptor;
|
||||
import org.springframework.grpc.sample.proto.HelloReply;
|
||||
import org.springframework.grpc.sample.proto.HelloRequest;
|
||||
import org.springframework.grpc.sample.proto.SimpleGrpc;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
|
||||
import org.springframework.security.web.SecurityFilterChain;
|
||||
import org.springframework.test.annotation.DirtiesContext;
|
||||
import org.springframework.web.bind.annotation.PostMapping;
|
||||
import org.springframework.web.bind.annotation.RequestBody;
|
||||
import org.springframework.web.bind.annotation.RestController;
|
||||
|
||||
import io.grpc.Status.Code;
|
||||
import io.grpc.StatusRuntimeException;
|
||||
|
||||
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
|
||||
properties = "spring.grpc.client.default-channel.address=0.0.0.0:${local.server.port}")
|
||||
public class CsrfDisabledApplicationTests {
|
||||
|
||||
private static Log log = LogFactory.getLog(CsrfDisabledApplicationTests.class);
|
||||
|
||||
public static void main(String[] args) {
|
||||
new SpringApplicationBuilder(GrpcServerApplication.class, ExtraConfiguration.class).run(args);
|
||||
}
|
||||
|
||||
@Autowired
|
||||
@Qualifier("simpleBlockingStub")
|
||||
private SimpleGrpc.SimpleBlockingStub stub;
|
||||
|
||||
@Autowired
|
||||
@Qualifier("basic")
|
||||
private SimpleGrpc.SimpleBlockingStub basic;
|
||||
|
||||
@Test
|
||||
@DirtiesContext
|
||||
void contextLoads() {
|
||||
}
|
||||
|
||||
@Test
|
||||
@DirtiesContext
|
||||
void unauthenticated() {
|
||||
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
|
||||
() -> stub.sayHello(HelloRequest.newBuilder().setName("Alien").build()));
|
||||
assertEquals(Code.UNAUTHENTICATED, exception.getStatus().getCode());
|
||||
}
|
||||
|
||||
@Test
|
||||
@DirtiesContext
|
||||
void authenticated() {
|
||||
log.info("Testing");
|
||||
HelloReply response = basic.sayHello(HelloRequest.newBuilder().setName("Alien").build());
|
||||
assertEquals("Hello ==> Alien", response.getMessage());
|
||||
}
|
||||
|
||||
@TestConfiguration
|
||||
@RestController
|
||||
static class ExtraConfiguration {
|
||||
|
||||
@Bean
|
||||
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
|
||||
return http.httpBasic(withDefaults())
|
||||
.csrf(CsrfConfigurer::disable)
|
||||
.authorizeHttpRequests(requests -> requests.anyRequest().fullyAuthenticated())
|
||||
.build();
|
||||
}
|
||||
|
||||
@Bean
|
||||
@Lazy
|
||||
SimpleGrpc.SimpleBlockingStub basic(GrpcChannelFactory channels, @LocalServerPort int port) {
|
||||
return SimpleGrpc.newBlockingStub(channels.createChannel("default", ChannelBuilderOptions.defaults()
|
||||
.withInterceptors(List.of(new BasicAuthenticationInterceptor("user", "user")))));
|
||||
}
|
||||
|
||||
@PostMapping
|
||||
GreetingResponse postGreeting(@RequestBody GreetingRequest request) {
|
||||
var helloRequest = HelloRequest.newBuilder().setName(request.name()).build();
|
||||
var helloReply = "Hello ==> " + helloRequest.getName();
|
||||
return new GreetingResponse(helloReply);
|
||||
}
|
||||
|
||||
record GreetingRequest(String name) {
|
||||
}
|
||||
|
||||
record GreetingResponse(String message) {
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
}
|
||||
@@ -16,8 +16,13 @@
|
||||
package org.springframework.grpc.autoconfigure.server.security;
|
||||
|
||||
import org.springframework.context.ApplicationContext;
|
||||
import org.springframework.grpc.server.service.GrpcServiceDiscoverer;
|
||||
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
|
||||
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
|
||||
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
|
||||
import org.springframework.security.web.csrf.CsrfFilter;
|
||||
import org.springframework.security.web.util.matcher.AndRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
|
||||
|
||||
/**
|
||||
* A custom {@link AbstractHttpConfigurer} that disables CSRF protection for gRPC
|
||||
@@ -38,11 +43,22 @@ public class GrpcDisableCsrfHttpConfigurer extends AbstractHttpConfigurer<GrpcDi
|
||||
@Override
|
||||
public void init(HttpSecurity http) throws Exception {
|
||||
ApplicationContext context = http.getSharedObject(ApplicationContext.class);
|
||||
if (context != null && isServletEnabledAndCsrfDisabled(context)) {
|
||||
http.csrf(csrf -> csrf.ignoringRequestMatchers(GrpcServletRequest.all()));
|
||||
if (context != null && context.getBeanNamesForType(GrpcServiceDiscoverer.class).length == 1
|
||||
&& isServletEnabledAndCsrfDisabled(context) && isCsrfConfigurerPresent(http)) {
|
||||
http.csrf(this::disable);
|
||||
}
|
||||
}
|
||||
|
||||
@SuppressWarnings("unchecked")
|
||||
private boolean isCsrfConfigurerPresent(HttpSecurity http) {
|
||||
return http.getConfigurer(CsrfConfigurer.class) != null;
|
||||
}
|
||||
|
||||
private void disable(CsrfConfigurer<HttpSecurity> csrf) {
|
||||
csrf.requireCsrfProtectionMatcher(new AndRequestMatcher(CsrfFilter.DEFAULT_CSRF_MATCHER,
|
||||
new NegatedRequestMatcher(GrpcServletRequest.all())));
|
||||
}
|
||||
|
||||
private boolean isServletEnabledAndCsrfDisabled(ApplicationContext context) {
|
||||
return context.getEnvironment().getProperty("spring.grpc.server.servlet.enabled", Boolean.class, true)
|
||||
&& !context.getEnvironment()
|
||||
|
||||
@@ -28,6 +28,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.OrRequestMatcher;
|
||||
import org.springframework.security.web.util.matcher.RequestMatcher;
|
||||
import org.springframework.util.Assert;
|
||||
import org.springframework.web.context.WebApplicationContext;
|
||||
|
||||
import jakarta.servlet.http.HttpServletRequest;
|
||||
|
||||
@@ -102,6 +103,11 @@ public class GrpcServletRequest {
|
||||
this.delegate = matchers.isEmpty() ? request -> false : new OrRequestMatcher(matchers);
|
||||
}
|
||||
|
||||
@Override
|
||||
protected boolean ignoreApplicationContext(WebApplicationContext context) {
|
||||
return context.getBeanNamesForType(GrpcServiceDiscoverer.class).length != 1;
|
||||
}
|
||||
|
||||
private Stream<RequestMatcher> getDelegateMatchers(GrpcServiceDiscoverer context) {
|
||||
return getPatterns(context).map(AntPathRequestMatcher::new);
|
||||
}
|
||||
|
||||
@@ -67,6 +67,22 @@ public class GrpcServletRequestTests {
|
||||
assertThat(matcher.matches(request)).isFalse();
|
||||
};
|
||||
|
||||
@Test
|
||||
void noServices() {
|
||||
GrpcServletRequestMatcher matcher = GrpcServletRequest.all();
|
||||
MockHttpServletRequest request = mockRequestNoServices("/my-service/Method");
|
||||
assertThat(matcher.matches(request)).isFalse();
|
||||
};
|
||||
|
||||
private MockHttpServletRequest mockRequestNoServices(String path) {
|
||||
MockServletContext servletContext = new MockServletContext();
|
||||
servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE,
|
||||
new StaticWebApplicationContext());
|
||||
MockHttpServletRequest request = new MockHttpServletRequest(servletContext);
|
||||
request.setPathInfo(path);
|
||||
return request;
|
||||
}
|
||||
|
||||
private MockHttpServletRequest mockRequest(String path) {
|
||||
MockServletContext servletContext = new MockServletContext();
|
||||
servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, this.context);
|
||||
|
||||
Reference in New Issue
Block a user