Commit 25f31b63 authored by Phillip Webb's avatar Phillip Webb

Polish RegistrationBean logic

parent d835ee54
...@@ -217,7 +217,7 @@ public abstract class AbstractEmbeddedServletContainerFactory implements ...@@ -217,7 +217,7 @@ public abstract class AbstractEmbeddedServletContainerFactory implements
@Override @Override
public void addErrorPages(ErrorPage... errorPages) { public void addErrorPages(ErrorPage... errorPages) {
Assert.notNull(this.initializers, "ErrorPages must not be null"); Assert.notNull(errorPages, "ErrorPages must not be null");
this.errorPages.addAll(Arrays.asList(errorPages)); this.errorPages.addAll(Arrays.asList(errorPages));
} }
......
...@@ -20,7 +20,6 @@ import java.util.ArrayList; ...@@ -20,7 +20,6 @@ import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
import java.util.Comparator; import java.util.Comparator;
import java.util.HashSet;
import java.util.LinkedHashSet; import java.util.LinkedHashSet;
import java.util.List; import java.util.List;
import java.util.Map.Entry; import java.util.Map.Entry;
...@@ -64,9 +63,10 @@ import org.springframework.web.context.support.WebApplicationContextUtils; ...@@ -64,9 +63,10 @@ import org.springframework.web.context.support.WebApplicationContextUtils;
* <p> * <p>
* In addition, any {@link Servlet} or {@link Filter} beans defined in the context will be * In addition, any {@link Servlet} or {@link Filter} beans defined in the context will be
* automatically registered with the embedded Servlet container. In the case of a single * automatically registered with the embedded Servlet container. In the case of a single
* Servlet bean, the '/*' mapping will be used. If multiple Servlet beans are found then * Servlet bean, the '/' mapping will be used. If multiple Servlet beans are found then
* the lowercase bean name will be used as a mapping prefix. Filter beans will be mapped * the lowercase bean name will be used as a mapping prefix. Any Servlet named
* to all URLs ('/*'). * 'dispatcherServlet' will always be mapped to '/'. Filter beans will be mapped to all
* URLs ('/*').
* *
* <p> * <p>
* For more advanced configuration, the context can instead define beans that implement * For more advanced configuration, the context can instead define beans that implement
...@@ -205,43 +205,47 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext ...@@ -205,43 +205,47 @@ public class EmbeddedWebApplicationContext extends GenericWebApplicationContext
protected Collection<ServletContextInitializer> getServletContextInitializerBeans() { protected Collection<ServletContextInitializer> getServletContextInitializerBeans() {
Set<ServletContextInitializer> initializers = new LinkedHashSet<ServletContextInitializer>(); Set<ServletContextInitializer> initializers = new LinkedHashSet<ServletContextInitializer>();
Set<Object> targets = new HashSet<Object>(); Set<Servlet> servletRegistrations = new LinkedHashSet<Servlet>();
Set<Filter> filterRegistrations = new LinkedHashSet<Filter>();
for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) { for (Entry<String, ServletContextInitializer> initializerBean : getOrderedBeansOfType(ServletContextInitializer.class)) {
ServletContextInitializer initializer = initializerBean.getValue(); ServletContextInitializer initializer = initializerBean.getValue();
if (initializer instanceof RegistrationBean) {
targets.add(((RegistrationBean) initializer).getRegistrationTarget());
}
initializers.add(initializer); initializers.add(initializer);
if (initializer instanceof ServletRegistrationBean) {
servletRegistrations.add(((ServletRegistrationBean) initializer)
.getServlet());
}
if (initializer instanceof FilterRegistrationBean) {
filterRegistrations.add(((FilterRegistrationBean) initializer)
.getFilter());
}
} }
List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class); List<Entry<String, Servlet>> servletBeans = getOrderedBeansOfType(Servlet.class);
for (Entry<String, Servlet> servletBean : servletBeans) { for (Entry<String, Servlet> servletBean : servletBeans) {
final String name = servletBean.getKey(); final String name = servletBean.getKey();
Servlet servlet = servletBean.getValue(); Servlet servlet = servletBean.getValue();
if (targets.contains(servlet)) { if (!servletRegistrations.contains(servlet)) {
continue; String url = (servletBeans.size() == 1 ? "/" : "/" + name + "/");
} if (name.equals(DISPATCHER_SERVLET_NAME)) {
String url = (servletBeans.size() == 1 ? "/" : "/" + name + "/*"); url = "/"; // always map the main dispatcherServlet to "/"
if (name.equals(DISPATCHER_SERVLET_NAME)) { }
url = "/"; // always map the main dispatcherServlet to "/" ServletRegistrationBean registration = new ServletRegistrationBean(
servlet, url);
registration.setName(name);
registration.setMultipartConfig(getMultipartConfig());
initializers.add(registration);
} }
ServletRegistrationBean registration = new ServletRegistrationBean(servlet,
url);
registration.setName(name);
registration.setMultipartConfig(getMultipartConfig());
initializers.add(registration);
} }
for (Entry<String, Filter> filterBean : getOrderedBeansOfType(Filter.class)) { for (Entry<String, Filter> filterBean : getOrderedBeansOfType(Filter.class)) {
String name = filterBean.getKey(); String name = filterBean.getKey();
Filter filter = filterBean.getValue(); Filter filter = filterBean.getValue();
if (targets.contains(filter)) { if (!servletRegistrations.contains(filter)) {
continue; FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name);
initializers.add(registration);
} }
FilterRegistrationBean registration = new FilterRegistrationBean(filter);
registration.setName(name);
initializers.add(registration);
} }
return initializers; return initializers;
......
...@@ -93,6 +93,13 @@ public class FilterRegistrationBean extends RegistrationBean { ...@@ -93,6 +93,13 @@ public class FilterRegistrationBean extends RegistrationBean {
} }
} }
/**
* Returns the filter being registered.
*/
protected Filter getFilter() {
return this.filter;
}
/** /**
* Set the filter to be registered. * Set the filter to be registered.
*/ */
...@@ -220,12 +227,7 @@ public class FilterRegistrationBean extends RegistrationBean { ...@@ -220,12 +227,7 @@ public class FilterRegistrationBean extends RegistrationBean {
@Override @Override
public void onStartup(ServletContext servletContext) throws ServletException { public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.filter, "Filter must not be null"); Assert.notNull(this.filter, "Filter must not be null");
configure(servletContext.addFilter(getName(), this.filter)); configure(servletContext.addFilter(getOrDeduceName(this.filter), this.filter));
}
@Override
public Object getRegistrationTarget() {
return this.filter;
} }
/** /**
......
...@@ -48,13 +48,6 @@ public abstract class RegistrationBean implements ServletContextInitializer { ...@@ -48,13 +48,6 @@ public abstract class RegistrationBean implements ServletContextInitializer {
this.name = name; this.name = name;
} }
/**
* @return the name
*/
public String getName() {
return getOrDeduceName(getRegistrationTarget());
}
/** /**
* Sets if asynchronous operations are support for this registration. If not specified * Sets if asynchronous operations are support for this registration. If not specified
* defaults to {@code true}. * defaults to {@code true}.
...@@ -98,20 +91,12 @@ public abstract class RegistrationBean implements ServletContextInitializer { ...@@ -98,20 +91,12 @@ public abstract class RegistrationBean implements ServletContextInitializer {
this.initParameters.put(name, value); this.initParameters.put(name, value);
} }
/**
* The target of the registration (e.g. a Servlet or a Filter) that can be used to
* guess its name if none is supplied explicitly.
*
* @return the target of this registration
*/
public abstract Object getRegistrationTarget();
/** /**
* Deduces the name for this registration. Will return user specified name or fallback * Deduces the name for this registration. Will return user specified name or fallback
* to convention based naming. * to convention based naming.
* @param value the object used for convention based names * @param value the object used for convention based names
*/ */
private String getOrDeduceName(Object value) { protected final String getOrDeduceName(Object value) {
return (this.name != null ? this.name : Conventions.getVariableName(value)); return (this.name != null ? this.name : Conventions.getVariableName(value));
} }
......
...@@ -75,6 +75,13 @@ public class ServletRegistrationBean extends RegistrationBean { ...@@ -75,6 +75,13 @@ public class ServletRegistrationBean extends RegistrationBean {
this.urlMappings.addAll(Arrays.asList(urlMappings)); this.urlMappings.addAll(Arrays.asList(urlMappings));
} }
/**
* Returns the servlet being registered.
*/
protected Servlet getServlet() {
return this.servlet;
}
/** /**
* Sets the servlet to be registered. * Sets the servlet to be registered.
*/ */
...@@ -140,12 +147,7 @@ public class ServletRegistrationBean extends RegistrationBean { ...@@ -140,12 +147,7 @@ public class ServletRegistrationBean extends RegistrationBean {
* Returns the servlet name that will be registered. * Returns the servlet name that will be registered.
*/ */
public String getServletName() { public String getServletName() {
return getName(); return getOrDeduceName(this.servlet);
}
@Override
public Object getRegistrationTarget() {
return this.servlet;
} }
@Override @Override
......
...@@ -190,9 +190,9 @@ public class EmbeddedWebApplicationContextTests { ...@@ -190,9 +190,9 @@ public class EmbeddedWebApplicationContextTests {
ordered.verify(servletContext).addServlet("servletBean1", servlet1); ordered.verify(servletContext).addServlet("servletBean1", servlet1);
ordered.verify(servletContext).addServlet("servletBean2", servlet2); ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping( verify(escf.getRegisteredServlet(0).getRegistration()).addMapping(
"/servletBean1/*"); "/servletBean1/");
verify(escf.getRegisteredServlet(1).getRegistration()).addMapping( verify(escf.getRegisteredServlet(1).getRegistration()).addMapping(
"/servletBean2/*"); "/servletBean2/");
} }
@Test @Test
...@@ -215,7 +215,7 @@ public class EmbeddedWebApplicationContextTests { ...@@ -215,7 +215,7 @@ public class EmbeddedWebApplicationContextTests {
ordered.verify(servletContext).addServlet("servletBean2", servlet2); ordered.verify(servletContext).addServlet("servletBean2", servlet2);
verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/"); verify(escf.getRegisteredServlet(0).getRegistration()).addMapping("/");
verify(escf.getRegisteredServlet(1).getRegistration()).addMapping( verify(escf.getRegisteredServlet(1).getRegistration()).addMapping(
"/servletBean2/*"); "/servletBean2/");
} }
@Test @Test
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment