diff --git a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java index e3bbb2fd9..7be88b573 100644 --- a/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java +++ b/spring-cloud-function-web/src/main/java/org/springframework/cloud/function/web/mvc/FunctionController.java @@ -18,8 +18,11 @@ package org.springframework.cloud.function.web.mvc; import java.util.Arrays; import java.util.Iterator; +import java.util.List; +import java.util.stream.Collectors; import org.reactivestreams.Publisher; +import reactor.core.publisher.Flux; import reactor.core.publisher.Mono; import org.springframework.cloud.function.context.catalog.SimpleFunctionRegistry.FunctionInvocationWrapper; @@ -28,12 +31,20 @@ import org.springframework.cloud.function.web.RequestProcessor.FunctionWrapper; import org.springframework.cloud.function.web.constants.WebRequestConstants; import org.springframework.http.MediaType; import org.springframework.http.ResponseEntity; +import org.springframework.http.ResponseEntity.BodyBuilder; +import org.springframework.messaging.Message; +import org.springframework.messaging.support.MessageBuilder; import org.springframework.stereotype.Component; +import org.springframework.util.CollectionUtils; +import org.springframework.util.MultiValueMap; import org.springframework.web.bind.annotation.GetMapping; import org.springframework.web.bind.annotation.PostMapping; import org.springframework.web.bind.annotation.RequestBody; import org.springframework.web.bind.annotation.ResponseBody; +import org.springframework.web.context.request.ServletWebRequest; import org.springframework.web.context.request.WebRequest; +import org.springframework.web.multipart.MultipartFile; +import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest; /** * @author Dave Syer @@ -48,14 +59,44 @@ public class FunctionController { this.processor = processor; } + + @PostMapping(path = "/**", consumes = { MediaType.APPLICATION_FORM_URLENCODED_VALUE, MediaType.MULTIPART_FORM_DATA_VALUE }) @ResponseBody public Mono> form(WebRequest request) { FunctionWrapper wrapper = wrapper(request); + + if (((ServletWebRequest) request).getRequest() instanceof StandardMultipartHttpServletRequest) { + MultiValueMap multiFileMap = ((StandardMultipartHttpServletRequest) ((ServletWebRequest) request) + .getRequest()).getMultiFileMap(); + if (!CollectionUtils.isEmpty(multiFileMap)) { + List> files = multiFileMap.values().stream().flatMap(v -> v.stream()) + .map(file -> MessageBuilder.withPayload(file).copyHeaders(wrapper.headers()).build()) + .collect(Collectors.toList()); + FunctionInvocationWrapper function = wrapper.function(); + + Publisher result = (Publisher) function.apply(Flux.fromIterable(files)); + BodyBuilder builder = ResponseEntity.ok(); + if (result instanceof Flux) { + result = Flux.from(result).map(message -> ((Message) message).getPayload()).collectList(); + } + return Mono.from(result).flatMap(body -> Mono.just(builder.body(body))); + } + } return this.processor.post(wrapper, null, false); } +// @PostMapping(path = "/**", consumes = { MediaType.APPLICATION_FORM_URLENCODED_VALUE, +// MediaType.MULTIPART_FORM_DATA_VALUE }) +// public Mono> handleFileUpload(@RequestParam("file") MultipartFile file, WebRequest request) { +// FunctionWrapper wrapper = wrapper(request); +// +// Object result = wrapper.function().apply(file); +// +// return Mono.just(ResponseEntity.status(HttpStatus.OK).body(result)); +// } + @PostMapping(path = "/**") @ResponseBody public Mono> post(WebRequest request, diff --git a/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/MultipartFileTests.java b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/MultipartFileTests.java new file mode 100644 index 000000000..9956b71e5 --- /dev/null +++ b/spring-cloud-function-web/src/test/java/org/springframework/cloud/function/web/mvc/MultipartFileTests.java @@ -0,0 +1,114 @@ +/* + * Copyright 2012-2019 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.web.mvc; + +import java.net.URI; +import java.util.List; +import java.util.function.Function; + +import org.junit.jupiter.api.AfterEach; +import org.junit.jupiter.api.BeforeEach; +import org.junit.jupiter.api.Test; + +import org.springframework.boot.SpringApplication; +import org.springframework.boot.autoconfigure.EnableAutoConfiguration; +import org.springframework.boot.test.web.client.TestRestTemplate; +import org.springframework.cloud.function.json.JsonMapper; +import org.springframework.context.ApplicationContext; +import org.springframework.context.annotation.Bean; +import org.springframework.core.io.ClassPathResource; +import org.springframework.http.HttpEntity; +import org.springframework.http.HttpHeaders; +import org.springframework.http.HttpMethod; +import org.springframework.http.MediaType; +import org.springframework.http.ResponseEntity; +import org.springframework.util.LinkedMultiValueMap; +import org.springframework.util.SocketUtils; +import org.springframework.web.multipart.MultipartFile; + +import static org.assertj.core.api.Assertions.assertThat; +/** + * + * @author Oleg Zhurakousky + * + */ +public class MultipartFileTests { + + @BeforeEach + public void init() throws Exception { + String port = "" + SocketUtils.findAvailableTcpPort(); + System.setProperty("server.port", port); + } + + @AfterEach + public void close() throws Exception { + System.clearProperty("server.port"); + } + + @Test + public void testMultipartFileUpload() throws Exception { + ApplicationContext context = SpringApplication.run(TestConfiguration.class); + JsonMapper mapper = context.getBean(JsonMapper.class); + TestRestTemplate template = new TestRestTemplate(); + String port = System.getProperty("server.port"); + + LinkedMultiValueMap map = new LinkedMultiValueMap<>(); + map.add("file", new ClassPathResource("META-INF/spring.factories")); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.MULTIPART_FORM_DATA); + + HttpEntity> requestEntity = new HttpEntity>( + map, headers); + ResponseEntity result = template.exchange(new URI("http://localhost:" + port + "/uppercase"), + HttpMethod.POST, requestEntity, String.class); + List resultCollection = mapper.fromJson(result.getBody(), List.class); + assertThat(resultCollection.get(0)).isEqualTo("SPRING.FACTORIES"); + } + + @Test + public void testMultipartFilesUpload() throws Exception { + ApplicationContext context = SpringApplication.run(TestConfiguration.class); + JsonMapper mapper = context.getBean(JsonMapper.class); + TestRestTemplate template = new TestRestTemplate(); + String port = System.getProperty("server.port"); + + LinkedMultiValueMap map = new LinkedMultiValueMap<>(); + map.add("fileA", new ClassPathResource("META-INF/spring.factories")); + map.add("fileB", new ClassPathResource("static/test.html")); + HttpHeaders headers = new HttpHeaders(); + headers.setContentType(MediaType.MULTIPART_FORM_DATA); + + HttpEntity> requestEntity = new HttpEntity>( + map, headers); + ResponseEntity result = template.exchange(new URI("http://localhost:" + port + "/uppercase"), + HttpMethod.POST, requestEntity, String.class); + List resultCollection = mapper.fromJson(result.getBody(), List.class); + assertThat(resultCollection.get(0)).isEqualTo("SPRING.FACTORIES"); + assertThat(resultCollection.get(1)).isEqualTo("TEST.HTML"); + } + + @EnableAutoConfiguration + protected static class TestConfiguration { + + @Bean + public Function uppercase() { + return value -> { + return value.getOriginalFilename().toUpperCase(); + }; + } + } +}