Support method parameter annotations on interface

Closes gh-480
This commit is contained in:
rstoyanchev
2022-09-09 12:25:10 +01:00
parent a7a78d9c6c
commit a1165051eb
2 changed files with 119 additions and 5 deletions

View File

@@ -17,6 +17,9 @@ package org.springframework.graphql.data.method;
import java.lang.annotation.Annotation;
import java.lang.reflect.Method;
import java.util.ArrayList;
import java.util.Arrays;
import java.util.List;
import java.util.stream.Collectors;
import java.util.stream.IntStream;
@@ -26,6 +29,7 @@ import org.apache.commons.logging.LogFactory;
import org.springframework.beans.factory.BeanFactory;
import org.springframework.core.BridgeMethodResolver;
import org.springframework.core.MethodParameter;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.SynthesizingMethodParameter;
import org.springframework.lang.Nullable;
@@ -67,6 +71,9 @@ public class HandlerMethod {
private final MethodParameter[] parameters;
@Nullable
private volatile List<Annotation[][]> interfaceParameterAnnotations;
/**
* Constructor with a handler instance and a method.
@@ -241,6 +248,40 @@ public class HandlerMethod {
return getBeanType().getSimpleName() + "#" + this.method.getName() + "[" + args + " args]";
}
private List<Annotation[][]> getInterfaceParameterAnnotations() {
List<Annotation[][]> parameterAnnotations = this.interfaceParameterAnnotations;
if (parameterAnnotations == null) {
parameterAnnotations = new ArrayList<>();
for (Class<?> ifc : ClassUtils.getAllInterfacesForClassAsSet(this.method.getDeclaringClass())) {
for (Method candidate : ifc.getMethods()) {
if (isOverrideFor(candidate)) {
parameterAnnotations.add(candidate.getParameterAnnotations());
}
}
}
this.interfaceParameterAnnotations = parameterAnnotations;
}
return parameterAnnotations;
}
private boolean isOverrideFor(Method candidate) {
if (!candidate.getName().equals(this.method.getName()) ||
candidate.getParameterCount() != this.method.getParameterCount()) {
return false;
}
Class<?>[] paramTypes = this.method.getParameterTypes();
if (Arrays.equals(candidate.getParameterTypes(), paramTypes)) {
return true;
}
for (int i = 0; i < paramTypes.length; i++) {
if (paramTypes[i] !=
ResolvableType.forMethodParameter(candidate, i, this.method.getDeclaringClass()).resolve()) {
return false;
}
}
return true;
}
@Override
public boolean equals(@Nullable Object other) {
@@ -323,6 +364,9 @@ public class HandlerMethod {
*/
protected class HandlerMethodParameter extends SynthesizingMethodParameter {
@Nullable
private volatile Annotation[] combinedAnnotations;
public HandlerMethodParameter(int index) {
super(HandlerMethod.this.bridgedMethod, index);
}
@@ -347,8 +391,38 @@ public class HandlerMethod {
}
@Override
public HandlerMethodParameter clone() {
return new HandlerMethodParameter(this);
public Annotation[] getParameterAnnotations() {
Annotation[] anns = this.combinedAnnotations;
if (anns == null) {
anns = super.getParameterAnnotations();
int index = getParameterIndex();
if (index >= 0) {
for (Annotation[][] ifcAnns : getInterfaceParameterAnnotations()) {
if (index < ifcAnns.length) {
Annotation[] paramAnns = ifcAnns[index];
if (paramAnns.length > 0) {
List<Annotation> merged = new ArrayList<>(anns.length + paramAnns.length);
merged.addAll(Arrays.asList(anns));
for (Annotation paramAnn : paramAnns) {
boolean existingType = false;
for (Annotation ann : anns) {
if (ann.annotationType() == paramAnn.annotationType()) {
existingType = true;
break;
}
}
if (!existingType) {
merged.add(adaptAnnotation(paramAnn));
}
}
anns = merged.toArray(new Annotation[0]);
}
}
}
}
this.combinedAnnotations = anns;
}
return anns;
}
}

View File

@@ -16,6 +16,8 @@
package org.springframework.graphql.data.method.annotation.support;
import java.lang.reflect.Method;
import java.util.Collections;
import java.util.concurrent.Callable;
import java.util.concurrent.CompletableFuture;
@@ -26,10 +28,14 @@ import org.junit.jupiter.api.Test;
import org.mockito.Mockito;
import org.springframework.core.task.SimpleAsyncTaskExecutor;
import org.springframework.graphql.data.GraphQlArgumentBinder;
import org.springframework.graphql.data.method.HandlerMethod;
import org.springframework.graphql.data.method.HandlerMethodArgumentResolver;
import org.springframework.graphql.data.method.HandlerMethodArgumentResolverComposite;
import org.springframework.graphql.data.method.annotation.Argument;
import org.springframework.graphql.data.method.annotation.QueryMapping;
import org.springframework.lang.Nullable;
import org.springframework.util.ClassUtils;
import static org.assertj.core.api.Assertions.assertThat;
@@ -40,6 +46,22 @@ import static org.assertj.core.api.Assertions.assertThat;
*/
public class DataFetcherHandlerMethodTests {
@Test
void annotatedMethodsOnInterface() {
HandlerMethodArgumentResolverComposite resolvers = new HandlerMethodArgumentResolverComposite();
resolvers.addResolver(new ArgumentMethodArgumentResolver(new GraphQlArgumentBinder()));
DataFetcherHandlerMethod handlerMethod = new DataFetcherHandlerMethod(
handlerMethodFor(new TestController(), "hello"), resolvers, null, null, false);
Object result = handlerMethod.invoke(
DataFetchingEnvironmentImpl.newDataFetchingEnvironment()
.arguments(Collections.singletonMap("name", "Neil"))
.build());
assertThat(result).isEqualTo("Hello, Neil");
}
@Test
void callableReturnValue() throws Exception {
@@ -48,8 +70,8 @@ public class DataFetcherHandlerMethodTests {
resolvers.addResolver(Mockito.mock(HandlerMethodArgumentResolver.class));
DataFetcherHandlerMethod handlerMethod = new DataFetcherHandlerMethod(
new HandlerMethod(new TestController(), TestController.class.getMethod("handleAndReturnCallable")),
resolvers, null, new SimpleAsyncTaskExecutor(), false);
handlerMethodFor(new TestController(), "handleAndReturnCallable"), resolvers, null,
new SimpleAsyncTaskExecutor(), false);
GraphQLContext graphQLContext = new GraphQLContext.Builder().build();
@@ -64,8 +86,26 @@ public class DataFetcherHandlerMethodTests {
assertThat(future.get()).isEqualTo("A");
}
private static HandlerMethod handlerMethodFor(Object controller, String methodName) {
Method method = ClassUtils.getMethod(controller.getClass(), methodName, (Class<?>[]) null);
return new HandlerMethod(controller, method);
}
private static class TestController {
interface TestInterface {
@QueryMapping
String hello(@Argument String name);
}
@SuppressWarnings("unused")
private static class TestController implements TestInterface {
@Override
public String hello(String name) {
return "Hello, " + name;
}
@Nullable
public Callable<String> handleAndReturnCallable() {