add GrpcExceptionHandledServerCall to catch exceptions

Signed-off-by: Thomas McKernan <tmeaglei@gmail.com>
This commit is contained in:
Thomas McKernan
2025-05-14 19:06:10 -05:00
committed by Dave Syer
parent 3a7bcdf4e8
commit bd70fcd69f
7 changed files with 85 additions and 34 deletions

View File

@@ -31,7 +31,8 @@ option java_outer_classname = "HelloWorldProto";
// The greeting service definition.
service Simple {
// Sends a greeting
rpc SayHello(HelloRequest) returns (HelloReply) {}
rpc SayHello (HelloRequest) returns (HelloReply) {
}
rpc StreamHello(HelloRequest) returns (stream HelloReply) {}
}

View File

@@ -209,17 +209,19 @@
<goal>compile-custom</goal>
</goals>
</execution>
<execution>
<id>grpc-kotlin</id>
<goals>
<goal>compile-custom</goal>
</goals>
<configuration>
<pluginId>grpc-kotlin</pluginId>
</configuration>
</execution>
</executions>
</plugin>
<plugin>
<groupId>org.apache.maven.plugins</groupId>
<artifactId>maven-surefire-plugin</artifactId>
<version>3.1.2</version>
<configuration>
<useModulePath>false</useModulePath>
<includes>
<include>**/*.class</include>
</includes>
</configuration>
</plugin>
</plugins>
</build>

View File

@@ -19,7 +19,7 @@ import org.springframework.grpc.test.AutoConfigureInProcessTransport
@SpringBootTest
@AutoConfigureInProcessTransport
internal class NoAutowiredClients {
class NoAutowiredClients {
@Autowired
private lateinit var context: ApplicationContext
@@ -36,7 +36,7 @@ internal class NoAutowiredClients {
@SpringBootTest(properties = ["spring.grpc.client.default-channel.address=0.0.0.0:9090"])
@AutoConfigureInProcessTransport
internal class DefaultAutowiredClients {
class DefaultAutowiredClients {
@Autowired
private lateinit var context: ApplicationContext
@@ -59,7 +59,7 @@ internal class DefaultAutowiredClients {
properties = ["spring.grpc.client.default-channel.address=0.0.0.0:9090"]
)
@AutoConfigureInProcessTransport
internal class SpecificAutowiredClients {
class SpecificAutowiredClients {
@Autowired
private lateinit var context: ApplicationContext

View File

@@ -50,7 +50,7 @@ import java.time.Duration
properties = ["spring.grpc.server.port=0", "spring.grpc.client.channels.health-test.address=static://0.0.0.0:\${local.grpc.port}", "spring.grpc.client.channels.health-test.health.enabled=true", "spring.grpc.client.channels.health-test.health.service-name=my-service"]
)
@DirtiesContext
internal class WithClientHealthEnabled {
class WithClientHealthEnabled {
@Test
fun loadBalancerRespectsServerHealth(
@Autowired channels: GrpcChannelFactory,
@@ -112,7 +112,7 @@ internal class WithClientHealthEnabled {
)
@AutoConfigureInProcessTransport
@DirtiesContext
internal class WithActuatorHealthAdapter {
class WithActuatorHealthAdapter {
@Test
fun healthIndicatorsAdaptedToGrpcHealthStatus(
@Autowired channels: GrpcChannelFactory,
@@ -158,7 +158,7 @@ internal class WithActuatorHealthAdapter {
}
}
internal class CustomHealthIndicator : HealthIndicator {
class CustomHealthIndicator : HealthIndicator {
override fun health(): Health? {
return if (SERVICE_IS_UP) Health.up().build() else Health.down().build()
}

View File

@@ -52,7 +52,7 @@ import java.util.concurrent.atomic.AtomicInteger
@SpringBootTest
@AutoConfigureInProcessTransport
internal class ServerWithInProcessChannel {
class ServerWithInProcessChannel {
@Test
fun servesResponseToClient(@Autowired channels: GrpcChannelFactory) {
assertThatResponseIsServedToChannel(channels.createChannel("0.0.0.0:0"))
@@ -62,7 +62,7 @@ internal class ServerWithInProcessChannel {
@SpringBootTest
@AutoConfigureInProcessTransport
internal class ServerWithException {
class ServerWithException {
@Test
fun specificErrorResponse(@Autowired channels: GrpcChannelFactory) {
@@ -81,7 +81,7 @@ internal class ServerWithException {
@Test
fun defaultErrorResponseIsUnknown(@Autowired channels: GrpcChannelFactory) {
val client = SimpleGrpc.newBlockingStub(channels.createChannel("0.0.0.0:0"))
Assertions.assertThat<Status.Code?>(
Assertions.assertThat(
Assert.assertThrows(
StatusRuntimeException::class.java
) { client.sayHello(HelloRequest.newBuilder().setName("internal").build()) }
@@ -94,14 +94,14 @@ internal class ServerWithException {
@SpringBootTest
@AutoConfigureInProcessTransport
internal class ServerWithExceptionInInterceptorCall {
class ServerWithExceptionInInterceptorCall {
@Test
fun specificErrorResponse(@Autowired channels: GrpcChannelFactory) {
val client = SimpleGrpc.newBlockingStub(channels.createChannel("0.0.0.0:0"))
Assertions.assertThat(
Assert.assertThrows(
StatusRuntimeException::class.java
) { client.sayHello(HelloRequest.newBuilder().setName("foo").build()) }
) { client.sayHello(HelloRequest.newBuilder().setName("error").build()) }
.status
.code
).isEqualTo(Status.Code.INVALID_ARGUMENT)
@@ -129,7 +129,7 @@ internal class ServerWithExceptionInInterceptorCall {
@SpringBootTest
@AutoConfigureInProcessTransport
internal class ServerWithExceptionInInterceptorListener {
class ServerWithExceptionInInterceptorListener {
@Test
fun specificErrorResponse(
@Autowired channels: GrpcChannelFactory,
@@ -150,7 +150,7 @@ internal class ServerWithExceptionInInterceptorListener {
}
@TestConfiguration
internal open class TestConfig {
open class TestConfig {
companion object {
var callCount: AtomicInteger = AtomicInteger()
var messageCount: AtomicInteger = AtomicInteger()
@@ -206,7 +206,7 @@ internal class ServerWithExceptionInInterceptorListener {
@SpringBootTest("spring.grpc.server.exception-handler.enabled=false")
@AutoConfigureInProcessTransport
internal class ServerWithUnhandledException {
class ServerWithUnhandledException {
@Test
fun specificErrorResponse(@Autowired channels: GrpcChannelFactory) {
val client = SimpleGrpc.newBlockingStub(channels.createChannel("0.0.0.0:0"))
@@ -236,7 +236,7 @@ internal class ServerWithUnhandledException {
@SpringBootTest(properties = ["spring.grpc.server.host=0.0.0.0", "spring.grpc.server.port=0"])
internal class ServerWithAnyIPv4AddressAndRandomPort {
class ServerWithAnyIPv4AddressAndRandomPort {
@Test
fun servesResponseToClientWithAnyIPv4AddressAndRandomPort(
@Autowired channels: GrpcChannelFactory,
@@ -248,7 +248,7 @@ internal class ServerWithAnyIPv4AddressAndRandomPort {
@SpringBootTest(properties = ["spring.grpc.server.host=::", "spring.grpc.server.port=0"])
internal class ServerWithAnyIPv6AddressAndRandomPort {
class ServerWithAnyIPv6AddressAndRandomPort {
@Test
fun servesResponseToClientWithAnyIPv4AddressAndRandomPort(
@Autowired channels: GrpcChannelFactory,
@@ -260,7 +260,7 @@ internal class ServerWithAnyIPv6AddressAndRandomPort {
@SpringBootTest(properties = ["spring.grpc.server.host=127.0.0.1", "spring.grpc.server.port=0"])
internal class ServerWithLocalhostAndRandomPort {
class ServerWithLocalhostAndRandomPort {
@Test
fun servesResponseToClientWithLocalhostAndRandomPort(
@Autowired channels: GrpcChannelFactory,
@@ -275,7 +275,7 @@ internal class ServerWithLocalhostAndRandomPort {
properties = ["spring.grpc.server.port=0", "spring.grpc.client.channels.test-channel.address=static://0.0.0.0:\${local.grpc.port}"]
)
@DirtiesContext
internal class ServerConfiguredWithStaticClientChannel {
class ServerConfiguredWithStaticClientChannel {
@Test
fun servesResponseToClientWithConfiguredChannel(@Autowired channels: GrpcChannelFactory) {
assertThatResponseIsServedToChannel(channels.createChannel("test-channel"))
@@ -285,7 +285,7 @@ internal class ServerConfiguredWithStaticClientChannel {
@SpringBootTest(properties = ["spring.grpc.server.address=unix:unix-test-channel"])
@EnabledOnOs(OS.LINUX)
internal class ServerWithUnixDomain {
class ServerWithUnixDomain {
@Test
fun clientChannelWithUnixDomain(@Autowired channels: GrpcChannelFactory) {
assertThatResponseIsServedToChannel(
@@ -304,7 +304,7 @@ internal class ServerWithUnixDomain {
)
@ActiveProfiles("ssl")
@DirtiesContext
internal class ServerWithSsl {
class ServerWithSsl {
@Test
fun clientChannelWithSsl(@Autowired channels: GrpcChannelFactory) {
assertThatResponseIsServedToChannel(channels.createChannel("test-channel"))
@@ -317,7 +317,7 @@ internal class ServerWithSsl {
)
@ActiveProfiles("ssl")
@DirtiesContext
internal class ServerWithClientAuth {
class ServerWithClientAuth {
@Test
fun clientChannelWithSsl(@Autowired channels: GrpcChannelFactory) {
assertThatResponseIsServedToChannel(channels.createChannel("test-channel"))

View File

@@ -0,0 +1,46 @@
/*
* Copyright 2024-2024 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* https://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.grpc.server.exception;
import io.grpc.ForwardingServerCall;
import io.grpc.Metadata;
import io.grpc.ServerCall;
import io.grpc.Status;
import io.grpc.StatusException;
public class GrpcExceptionHandledServerCall<ReqT, RespT>
extends ForwardingServerCall.SimpleForwardingServerCall<ReqT, RespT> {
private final GrpcExceptionHandler exceptionHandler;
protected GrpcExceptionHandledServerCall(ServerCall<ReqT, RespT> delegate, GrpcExceptionHandler handler) {
super(delegate);
this.exceptionHandler = handler;
}
@Override
public void close(Status status, Metadata trailers) {
if (status.getCode() == Status.Code.UNKNOWN && status.getCause() != null) {
final Throwable cause = status.getCause();
final StatusException statusException = this.exceptionHandler.handleException(cause);
super.close(statusException.getStatus(), trailers);
}
else {
super.close(status, trailers);
}
}
}

View File

@@ -65,16 +65,18 @@ public class GrpcExceptionHandlerInterceptor implements ServerInterceptor {
ServerCallHandler<ReqT, RespT> next) {
Listener<ReqT> listener;
FallbackHandler handler = new FallbackHandler(this.exceptionHandler);
final GrpcExceptionHandledServerCall<ReqT, RespT> exceptionHandledServerCall = new GrpcExceptionHandledServerCall<>(
call, handler);
try {
listener = next.startCall(call, headers);
listener = next.startCall(exceptionHandledServerCall, headers);
}
catch (Throwable t) {
call.close(handler.handleException(t).getStatus(), headers(t));
exceptionHandledServerCall.close(handler.handleException(t).getStatus(), headers(t));
listener = new Listener<ReqT>() {
};
return listener;
}
return new ExceptionHandlerListener<>(listener, call, handler);
return new ExceptionHandlerListener<>(listener, exceptionHandledServerCall, handler);
}
private static Metadata headers(Throwable t) {