diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyAsyncContext.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyAsyncContext.java new file mode 100644 index 000000000..5078a7491 --- /dev/null +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyAsyncContext.java @@ -0,0 +1,176 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.serverless.web; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + +import jakarta.servlet.AsyncContext; +import jakarta.servlet.AsyncEvent; +import jakarta.servlet.AsyncListener; +import jakarta.servlet.ServletContext; +import jakarta.servlet.ServletException; +import jakarta.servlet.ServletRequest; +import jakarta.servlet.ServletResponse; +import jakarta.servlet.http.HttpServletRequest; +import jakarta.servlet.http.HttpServletResponse; + +import org.springframework.beans.BeanUtils; +import org.springframework.lang.Nullable; +import org.springframework.util.Assert; +import org.springframework.web.util.WebUtils; + +/** + * Implementation of Async context for {@link ProxyMvc}. + * + * @author Oleg Zhurakousky + */ +public class ProxyAsyncContext implements AsyncContext { + private final HttpServletRequest request; + + @Nullable + private final HttpServletResponse response; + + private final List listeners = new ArrayList<>(); + + @Nullable + private String dispatchedPath; + + private long timeout = 10 * 1000L; + + private final List dispatchHandlers = new ArrayList<>(); + + + public ProxyAsyncContext(ServletRequest request, @Nullable ServletResponse response) { + this.request = (HttpServletRequest) request; + this.response = (HttpServletResponse) response; + } + + + public void addDispatchHandler(Runnable handler) { + Assert.notNull(handler, "Dispatch handler must not be null"); + synchronized (this) { + if (this.dispatchedPath == null) { + this.dispatchHandlers.add(handler); + } + else { + handler.run(); + } + } + } + + @Override + public ServletRequest getRequest() { + return this.request; + } + + @Override + @Nullable + public ServletResponse getResponse() { + return this.response; + } + + @Override + public boolean hasOriginalRequestAndResponse() { + return (this.request instanceof ProxyHttpServletRequest && this.response instanceof ProxyHttpServletResponse); + } + + @Override + public void dispatch() { + dispatch(this.request.getRequestURI()); + } + + @Override + public void dispatch(String path) { + dispatch(null, path); + } + + @Override + public void dispatch(@Nullable ServletContext context, String path) { + synchronized (this) { + this.dispatchedPath = path; + this.dispatchHandlers.forEach(Runnable::run); + } + } + + @Nullable + public String getDispatchedPath() { + return this.dispatchedPath; + } + + @Override + public void complete() { + ProxyHttpServletRequest mockRequest = WebUtils.getNativeRequest(this.request, ProxyHttpServletRequest.class); + if (mockRequest != null) { + mockRequest.setAsyncStarted(false); + } + for (AsyncListener listener : this.listeners) { + try { + listener.onComplete(new AsyncEvent(this, this.request, this.response)); + } + catch (IOException ex) { + throw new IllegalStateException("AsyncListener failure", ex); + } + } + } + + @Override + public void start(Runnable runnable) { + runnable.run(); + } + + @Override + public void addListener(AsyncListener listener) { + this.listeners.add(listener); + } + + @Override + public void addListener(AsyncListener listener, ServletRequest request, ServletResponse response) { + this.listeners.add(listener); + } + + public List getListeners() { + return this.listeners; + } + + @Override + public T createListener(Class clazz) throws ServletException { + return BeanUtils.instantiateClass(clazz); + } + + /** + * By default this is set to 10000 (10 seconds) even though the Servlet API + * specifies a default async request timeout of 30 seconds. Keep in mind the + * timeout could further be impacted by global configuration through the MVC + * Java config or the XML namespace, as well as be overridden per request on + * {@link org.springframework.web.context.request.async.DeferredResult DeferredResult} + * or on + * {@link org.springframework.web.servlet.mvc.method.annotation.SseEmitter SseEmitter}. + * @param timeout the timeout value to use. + * @see AsyncContext#setTimeout(long) + */ + @Override + public void setTimeout(long timeout) { + this.timeout = timeout; + } + + @Override + public long getTimeout() { + return this.timeout; + } +} diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletRequest.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletRequest.java index 574b78c1f..e669d4226 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletRequest.java +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletRequest.java @@ -113,7 +113,7 @@ public class ProxyHttpServletRequest implements HttpServletRequest { private boolean asyncStarted = false; - private boolean asyncSupported = false; + private boolean asyncSupported = true; private DispatcherType dispatcherType = DispatcherType.REQUEST; @@ -163,6 +163,8 @@ public class ProxyHttpServletRequest implements HttpServletRequest { private final MultiValueMap parts = new LinkedMultiValueMap<>(); + private AsyncContext asyncContext; + public ProxyHttpServletRequest(ServletContext servletContext, String method, String requestURI) { this.servletContext = servletContext; this.method = method; @@ -246,8 +248,6 @@ public class ProxyHttpServletRequest implements HttpServletRequest { */ @Nullable public String getContentAsString() throws IllegalStateException, UnsupportedEncodingException { -// Assert.state(this.characterEncoding != null, "Cannot get content as a String for a null character encoding. " -// + "Consider setting the characterEncoding in the request."); if (this.content == null) { return null; @@ -633,7 +633,10 @@ public class ProxyHttpServletRequest implements HttpServletRequest { @Override public AsyncContext startAsync(ServletRequest request, @Nullable ServletResponse response) { - throw new UnsupportedOperationException(); + Assert.state(this.asyncSupported, "Async not supported"); + this.asyncStarted = true; + this.asyncContext = this.asyncContext == null ? new ProxyAsyncContext(request, response) : this.asyncContext; + return this.asyncContext; } public void setAsyncStarted(boolean asyncStarted) { @@ -647,6 +650,7 @@ public class ProxyHttpServletRequest implements HttpServletRequest { public void setAsyncSupported(boolean asyncSupported) { this.asyncSupported = asyncSupported; + this.dispatcherType = DispatcherType.ASYNC; } @Override @@ -655,15 +659,16 @@ public class ProxyHttpServletRequest implements HttpServletRequest { } public void setAsyncContext(@Nullable AsyncContext asyncContext) { - throw new UnsupportedOperationException(); + this.asyncContext = asyncContext; } @Override @Nullable public AsyncContext getAsyncContext() { - return null; + return this.asyncContext; } + public void setDispatcherType(DispatcherType dispatcherType) { this.dispatcherType = dispatcherType; } @@ -692,7 +697,7 @@ public class ProxyHttpServletRequest implements HttpServletRequest { @Override @Nullable public String getHeader(String name) { - return this.headers.containsKey(name) ? this.headers.get(name).toString() : null; + return this.headers.containsKey(name) ? this.headers.get(name).get(0) : null; } @Override diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletResponse.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletResponse.java index 841293800..c33f3bd1f 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletResponse.java +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyHttpServletResponse.java @@ -41,6 +41,7 @@ import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.web.util.WebUtils; + /** * * @author Oleg Zhurakousky @@ -160,7 +161,6 @@ public class ProxyHttpServletResponse implements HttpServletResponse { @Override public void flushBuffer() { - } @Override @@ -248,7 +248,7 @@ public class ProxyHttpServletResponse implements HttpServletResponse { @Override @Nullable public String getHeader(String name) { - return this.headers.containsKey(name) ? this.headers.get(name).toString() : null; + return this.headers.containsKey(name) ? this.headers.get(name).get(0) : null; } /** diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyMvc.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyMvc.java index 60cf1b19e..abb7850ec 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyMvc.java +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyMvc.java @@ -41,13 +41,16 @@ import jakarta.servlet.http.HttpServletResponse; import org.apache.commons.logging.Log; import org.apache.commons.logging.LogFactory; -import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebApplicationContext; +import org.springframework.boot.autoconfigure.web.servlet.DispatcherServletAutoConfiguration; +import org.springframework.context.support.GenericApplicationContext; import org.springframework.http.HttpStatus; import org.springframework.lang.Nullable; import org.springframework.util.Assert; import org.springframework.util.ObjectUtils; import org.springframework.util.StringUtils; import org.springframework.web.context.ConfigurableWebApplicationContext; +import org.springframework.web.context.request.async.WebAsyncManager; +import org.springframework.web.context.request.async.WebAsyncUtils; import org.springframework.web.servlet.DispatcherServlet; /** @@ -59,7 +62,7 @@ import org.springframework.web.servlet.DispatcherServlet; * @author Oleg Zhurakousky * */ -public class ProxyMvc { +public final class ProxyMvc { private static Log LOG = LogFactory.getLog(ProxyMvc.class); @@ -84,8 +87,7 @@ public class ProxyMvc { } public static ProxyMvc INSTANCE(Class... componentClasses) { - AnnotationConfigServletWebApplicationContext applpicationContext = new AnnotationConfigServletWebApplicationContext(); - applpicationContext.scan(componentClasses[0].getPackageName()); + ConfigurableWebApplicationContext applpicationContext = ServerlessWebApplication.run(componentClasses, new String[] {}); return INSTANCE(applpicationContext); } @@ -97,21 +99,23 @@ public class ProxyMvc { this.applicationContext = applicationContext; ProxyServletContext servletContext = new ProxyServletContext(); this.applicationContext.setServletContext(servletContext); - this.dispatcher = new DispatcherServlet(this.applicationContext); - this.dispatcher.setDetectAllHandlerMappings(false); + this.applicationContext.refresh(); - ServletRegistration.Dynamic reg = servletContext.addServlet("dispatcherServlet", dispatcher); + if (this.applicationContext.containsBean(DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME)) { + this.dispatcher = this.applicationContext.getBean(DispatcherServlet.class); + } + else { + this.dispatcher = new DispatcherServlet(this.applicationContext); + this.dispatcher.setDetectAllHandlerMappings(false); + ((GenericApplicationContext) this.applicationContext).registerBean(DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME, + DispatcherServlet.class, () -> this.dispatcher); + } + + ServletRegistration.Dynamic reg = servletContext.addServlet(DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME, dispatcher); reg.setLoadOnStartup(1); this.servletContext = applicationContext.getServletContext(); try { - this.dispatcher.init(new ProxyServletConfig(this.servletContext)); - try { - this.service(new ProxyHttpServletRequest(servletContext, "INFO", "/"), new ProxyHttpServletResponse()); - } - catch (Exception e) { - //ignore as this is just a pre-warming attempt - } } catch (Exception e) { throw new IllegalStateException("Faild to create Spring MVC DispatcherServlet proxy", e); @@ -137,10 +141,17 @@ public class ProxyMvc { this.service(request, response, (CountDownLatch) null); } + public void service(HttpServletRequest request, HttpServletResponse response, CountDownLatch latch) throws Exception { ProxyFilterChain filterChain = new ProxyFilterChain(this.dispatcher); filterChain.doFilter(request, response); + WebAsyncManager asyncManager = WebAsyncUtils.getAsyncManager(request); + if (asyncManager.isConcurrentHandlingStarted()) { + this.dispatcher.service(request, response); + } + + if (latch != null) { latch.countDown(); } @@ -170,7 +181,7 @@ public class ProxyMvc { ProxyFilterChain(DispatcherServlet servlet) { List filters = new ArrayList<>(); servlet.getServletContext().getFilterRegistrations().values().forEach(fr -> filters.add(((ProxyFilterRegistration) fr).getFilter())); - servlet.getWebApplicationContext().getBeansOfType(Filter.class).values().forEach(f -> filters.add(f)); + //servlet.getWebApplicationContext().getBeansOfType(Filter.class).values().forEach(f -> filters.add(f)); Assert.notNull(filters, "filters cannot be null"); Assert.noNullElements(filters, "filters cannot contain null values"); this.filters = initFilterList(servlet, filters.toArray(new Filter[] {})); @@ -310,7 +321,7 @@ public class ProxyMvc { @Override public String getServletName() { - return "spring-serverless-proxy"; + return DispatcherServletAutoConfiguration.DEFAULT_DISPATCHER_SERVLET_BEAN_NAME; } @Override diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyServletContext.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyServletContext.java index 9b0138fb1..6567df9e8 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyServletContext.java +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ProxyServletContext.java @@ -227,7 +227,9 @@ public class ProxyServletContext implements ServletContext { @Override public FilterRegistration.Dynamic addFilter(String filterName, Filter filter) { - throw new UnsupportedOperationException("This ServletContext does not represent a running web container"); + ProxyFilterRegistration registration = new ProxyFilterRegistration(filterName, filter); + filterRegistrations.put(filterName, registration); + return registration; } Map filterRegistrations = new HashMap<>(); diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ServerlessWebApplication.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ServerlessWebApplication.java new file mode 100644 index 000000000..1c4bcd3ed --- /dev/null +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/main/java/org/springframework/cloud/function/serverless/web/ServerlessWebApplication.java @@ -0,0 +1,402 @@ +/* + * Copyright 2023-2023 the original author or authors. + * + * Licensed under the Apache License, Version 2.0 (the "License"); + * you may not use this file except in compliance with the License. + * You may obtain a copy of the License at + * + * https://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.springframework.cloud.function.serverless.web; + +import java.io.IOException; +import java.io.PrintStream; +import java.util.ArrayList; +import java.util.List; +import java.util.Set; +import java.util.function.Consumer; + +import org.apache.commons.logging.Log; +import org.apache.commons.logging.LogFactory; + +import org.springframework.aot.AotDetector; +import org.springframework.beans.BeansException; +import org.springframework.beans.factory.config.BeanFactoryPostProcessor; +import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; +import org.springframework.beans.factory.support.AbstractAutowireCapableBeanFactory; +import org.springframework.beans.factory.support.DefaultListableBeanFactory; +import org.springframework.boot.ApplicationArguments; +import org.springframework.boot.ApplicationContextFactory; +import org.springframework.boot.Banner; +import org.springframework.boot.BootstrapRegistryInitializer; +import org.springframework.boot.ConfigurableBootstrapContext; +import org.springframework.boot.DefaultApplicationArguments; +import org.springframework.boot.DefaultBootstrapContext; +import org.springframework.boot.DefaultPropertiesPropertySource; +import org.springframework.boot.LazyInitializationBeanFactoryPostProcessor; +import org.springframework.boot.ResourceBanner; +import org.springframework.boot.SpringApplication; +import org.springframework.boot.SpringApplicationRunListener; +import org.springframework.boot.SpringBootVersion; +import org.springframework.boot.WebApplicationType; +import org.springframework.boot.ansi.AnsiColor; +import org.springframework.boot.ansi.AnsiOutput; +import org.springframework.boot.ansi.AnsiStyle; +import org.springframework.boot.context.properties.source.ConfigurationPropertySources; +import org.springframework.context.ApplicationContextInitializer; +import org.springframework.context.ConfigurableApplicationContext; +import org.springframework.context.aot.AotApplicationContextInitializer; +import org.springframework.core.Ordered; +import org.springframework.core.env.ConfigurableEnvironment; +import org.springframework.core.env.Environment; +import org.springframework.core.io.DefaultResourceLoader; +import org.springframework.core.io.Resource; +import org.springframework.core.io.ResourceLoader; +import org.springframework.core.io.support.SpringFactoriesLoader; +import org.springframework.core.io.support.SpringFactoriesLoader.ArgumentResolver; +import org.springframework.core.metrics.ApplicationStartup; +import org.springframework.core.metrics.StartupStep; +import org.springframework.util.Assert; +import org.springframework.web.context.ConfigurableWebApplicationContext; + +/** + * + * @author Oleg Zhurakousky + * + */ +class ServerlessWebApplication extends SpringApplication { + + private static final Log logger = LogFactory.getLog(ServerlessWebApplication.class); + + private ApplicationStartup applicationStartup = ApplicationStartup.DEFAULT; + + private ApplicationContextFactory applicationContextFactory = ApplicationContextFactory.DEFAULT; + + private boolean allowCircularReferences; + + private boolean allowBeanDefinitionOverriding; + + private boolean logStartupInfo = true; + + private boolean lazyInitialization = false; + + private WebApplicationType webApplicationType; + + private List> initializers; + + public static ConfigurableWebApplicationContext run(Class[] primarySources, String[] args) { + return new ServerlessWebApplication(primarySources).run(args); + } + + ServerlessWebApplication(Class... classes) { + super(classes); + } + + @Override + public ConfigurableWebApplicationContext run(String... args) { + this.webApplicationType = WebApplicationType.SERVLET; + DefaultBootstrapContext bootstrapContext = createBootstrapContext(); + ConfigurableWebApplicationContext context = null; + SpringApplicationRunListeners listeners = getRunListeners(args); + listeners.starting(bootstrapContext, this.getMainApplicationClass()); + try { + ApplicationArguments applicationArguments = new DefaultApplicationArguments(args); + ConfigurableEnvironment environment = prepareEnvironment(listeners, bootstrapContext, applicationArguments); + Banner printedBanner = printBanner(environment); + context = (ConfigurableWebApplicationContext) createApplicationContext(); + context.setApplicationStartup(this.applicationStartup); + prepareContext(bootstrapContext, context, environment, listeners, applicationArguments, printedBanner); + } + catch (Throwable ex) { + throw new IllegalStateException(ex); + } + + return context; + } + + + private ConfigurableEnvironment prepareEnvironment(SpringApplicationRunListeners listeners, + DefaultBootstrapContext bootstrapContext, ApplicationArguments applicationArguments) { + // Create and configure the environment + ConfigurableEnvironment environment = getOrCreateEnvironment(); + configureEnvironment(environment, applicationArguments.getSourceArgs()); + ConfigurationPropertySources.attach(environment); + listeners.environmentPrepared(bootstrapContext, environment); + DefaultPropertiesPropertySource.moveToEnd(environment); + Assert.state(!environment.containsProperty("spring.main.environment-prefix"), + "Environment prefix cannot be set via properties."); + bindToSpringApplication(environment); + ConfigurationPropertySources.attach(environment); + return environment; + } + + private ConfigurableEnvironment getOrCreateEnvironment() { + ConfigurableEnvironment environment = this.applicationContextFactory.createEnvironment(this.webApplicationType); + if (environment == null && this.applicationContextFactory != ApplicationContextFactory.DEFAULT) { + environment = ApplicationContextFactory.DEFAULT.createEnvironment(this.webApplicationType); + } + return environment; + } + + private SpringApplicationRunListeners getRunListeners(String[] args) { + ArgumentResolver argumentResolver = ArgumentResolver.of(SpringApplication.class, this); + argumentResolver = argumentResolver.and(String[].class, args); + List listeners = getSpringFactoriesInstances(SpringApplicationRunListener.class, argumentResolver); + return new SpringApplicationRunListeners(logger, listeners, this.applicationStartup); + } + + private Banner printBanner(ConfigurableEnvironment environment) { + ResourceLoader resourceLoader = (this.getResourceLoader() != null) ? this.getResourceLoader() + : new DefaultResourceLoader(null); + SpringApplicationBannerPrinter bannerPrinter = new SpringApplicationBannerPrinter(resourceLoader, new SpringAwsBanner()); + return bannerPrinter.print(environment, this.getMainApplicationClass(), System.out); + } + + + private DefaultBootstrapContext createBootstrapContext() { + DefaultBootstrapContext bootstrapContext = new DefaultBootstrapContext(); + ArrayList bootstrapRegistryInitializers = new ArrayList<>(getSpringFactoriesInstances(BootstrapRegistryInitializer.class)); + bootstrapRegistryInitializers.forEach((initializer) -> initializer.initialize(bootstrapContext)); + return bootstrapContext; + } + + private List getSpringFactoriesInstances(Class type) { + return getSpringFactoriesInstances(type, null); + } + + private List getSpringFactoriesInstances(Class type, ArgumentResolver argumentResolver) { + return SpringFactoriesLoader.forDefaultResourceLocation(getClassLoader()).load(type, argumentResolver); + } + + private void prepareContext(DefaultBootstrapContext bootstrapContext, ConfigurableApplicationContext context, + ConfigurableEnvironment environment, SpringApplicationRunListeners listeners, + ApplicationArguments applicationArguments, Banner printedBanner) { + context.setEnvironment(environment); + postProcessApplicationContext(context); + addAotGeneratedInitializerIfNecessary(this.initializers); + applyInitializers(context); + listeners.contextPrepared(context); + bootstrapContext.close(context); + if (this.logStartupInfo) { + logStartupInfo(context.getParent() == null); + logStartupProfileInfo(context); + } + // Add boot specific singleton beans + ConfigurableListableBeanFactory beanFactory = context.getBeanFactory(); + beanFactory.registerSingleton("springApplicationArguments", applicationArguments); + if (printedBanner != null) { + beanFactory.registerSingleton("springBootBanner", printedBanner); + } + if (beanFactory instanceof AbstractAutowireCapableBeanFactory autowireCapableBeanFactory) { + autowireCapableBeanFactory.setAllowCircularReferences(this.allowCircularReferences); + if (beanFactory instanceof DefaultListableBeanFactory listableBeanFactory) { + listableBeanFactory.setAllowBeanDefinitionOverriding(this.allowBeanDefinitionOverriding); + } + } + if (this.lazyInitialization) { + context.addBeanFactoryPostProcessor(new LazyInitializationBeanFactoryPostProcessor()); + } + context.addBeanFactoryPostProcessor(new PropertySourceOrderingBeanFactoryPostProcessor(context)); + if (!AotDetector.useGeneratedArtifacts()) { + // Load the sources + Set sources = getAllSources(); + Assert.notEmpty(sources, "Sources must not be empty"); + load(context, sources.toArray(new Object[0])); + } + listeners.contextLoaded(context); + } + + private void addAotGeneratedInitializerIfNecessary(List> initializers) { + if (AotDetector.useGeneratedArtifacts()) { + List> aotInitializers = new ArrayList<>( + initializers.stream().filter(AotApplicationContextInitializer.class::isInstance).toList()); + if (aotInitializers.isEmpty()) { + String initializerClassName = this.getMainApplicationClass().getName() + "__ApplicationContextInitializer"; + aotInitializers.add(AotApplicationContextInitializer.forInitializerClasses(initializerClassName)); + } + initializers.removeAll(aotInitializers); + initializers.addAll(0, aotInitializers); + } + } + + private static class PropertySourceOrderingBeanFactoryPostProcessor implements BeanFactoryPostProcessor, Ordered { + + private final ConfigurableApplicationContext context; + + PropertySourceOrderingBeanFactoryPostProcessor(ConfigurableApplicationContext context) { + this.context = context; + } + + @Override + public int getOrder() { + return Ordered.HIGHEST_PRECEDENCE; + } + + @Override + public void postProcessBeanFactory(ConfigurableListableBeanFactory beanFactory) throws BeansException { + DefaultPropertiesPropertySource.moveToEnd(this.context.getEnvironment()); + } + + } + + private static class SpringApplicationBannerPrinter { + + static final String BANNER_LOCATION_PROPERTY = "spring.banner.location"; + + static final String DEFAULT_BANNER_LOCATION = "banner.txt"; + + private static final Banner DEFAULT_BANNER = new SpringAwsBanner(); + + private final ResourceLoader resourceLoader; + + private final Banner fallbackBanner; + + SpringApplicationBannerPrinter(ResourceLoader resourceLoader, Banner fallbackBanner) { + this.resourceLoader = resourceLoader; + this.fallbackBanner = fallbackBanner; + } + + Banner print(Environment environment, Class sourceClass, PrintStream out) { + Banner banner = getBanner(environment); + banner.printBanner(environment, sourceClass, out); + return new PrintedBanner(banner, sourceClass); + } + + private Banner getBanner(Environment environment) { + Banner textBanner = getTextBanner(environment); + if (textBanner != null) { + return textBanner; + } + if (this.fallbackBanner != null) { + return this.fallbackBanner; + } + return DEFAULT_BANNER; + } + + private Banner getTextBanner(Environment environment) { + String location = environment.getProperty(BANNER_LOCATION_PROPERTY, DEFAULT_BANNER_LOCATION); + Resource resource = this.resourceLoader.getResource(location); + try { + if (resource.exists() && !resource.getURL().toExternalForm().contains("liquibase-core")) { + return new ResourceBanner(resource); + } + } + catch (IOException ex) { + // Ignore + } + return null; + } + + /** + * Decorator that allows a {@link Banner} to be printed again without needing to + * specify the source class. + */ + private static class PrintedBanner implements Banner { + + private final Banner banner; + + private final Class sourceClass; + + PrintedBanner(Banner banner, Class sourceClass) { + this.banner = banner; + this.sourceClass = sourceClass; + } + + @Override + public void printBanner(Environment environment, Class sourceClass, PrintStream out) { + sourceClass = (sourceClass != null) ? sourceClass : this.sourceClass; + this.banner.printBanner(environment, sourceClass, out); + } + + } + } + + private static class SpringAwsBanner implements Banner { + + private static final String[] BANNER = { "", "\n" + + " ____ _ _____ ______ _ _ _ \n" + + " / ___| _ __ _ __(_)_ __ __ _ / / \\ \\ / / ___| | | __ _ _ __ ___ | |__ __| | __ _ \n" + + " \\___ \\| '_ \\| '__| | '_ \\ / _` | / / _ \\ \\ /\\ / /\\___ \\ | | / _` | '_ ` _ \\| '_ \\ / _` |/ _` |\n" + + " ___) | |_) | | | | | | | (_| |/ / ___ \\ V V / ___) | | |__| (_| | | | | | | |_) | (_| | (_| |\n" + + " |____/| .__/|_| |_|_| |_|\\__, /_/_/ \\_\\_/\\_/ |____/ |_____\\__,_|_| |_| |_|_.__/ \\__,_|\\__,_|\n" + + " |_| |___/ \n" + + "" }; + + private static final String SPRING_BOOT = " :: Spring Boot :: "; + + private static final int STRAP_LINE_SIZE = 42; + + @Override + public void printBanner(Environment environment, Class sourceClass, PrintStream printStream) { + for (String line : BANNER) { + printStream.println(line); + } + String version = SpringBootVersion.getVersion(); + version = (version != null) ? " (v" + version + ")" : ""; + StringBuilder padding = new StringBuilder(); + while (padding.length() < STRAP_LINE_SIZE - (version.length() + SPRING_BOOT.length())) { + padding.append(" "); + } + + printStream.println(AnsiOutput.toString(AnsiColor.GREEN, SPRING_BOOT, AnsiColor.DEFAULT, padding.toString(), + AnsiStyle.FAINT, version)); + printStream.println(); + } + + } + + private static class SpringApplicationRunListeners { + + private final List listeners; + + private final ApplicationStartup applicationStartup; + + SpringApplicationRunListeners(Log log, List listeners, + ApplicationStartup applicationStartup) { + this.listeners = List.copyOf(listeners); + this.applicationStartup = applicationStartup; + } + + void starting(ConfigurableBootstrapContext bootstrapContext, Class mainApplicationClass) { + doWithListeners("spring.boot.application.starting", (listener) -> listener.starting(bootstrapContext), + (step) -> { + if (mainApplicationClass != null) { + step.tag("mainApplicationClass", mainApplicationClass.getName()); + } + }); + } + + void environmentPrepared(ConfigurableBootstrapContext bootstrapContext, ConfigurableEnvironment environment) { + doWithListeners("spring.boot.application.environment-prepared", + (listener) -> listener.environmentPrepared(bootstrapContext, environment)); + } + + void contextPrepared(ConfigurableApplicationContext context) { + doWithListeners("spring.boot.application.context-prepared", (listener) -> listener.contextPrepared(context)); + } + + void contextLoaded(ConfigurableApplicationContext context) { + doWithListeners("spring.boot.application.context-loaded", (listener) -> listener.contextLoaded(context)); + } + private void doWithListeners(String stepName, Consumer listenerAction) { + doWithListeners(stepName, listenerAction, null); + } + + private void doWithListeners(String stepName, Consumer listenerAction, + Consumer stepAction) { + StartupStep step = this.applicationStartup.start(stepName); + this.listeners.forEach(listenerAction); + if (stepAction != null) { + stepAction.accept(step); + } + step.end(); + } + } +} diff --git a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/test/java/org/springframework/cloud/function/serverless/web/RequestResponseTests.java b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/test/java/org/springframework/cloud/function/serverless/web/RequestResponseTests.java index 07bd8b5eb..ceeb273b9 100644 --- a/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/test/java/org/springframework/cloud/function/serverless/web/RequestResponseTests.java +++ b/spring-cloud-function-adapters/spring-cloud-function-serverless-web/src/test/java/org/springframework/cloud/function/serverless/web/RequestResponseTests.java @@ -25,6 +25,7 @@ import org.junit.jupiter.api.AfterEach; import org.junit.jupiter.api.BeforeEach; import org.junit.jupiter.api.Test; + import org.springframework.cloud.function.test.app.Pet; import org.springframework.cloud.function.test.app.PetStoreSpringAppConfig; import org.springframework.http.HttpStatus; @@ -45,7 +46,7 @@ public class RequestResponseTests { @BeforeEach public void before() { - this.mvc = ProxyMvc.INSTANCE(PetStoreSpringAppConfig.class, ProxyErrorController.class); + this.mvc = ProxyMvc.INSTANCE(ProxyErrorController.class, PetStoreSpringAppConfig.class); } @AfterEach