diff --git a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java index dea33261c2..75cb972331 100644 --- a/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java +++ b/spring-test/src/main/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilder.java @@ -16,6 +16,7 @@ package org.springframework.test.web.servlet.request; +import java.io.IOException; import java.net.URI; import java.util.ArrayList; import java.util.Collection; @@ -24,12 +25,14 @@ import java.util.List; import javax.servlet.ServletContext; import javax.servlet.http.Part; +import org.springframework.http.HttpHeaders; import org.springframework.http.HttpMethod; import org.springframework.http.MediaType; import org.springframework.lang.Nullable; import org.springframework.mock.web.MockHttpServletRequest; import org.springframework.mock.web.MockMultipartFile; import org.springframework.mock.web.MockMultipartHttpServletRequest; +import org.springframework.mock.web.MockPart; import org.springframework.util.Assert; import org.springframework.util.LinkedMultiValueMap; import org.springframework.util.MultiValueMap; @@ -119,7 +122,7 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque if (parent instanceof MockMultipartHttpServletRequestBuilder) { MockMultipartHttpServletRequestBuilder parentBuilder = (MockMultipartHttpServletRequestBuilder) parent; this.files.addAll(parentBuilder.files); - parentBuilder.parts.keySet().stream().forEach(name -> + parentBuilder.parts.keySet().forEach(name -> this.parts.putIfAbsent(name, parentBuilder.parts.get(name))); } @@ -138,9 +141,26 @@ public class MockMultipartHttpServletRequestBuilder extends MockHttpServletReque @Override protected final MockHttpServletRequest createServletRequest(ServletContext servletContext) { MockMultipartHttpServletRequest request = new MockMultipartHttpServletRequest(servletContext); - this.files.forEach(request::addFile); + this.files.forEach(file -> request.addPart(toMockPart(file))); this.parts.values().stream().flatMap(Collection::stream).forEach(request::addPart); return request; } + private MockPart toMockPart(MockMultipartFile file) { + byte[] bytes = null; + if (!file.isEmpty()) { + try { + bytes = file.getBytes(); + } + catch (IOException ex) { + throw new IllegalStateException("Unexpected IOException", ex); + } + } + MockPart part = new MockPart(file.getName(), file.getOriginalFilename(), bytes); + if (file.getContentType() != null) { + part.getHeaders().set(HttpHeaders.CONTENT_TYPE, file.getContentType()); + } + return part; + } + } diff --git a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java index 82f039d590..a7087f424f 100644 --- a/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java +++ b/spring-test/src/test/java/org/springframework/test/web/servlet/request/MockMultipartHttpServletRequestBuilderTests.java @@ -1,5 +1,5 @@ /* - * Copyright 2002-2019 the original author or authors. + * Copyright 2002-2020 the original author or authors. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -16,21 +16,43 @@ package org.springframework.test.web.servlet.request; +import java.nio.charset.StandardCharsets; + +import javax.servlet.http.Part; + import org.junit.jupiter.api.Test; import org.springframework.http.HttpMethod; import org.springframework.mock.web.MockHttpServletRequest; +import org.springframework.mock.web.MockMultipartFile; +import org.springframework.mock.web.MockPart; import org.springframework.mock.web.MockServletContext; +import org.springframework.web.multipart.support.StandardMultipartHttpServletRequest; import static org.assertj.core.api.Assertions.assertThat; /** + * Unit tests for {@link MockMultipartHttpServletRequestBuilder}. * @author Rossen Stoyanchev */ public class MockMultipartHttpServletRequestBuilderTests { + @Test // gh-26166 + void addFilesAndParts() throws Exception { + MockHttpServletRequest mockRequest = new MockMultipartHttpServletRequestBuilder("/upload") + .file(new MockMultipartFile("file", "test.txt", "text/plain", "Test".getBytes(StandardCharsets.UTF_8))) + .part(new MockPart("data", "{\"node\":\"node\"}".getBytes(StandardCharsets.UTF_8))) + .buildRequest(new MockServletContext()); + + StandardMultipartHttpServletRequest parsedRequest = new StandardMultipartHttpServletRequest(mockRequest); + + assertThat(parsedRequest.getParameterMap()).containsOnlyKeys("data"); + assertThat(parsedRequest.getFileMap()).containsOnlyKeys("file"); + assertThat(parsedRequest.getParts()).extracting(Part::getName).containsExactly("file", "data"); + } + @Test - public void test() { + void mergeAndBuild() { MockHttpServletRequestBuilder parent = new MockHttpServletRequestBuilder(HttpMethod.GET, "/"); parent.characterEncoding("UTF-8"); Object result = new MockMultipartHttpServletRequestBuilder("/fileUpload").merge(parent);