Check if csrf protection is already disabled

If the user has explicitly disabled CSRF protection in a custom
SecurityFilterChain, we should be able to back off and not try
to disable it just for the gRPC endpoints (accidentally switching
it back on for the other endpoints).

Fixes gh-142
This commit is contained in:
Dave Syer
2025-03-24 10:09:43 +00:00
parent cfb14ba7e0
commit e4fa0f83df
5 changed files with 152 additions and 3 deletions

View File

@@ -1,4 +1,5 @@
spring.application.name=grpc-tomcat-secure
server.port=9090
spring.security.user.name=user
spring.security.user.password=user
spring.security.user.password=user
#logging.level.org.springframework.security=TRACE

View File

@@ -0,0 +1,110 @@
package org.springframework.grpc.sample;
import static org.junit.jupiter.api.Assertions.assertEquals;
import static org.junit.jupiter.api.Assertions.assertThrows;
import static org.springframework.security.config.Customizer.withDefaults;
import java.util.List;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.junit.jupiter.api.Test;
import org.springframework.beans.factory.annotation.Autowired;
import org.springframework.beans.factory.annotation.Qualifier;
import org.springframework.boot.builder.SpringApplicationBuilder;
import org.springframework.boot.test.context.SpringBootTest;
import org.springframework.boot.test.context.TestConfiguration;
import org.springframework.boot.test.web.server.LocalServerPort;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Lazy;
import org.springframework.grpc.client.ChannelBuilderOptions;
import org.springframework.grpc.client.GrpcChannelFactory;
import org.springframework.grpc.client.interceptor.security.BasicAuthenticationInterceptor;
import org.springframework.grpc.sample.proto.HelloReply;
import org.springframework.grpc.sample.proto.HelloRequest;
import org.springframework.grpc.sample.proto.SimpleGrpc;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.web.SecurityFilterChain;
import org.springframework.test.annotation.DirtiesContext;
import org.springframework.web.bind.annotation.PostMapping;
import org.springframework.web.bind.annotation.RequestBody;
import org.springframework.web.bind.annotation.RestController;
import io.grpc.Status.Code;
import io.grpc.StatusRuntimeException;
@SpringBootTest(webEnvironment = SpringBootTest.WebEnvironment.RANDOM_PORT,
properties = "spring.grpc.client.default-channel.address=0.0.0.0:${local.server.port}")
public class CsrfDisabledApplicationTests {
private static Log log = LogFactory.getLog(CsrfDisabledApplicationTests.class);
public static void main(String[] args) {
new SpringApplicationBuilder(GrpcServerApplication.class, ExtraConfiguration.class).run(args);
}
@Autowired
@Qualifier("simpleBlockingStub")
private SimpleGrpc.SimpleBlockingStub stub;
@Autowired
@Qualifier("basic")
private SimpleGrpc.SimpleBlockingStub basic;
@Test
@DirtiesContext
void contextLoads() {
}
@Test
@DirtiesContext
void unauthenticated() {
StatusRuntimeException exception = assertThrows(StatusRuntimeException.class,
() -> stub.sayHello(HelloRequest.newBuilder().setName("Alien").build()));
assertEquals(Code.UNAUTHENTICATED, exception.getStatus().getCode());
}
@Test
@DirtiesContext
void authenticated() {
log.info("Testing");
HelloReply response = basic.sayHello(HelloRequest.newBuilder().setName("Alien").build());
assertEquals("Hello ==> Alien", response.getMessage());
}
@TestConfiguration
@RestController
static class ExtraConfiguration {
@Bean
SecurityFilterChain securityFilterChain(HttpSecurity http) throws Exception {
return http.httpBasic(withDefaults())
.csrf(CsrfConfigurer::disable)
.authorizeHttpRequests(requests -> requests.anyRequest().fullyAuthenticated())
.build();
}
@Bean
@Lazy
SimpleGrpc.SimpleBlockingStub basic(GrpcChannelFactory channels, @LocalServerPort int port) {
return SimpleGrpc.newBlockingStub(channels.createChannel("default", ChannelBuilderOptions.defaults()
.withInterceptors(List.of(new BasicAuthenticationInterceptor("user", "user")))));
}
@PostMapping
GreetingResponse postGreeting(@RequestBody GreetingRequest request) {
var helloRequest = HelloRequest.newBuilder().setName(request.name()).build();
var helloReply = "Hello ==> " + helloRequest.getName();
return new GreetingResponse(helloReply);
}
record GreetingRequest(String name) {
}
record GreetingResponse(String message) {
}
}
}

