diff --git a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolver.java b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolver.java index 9e14cb324e..ce1d3d1663 100644 --- a/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolver.java +++ b/spring-messaging/src/main/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolver.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2021 the original author or authors. + * Copyright 2002-2023 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -17,6 +17,7 @@ package org.springframework.messaging.handler.annotation.support; import java.lang.annotation.Annotation; +import java.util.Optional; import org.springframework.core.MethodParameter; import org.springframework.core.annotation.AnnotationUtils; @@ -27,6 +28,7 @@ import org.springframework.messaging.converter.MessageConverter; import org.springframework.messaging.converter.SmartMessageConverter; import org.springframework.messaging.handler.annotation.Payload; import org.springframework.messaging.handler.invocation.HandlerMethodArgumentResolver; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; @@ -113,24 +115,30 @@ public class PayloadMethodArgumentResolver implements HandlerMethodArgumentResol throw new IllegalStateException("@Payload SpEL expressions not supported by this resolver"); } + boolean isOptionalTargetClass = (parameter.getParameterType() == Optional.class); Object payload = message.getPayload(); if (isEmptyPayload(payload)) { - if (ann == null || ann.required()) { + if ((ann == null || ann.required()) && !isOptionalTargetClass) { String paramName = getParameterName(parameter); BindingResult bindingResult = new BeanPropertyBindingResult(payload, paramName); bindingResult.addError(new ObjectError(paramName, "Payload value must not be empty")); throw new MethodArgumentNotValidException(message, parameter, bindingResult); } else { - return null; + return (isOptionalTargetClass ? Optional.empty() : null); } } + if (payload instanceof Optional optional) { + payload = optional.get(); + message = MessageBuilder.createMessage(payload, message.getHeaders()); + } + Class targetClass = resolveTargetClass(parameter, message); Class payloadClass = payload.getClass(); if (ClassUtils.isAssignable(targetClass, payloadClass)) { validate(message, parameter, payload); - return payload; + return (isOptionalTargetClass ? Optional.of(payload) : payload); } else { if (this.converter instanceof SmartMessageConverter smartConverter) { @@ -144,7 +152,7 @@ public class PayloadMethodArgumentResolver implements HandlerMethodArgumentResol payloadClass.getName() + "] to [" + targetClass.getName() + "] for " + message); } validate(message, parameter, payload); - return payload; + return (isOptionalTargetClass ? Optional.of(payload) : payload); } } @@ -161,11 +169,14 @@ public class PayloadMethodArgumentResolver implements HandlerMethodArgumentResol if (payload == null) { return true; } - else if (payload instanceof byte[]) { - return ((byte[]) payload).length == 0; + else if (payload instanceof byte[] bytes) { + return bytes.length == 0; } - else if (payload instanceof String) { - return !StringUtils.hasText((String) payload); + else if (payload instanceof String s) { + return !StringUtils.hasText(s); + } + else if (payload instanceof Optional optional) { + return optional.isEmpty(); } else { return false; @@ -184,7 +195,7 @@ public class PayloadMethodArgumentResolver implements HandlerMethodArgumentResol * @since 5.2 */ protected Class resolveTargetClass(MethodParameter parameter, Message message) { - return parameter.getParameterType(); + return parameter.nestedIfOptional().getNestedParameterType(); } /** diff --git a/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolverTests.java b/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolverTests.java index 8782f59bd4..4c3fab7d7f 100644 --- a/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolverTests.java +++ b/spring-messaging/src/test/java/org/springframework/messaging/handler/annotation/support/PayloadMethodArgumentResolverTests.java @@ -22,6 +22,7 @@ import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; import java.lang.reflect.Method; import java.util.Locale; +import java.util.Optional; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; @@ -61,6 +62,8 @@ public class PayloadMethodArgumentResolverTests { private MethodParameter paramWithSpelExpression; + private MethodParameter paramOptional; + private MethodParameter paramNotAnnotated; private MethodParameter paramValidatedNotAnnotated; @@ -74,16 +77,17 @@ public class PayloadMethodArgumentResolverTests { Method payloadMethod = PayloadMethodArgumentResolverTests.class.getDeclaredMethod( "handleMessage", String.class, String.class, Locale.class, - String.class, String.class, String.class, String.class); + String.class, Optional.class, String.class, String.class, String.class); this.paramAnnotated = new SynthesizingMethodParameter(payloadMethod, 0); this.paramAnnotatedNotRequired = new SynthesizingMethodParameter(payloadMethod, 1); this.paramAnnotatedRequired = new SynthesizingMethodParameter(payloadMethod, 2); this.paramWithSpelExpression = new SynthesizingMethodParameter(payloadMethod, 3); - this.paramValidated = new SynthesizingMethodParameter(payloadMethod, 4); + this.paramOptional = new SynthesizingMethodParameter(payloadMethod, 4); + this.paramValidated = new SynthesizingMethodParameter(payloadMethod, 5); this.paramValidated.initParameterNameDiscovery(new DefaultParameterNameDiscoverer()); - this.paramValidatedNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 5); - this.paramNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 6); + this.paramValidatedNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 6); + this.paramNotAnnotated = new SynthesizingMethodParameter(payloadMethod, 7); } @Test @@ -127,13 +131,33 @@ public class PayloadMethodArgumentResolverTests { Message emptyByteArrayMessage = MessageBuilder.withPayload(new byte[0]).build(); assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, emptyByteArrayMessage)).isNull(); - Message emptyStringMessage = MessageBuilder.withPayload("").build(); + Message emptyStringMessage = MessageBuilder.withPayload(" ").build(); assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, emptyStringMessage)).isNull(); + assertThat(((Optional) this.resolver.resolveArgument(this.paramOptional, emptyStringMessage)).isEmpty()).isTrue(); + + Message emptyOptionalMessage = MessageBuilder.withPayload(Optional.empty()).build(); + assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, emptyOptionalMessage)).isNull(); Message notEmptyMessage = MessageBuilder.withPayload("ABC".getBytes()).build(); assertThat(this.resolver.resolveArgument(this.paramAnnotatedNotRequired, notEmptyMessage)).isEqualTo("ABC"); } + @Test + public void resolveOptionalTarget() throws Exception { + Message message = MessageBuilder.withPayload("ABC".getBytes()).build(); + Object actual = this.resolver.resolveArgument(paramOptional, message); + + assertThat(((Optional) actual).get()).isEqualTo("ABC"); + } + + @Test + public void resolveOptionalSource() throws Exception { + Message message = MessageBuilder.withPayload(Optional.of("ABC".getBytes())).build(); + Object actual = this.resolver.resolveArgument(paramAnnotated, message); + + assertThat(actual).isEqualTo("ABC"); + } + @Test public void resolveNonConvertibleParam() { Message notEmptyMessage = MessageBuilder.withPayload(123).build(); @@ -218,6 +242,7 @@ public class PayloadMethodArgumentResolverTests { @Payload(required=false) String paramNotRequired, @Payload(required=true) Locale nonConvertibleRequiredParam, @Payload("foo.bar") String paramWithSpelExpression, + @Payload Optional optionalParam, @MyValid @Payload String validParam, @Validated String validParamNotAnnotated, String paramNotAnnotated) {