Commit 42598175 authored by Stephane Nicoll's avatar Stephane Nicoll

Fix extension discovery when endpoint instance is sub-classed

This commit fixes endpoint extension discovery when the related endpoint
is sub-classed. Previously, a strict by type check was applied against
the `endpoint` attribute of `EndpointExtension`.

Rather than using a `Class` check, this commit extracts the id of an
endpoint and uses it to match its extension, if any.

Closes gh-13082
parent 9de2d1f8
...@@ -145,17 +145,17 @@ public abstract class EndpointDiscoverer<E extends ExposableEndpoint<O>, O exten ...@@ -145,17 +145,17 @@ public abstract class EndpointDiscoverer<E extends ExposableEndpoint<O>, O exten
} }
private void addExtensionBeans(Collection<EndpointBean> endpointBeans) { private void addExtensionBeans(Collection<EndpointBean> endpointBeans) {
Map<?, EndpointBean> byType = endpointBeans.stream() Map<String, EndpointBean> byId = endpointBeans.stream()
.collect(Collectors.toMap((bean) -> bean.getType(), (bean) -> bean)); .collect(Collectors.toMap(EndpointBean::getId, (bean) -> bean));
String[] beanNames = BeanFactoryUtils.beanNamesForAnnotationIncludingAncestors( String[] beanNames = BeanFactoryUtils.beanNamesForAnnotationIncludingAncestors(
this.applicationContext, EndpointExtension.class); this.applicationContext, EndpointExtension.class);
for (String beanName : beanNames) { for (String beanName : beanNames) {
ExtensionBean extensionBean = createExtensionBean(beanName); ExtensionBean extensionBean = createExtensionBean(beanName);
EndpointBean endpointBean = byType.get(extensionBean.getEndpointType()); EndpointBean endpointBean = byId.get(extensionBean.getEndpointId());
Assert.state(endpointBean != null, Assert.state(endpointBean != null,
() -> ("Invalid extension '" + extensionBean.getBeanName() () -> ("Invalid extension '" + extensionBean.getBeanName()
+ "': no endpoint found with type '" + "': no endpoint found with id '"
+ extensionBean.getEndpointType().getName() + "'")); + extensionBean.getEndpointId() + "'"));
addExtensionBean(endpointBean, extensionBean); addExtensionBean(endpointBean, extensionBean);
} }
} }
...@@ -488,20 +488,24 @@ public abstract class EndpointDiscoverer<E extends ExposableEndpoint<O>, O exten ...@@ -488,20 +488,24 @@ public abstract class EndpointDiscoverer<E extends ExposableEndpoint<O>, O exten
private final Object bean; private final Object bean;
private final Class<?> endpointType; private final String endpointId;
private final Class<?> filter; private final Class<?> filter;
ExtensionBean(String beanName, Object bean) { ExtensionBean(String beanName, Object bean) {
this.bean = bean;
this.beanName = beanName;
AnnotationAttributes attributes = AnnotatedElementUtils AnnotationAttributes attributes = AnnotatedElementUtils
.getMergedAnnotationAttributes(bean.getClass(), .getMergedAnnotationAttributes(bean.getClass(),
EndpointExtension.class); EndpointExtension.class);
this.beanName = beanName; Class<?> endpointType = attributes.getClass("endpoint");
this.bean = bean; AnnotationAttributes endpointAttributes = AnnotatedElementUtils
this.endpointType = attributes.getClass("endpoint"); .findMergedAnnotationAttributes(endpointType, Endpoint.class, true,
true);
Assert.state(endpointAttributes != null, () -> "Extension "
+ endpointType.getName() + " does not specify an endpoint");
this.endpointId = endpointAttributes.getString("id");
this.filter = attributes.getClass("filter"); this.filter = attributes.getClass("filter");
Assert.state(!this.endpointType.equals(Void.class), () -> "Extension "
+ this.endpointType.getName() + " does not specify an endpoint");
} }
public String getBeanName() { public String getBeanName() {
...@@ -512,8 +516,8 @@ public abstract class EndpointDiscoverer<E extends ExposableEndpoint<O>, O exten ...@@ -512,8 +516,8 @@ public abstract class EndpointDiscoverer<E extends ExposableEndpoint<O>, O exten
return this.bean; return this.bean;
} }
public Class<?> getEndpointType() { public String getEndpointId() {
return this.endpointType; return this.endpointId;
} }
public Class<?> getFilter() { public Class<?> getFilter() {
......
...@@ -233,6 +233,25 @@ public class EndpointDiscovererTests { ...@@ -233,6 +233,25 @@ public class EndpointDiscovererTests {
}); });
} }
@Test
public void getEndpointShouldFindParentExtension() {
load(SubSpecializedEndpointsConfiguration.class, (context) -> {
SpecializedEndpointDiscoverer discoverer = new SpecializedEndpointDiscoverer(
context);
Map<String, SpecializedExposableEndpoint> endpoints = mapEndpoints(
discoverer.getEndpoints());
Map<Method, SpecializedOperation> operations = mapOperations(
endpoints.get("specialized"));
assertThat(operations).containsKeys(
ReflectionUtils.findMethod(SpecializedTestEndpoint.class, "getAll"));
assertThat(operations).containsKeys(ReflectionUtils.findMethod(
SubSpecializedTestEndpoint.class, "getSpecialOne", String.class));
assertThat(operations).containsKeys(
ReflectionUtils.findMethod(SpecializedExtension.class, "getSpecial"));
assertThat(operations).hasSize(3);
});
}
@Test @Test
public void getEndpointsShouldApplyFilters() { public void getEndpointsShouldApplyFilters() {
load(SpecializedEndpointsConfiguration.class, (context) -> { load(SpecializedEndpointsConfiguration.class, (context) -> {
...@@ -371,6 +390,12 @@ public class EndpointDiscovererTests { ...@@ -371,6 +390,12 @@ public class EndpointDiscovererTests {
} }
@Import({ TestEndpoint.class, SubSpecializedTestEndpoint.class,
SpecializedExtension.class })
static class SubSpecializedEndpointsConfiguration {
}
@Endpoint(id = "test") @Endpoint(id = "test")
static class TestEndpoint { static class TestEndpoint {
...@@ -449,6 +474,15 @@ public class EndpointDiscovererTests { ...@@ -449,6 +474,15 @@ public class EndpointDiscovererTests {
} }
static class SubSpecializedTestEndpoint extends SpecializedTestEndpoint {
@ReadOperation
public Object getSpecialOne(@Selector String id) {
return null;
}
}
static class TestEndpointDiscoverer static class TestEndpointDiscoverer
extends EndpointDiscoverer<TestExposableEndpoint, TestOperation> { extends EndpointDiscoverer<TestExposableEndpoint, TestOperation> {
......
...@@ -117,8 +117,7 @@ public class JmxEndpointDiscovererTests { ...@@ -117,8 +117,7 @@ public class JmxEndpointDiscovererTests {
this.thrown.expect(IllegalStateException.class); this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage( this.thrown.expectMessage(
"Invalid extension 'jmxEndpointDiscovererTests.TestJmxEndpointExtension': " "Invalid extension 'jmxEndpointDiscovererTests.TestJmxEndpointExtension': "
+ "no endpoint found with type '" + "no endpoint found with id 'test'");
+ TestEndpoint.class.getName() + "'");
discoverer.getEndpoints(); discoverer.getEndpoints();
}); });
} }
......
...@@ -82,8 +82,8 @@ public class WebEndpointDiscovererTests { ...@@ -82,8 +82,8 @@ public class WebEndpointDiscovererTests {
load(TestWebEndpointExtensionConfiguration.class, (discoverer) -> { load(TestWebEndpointExtensionConfiguration.class, (discoverer) -> {
this.thrown.expect(IllegalStateException.class); this.thrown.expect(IllegalStateException.class);
this.thrown.expectMessage( this.thrown.expectMessage(
"Invalid extension 'endpointExtension': no endpoint found with type '" "Invalid extension 'endpointExtension': no endpoint found with id '"
+ TestEndpoint.class.getName() + "'"); + "test'");
discoverer.getEndpoints(); discoverer.getEndpoints();
}); });
} }
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment