diff --git a/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java b/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java index c2e2a691be..37551a38b3 100644 --- a/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java +++ b/spring-web/src/main/java/org/springframework/web/method/ControllerAdviceBean.java @@ -47,14 +47,13 @@ import org.springframework.web.bind.annotation.ControllerAdvice; public class ControllerAdviceBean implements Ordered { /** - * Declared as {@code Object} since this may be a reference to a - * {@code String} representing the bean name or a reference to the actual - * bean instance. + * Reference to the actual bean instance or a {@code String} representing + * the bean name. */ - private final Object bean; + private final Object beanOrName; /** - * A reference to the resolved bean instance, potentially lazily retrieved + * Reference to the resolved bean instance, potentially lazily retrieved * via the {@code BeanFactory}. */ private Object resolvedBean; @@ -75,7 +74,13 @@ public class ControllerAdviceBean implements Ordered { * @param bean the bean instance */ public ControllerAdviceBean(Object bean) { - this(bean, null); + Assert.notNull(bean, "Bean must not be null"); + this.beanOrName = bean; + this.resolvedBean = bean; + this.beanType = ClassUtils.getUserClass(bean.getClass()); + this.beanTypePredicate = createBeanTypePredicate(this.beanType); + this.beanFactory = null; + this.order = initOrderFromBean(bean); } /** @@ -85,51 +90,16 @@ public class ControllerAdviceBean implements Ordered { * and later to resolve the actual bean */ public ControllerAdviceBean(String beanName, BeanFactory beanFactory) { - this((Object) beanName, beanFactory); - } + Assert.hasText(beanName, "Bean name must contain text"); + Assert.notNull(beanFactory, "BeanFactory must not be null"); + Assert.isTrue(beanFactory.containsBean(beanName), () -> "BeanFactory [" + beanFactory + + "] does not contain specified controller advice bean '" + beanName + "'"); - private ControllerAdviceBean(Object bean, @Nullable BeanFactory beanFactory) { - this.bean = bean; + this.beanOrName = beanName; + this.beanType = getBeanType(beanName, beanFactory); + this.beanTypePredicate = createBeanTypePredicate(this.beanType); this.beanFactory = beanFactory; - Class beanType; - - if (bean instanceof String) { - String beanName = (String) bean; - Assert.hasText(beanName, "Bean name must not be empty"); - Assert.notNull(beanFactory, "BeanFactory must not be null"); - if (!beanFactory.containsBean(beanName)) { - throw new IllegalArgumentException("BeanFactory [" + beanFactory + - "] does not contain specified controller advice bean '" + beanName + "'"); - } - beanType = this.beanFactory.getType(beanName); - if (beanType != null) { - beanType = ClassUtils.getUserClass(beanType); - } - this.order = initOrderFromBeanType(beanType); - } - else { - Assert.notNull(bean, "Bean must not be null"); - beanType = ClassUtils.getUserClass(bean.getClass()); - this.resolvedBean = bean; - this.order = initOrderFromBean(bean); - } - - this.beanType = beanType; - - ControllerAdvice annotation = (beanType != null ? - AnnotatedElementUtils.findMergedAnnotation(beanType, ControllerAdvice.class) : null); - - if (annotation != null) { - this.beanTypePredicate = HandlerTypePredicate.builder() - .basePackage(annotation.basePackages()) - .basePackageClass(annotation.basePackageClasses()) - .assignableType(annotation.assignableTypes()) - .annotation(annotation.annotations()) - .build(); - } - else { - this.beanTypePredicate = HandlerTypePredicate.forAnyHandlerType(); - } + this.order = initOrderFromBeanType(this.beanType); } @@ -160,9 +130,9 @@ public class ControllerAdviceBean implements Ordered { */ public Object resolveBean() { if (this.resolvedBean == null) { - // this.bean must be a String representing the bean name if + // this.beanOrName must be a String representing the bean name if // this.resolvedBean is null. - this.resolvedBean = obtainBeanFactory().getBean((String) this.bean); + this.resolvedBean = obtainBeanFactory().getBean((String) this.beanOrName); } return this.resolvedBean; } @@ -193,17 +163,17 @@ public class ControllerAdviceBean implements Ordered { return false; } ControllerAdviceBean otherAdvice = (ControllerAdviceBean) other; - return (this.bean.equals(otherAdvice.bean) && this.beanFactory == otherAdvice.beanFactory); + return (this.beanOrName.equals(otherAdvice.beanOrName) && this.beanFactory == otherAdvice.beanFactory); } @Override public int hashCode() { - return this.bean.hashCode(); + return this.beanOrName.hashCode(); } @Override public String toString() { - return this.bean.toString(); + return this.beanOrName.toString(); } @@ -222,6 +192,26 @@ public class ControllerAdviceBean implements Ordered { return adviceBeans; } + private static Class getBeanType(String beanName, BeanFactory beanFactory) { + Class beanType = beanFactory.getType(beanName); + return (beanType != null ? ClassUtils.getUserClass(beanType) : null); + } + + private static HandlerTypePredicate createBeanTypePredicate(Class beanType) { + ControllerAdvice annotation = (beanType != null ? + AnnotatedElementUtils.findMergedAnnotation(beanType, ControllerAdvice.class) : null); + + if (annotation != null) { + return HandlerTypePredicate.builder() + .basePackage(annotation.basePackages()) + .basePackageClass(annotation.basePackageClasses()) + .assignableType(annotation.assignableTypes()) + .annotation(annotation.annotations()) + .build(); + } + return HandlerTypePredicate.forAnyHandlerType(); + } + private static int initOrderFromBean(Object bean) { return (bean instanceof Ordered ? ((Ordered) bean).getOrder() : initOrderFromBeanType(bean.getClass())); } diff --git a/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java b/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java index 6ba336bf2a..3810ea4154 100644 --- a/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java +++ b/spring-web/src/test/java/org/springframework/web/method/ControllerAdviceBeanTests.java @@ -51,15 +51,15 @@ public class ControllerAdviceBeanTests { assertThatIllegalArgumentException() .isThrownBy(() -> new ControllerAdviceBean((String) null, null)) - .withMessage("Bean must not be null"); + .withMessage("Bean name must contain text"); assertThatIllegalArgumentException() .isThrownBy(() -> new ControllerAdviceBean("", null)) - .withMessage("Bean name must not be empty"); + .withMessage("Bean name must contain text"); assertThatIllegalArgumentException() .isThrownBy(() -> new ControllerAdviceBean("\t", null)) - .withMessage("Bean name must not be empty"); + .withMessage("Bean name must contain text"); assertThatIllegalArgumentException() .isThrownBy(() -> new ControllerAdviceBean("myBean", null))