Improved generics support in ResolvableMethod

This commit is contained in:
Rossen Stoyanchev
2017-03-05 18:10:45 -05:00
parent c5351fdbef
commit 0296d003af
8 changed files with 234 additions and 217 deletions

View File

@@ -38,18 +38,21 @@ import org.springframework.cglib.proxy.Callback;
import org.springframework.cglib.proxy.Enhancer;
import org.springframework.cglib.proxy.Factory;
import org.springframework.cglib.proxy.MethodProxy;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.core.MethodIntrospector;
import org.springframework.core.MethodParameter;
import org.springframework.core.ParameterNameDiscoverer;
import org.springframework.core.ResolvableType;
import org.springframework.core.annotation.AnnotatedElementUtils;
import org.springframework.core.annotation.SynthesizingMethodParameter;
import org.springframework.objenesis.ObjenesisException;
import org.springframework.objenesis.SpringObjenesis;
import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils;
import org.springframework.util.ReflectionUtils;
/**
* Convenience class to resolve a method and its parameters based on hints.
* Convenience class to resolve method parameters from hints.
*
* <h1>Background</h1>
*
@@ -57,19 +60,15 @@ import org.springframework.util.ReflectionUtils;
* "TestController" with a diverse range of method signatures representing
* supported annotations and argument types. It becomes challenging to use
* naming strategies to keep track of methods and arguments especially in
* combination variables for reflection metadata.
* combination with variables for reflection metadata.
*
* <p>The idea with {@link ResolvableMethod} is NOT to rely on naming techniques
* but to use hints to zero in on method parameters. Especially in combination
* with {@link ResolvableType} such hints can be strongly typed and make tests
* more readable by being explicit about what is being tested and more robust
* since the provided hints have to match.
*
* <p>Common use cases:
* but to use hints to zero in on method parameters. Such hints can be strongly
* typed and explicit about what is being tested.
*
* <h2>1. Declared Return Type</h2>
*
* When testing return types it's common to have many methods with a unique
* When testing return types it's likely to have many methods with a unique
* return type, possibly with or without an annotation.
*
* <pre>
@@ -78,6 +77,8 @@ import org.springframework.util.ReflectionUtils;
*
* // Return type
* on(TestController.class).resolveReturnType(Foo.class);
* on(TestController.class).resolveReturnType(List.class, Foo.class);
* on(TestController.class).resolveReturnType(Mono.class, responseEntity(Foo.class));
*
* // Annotation + return type
* on(TestController.class).annotated(ResponseBody.class).resolveReturnType(Bar.class);
@@ -85,7 +86,7 @@ import org.springframework.util.ReflectionUtils;
* // Annotation not present
* on(TestController.class).notAnnotated(ResponseBody.class).resolveReturnType();
*
* // Annotation properties
* // Annotation with attributes
* on(TestController.class)
* .annotated(RequestMapping.class, patterns("/foo"), params("p"))
* .annotated(ResponseBody.class)
@@ -124,6 +125,9 @@ public class ResolvableMethod {
private static final SpringObjenesis objenesis = new SpringObjenesis();
private static final ParameterNameDiscoverer nameDiscoverer =
new LocalVariableTableParameterNameDiscoverer();
private final Method method;
@@ -151,9 +155,20 @@ public class ResolvableMethod {
/**
* Find a unique argument matching the given type.
* @param type the expected type
* @param generics optional array of generic types
*/
public MethodParameter arg(Class<?> type) {
return new ArgResolver().arg(type);
public MethodParameter arg(Class<?> type, Class<?>... generics) {
return new ArgResolver().arg(type, generics);
}
/**
* Find a unique argument matching the given type.
* @param type the expected type
* @param generic at least one generic type
* @param generics optional array of generic types
*/
public MethodParameter arg(Class<?> type, ResolvableType generic, ResolvableType... generics) {
return new ArgResolver().arg(type, generic, generics);
}
/**
@@ -207,6 +222,19 @@ public class ResolvableMethod {
.collect(Collectors.joining(",\n\t", "(\n\t", "\n)"));
}
private static ResolvableType toResolvableType(Class<?> type, Class<?>... generics) {
return ObjectUtils.isEmpty(generics) ?
ResolvableType.forClass(type) :
ResolvableType.forClassWithGenerics(type, generics);
}
private static ResolvableType toResolvableType(Class<?> type, ResolvableType generic, ResolvableType... generics) {
ResolvableType[] genericTypes = new ResolvableType[generics.length + 1];
genericTypes[0] = generic;
System.arraycopy(generics, 0, genericTypes, 1, generics.length);
return ResolvableType.forClassWithGenerics(type, genericTypes);
}
/**
* Main entry point providing access to a {@code ResolvableMethod} builder.
@@ -278,16 +306,29 @@ public class ResolvableMethod {
/**
* Filter on methods returning the given type.
* @param returnType the return type
* @param generics optional array of generic types
*/
public Builder returning(Class<?> returnType) {
return returning(ResolvableType.forClass(returnType));
public Builder returning(Class<?> returnType, Class<?>... generics) {
return returning(toResolvableType(returnType, generics));
}
/**
* Filter on methods returning the given type with generics.
* @param returnType the return type
* @param generic at least one generic type
* @param generics optional extra generic types
*/
public Builder returning(Class<?> returnType, ResolvableType generic, ResolvableType... generics) {
return returning(toResolvableType(returnType, generic, generics));
}
/**
* Filter on methods returning the given type.
* @param returnType the return type
*/
public Builder returning(ResolvableType resolvableType) {
String expected = resolvableType.toString();
public Builder returning(ResolvableType returnType) {
String expected = returnType.toString();
String message = "returnType=" + expected;
addFilter(message, m -> expected.equals(ResolvableType.forMethodReturnType(m).toString()));
return this;
@@ -336,7 +377,7 @@ public class ResolvableMethod {
}
// Build & Resolve shortcuts...
// Build & resolve shortcuts...
/**
* Resolve and return the {@code Method} equivalent to:
@@ -365,15 +406,26 @@ public class ResolvableMethod {
/**
* Shortcut to the unique return type equivalent to:
* <p>{@code returning(returnType).build().returnType()}
* @param returnType the return type
* @param generics optional array of generic types
*/
public MethodParameter resolveReturnType(Class<?> returnType) {
return returning(returnType).build().returnType();
public MethodParameter resolveReturnType(Class<?> returnType, Class<?>... generics) {
return returning(returnType, generics).build().returnType();
}
/**
* Shortcut to the unique return type equivalent to:
* <p>{@code returning(returnType).build().returnType()}
* @param returnType the return type
* @param generic at least one generic type
* @param generics optional extra generic types
*/
public MethodParameter resolveReturnType(Class<?> returnType, ResolvableType generic,
ResolvableType... generics) {
return returning(returnType, generic, generics).build().returnType();
}
public MethodParameter resolveReturnType(ResolvableType returnType) {
return returning(returnType).build().returnType();
}
@@ -392,51 +444,6 @@ public class ResolvableMethod {
}
}
@SuppressWarnings("unchecked")
private static <T> T initProxy(Class<?> type, MethodInvocationInterceptor interceptor) {
Assert.notNull(type, "'type' must not be null");
if (type.isInterface()) {
ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE);
factory.addInterface(type);
factory.addInterface(Supplier.class);
factory.addAdvice(interceptor);
return (T) factory.getProxy();
}
else {
Enhancer enhancer = new Enhancer();
enhancer.setSuperclass(type);
enhancer.setInterfaces(new Class<?>[] {Supplier.class});
enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE);
enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class);
Class<?> proxyClass = enhancer.createClass();
Object proxy = null;
if (objenesis.isWorthTrying()) {
try {
proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache());
}
catch (ObjenesisException ex) {
logger.debug("Objenesis failed, falling back to default constructor", ex);
}
}
if (proxy == null) {
try {
proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance();
}
catch (Throwable ex) {
throw new IllegalStateException("Unable to instantiate proxy " +
"via both Objenesis and default constructor fails as well", ex);
}
}
((Factory) proxy).setCallbacks(new Callback[] {interceptor});
return (T) proxy;
}
}
/**
* Predicate with a descriptive label.
*/
@@ -533,9 +540,16 @@ public class ResolvableMethod {
* Resolve the argument also matching to the given type.
* @param type the expected type
*/
public MethodParameter arg(Class<?> type) {
this.filters.add(p -> type.equals(p.getParameterType()));
return arg(ResolvableType.forClass(type));
public MethodParameter arg(Class<?> type, Class<?>... generics) {
return arg(toResolvableType(type, generics));
}
/**
* Resolve the argument also matching to the given type.
* @param type the expected type
*/
public MethodParameter arg(Class<?> type, ResolvableType generic, ResolvableType... generics) {
return arg(toResolvableType(type, generic, generics));
}
/**
@@ -562,6 +576,7 @@ public class ResolvableMethod {
List<MethodParameter> matches = new ArrayList<>();
for (int i = 0; i < method.getParameterCount(); i++) {
MethodParameter param = new SynthesizingMethodParameter(method, i);
param.initParameterNameDiscovery(nameDiscoverer);
if (this.filters.stream().allMatch(p -> p.test(param))) {
matches.add(param);
}
@@ -597,4 +612,49 @@ public class ResolvableMethod {
}
}
@SuppressWarnings("unchecked")
private static <T> T initProxy(Class<?> type, MethodInvocationInterceptor interceptor) {
Assert.notNull(type, "'type' must not be null");
if (type.isInterface()) {
ProxyFactory factory = new ProxyFactory(EmptyTargetSource.INSTANCE);
factory.addInterface(type);
factory.addInterface(Supplier.class);
factory.addAdvice(interceptor);
return (T) factory.getProxy();
}
else {
Enhancer enhancer = new Enhancer();
enhancer.setSuperclass(type);
enhancer.setInterfaces(new Class<?>[] {Supplier.class});
enhancer.setNamingPolicy(SpringNamingPolicy.INSTANCE);
enhancer.setCallbackType(org.springframework.cglib.proxy.MethodInterceptor.class);
Class<?> proxyClass = enhancer.createClass();
Object proxy = null;
if (objenesis.isWorthTrying()) {
try {
proxy = objenesis.newInstance(proxyClass, enhancer.getUseCache());
}
catch (ObjenesisException ex) {
logger.debug("Objenesis failed, falling back to default constructor", ex);
}
}
if (proxy == null) {
try {
proxy = ReflectionUtils.accessibleConstructor(proxyClass).newInstance();
}
catch (Throwable ex) {
throw new IllegalStateException("Unable to instantiate proxy " +
"via both Objenesis and default constructor fails as well", ex);
}
}
((Factory) proxy).setCallbacks(new Callback[] {interceptor});
return (T) proxy;
}
}
}

View File

@@ -27,7 +27,6 @@ import org.junit.Before;
import org.junit.Test;
import org.springframework.beans.propertyeditors.StringTrimmerEditor;
import org.springframework.core.LocalVariableTableParameterNameDiscoverer;
import org.springframework.core.MethodParameter;
import org.springframework.core.convert.support.DefaultConversionService;
import org.springframework.mock.web.test.MockHttpServletRequest;
@@ -59,7 +58,6 @@ import static org.junit.Assert.assertTrue;
import static org.junit.Assert.fail;
import static org.mockito.BDDMockito.given;
import static org.mockito.BDDMockito.mock;
import static org.springframework.core.ResolvableType.forClassWithGenerics;
/**
* Test fixture with {@link org.springframework.web.method.annotation.RequestParamMethodArgumentResolver}.
@@ -101,7 +99,7 @@ public class RequestParamMethodArgumentResolverTests {
param = this.testMethod.annotated(RequestParam.class).arg(MultipartFile.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.annotated(RequestParam.class).arg(forClassWithGenerics(List.class, MultipartFile.class));
param = this.testMethod.annotated(RequestParam.class).arg(List.class, MultipartFile.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.annotated(RequestParam.class).arg(MultipartFile[].class);
@@ -110,7 +108,7 @@ public class RequestParamMethodArgumentResolverTests {
param = this.testMethod.annotated(RequestParam.class).arg(Part.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.annotated(RequestParam.class).arg(forClassWithGenerics(List.class, Part.class));
param = this.testMethod.annotated(RequestParam.class).arg(List.class, Part.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.annotated(RequestParam.class).arg(Part[].class);
@@ -125,7 +123,7 @@ public class RequestParamMethodArgumentResolverTests {
param = this.testMethod.notAnnotated().arg(MultipartFile.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.notAnnotated(RequestParam.class).arg(forClassWithGenerics(List.class, MultipartFile.class));
param = this.testMethod.notAnnotated(RequestParam.class).arg(List.class, MultipartFile.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.notAnnotated(RequestParam.class).arg(Part.class);
@@ -140,10 +138,10 @@ public class RequestParamMethodArgumentResolverTests {
param = this.testMethod.annotated(RequestParam.class, required().negate()).arg(String.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.annotated(RequestParam.class).arg(forClassWithGenerics(Optional.class, Integer.class));
param = this.testMethod.annotated(RequestParam.class).arg(Optional.class, Integer.class);
assertTrue(resolver.supportsParameter(param));
param = this.testMethod.annotated(RequestParam.class).arg(forClassWithGenerics(Optional.class, MultipartFile.class));
param = this.testMethod.annotated(RequestParam.class).arg(Optional.class, MultipartFile.class);
assertTrue(resolver.supportsParameter(param));
resolver = new RequestParamMethodArgumentResolver(null, false);
@@ -201,8 +199,7 @@ public class RequestParamMethodArgumentResolverTests {
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod
.annotated(RequestParam.class)
.arg(forClassWithGenerics(List.class, MultipartFile.class));
.annotated(RequestParam.class).arg(List.class, MultipartFile.class);
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertTrue(result instanceof List);
@@ -255,10 +252,7 @@ public class RequestParamMethodArgumentResolverTests {
request.addPart(new MockPart("other", "Hello World 3".getBytes()));
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod
.annotated(RequestParam.class)
.arg(forClassWithGenerics(List.class, Part.class));
MethodParameter param = this.testMethod.annotated(RequestParam.class).arg(List.class, Part.class);
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertTrue(result instanceof List);
assertEquals(Arrays.asList(expected1, expected2), result);
@@ -293,8 +287,6 @@ public class RequestParamMethodArgumentResolverTests {
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod.notAnnotated().arg(MultipartFile.class);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertTrue(result instanceof MultipartFile);
assertEquals("Invalid result", expected, result);
@@ -310,9 +302,7 @@ public class RequestParamMethodArgumentResolverTests {
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod
.notAnnotated(RequestParam.class)
.arg(forClassWithGenerics(List.class, MultipartFile.class));
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
.notAnnotated(RequestParam.class).arg(List.class, MultipartFile.class);
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertTrue(result instanceof List);
@@ -335,9 +325,7 @@ public class RequestParamMethodArgumentResolverTests {
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod
.notAnnotated(RequestParam.class)
.arg(forClassWithGenerics(List.class, MultipartFile.class));
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
.notAnnotated(RequestParam.class).arg(List.class, MultipartFile.class);
Object actual = resolver.resolveArgument(param, null, webRequest, null);
assertTrue(actual instanceof List);
@@ -371,7 +359,6 @@ public class RequestParamMethodArgumentResolverTests {
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod.notAnnotated(RequestParam.class).arg(Part.class);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertTrue(result instanceof Part);
assertEquals("Invalid result", expected, result);
@@ -403,7 +390,6 @@ public class RequestParamMethodArgumentResolverTests {
this.request.addParameter("stringNotAnnot", "");
MethodParameter param = this.testMethod.notAnnotated(RequestParam.class).arg(String.class);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
Object arg = resolver.resolveArgument(param, null, webRequest, binderFactory);
assertNull(arg);
}
@@ -427,9 +413,7 @@ public class RequestParamMethodArgumentResolverTests {
public void resolveSimpleTypeParam() throws Exception {
request.setParameter("stringNotAnnot", "plainValue");
MethodParameter param = this.testMethod.notAnnotated(RequestParam.class).arg(String.class);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
Object result = resolver.resolveArgument(param, null, webRequest, null);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
assertTrue(result instanceof String);
assertEquals("plainValue", result);
@@ -438,7 +422,6 @@ public class RequestParamMethodArgumentResolverTests {
@Test // SPR-8561
public void resolveSimpleTypeParamToNull() throws Exception {
MethodParameter param = this.testMethod.notAnnotated(RequestParam.class).arg(String.class);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertNull(result);
}
@@ -455,7 +438,6 @@ public class RequestParamMethodArgumentResolverTests {
public void resolveEmptyValueWithoutDefault() throws Exception {
this.request.addParameter("stringNotAnnot", "");
MethodParameter param = this.testMethod.notAnnotated(RequestParam.class).arg(String.class);
param.initParameterNameDiscovery(new LocalVariableTableParameterNameDiscoverer());
Object result = resolver.resolveArgument(param, null, webRequest, null);
assertEquals("", result);
}
@@ -476,8 +458,7 @@ public class RequestParamMethodArgumentResolverTests {
WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer);
MethodParameter param = this.testMethod
.annotated(RequestParam.class)
.arg(forClassWithGenerics(Optional.class, Integer.class));
.annotated(RequestParam.class).arg(Optional.class, Integer.class);
Object result = resolver.resolveArgument(param, null, webRequest, binderFactory);
assertEquals(Optional.empty(), result);
@@ -500,8 +481,7 @@ public class RequestParamMethodArgumentResolverTests {
webRequest = new ServletWebRequest(request);
MethodParameter param = this.testMethod
.annotated(RequestParam.class)
.arg(forClassWithGenerics(Optional.class, MultipartFile.class));
.annotated(RequestParam.class).arg(Optional.class, MultipartFile.class);
Object result = resolver.resolveArgument(param, null, webRequest, binderFactory);
assertTrue(result instanceof Optional);
@@ -518,10 +498,10 @@ public class RequestParamMethodArgumentResolverTests {
request.setContentType("multipart/form-data");
MethodParameter param = this.testMethod
.annotated(RequestParam.class)
.arg(forClassWithGenerics(Optional.class, MultipartFile.class));
.annotated(RequestParam.class).arg(Optional.class, MultipartFile.class);
assertEquals(Optional.empty(), resolver.resolveArgument(param, null, webRequest, binderFactory));
Object actual = resolver.resolveArgument(param, null, webRequest, binderFactory);
assertEquals(Optional.empty(), actual);
}
@Test
@@ -531,10 +511,10 @@ public class RequestParamMethodArgumentResolverTests {
WebDataBinderFactory binderFactory = new DefaultDataBinderFactory(initializer);
MethodParameter param = this.testMethod
.annotated(RequestParam.class)
.arg(forClassWithGenerics(Optional.class, MultipartFile.class));
.annotated(RequestParam.class).arg(Optional.class, MultipartFile.class);
assertEquals(Optional.empty(), resolver.resolveArgument(param, null, webRequest, binderFactory));
Object actual = resolver.resolveArgument(param, null, webRequest, binderFactory);
assertEquals(Optional.empty(), actual);
}
private Predicate<RequestParam> name(String name) {