View File

@@ -16,8 +16,13 @@
package org.springframework.grpc.autoconfigure.server.security;
import org.springframework.context.ApplicationContext;
import org.springframework.grpc.server.service.GrpcServiceDiscoverer;
import org.springframework.security.config.annotation.web.builders.HttpSecurity;
import org.springframework.security.config.annotation.web.configurers.AbstractHttpConfigurer;
import org.springframework.security.config.annotation.web.configurers.CsrfConfigurer;
import org.springframework.security.web.csrf.CsrfFilter;
import org.springframework.security.web.util.matcher.AndRequestMatcher;
import org.springframework.security.web.util.matcher.NegatedRequestMatcher;
/**
* A custom {@link AbstractHttpConfigurer} that disables CSRF protection for gRPC
@@ -38,11 +43,22 @@ public class GrpcDisableCsrfHttpConfigurer extends AbstractHttpConfigurer<GrpcDi
@Override
public void init(HttpSecurity http) throws Exception {
ApplicationContext context = http.getSharedObject(ApplicationContext.class);
if (context != null && isServletEnabledAndCsrfDisabled(context)) {
http.csrf(csrf -> csrf.ignoringRequestMatchers(GrpcServletRequest.all()));
if (context != null && context.getBeanNamesForType(GrpcServiceDiscoverer.class).length == 1
&& isServletEnabledAndCsrfDisabled(context) && isCsrfConfigurerPresent(http)) {
http.csrf(this::disable);
}
}
@SuppressWarnings("unchecked")
private boolean isCsrfConfigurerPresent(HttpSecurity http) {
return http.getConfigurer(CsrfConfigurer.class) != null;
}
private void disable(CsrfConfigurer<HttpSecurity> csrf) {
csrf.requireCsrfProtectionMatcher(new AndRequestMatcher(CsrfFilter.DEFAULT_CSRF_MATCHER,
new NegatedRequestMatcher(GrpcServletRequest.all())));
}
private boolean isServletEnabledAndCsrfDisabled(ApplicationContext context) {
return context.getEnvironment().getProperty("spring.grpc.server.servlet.enabled", Boolean.class, true)
&& !context.getEnvironment()

View File

@@ -28,6 +28,7 @@ import org.springframework.security.web.util.matcher.AntPathRequestMatcher;
import org.springframework.security.web.util.matcher.OrRequestMatcher;
import org.springframework.security.web.util.matcher.RequestMatcher;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import jakarta.servlet.http.HttpServletRequest;
@@ -102,6 +103,11 @@ public class GrpcServletRequest {
this.delegate = matchers.isEmpty() ? request -> false : new OrRequestMatcher(matchers);
}
@Override
protected boolean ignoreApplicationContext(WebApplicationContext context) {
return context.getBeanNamesForType(GrpcServiceDiscoverer.class).length != 1;
}
private Stream<RequestMatcher> getDelegateMatchers(GrpcServiceDiscoverer context) {
return getPatterns(context).map(AntPathRequestMatcher::new);
}

View File

@@ -67,6 +67,22 @@ public class GrpcServletRequestTests {
assertThat(matcher.matches(request)).isFalse();
};
@Test
void noServices() {
GrpcServletRequestMatcher matcher = GrpcServletRequest.all();
MockHttpServletRequest request = mockRequestNoServices("/my-service/Method");
assertThat(matcher.matches(request)).isFalse();
};
private MockHttpServletRequest mockRequestNoServices(String path) {
MockServletContext servletContext = new MockServletContext();
servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE,
new StaticWebApplicationContext());
MockHttpServletRequest request = new MockHttpServletRequest(servletContext);
request.setPathInfo(path);
return request;
}
private MockHttpServletRequest mockRequest(String path) {
MockServletContext servletContext = new MockServletContext();
servletContext.setAttribute(WebApplicationContext.ROOT_WEB_APPLICATION_CONTEXT_ATTRIBUTE, this.context);