diff --git a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java index 33376a4d35..0b16e4f216 100644 --- a/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java +++ b/spring-test/src/main/java/org/springframework/test/context/jdbc/SqlScriptsTestExecutionListener.java @@ -17,14 +17,17 @@ package org.springframework.test.context.jdbc; import java.lang.reflect.Method; +import java.util.Arrays; import java.util.List; import java.util.Set; +import java.util.stream.Stream; import javax.sql.DataSource; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; +import org.springframework.aot.hint.RuntimeHints; import org.springframework.context.ApplicationContext; import org.springframework.core.annotation.AnnotatedElementUtils; import org.springframework.core.io.ByteArrayResource; @@ -35,6 +38,7 @@ import org.springframework.lang.NonNull; import org.springframework.lang.Nullable; import org.springframework.test.context.TestContext; import org.springframework.test.context.TestContextAnnotationUtils; +import org.springframework.test.context.aot.AotTestExecutionListener; import org.springframework.test.context.jdbc.Sql.ExecutionPhase; import org.springframework.test.context.jdbc.SqlConfig.ErrorMode; import org.springframework.test.context.jdbc.SqlConfig.TransactionMode; @@ -52,9 +56,11 @@ import org.springframework.util.Assert; import org.springframework.util.ClassUtils; import org.springframework.util.ObjectUtils; import org.springframework.util.ReflectionUtils; -import org.springframework.util.ResourceUtils; +import org.springframework.util.ReflectionUtils.MethodFilter; import org.springframework.util.StringUtils; +import static org.springframework.util.ResourceUtils.CLASSPATH_URL_PREFIX; + /** * {@code TestExecutionListener} that provides support for executing SQL * {@link Sql#scripts scripts} and inlined {@link Sql#statements statements} @@ -90,18 +96,22 @@ import org.springframework.util.StringUtils; * @since 4.1 * @see Sql * @see SqlConfig + * @see SqlMergeMode * @see SqlGroup * @see org.springframework.test.context.transaction.TestContextTransactionUtils * @see org.springframework.test.context.transaction.TransactionalTestExecutionListener * @see org.springframework.jdbc.datasource.init.ResourceDatabasePopulator * @see org.springframework.jdbc.datasource.init.ScriptUtils */ -public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener { +public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListener implements AotTestExecutionListener { private static final String SLASH = "/"; private static final Log logger = LogFactory.getLog(SqlScriptsTestExecutionListener.class); + private static final MethodFilter sqlMethodFilter = ReflectionUtils.USER_DECLARED_METHODS + .and(method -> AnnotatedElementUtils.hasAnnotation(method, Sql.class)); + /** * Returns {@code 5000}. @@ -129,6 +139,21 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen executeSqlScripts(testContext, ExecutionPhase.AFTER_TEST_METHOD); } + /** + * Process the supplied test class and its methods and register run-time + * hints for any SQL scripts configured or detected as classpath resources + * via {@link Sql @Sql}. + * @since 6.0 + */ + @Override + public void processAheadOfTime(Class testClass, RuntimeHints runtimeHints, ClassLoader classLoader) { + getSqlAnnotationsFor(testClass).forEach(sql -> + registerClasspathResources(runtimeHints, getScripts(sql, testClass, null, true))); + getSqlMethods(testClass).forEach(testMethod -> + getSqlAnnotationsFor(testMethod).forEach(sql -> + registerClasspathResources(runtimeHints, getScripts(sql, testClass, testMethod, false)))); + } + /** * Execute SQL scripts configured via {@link Sql @Sql} for the supplied * {@link TestContext} and {@link ExecutionPhase}. @@ -226,8 +251,7 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen mergedSqlConfig, executionPhase, testContext)); } - String[] scripts = getScripts(sql, testContext, classLevel); - scripts = TestContextResourceUtils.convertToClasspathResourcePaths(testContext.getTestClass(), scripts); + String[] scripts = getScripts(sql, testContext.getTestClass(), testContext.getTestMethod(), classLevel); List scriptResources = TestContextResourceUtils.convertToResourceList( testContext.getApplicationContext(), scripts); for (String stmt : sql.statements()) { @@ -321,31 +345,29 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen return null; } - private String[] getScripts(Sql sql, TestContext testContext, boolean classLevel) { + private String[] getScripts(Sql sql, Class testClass, Method testMethod, boolean classLevel) { String[] scripts = sql.scripts(); if (ObjectUtils.isEmpty(scripts) && ObjectUtils.isEmpty(sql.statements())) { - scripts = new String[] {detectDefaultScript(testContext, classLevel)}; + scripts = new String[] {detectDefaultScript(testClass, testMethod, classLevel)}; } - return scripts; + return TestContextResourceUtils.convertToClasspathResourcePaths(testClass, scripts); } /** * Detect a default SQL script by implementing the algorithm defined in * {@link Sql#scripts}. */ - private String detectDefaultScript(TestContext testContext, boolean classLevel) { - Class clazz = testContext.getTestClass(); - Method method = testContext.getTestMethod(); + private String detectDefaultScript(Class testClass, Method testMethod, boolean classLevel) { String elementType = (classLevel ? "class" : "method"); - String elementName = (classLevel ? clazz.getName() : method.toString()); + String elementName = (classLevel ? testClass.getName() : testMethod.toString()); - String resourcePath = ClassUtils.convertClassNameToResourcePath(clazz.getName()); + String resourcePath = ClassUtils.convertClassNameToResourcePath(testClass.getName()); if (!classLevel) { - resourcePath += "." + method.getName(); + resourcePath += "." + testMethod.getName(); } resourcePath += ".sql"; - String prefixedResourcePath = ResourceUtils.CLASSPATH_URL_PREFIX + SLASH + resourcePath; + String prefixedResourcePath = CLASSPATH_URL_PREFIX + SLASH + resourcePath; ClassPathResource classPathResource = new ClassPathResource(resourcePath); if (classPathResource.exists()) { @@ -364,4 +386,23 @@ public class SqlScriptsTestExecutionListener extends AbstractTestExecutionListen } } + private Stream getSqlMethods(Class testClass) { + return Arrays.stream(ReflectionUtils.getUniqueDeclaredMethods(testClass, sqlMethodFilter)); + } + + private void registerClasspathResources(RuntimeHints runtimeHints, String... locations) { + Arrays.stream(locations) + .filter(location -> location.startsWith(CLASSPATH_URL_PREFIX)) + .map(this::cleanClasspathResource) + .forEach(runtimeHints.resources()::registerPattern); + } + + private String cleanClasspathResource(String location) { + location = location.substring(CLASSPATH_URL_PREFIX.length()); + if (!location.startsWith(SLASH)) { + location = SLASH + location; + } + return location; + } + } diff --git a/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java b/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java index dbe7ecc469..8124709a86 100644 --- a/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java +++ b/spring-test/src/test/java/org/springframework/test/context/aot/TestContextAotGeneratorTests.java @@ -191,6 +191,12 @@ class TestContextAotGeneratorTests extends AbstractAotTests { // @WebAppConfiguration(value = ...) assertThat(resource().forResource("/META-INF/web-resources/resources/Spring.js")).accepts(runtimeHints); assertThat(resource().forResource("/META-INF/web-resources/WEB-INF/views/home.jsp")).accepts(runtimeHints); + + // @Sql(scripts = ...) + assertThat(resource().forResource("/org/springframework/test/context/jdbc/schema.sql")) + .accepts(runtimeHints); + assertThat(resource().forResource("/org/springframework/test/context/aot/samples/jdbc/SqlScriptsSpringJupiterTests.test.sql")) + .accepts(runtimeHints); } private static void assertReflectionRegistered(RuntimeHints runtimeHints, String type, MemberCategory memberCategory) {