diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java index eb1b1be7c2..fc6cb193f8 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/annotation/AutowiredAnnotationBeanRegistrationAotContributionTests.java @@ -91,7 +91,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { } @Test - @CompileWithTargetClassAccess(classes = PackagePrivateFieldInjectionSample.class) + @CompileWithTargetClassAccess void contributeWhenPackagePrivateFieldInjectionInjectsUsingConsumer() { Environment environment = new StandardEnvironment(); this.beanFactory.registerSingleton("environment", environment); @@ -122,7 +122,7 @@ class AutowiredAnnotationBeanRegistrationAotContributionTests { } @Test - @CompileWithTargetClassAccess(classes = PackagePrivateMethodInjectionSample.class) + @CompileWithTargetClassAccess void contributeWhenPackagePrivateMethodInjectionInjectsUsingConsumer() { Environment environment = new StandardEnvironment(); this.beanFactory.registerSingleton("environment", environment); diff --git a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java index 56df55ccde..c1641d21ee 100644 --- a/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java +++ b/spring-beans/src/test/java/org/springframework/beans/factory/aot/BeanDefinitionMethodGeneratorTests.java @@ -342,7 +342,7 @@ class BeanDefinitionMethodGeneratorTests { } @Test - @CompileWithTargetClassAccess(classes = PackagePrivateTestBean.class) + @CompileWithTargetClassAccess void generateBeanDefinitionMethodWhenPackagePrivateBean() { RegisteredBean registeredBean = registerBean( new RootBeanDefinition(PackagePrivateTestBean.class)); diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccess.java b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccess.java index 181b32ba79..471b3f805a 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccess.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccess.java @@ -21,16 +21,13 @@ import java.lang.annotation.ElementType; import java.lang.annotation.Retention; import java.lang.annotation.RetentionPolicy; import java.lang.annotation.Target; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.MethodHandles.Lookup; import org.junit.jupiter.api.extension.ExtendWith; /** * Annotation that can be used on tests that need a {@link TestCompiler} with - * non-public access to a target class. Allows the compiler to use - * {@link MethodHandles#privateLookupIn} to {@link Lookup#defineClass define the - * class} without polluting the test {@link ClassLoader}. + * non-public access to target classes. Allows the compiler to define classes + * without polluting the test {@link ClassLoader}. * * @author Phillip Webb * @since 6.0 @@ -41,16 +38,4 @@ import org.junit.jupiter.api.extension.ExtendWith; @ExtendWith(CompileWithTargetClassAccessExtension.class) public @interface CompileWithTargetClassAccess { - /** - * The target class names. - * @return the class name - */ - String[] classNames() default {}; - - /** - * The target classes. - * @return the classes - */ - Class[] classes() default {}; - } diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessClassLoader.java b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessClassLoader.java index 9882c1864b..8114a6cdf1 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessClassLoader.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessClassLoader.java @@ -32,21 +32,13 @@ final class CompileWithTargetClassAccessClassLoader extends ClassLoader { private final ClassLoader testClassLoader; - private final String[] targetClasses; - - public CompileWithTargetClassAccessClassLoader(ClassLoader testClassLoader, - String[] targetClasses) { + public CompileWithTargetClassAccessClassLoader(ClassLoader testClassLoader) { super(testClassLoader.getParent()); this.testClassLoader = testClassLoader; - this.targetClasses = targetClasses; } - public String[] getTargetClasses() { - return this.targetClasses; - } - @Override public Class loadClass(String name) throws ClassNotFoundException { if (name.startsWith("org.junit") || name.startsWith("org.hamcrest")) { @@ -70,6 +62,11 @@ final class CompileWithTargetClassAccessClassLoader extends ClassLoader { return super.findClass(name); } + + Class defineClassWithTargetAccess(String name, byte[] b, int off, int len) { + return super.defineClass(name, b, off, len); + } + @Override protected Enumeration findResources(String name) throws IOException { return this.testClassLoader.getResources(name); diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessExtension.java b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessExtension.java index 0fdf471773..a221b60f8a 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessExtension.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/CompileWithTargetClassAccessExtension.java @@ -16,11 +16,7 @@ package org.springframework.aot.test.generator.compile; -import java.lang.reflect.AnnotatedElement; import java.lang.reflect.Method; -import java.util.Arrays; -import java.util.LinkedHashSet; -import java.util.Set; import org.junit.jupiter.api.extension.ExtensionContext; import org.junit.jupiter.api.extension.InvocationInterceptor; @@ -34,8 +30,6 @@ import org.junit.platform.launcher.core.LauncherFactory; import org.junit.platform.launcher.listeners.SummaryGeneratingListener; import org.junit.platform.launcher.listeners.TestExecutionSummary; -import org.springframework.core.annotation.MergedAnnotation; -import org.springframework.core.annotation.MergedAnnotations; import org.springframework.util.Assert; import org.springframework.util.CollectionUtils; import org.springframework.util.ReflectionUtils; @@ -121,10 +115,9 @@ class CompileWithTargetClassAccessExtension implements InvocationInterceptor { Class testClass = extensionContext.getRequiredTestClass(); Method testMethod = invocationContext.getExecutable(); - String[] targetClasses = getTargetClasses(testClass, testMethod); ClassLoader originalClassLoader = Thread.currentThread().getContextClassLoader(); ClassLoader forkedClassPathClassLoader = new CompileWithTargetClassAccessClassLoader( - testClass.getClassLoader(), targetClasses); + testClass.getClassLoader()); Thread.currentThread().setContextClassLoader(forkedClassPathClassLoader); try { runTest(forkedClassPathClassLoader, testClass.getName(), testMethod.getName()); @@ -134,22 +127,6 @@ class CompileWithTargetClassAccessExtension implements InvocationInterceptor { } } - private String[] getTargetClasses(AnnotatedElement... elements) { - Set targetClasses = new LinkedHashSet<>(); - for (AnnotatedElement element : elements) { - MergedAnnotation annotation = MergedAnnotations.from(element) - .get(CompileWithTargetClassAccess.class); - if (annotation.isPresent()) { - Arrays.stream(annotation.getStringArray("classNames")).forEach(targetClasses::add); - Arrays.stream(annotation.getClassArray("classes")).map(Class::getName).forEach(targetClasses::add); - if (element instanceof Class clazz) { - targetClasses.add(clazz.getName()); - } - } - } - return targetClasses.toArray(String[]::new); - } - private void runTest(ClassLoader classLoader, String testClassName, String testMethodName) throws Throwable { diff --git a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/DynamicClassLoader.java b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/DynamicClassLoader.java index e6d66fe011..8c13a234cb 100644 --- a/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/DynamicClassLoader.java +++ b/spring-core-test/src/main/java/org/springframework/aot/test/generator/compile/DynamicClassLoader.java @@ -19,10 +19,6 @@ package org.springframework.aot.test.generator.compile; import java.io.ByteArrayInputStream; import java.io.IOException; import java.io.InputStream; -import java.lang.System.Logger; -import java.lang.System.Logger.Level; -import java.lang.invoke.MethodHandles; -import java.lang.invoke.MethodHandles.Lookup; import java.lang.reflect.Method; import java.net.MalformedURLException; import java.net.URL; @@ -35,7 +31,6 @@ import java.util.Map; import org.springframework.aot.test.generator.file.ResourceFile; import org.springframework.aot.test.generator.file.ResourceFiles; import org.springframework.lang.Nullable; -import org.springframework.util.ClassUtils; import org.springframework.util.ReflectionUtils; /** @@ -47,13 +42,13 @@ import org.springframework.util.ReflectionUtils; */ public class DynamicClassLoader extends ClassLoader { - private static final Logger logger = System.getLogger(DynamicClassLoader.class.getName()); - - private final ResourceFiles resourceFiles; private final Map classFiles; + @Nullable + private final Method defineClassMethod; + public DynamicClassLoader(ClassLoader parent, ResourceFiles resourceFiles, Map classFiles) { @@ -61,6 +56,19 @@ public class DynamicClassLoader extends ClassLoader { super(parent); this.resourceFiles = resourceFiles; this.classFiles = classFiles; + this.defineClassMethod = findDefineClassMethod(parent); + } + + @Nullable + private Method findDefineClassMethod(ClassLoader parent) { + Class parentClass = parent.getClass(); + if (parentClass.getName().equals(CompileWithTargetClassAccessClassLoader.class.getName())) { + Method defineClassMethod = ReflectionUtils.findMethod(parentClass, + "defineClassWithTargetAccess", String.class, byte[].class, int.class, int.class); + ReflectionUtils.makeAccessible(defineClassMethod); + return defineClassMethod; + } + return null; } @@ -75,37 +83,13 @@ public class DynamicClassLoader extends ClassLoader { private Class defineClass(String name, DynamicClassFileObject classFile) { byte[] bytes = classFile.getBytes(); - Class targetClass = getTargetClass(name); - if (targetClass != null) { - try { - Lookup lookup = MethodHandles.privateLookupIn(targetClass, MethodHandles.lookup()); - return lookup.defineClass(bytes); - } - catch (IllegalAccessException ex) { - logger.log(Level.WARNING, "Unable to define class using MethodHandles Lookup, " - + "only public methods and classes will be accessible"); - } + if (this.defineClassMethod != null) { + return (Class) ReflectionUtils.invokeMethod(this.defineClassMethod, + getParent(), name, bytes, 0, bytes.length); } - return defineClass(name, bytes, 0, bytes.length, null); + return defineClass(name, bytes, 0, bytes.length); } - private Class getTargetClass(String name) { - ClassLoader parentClassLoader = getParent(); - if (parentClassLoader.getClass().getName() - .equals(CompileWithTargetClassAccessClassLoader.class.getName())) { - String packageName = ClassUtils.getPackageName(name); - Method method = ReflectionUtils.findMethod(parentClassLoader.getClass(), "getTargetClasses"); - ReflectionUtils.makeAccessible(method); - String[] targetCasses = (String[]) ReflectionUtils.invokeMethod(method, parentClassLoader); - for (String targetClass : targetCasses) { - String targetPackageName = ClassUtils.getPackageName(targetClass); - if (targetPackageName.equals(packageName)) { - return ClassUtils.resolveClassName(targetClass, this); - } - } - } - return null; - } @Override protected Enumeration findResources(String name) throws IOException { diff --git a/spring-core-test/src/test/java/org/springframework/aot/test/generator/compile/TestCompilerTests.java b/spring-core-test/src/test/java/org/springframework/aot/test/generator/compile/TestCompilerTests.java index 4bc7b12556..0c2667644b 100644 --- a/spring-core-test/src/test/java/org/springframework/aot/test/generator/compile/TestCompilerTests.java +++ b/spring-core-test/src/test/java/org/springframework/aot/test/generator/compile/TestCompilerTests.java @@ -170,7 +170,7 @@ class TestCompilerTests { } @Test - @CompileWithTargetClassAccess(classNames = "com.example.PackagePrivate") + @CompileWithTargetClassAccess void compiledCodeCanAccessExistingPackagePrivateClassIfAnnotated() throws ClassNotFoundException, LinkageError { SourceFiles sourceFiles = SourceFiles.of(SourceFile.of(""" package com.example;