Commit a3f97879 authored by Phillip Webb's avatar Phillip Webb

Add DelegatingFilterProxyRegistrationBean

Add a RegistrationBean that can be used to create DelegatingFilterProxy
filters that don't cause early initialization.

Fixes gh-4165
parent a7bdef61
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.context.embedded;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert;
/**
* Abstract base {@link ServletContextInitializer} to register {@link Filter}s in a
* Servlet 3.0+ container.
*
* @author Phillip Webb
*/
abstract class AbstractFilterRegistrationBean extends RegistrationBean {
/**
* Filters that wrap the servlet request should be ordered less than or equal to this.
*/
protected static final int REQUEST_WRAPPER_FILTER_MAX_ORDER = 0;
private final Log logger = LogFactory.getLog(getClass());
static final EnumSet<DispatcherType> ASYNC_DISPATCHER_TYPES = EnumSet.of(
DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST,
DispatcherType.ASYNC);
static final EnumSet<DispatcherType> NON_ASYNC_DISPATCHER_TYPES = EnumSet
.of(DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST);
private static final String[] DEFAULT_URL_MAPPINGS = { "/*" };
private Set<ServletRegistrationBean> servletRegistrationBeans = new LinkedHashSet<ServletRegistrationBean>();
private Set<String> servletNames = new LinkedHashSet<String>();
private Set<String> urlPatterns = new LinkedHashSet<String>();
private EnumSet<DispatcherType> dispatcherTypes;
private boolean matchAfter = false;
/**
* Create a new instance to be registered with the specified
* {@link ServletRegistrationBean}s.
* @param servletRegistrationBeans associate {@link ServletRegistrationBean}s
*/
AbstractFilterRegistrationBean(ServletRegistrationBean... servletRegistrationBeans) {
Assert.notNull(servletRegistrationBeans,
"ServletRegistrationBeans must not be null");
Collections.addAll(this.servletRegistrationBeans, servletRegistrationBeans);
}
/**
* Set {@link ServletRegistrationBean}s that the filter will be registered against.
* @param servletRegistrationBeans the Servlet registration beans
*/
public void setServletRegistrationBeans(
Collection<? extends ServletRegistrationBean> servletRegistrationBeans) {
Assert.notNull(servletRegistrationBeans,
"ServletRegistrationBeans must not be null");
this.servletRegistrationBeans = new LinkedHashSet<ServletRegistrationBean>(
servletRegistrationBeans);
}
/**
* Return a mutable collection of the {@link ServletRegistrationBean} that the filter
* will be registered against. {@link ServletRegistrationBean}s.
* @return the Servlet registration beans
* @see #setServletNames
* @see #setUrlPatterns
*/
public Collection<ServletRegistrationBean> getServletRegistrationBeans() {
return this.servletRegistrationBeans;
}
/**
* Add {@link ServletRegistrationBean}s for the filter.
* @param servletRegistrationBeans the servlet registration beans to add
* @see #setServletRegistrationBeans
*/
public void addServletRegistrationBeans(
ServletRegistrationBean... servletRegistrationBeans) {
Assert.notNull(servletRegistrationBeans,
"ServletRegistrationBeans must not be null");
Collections.addAll(this.servletRegistrationBeans, servletRegistrationBeans);
}
/**
* Set servlet names that the filter will be registered against. This will replace any
* previously specified servlet names.
* @param servletNames the servlet names
* @see #setServletRegistrationBeans
* @see #setUrlPatterns
*/
public void setServletNames(Collection<String> servletNames) {
Assert.notNull(servletNames, "ServletNames must not be null");
this.servletNames = new LinkedHashSet<String>(servletNames);
}
/**
* Return a mutable collection of servlet names that the filter will be registered
* against.
* @return the servlet names
*/
public Collection<String> getServletNames() {
return this.servletNames;
}
/**
* Add servlet names for the filter.
* @param servletNames the servlet names to add
*/
public void addServletNames(String... servletNames) {
Assert.notNull(servletNames, "ServletNames must not be null");
this.servletNames.addAll(Arrays.asList(servletNames));
}
/**
* Set the URL patterns that the filter will be registered against. This will replace
* any previously specified URL patterns.
* @param urlPatterns the URL patterns
* @see #setServletRegistrationBeans
* @see #setServletNames
*/
public void setUrlPatterns(Collection<String> urlPatterns) {
Assert.notNull(urlPatterns, "UrlPatterns must not be null");
this.urlPatterns = new LinkedHashSet<String>(urlPatterns);
}
/**
* Return a mutable collection of URL patterns that the filter will be registered
* against.
* @return the URL patterns
*/
public Collection<String> getUrlPatterns() {
return this.urlPatterns;
}
/**
* Add URL patterns that the filter will be registered against.
* @param urlPatterns the URL patterns
*/
public void addUrlPatterns(String... urlPatterns) {
Assert.notNull(urlPatterns, "UrlPatterns must not be null");
Collections.addAll(this.urlPatterns, urlPatterns);
}
/**
* Convenience method to {@link #setDispatcherTypes(EnumSet) set dispatcher types}
* using the specified elements.
* @param first the first dispatcher type
* @param rest additional dispatcher types
*/
public void setDispatcherTypes(DispatcherType first, DispatcherType... rest) {
this.dispatcherTypes = EnumSet.of(first, rest);
}
/**
* Sets the dispatcher types that should be used with the registration. If not
* specified the types will be deduced based on the value of
* {@link #isAsyncSupported()}.
* @param dispatcherTypes the dispatcher types
*/
public void setDispatcherTypes(EnumSet<DispatcherType> dispatcherTypes) {
this.dispatcherTypes = dispatcherTypes;
}
/**
* Set if the filter mappings should be matched after any declared filter mappings of
* the ServletContext. Defaults to {@code false} indicating the filters are supposed
* to be matched before any declared filter mappings of the ServletContext.
* @param matchAfter if filter mappings are matched after
*/
public void setMatchAfter(boolean matchAfter) {
this.matchAfter = matchAfter;
}
/**
* Return if filter mappings should be matched after any declared Filter mappings of
* the ServletContext.
* @return if filter mappings are matched after
*/
public boolean isMatchAfter() {
return this.matchAfter;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
Filter filter = getFilter();
Assert.notNull(filter, "Filter must not be null");
String name = getOrDeduceName(filter);
if (!isEnabled()) {
this.logger.info("Filter " + name + " was not registered (disabled)");
return;
}
FilterRegistration.Dynamic added = servletContext.addFilter(name, filter);
if (added == null) {
this.logger.info("Filter " + name + " was not registered "
+ "(possibly already registered?)");
return;
}
configure(added);
}
/**
* Return the {@link Filter} to be registered.
* @return the filter
*/
protected abstract Filter getFilter();
/**
* Configure registration settings. Subclasses can override this method to perform
* additional configuration if required.
* @param registration the registration
*/
protected void configure(FilterRegistration.Dynamic registration) {
super.configure(registration);
EnumSet<DispatcherType> dispatcherTypes = this.dispatcherTypes;
if (dispatcherTypes == null) {
dispatcherTypes = (isAsyncSupported() ? ASYNC_DISPATCHER_TYPES
: NON_ASYNC_DISPATCHER_TYPES);
}
Set<String> servletNames = new LinkedHashSet<String>();
for (ServletRegistrationBean servletRegistrationBean : this.servletRegistrationBeans) {
servletNames.add(servletRegistrationBean.getServletName());
}
servletNames.addAll(this.servletNames);
if (servletNames.isEmpty() && this.urlPatterns.isEmpty()) {
this.logger.info("Mapping filter: '" + registration.getName() + "' to: "
+ Arrays.asList(DEFAULT_URL_MAPPINGS));
registration.addMappingForUrlPatterns(dispatcherTypes, this.matchAfter,
DEFAULT_URL_MAPPINGS);
}
else {
if (servletNames.size() > 0) {
this.logger.info("Mapping filter: '" + registration.getName()
+ "' to servlets: " + servletNames);
registration.addMappingForServletNames(dispatcherTypes, this.matchAfter,
servletNames.toArray(new String[servletNames.size()]));
}
if (this.urlPatterns.size() > 0) {
this.logger.info("Mapping filter: '" + registration.getName()
+ "' to urls: " + this.urlPatterns);
registration.addMappingForUrlPatterns(dispatcherTypes, this.matchAfter,
this.urlPatterns.toArray(new String[this.urlPatterns.size()]));
}
}
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.context.embedded;
import javax.servlet.Filter;
import javax.servlet.ServletContext;
import org.springframework.beans.BeansException;
import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationContextAware;
import org.springframework.util.Assert;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.filter.DelegatingFilterProxy;
/**
* A {@link ServletContextInitializer} to register {@link DelegatingFilterProxy}s in a
* Servlet 3.0+ container. Similar to the {@link ServletContext#addFilter(String, Filter)
* registration} features provided by {@link ServletContext} but with a Spring Bean
* friendly design.
* <p>
* The bean name of the actual delegate {@link Filter} should be specified using the
* {@code targetBeanName} constructor argument. Unlike the {@link FilterRegistrationBean},
* referenced filters are not instantiated early. In fact, if the delegate filter bean is
* marked {@code @Lazy} it won't be instantiated at all until the filter is called.
* <p>
* Registrations can be associated with {@link #setUrlPatterns URL patterns} and/or
* servlets (either by {@link #setServletNames name} or via a
* {@link #setServletRegistrationBeans ServletRegistrationBean}s. When no URL pattern or
* servlets are specified the filter will be associated to '/*'. The targetBeanName will
* be used as the filter name if not otherwise specified.
*
* @author Phillip Webb
* @since 1.3.0
* @see ServletContextInitializer
* @see ServletContext#addFilter(String, Filter)
* @see FilterRegistrationBean
* @see DelegatingFilterProxy
*/
public class DelegatingFilterProxyRegistrationBean extends AbstractFilterRegistrationBean
implements ApplicationContextAware {
private ApplicationContext applicationContext;
private final String targetBeanName;
/**
* Create a new {@link DelegatingFilterProxyRegistrationBean} instance to be
* registered with the specified {@link ServletRegistrationBean}s.
* @param targetBeanName name of the target filter bean to look up in the Spring
* application context (must not be {@code null}).
* @param servletRegistrationBeans associate {@link ServletRegistrationBean}s
*/
public DelegatingFilterProxyRegistrationBean(String targetBeanName,
ServletRegistrationBean... servletRegistrationBeans) {
super(servletRegistrationBeans);
Assert.hasLength(targetBeanName, "TargetBeanName must not be null or empty");
this.targetBeanName = targetBeanName;
setName(targetBeanName);
}
@Override
public void setApplicationContext(ApplicationContext applicationContext)
throws BeansException {
this.applicationContext = applicationContext;
}
protected String getTargetBeanName() {
return this.targetBeanName;
}
@Override
protected Filter getFilter() {
return new DelegatingFilterProxy(this.targetBeanName, getWebApplicationContext());
}
private WebApplicationContext getWebApplicationContext() {
Assert.notNull(this.applicationContext, "ApplicationContext be injected");
Assert.isInstanceOf(WebApplicationContext.class, this.applicationContext);
return (WebApplicationContext) this.applicationContext;
}
}
...@@ -16,21 +16,9 @@ ...@@ -16,21 +16,9 @@
package org.springframework.boot.context.embedded; package org.springframework.boot.context.embedded;
import java.util.Arrays;
import java.util.Collection;
import java.util.Collections;
import java.util.EnumSet;
import java.util.LinkedHashSet;
import java.util.Set;
import javax.servlet.DispatcherType;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext; import javax.servlet.ServletContext;
import javax.servlet.ServletException;
import org.apache.commons.logging.Log;
import org.apache.commons.logging.LogFactory;
import org.springframework.util.Assert; import org.springframework.util.Assert;
/** /**
...@@ -48,37 +36,17 @@ import org.springframework.util.Assert; ...@@ -48,37 +36,17 @@ import org.springframework.util.Assert;
* @author Phillip Webb * @author Phillip Webb
* @see ServletContextInitializer * @see ServletContextInitializer
* @see ServletContext#addFilter(String, Filter) * @see ServletContext#addFilter(String, Filter)
* @see DelegatingFilterProxyRegistrationBean
*/ */
public class FilterRegistrationBean extends RegistrationBean { public class FilterRegistrationBean extends AbstractFilterRegistrationBean {
/** /**
* Filters that wrap the servlet request should be ordered less than or equal to this. * Filters that wrap the servlet request should be ordered less than or equal to this.
*/ */
public static final int REQUEST_WRAPPER_FILTER_MAX_ORDER = 0; public static final int REQUEST_WRAPPER_FILTER_MAX_ORDER = AbstractFilterRegistrationBean.REQUEST_WRAPPER_FILTER_MAX_ORDER;
private static Log logger = LogFactory.getLog(FilterRegistrationBean.class);
static final EnumSet<DispatcherType> ASYNC_DISPATCHER_TYPES = EnumSet.of(
DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST,
DispatcherType.ASYNC);
static final EnumSet<DispatcherType> NON_ASYNC_DISPATCHER_TYPES = EnumSet
.of(DispatcherType.FORWARD, DispatcherType.INCLUDE, DispatcherType.REQUEST);
private static final String[] DEFAULT_URL_MAPPINGS = { "/*" };
private Filter filter; private Filter filter;
private Set<ServletRegistrationBean> servletRegistrationBeans = new LinkedHashSet<ServletRegistrationBean>();
private Set<String> servletNames = new LinkedHashSet<String>();
private Set<String> urlPatterns = new LinkedHashSet<String>();
private EnumSet<DispatcherType> dispatcherTypes;
private boolean matchAfter = false;
/** /**
* Create a new {@link FilterRegistrationBean} instance. * Create a new {@link FilterRegistrationBean} instance.
*/ */
...@@ -93,17 +61,12 @@ public class FilterRegistrationBean extends RegistrationBean { ...@@ -93,17 +61,12 @@ public class FilterRegistrationBean extends RegistrationBean {
*/ */
public FilterRegistrationBean(Filter filter, public FilterRegistrationBean(Filter filter,
ServletRegistrationBean... servletRegistrationBeans) { ServletRegistrationBean... servletRegistrationBeans) {
super(servletRegistrationBeans);
Assert.notNull(filter, "Filter must not be null"); Assert.notNull(filter, "Filter must not be null");
Assert.notNull(servletRegistrationBeans,
"ServletRegistrationBeans must not be null");
this.filter = filter; this.filter = filter;
Collections.addAll(this.servletRegistrationBeans, servletRegistrationBeans);
} }
/** @Override
* Returns the filter being registered.
* @return the filter
*/
protected Filter getFilter() { protected Filter getFilter() {
return this.filter; return this.filter;
} }
...@@ -117,196 +80,4 @@ public class FilterRegistrationBean extends RegistrationBean { ...@@ -117,196 +80,4 @@ public class FilterRegistrationBean extends RegistrationBean {
this.filter = filter; this.filter = filter;
} }
/**
* Set {@link ServletRegistrationBean}s that the filter will be registered against.
* @param servletRegistrationBeans the Servlet registration beans
*/
public void setServletRegistrationBeans(
Collection<? extends ServletRegistrationBean> servletRegistrationBeans) {
Assert.notNull(servletRegistrationBeans,
"ServletRegistrationBeans must not be null");
this.servletRegistrationBeans = new LinkedHashSet<ServletRegistrationBean>(
servletRegistrationBeans);
}
/**
* Return a mutable collection of the {@link ServletRegistrationBean} that the filter
* will be registered against. {@link ServletRegistrationBean}s.
* @return the Servlet registration beans
* @see #setServletNames
* @see #setUrlPatterns
*/
public Collection<ServletRegistrationBean> getServletRegistrationBeans() {
return this.servletRegistrationBeans;
}
/**
* Add {@link ServletRegistrationBean}s for the filter.
* @param servletRegistrationBeans the servlet registration beans to add
* @see #setServletRegistrationBeans
*/
public void addServletRegistrationBeans(
ServletRegistrationBean... servletRegistrationBeans) {
Assert.notNull(servletRegistrationBeans,
"ServletRegistrationBeans must not be null");
Collections.addAll(this.servletRegistrationBeans, servletRegistrationBeans);
}
/**
* Set servlet names that the filter will be registered against. This will replace any
* previously specified servlet names.
* @param servletNames the servlet names
* @see #setServletRegistrationBeans
* @see #setUrlPatterns
*/
public void setServletNames(Collection<String> servletNames) {
Assert.notNull(servletNames, "ServletNames must not be null");
this.servletNames = new LinkedHashSet<String>(servletNames);
}
/**
* Return a mutable collection of servlet names that the filter will be registered
* against.
* @return the servlet names
*/
public Collection<String> getServletNames() {
return this.servletNames;
}
/**
* Add servlet names for the filter.
* @param servletNames the servlet names to add
*/
public void addServletNames(String... servletNames) {
Assert.notNull(servletNames, "ServletNames must not be null");
this.servletNames.addAll(Arrays.asList(servletNames));
}
/**
* Set the URL patterns that the filter will be registered against. This will replace
* any previously specified URL patterns.
* @param urlPatterns the URL patterns
* @see #setServletRegistrationBeans
* @see #setServletNames
*/
public void setUrlPatterns(Collection<String> urlPatterns) {
Assert.notNull(urlPatterns, "UrlPatterns must not be null");
this.urlPatterns = new LinkedHashSet<String>(urlPatterns);
}
/**
* Return a mutable collection of URL patterns that the filter will be registered
* against.
* @return the URL patterns
*/
public Collection<String> getUrlPatterns() {
return this.urlPatterns;
}
/**
* Add URL patterns that the filter will be registered against.
* @param urlPatterns the URL patterns
*/
public void addUrlPatterns(String... urlPatterns) {
Assert.notNull(urlPatterns, "UrlPatterns must not be null");
Collections.addAll(this.urlPatterns, urlPatterns);
}
/**
* Convenience method to {@link #setDispatcherTypes(EnumSet) set dispatcher types}
* using the specified elements.
* @param first the first dispatcher type
* @param rest additional dispatcher types
*/
public void setDispatcherTypes(DispatcherType first, DispatcherType... rest) {
this.dispatcherTypes = EnumSet.of(first, rest);
}
/**
* Sets the dispatcher types that should be used with the registration. If not
* specified the types will be deduced based on the value of
* {@link #isAsyncSupported()}.
* @param dispatcherTypes the dispatcher types
*/
public void setDispatcherTypes(EnumSet<DispatcherType> dispatcherTypes) {
this.dispatcherTypes = dispatcherTypes;
}
/**
* Set if the filter mappings should be matched after any declared filter mappings of
* the ServletContext. Defaults to {@code false} indicating the filters are supposed
* to be matched before any declared filter mappings of the ServletContext.
* @param matchAfter if filter mappings are matched after
*/
public void setMatchAfter(boolean matchAfter) {
this.matchAfter = matchAfter;
}
/**
* Return if filter mappings should be matched after any declared Filter mappings of
* the ServletContext.
* @return if filter mappings are matched after
*/
public boolean isMatchAfter() {
return this.matchAfter;
}
@Override
public void onStartup(ServletContext servletContext) throws ServletException {
Assert.notNull(this.filter, "Filter must not be null");
String name = getOrDeduceName(this.filter);
if (!isEnabled()) {
logger.info("Filter " + name + " was not registered (disabled)");
return;
}
FilterRegistration.Dynamic added = servletContext.addFilter(name, this.filter);
if (added == null) {
logger.info("Filter " + name + " was not registered "
+ "(possibly already registered?)");
return;
}
configure(added);
}
/**
* Configure registration settings. Subclasses can override this method to perform
* additional configuration if required.
* @param registration the registration
*/
protected void configure(FilterRegistration.Dynamic registration) {
super.configure(registration);
EnumSet<DispatcherType> dispatcherTypes = this.dispatcherTypes;
if (dispatcherTypes == null) {
dispatcherTypes = (isAsyncSupported() ? ASYNC_DISPATCHER_TYPES
: NON_ASYNC_DISPATCHER_TYPES);
}
Set<String> servletNames = new LinkedHashSet<String>();
for (ServletRegistrationBean servletRegistrationBean : this.servletRegistrationBeans) {
servletNames.add(servletRegistrationBean.getServletName());
}
servletNames.addAll(this.servletNames);
if (servletNames.isEmpty() && this.urlPatterns.isEmpty()) {
logger.info("Mapping filter: '" + registration.getName() + "' to: "
+ Arrays.asList(DEFAULT_URL_MAPPINGS));
registration.addMappingForUrlPatterns(dispatcherTypes, this.matchAfter,
DEFAULT_URL_MAPPINGS);
}
else {
if (servletNames.size() > 0) {
logger.info("Mapping filter: '" + registration.getName()
+ "' to servlets: " + servletNames);
registration.addMappingForServletNames(dispatcherTypes, this.matchAfter,
servletNames.toArray(new String[servletNames.size()]));
}
if (this.urlPatterns.size() > 0) {
logger.info("Mapping filter: '" + registration.getName() + "' to urls: "
+ this.urlPatterns);
registration.addMappingForUrlPatterns(dispatcherTypes, this.matchAfter,
this.urlPatterns.toArray(new String[this.urlPatterns.size()]));
}
}
}
} }
...@@ -31,6 +31,7 @@ import org.springframework.util.Assert; ...@@ -31,6 +31,7 @@ import org.springframework.util.Assert;
* @author Phillip Webb * @author Phillip Webb
* @see ServletRegistrationBean * @see ServletRegistrationBean
* @see FilterRegistrationBean * @see FilterRegistrationBean
* @see DelegatingFilterProxyRegistrationBean
* @see ServletListenerRegistrationBean * @see ServletListenerRegistrationBean
*/ */
public abstract class RegistrationBean implements ServletContextInitializer, Ordered { public abstract class RegistrationBean implements ServletContextInitializer, Ordered {
......
...@@ -62,6 +62,9 @@ class ServletContextInitializerBeans ...@@ -62,6 +62,9 @@ class ServletContextInitializerBeans
private final Log log = LogFactory.getLog(getClass()); private final Log log = LogFactory.getLog(getClass());
/**
* Seen bean instances or bean names.
*/
private final Set<Object> seen = new HashSet<Object>(); private final Set<Object> seen = new HashSet<Object>();
private final MultiValueMap<Class<?>, ServletContextInitializer> initializers; private final MultiValueMap<Class<?>, ServletContextInitializer> initializers;
...@@ -92,17 +95,26 @@ class ServletContextInitializerBeans ...@@ -92,17 +95,26 @@ class ServletContextInitializerBeans
private void addServletContextInitializerBean(String beanName, private void addServletContextInitializerBean(String beanName,
ServletContextInitializer initializer, ListableBeanFactory beanFactory) { ServletContextInitializer initializer, ListableBeanFactory beanFactory) {
if (initializer instanceof ServletRegistrationBean) { if (initializer instanceof ServletRegistrationBean) {
Servlet source = ((ServletRegistrationBean) initializer).getServlet();
addServletContextInitializerBean(Servlet.class, beanName, initializer, addServletContextInitializerBean(Servlet.class, beanName, initializer,
beanFactory, ((ServletRegistrationBean) initializer).getServlet()); beanFactory, source);
} }
else if (initializer instanceof FilterRegistrationBean) { else if (initializer instanceof FilterRegistrationBean) {
Filter source = ((FilterRegistrationBean) initializer).getFilter();
addServletContextInitializerBean(Filter.class, beanName, initializer,
beanFactory, source);
}
else if (initializer instanceof DelegatingFilterProxyRegistrationBean) {
String source = ((DelegatingFilterProxyRegistrationBean) initializer)
.getTargetBeanName();
addServletContextInitializerBean(Filter.class, beanName, initializer, addServletContextInitializerBean(Filter.class, beanName, initializer,
beanFactory, ((FilterRegistrationBean) initializer).getFilter()); beanFactory, source);
} }
else if (initializer instanceof ServletListenerRegistrationBean) { else if (initializer instanceof ServletListenerRegistrationBean) {
EventListener source = ((ServletListenerRegistrationBean<?>) initializer)
.getListener();
addServletContextInitializerBean(EventListener.class, beanName, initializer, addServletContextInitializerBean(EventListener.class, beanName, initializer,
beanFactory, beanFactory, source);
((ServletListenerRegistrationBean<?>) initializer).getListener());
} }
else { else {
addServletContextInitializerBean(ServletContextInitializer.class, beanName, addServletContextInitializerBean(ServletContextInitializer.class, beanName,
...@@ -164,7 +176,8 @@ class ServletContextInitializerBeans ...@@ -164,7 +176,8 @@ class ServletContextInitializerBeans
private <T, B extends T> void addAsRegistrationBean(ListableBeanFactory beanFactory, private <T, B extends T> void addAsRegistrationBean(ListableBeanFactory beanFactory,
Class<T> type, Class<B> beanType, RegistrationBeanAdapter<T> adapter) { Class<T> type, Class<B> beanType, RegistrationBeanAdapter<T> adapter) {
List<Map.Entry<String, B>> beans = getOrderedBeansOfType(beanFactory, beanType); List<Map.Entry<String, B>> beans = getOrderedBeansOfType(beanFactory, beanType,
this.seen);
for (Entry<String, B> bean : beans) { for (Entry<String, B> bean : beans) {
if (this.seen.add(bean.getValue())) { if (this.seen.add(bean.getValue())) {
int order = getOrder(bean.getValue()); int order = getOrder(bean.getValue());
...@@ -175,7 +188,6 @@ class ServletContextInitializerBeans ...@@ -175,7 +188,6 @@ class ServletContextInitializerBeans
registration.setName(beanName); registration.setName(beanName);
registration.setOrder(order); registration.setOrder(order);
this.initializers.add(type, registration); this.initializers.add(type, registration);
if (this.log.isDebugEnabled()) { if (this.log.isDebugEnabled()) {
this.log.debug( this.log.debug(
"Created " + type.getSimpleName() + " initializer for bean '" "Created " + type.getSimpleName() + " initializer for bean '"
...@@ -197,18 +209,30 @@ class ServletContextInitializerBeans ...@@ -197,18 +209,30 @@ class ServletContextInitializerBeans
private <T> List<Entry<String, T>> getOrderedBeansOfType( private <T> List<Entry<String, T>> getOrderedBeansOfType(
ListableBeanFactory beanFactory, Class<T> type) { ListableBeanFactory beanFactory, Class<T> type) {
return getOrderedBeansOfType(beanFactory, type, Collections.emptySet());
}
private <T> List<Entry<String, T>> getOrderedBeansOfType(
ListableBeanFactory beanFactory, Class<T> type, Set<?> excludes) {
List<Entry<String, T>> beans = new ArrayList<Entry<String, T>>(); List<Entry<String, T>> beans = new ArrayList<Entry<String, T>>();
Comparator<Entry<String, T>> comparator = new Comparator<Entry<String, T>>() { Comparator<Entry<String, T>> comparator = new Comparator<Entry<String, T>>() {
@Override @Override
public int compare(Entry<String, T> o1, Entry<String, T> o2) { public int compare(Entry<String, T> o1, Entry<String, T> o2) {
return AnnotationAwareOrderComparator.INSTANCE.compare(o1.getValue(), return AnnotationAwareOrderComparator.INSTANCE.compare(o1.getValue(),
o2.getValue()); o2.getValue());
} }
}; };
String[] names = beanFactory.getBeanNamesForType(type, true, false); String[] names = beanFactory.getBeanNamesForType(type, true, false);
Map<String, T> map = new LinkedHashMap<String, T>(); Map<String, T> map = new LinkedHashMap<String, T>();
for (String name : names) { for (String name : names) {
map.put(name, beanFactory.getBean(name, type)); if (!excludes.contains(name)) {
T bean = beanFactory.getBean(name, type);
if (!excludes.contains(bean)) {
map.put(name, bean);
}
}
} }
beans.addAll(map.entrySet()); beans.addAll(map.entrySet());
Collections.sort(beans, comparator); Collections.sort(beans, comparator);
......
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.context.embedded;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import javax.servlet.DispatcherType;
import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import static org.mockito.BDDMockito.given;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Matchers.eq;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify;
/**
* Abstract base for {@link AbstractFilterRegistrationBean} tests.
*
* @author Phillip Webb
*/
public abstract class AbstractFilterRegistrationBeanTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
@Mock
ServletContext servletContext;
@Mock
FilterRegistration.Dynamic registration;
@Before
public void setupMocks() {
MockitoAnnotations.initMocks(this);
given(this.servletContext.addFilter(anyString(), (Filter) anyObject()))
.willReturn(this.registration);
}
@Test
public void startupWithDefaults() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter(eq("mockFilter"), getExpectedFilter());
verify(this.registration).setAsyncSupported(true);
verify(this.registration).addMappingForUrlPatterns(
AbstractFilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "/*");
}
@Test
public void startupWithSpecifiedValues() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.setName("test");
bean.setAsyncSupported(false);
bean.setInitParameters(Collections.singletonMap("a", "b"));
bean.addInitParameter("c", "d");
bean.setUrlPatterns(new LinkedHashSet<String>(Arrays.asList("/a", "/b")));
bean.addUrlPatterns("/c");
bean.setServletNames(new LinkedHashSet<String>(Arrays.asList("s1", "s2")));
bean.addServletNames("s3");
bean.setServletRegistrationBeans(
Collections.singleton(mockServletRegistation("s4")));
bean.addServletRegistrationBeans(mockServletRegistation("s5"));
bean.setMatchAfter(true);
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter(eq("test"), getExpectedFilter());
verify(this.registration).setAsyncSupported(false);
Map<String, String> expectedInitParameters = new HashMap<String, String>();
expectedInitParameters.put("a", "b");
expectedInitParameters.put("c", "d");
verify(this.registration).setInitParameters(expectedInitParameters);
verify(this.registration).addMappingForUrlPatterns(
AbstractFilterRegistrationBean.NON_ASYNC_DISPATCHER_TYPES, true, "/a",
"/b", "/c");
verify(this.registration).addMappingForServletNames(
AbstractFilterRegistrationBean.NON_ASYNC_DISPATCHER_TYPES, true, "s4",
"s5", "s1", "s2", "s3");
}
@Test
public void specificName() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.setName("specificName");
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter(eq("specificName"), getExpectedFilter());
}
@Test
public void deducedName() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter(eq("mockFilter"), getExpectedFilter());
}
@Test
public void disable() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.setEnabled(false);
bean.onStartup(this.servletContext);
verify(this.servletContext, times(0)).addFilter(eq("mockFilter"),
getExpectedFilter());
}
@Test
public void setServletRegistrationBeanMustNotBeNull() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletRegistrationBeans must not be null");
bean.setServletRegistrationBeans(null);
}
@Test
public void addServletRegistrationBeanMustNotBeNull() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletRegistrationBeans must not be null");
bean.addServletRegistrationBeans((ServletRegistrationBean[]) null);
}
@Test
public void setServletRegistrationBeanReplacesValue() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean(
mockServletRegistation("a"));
bean.setServletRegistrationBeans(new LinkedHashSet<ServletRegistrationBean>(
Arrays.asList(mockServletRegistation("b"))));
bean.onStartup(this.servletContext);
verify(this.registration).addMappingForServletNames(
AbstractFilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "b");
}
@Test
public void modifyInitParameters() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.addInitParameter("a", "b");
bean.getInitParameters().put("a", "c");
bean.onStartup(this.servletContext);
verify(this.registration).setInitParameters(Collections.singletonMap("a", "c"));
}
@Test
public void setUrlPatternMustNotBeNull() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("UrlPatterns must not be null");
bean.setUrlPatterns(null);
}
@Test
public void addUrlPatternMustNotBeNull() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("UrlPatterns must not be null");
bean.addUrlPatterns((String[]) null);
}
@Test
public void setServletNameMustNotBeNull() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletNames must not be null");
bean.setServletNames(null);
}
@Test
public void addServletNameMustNotBeNull() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletNames must not be null");
bean.addServletNames((String[]) null);
}
@Test
public void withSpecificDispatcherTypes() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
bean.setDispatcherTypes(DispatcherType.INCLUDE, DispatcherType.FORWARD);
bean.onStartup(this.servletContext);
verify(this.registration).addMappingForUrlPatterns(
EnumSet.of(DispatcherType.INCLUDE, DispatcherType.FORWARD), false, "/*");
}
@Test
public void withSpecificDispatcherTypesEnumSet() throws Exception {
AbstractFilterRegistrationBean bean = createFilterRegistrationBean();
EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.INCLUDE,
DispatcherType.FORWARD);
bean.setDispatcherTypes(types);
bean.onStartup(this.servletContext);
verify(this.registration).addMappingForUrlPatterns(types, false, "/*");
}
protected abstract Filter getExpectedFilter();
protected abstract AbstractFilterRegistrationBean createFilterRegistrationBean(
ServletRegistrationBean... servletRegistrationBeans);
protected final ServletRegistrationBean mockServletRegistation(String name) {
ServletRegistrationBean bean = new ServletRegistrationBean();
bean.setName(name);
return bean;
}
}
/*
* Copyright 2012-2015 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.context.embedded;
import javax.servlet.Filter;
import org.junit.Test;
import org.springframework.mock.web.MockServletContext;
import org.springframework.test.util.ReflectionTestUtils;
import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.support.GenericWebApplicationContext;
import org.springframework.web.filter.DelegatingFilterProxy;
import static org.hamcrest.Matchers.equalTo;
import static org.hamcrest.Matchers.instanceOf;
import static org.junit.Assert.assertThat;
import static org.mockito.Matchers.isA;
/**
* Tests for {@link DelegatingFilterProxyRegistrationBean}.
*
* @author Phillip Webb
*/
public class DelegatingFilterProxyRegistrationBeanTests
extends AbstractFilterRegistrationBeanTests {
private WebApplicationContext applicationContext = new GenericWebApplicationContext(
new MockServletContext());;
@Test
public void targetBeanNameMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("TargetBeanName must not be null or empty");
new DelegatingFilterProxyRegistrationBean(null);
}
@Test
public void targetBeanNameMustNotBeEmpty() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("TargetBeanName must not be null or empty");
new DelegatingFilterProxyRegistrationBean("");
}
@Test
public void nameDefaultsToTargetBeanName() throws Exception {
assertThat(new DelegatingFilterProxyRegistrationBean("myFilter")
.getOrDeduceName(null), equalTo("myFilter"));
}
@Test
public void getFilterUsesDelegatingFilterProxy() throws Exception {
AbstractFilterRegistrationBean registrationBean = createFilterRegistrationBean();
Filter filter = registrationBean.getFilter();
assertThat(filter, instanceOf(DelegatingFilterProxy.class));
assertThat(ReflectionTestUtils.getField(filter, "webApplicationContext"),
equalTo((Object) this.applicationContext));
assertThat(ReflectionTestUtils.getField(filter, "targetBeanName"),
equalTo((Object) "mockFilter"));
}
@Test
public void createServletRegistrationBeanMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletRegistrationBeans must not be null");
new DelegatingFilterProxyRegistrationBean("mockFilter",
(ServletRegistrationBean[]) null);
}
@Override
protected AbstractFilterRegistrationBean createFilterRegistrationBean(
ServletRegistrationBean... servletRegistrationBeans) {
DelegatingFilterProxyRegistrationBean bean = new DelegatingFilterProxyRegistrationBean(
"mockFilter", servletRegistrationBeans);
bean.setApplicationContext(this.applicationContext);
return bean;
}
@Override
protected Filter getExpectedFilter() {
return isA(DelegatingFilterProxy.class);
}
}
...@@ -29,13 +29,19 @@ import javax.servlet.ServletException; ...@@ -29,13 +29,19 @@ import javax.servlet.ServletException;
import javax.servlet.ServletRequest; import javax.servlet.ServletRequest;
import javax.servlet.ServletResponse; import javax.servlet.ServletResponse;
import org.apache.struts.mock.MockHttpServletRequest;
import org.apache.struts.mock.MockHttpServletResponse;
import org.junit.After; import org.junit.After;
import org.junit.Before; import org.junit.Before;
import org.junit.Rule; import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.mockito.ArgumentCaptor;
import org.mockito.Captor;
import org.mockito.InOrder; import org.mockito.InOrder;
import org.mockito.MockitoAnnotations;
import org.springframework.beans.MutablePropertyValues; import org.springframework.beans.MutablePropertyValues;
import org.springframework.beans.factory.BeanCreationException;
import org.springframework.beans.factory.DisposableBean; import org.springframework.beans.factory.DisposableBean;
import org.springframework.beans.factory.config.BeanDefinition; import org.springframework.beans.factory.config.BeanDefinition;
import org.springframework.beans.factory.config.ConfigurableListableBeanFactory; import org.springframework.beans.factory.config.ConfigurableListableBeanFactory;
...@@ -52,6 +58,8 @@ import org.springframework.context.support.PropertySourcesPlaceholderConfigurer; ...@@ -52,6 +58,8 @@ import org.springframework.context.support.PropertySourcesPlaceholderConfigurer;
import org.springframework.core.Ordered; import org.springframework.core.Ordered;
import org.springframework.core.annotation.Order; import org.springframework.core.annotation.Order;
import org.springframework.core.env.ConfigurableEnvironment; import org.springframework.core.env.ConfigurableEnvironment;
import org.springframework.mock.web.MockFilterChain;
import org.springframework.mock.web.MockFilterConfig;
import org.springframework.web.context.ServletContextAware; import org.springframework.web.context.ServletContextAware;
import org.springframework.web.context.WebApplicationContext; import org.springframework.web.context.WebApplicationContext;
import org.springframework.web.context.request.SessionScope; import org.springframework.web.context.request.SessionScope;
...@@ -88,8 +96,12 @@ public class EmbeddedWebApplicationContextTests { ...@@ -88,8 +96,12 @@ public class EmbeddedWebApplicationContextTests {
private EmbeddedWebApplicationContext context; private EmbeddedWebApplicationContext context;
@Captor
private ArgumentCaptor<Filter> filterCaptor;
@Before @Before
public void setup() { public void setup() {
MockitoAnnotations.initMocks(this);
this.context = new EmbeddedWebApplicationContext(); this.context = new EmbeddedWebApplicationContext();
} }
...@@ -302,9 +314,9 @@ public class EmbeddedWebApplicationContextTests { ...@@ -302,9 +314,9 @@ public class EmbeddedWebApplicationContextTests {
ordered.verify(escf.getServletContext()).addFilter("filterBean1", filter1); ordered.verify(escf.getServletContext()).addFilter("filterBean1", filter1);
ordered.verify(escf.getServletContext()).addFilter("filterBean2", filter2); ordered.verify(escf.getServletContext()).addFilter("filterBean2", filter2);
verify(escf.getRegisteredFilter(0).getRegistration()).addMappingForUrlPatterns( verify(escf.getRegisteredFilter(0).getRegistration()).addMappingForUrlPatterns(
FilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "/*"); AbstractFilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "/*");
verify(escf.getRegisteredFilter(1).getRegistration()).addMappingForUrlPatterns( verify(escf.getRegisteredFilter(1).getRegistration()).addMappingForUrlPatterns(
FilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "/*"); AbstractFilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "/*");
} }
@Test @Test
...@@ -395,7 +407,7 @@ public class EmbeddedWebApplicationContextTests { ...@@ -395,7 +407,7 @@ public class EmbeddedWebApplicationContextTests {
} }
@Test @Test
public void filterReegistrationBeansSkipsRegisteredFilters() throws Exception { public void filterRegistrationBeansSkipsRegisteredFilters() throws Exception {
addEmbeddedServletContainerFactoryBean(); addEmbeddedServletContainerFactoryBean();
Filter filter = mock(Filter.class); Filter filter = mock(Filter.class);
FilterRegistrationBean initializer = new FilterRegistrationBean(filter); FilterRegistrationBean initializer = new FilterRegistrationBean(filter);
...@@ -408,6 +420,32 @@ public class EmbeddedWebApplicationContextTests { ...@@ -408,6 +420,32 @@ public class EmbeddedWebApplicationContextTests {
verify(servletContext, atMost(1)).addFilter(anyString(), (Filter) anyObject()); verify(servletContext, atMost(1)).addFilter(anyString(), (Filter) anyObject());
} }
@Test
public void delegatingFilterProxyRegistrationBeansSkipsTargetBeanNames()
throws Exception {
addEmbeddedServletContainerFactoryBean();
DelegatingFilterProxyRegistrationBean initializer = new DelegatingFilterProxyRegistrationBean(
"filterBean");
this.context.registerBeanDefinition("initializerBean",
beanDefinition(initializer));
BeanDefinition filterBeanDefinition = beanDefinition(
new IllegalStateException("Create FilterBean Failure"));
filterBeanDefinition.setLazyInit(true);
this.context.registerBeanDefinition("filterBean", filterBeanDefinition);
this.context.refresh();
ServletContext servletContext = getEmbeddedServletContainerFactory()
.getServletContext();
verify(servletContext, atMost(1)).addFilter(anyString(),
this.filterCaptor.capture());
// Up to this point the filterBean should not have been created, calling
// the delegate proxy will trigger creation and an exception
this.thrown.expect(BeanCreationException.class);
this.thrown.expectMessage("Create FilterBean Failure");
this.filterCaptor.getValue().init(new MockFilterConfig());
this.filterCaptor.getValue().doFilter(new MockHttpServletRequest(),
new MockHttpServletResponse(), new MockFilterChain());
}
@Test @Test
public void postProcessEmbeddedServletContainerFactory() throws Exception { public void postProcessEmbeddedServletContainerFactory() throws Exception {
RootBeanDefinition bd = new RootBeanDefinition( RootBeanDefinition bd = new RootBeanDefinition(
...@@ -459,7 +497,6 @@ public class EmbeddedWebApplicationContextTests { ...@@ -459,7 +497,6 @@ public class EmbeddedWebApplicationContextTests {
this.context.close(); this.context.close();
assertThat(validator.destroyed, is(true)); assertThat(validator.destroyed, is(true));
assertThat(validator.containerStoppedFirst, is(true)); assertThat(validator.containerStoppedFirst, is(true));
} }
private void addEmbeddedServletContainerFactoryBean() { private void addEmbeddedServletContainerFactoryBean() {
...@@ -482,6 +519,9 @@ public class EmbeddedWebApplicationContextTests { ...@@ -482,6 +519,9 @@ public class EmbeddedWebApplicationContextTests {
} }
public static <T> T getBean(T object) { public static <T> T getBean(T object) {
if (object instanceof RuntimeException) {
throw (RuntimeException) object;
}
return object; return object;
} }
......
/* /*
* Copyright 2012-2013 the original author or authors. * Copyright 2012-2015 the original author or authors.
* *
* Licensed under the Apache License, Version 2.0 (the "License"); * Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License. * you may not use this file except in compliance with the License.
...@@ -16,29 +16,11 @@ ...@@ -16,29 +16,11 @@
package org.springframework.boot.context.embedded; package org.springframework.boot.context.embedded;
import java.util.Arrays;
import java.util.Collections;
import java.util.EnumSet;
import java.util.HashMap;
import java.util.LinkedHashSet;
import java.util.Map;
import javax.servlet.DispatcherType;
import javax.servlet.Filter; import javax.servlet.Filter;
import javax.servlet.FilterRegistration;
import javax.servlet.ServletContext;
import org.junit.Before;
import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.mockito.Mock;
import org.mockito.MockitoAnnotations;
import static org.mockito.BDDMockito.given; import static org.mockito.Matchers.eq;
import static org.mockito.Matchers.anyObject;
import static org.mockito.Matchers.anyString;
import static org.mockito.Mockito.times;
import static org.mockito.Mockito.verify; import static org.mockito.Mockito.verify;
/** /**
...@@ -46,93 +28,18 @@ import static org.mockito.Mockito.verify; ...@@ -46,93 +28,18 @@ import static org.mockito.Mockito.verify;
* *
* @author Phillip Webb * @author Phillip Webb
*/ */
public class FilterRegistrationBeanTests { public class FilterRegistrationBeanTests extends AbstractFilterRegistrationBeanTests {
@Rule
public ExpectedException thrown = ExpectedException.none();
private final MockFilter filter = new MockFilter(); private final MockFilter filter = new MockFilter();
@Mock
private ServletContext servletContext;
@Mock
private FilterRegistration.Dynamic registration;
@Before
public void setupMocks() {
MockitoAnnotations.initMocks(this);
given(this.servletContext.addFilter(anyString(), (Filter) anyObject()))
.willReturn(this.registration);
}
@Test
public void startupWithDefaults() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter("mockFilter", this.filter);
verify(this.registration).setAsyncSupported(true);
verify(this.registration).addMappingForUrlPatterns(
FilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "/*");
}
@Test
public void startupWithSpecifiedValues() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setName("test");
bean.setFilter(this.filter);
bean.setAsyncSupported(false);
bean.setInitParameters(Collections.singletonMap("a", "b"));
bean.addInitParameter("c", "d");
bean.setUrlPatterns(new LinkedHashSet<String>(Arrays.asList("/a", "/b")));
bean.addUrlPatterns("/c");
bean.setServletNames(new LinkedHashSet<String>(Arrays.asList("s1", "s2")));
bean.addServletNames("s3");
bean.setServletRegistrationBeans(
Collections.singleton(mockServletRegistation("s4")));
bean.addServletRegistrationBeans(mockServletRegistation("s5"));
bean.setMatchAfter(true);
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter("test", this.filter);
verify(this.registration).setAsyncSupported(false);
Map<String, String> expectedInitParameters = new HashMap<String, String>();
expectedInitParameters.put("a", "b");
expectedInitParameters.put("c", "d");
verify(this.registration).setInitParameters(expectedInitParameters);
verify(this.registration).addMappingForUrlPatterns(
FilterRegistrationBean.NON_ASYNC_DISPATCHER_TYPES, true, "/a", "/b",
"/c");
verify(this.registration).addMappingForServletNames(
FilterRegistrationBean.NON_ASYNC_DISPATCHER_TYPES, true, "s4", "s5", "s1",
"s2", "s3");
}
@Test @Test
public void specificName() throws Exception { public void setFilter() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setName("specificName");
bean.setFilter(this.filter);
bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter("specificName", this.filter);
}
@Test
public void deducedName() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(); FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setFilter(this.filter); bean.setFilter(this.filter);
bean.onStartup(this.servletContext); bean.onStartup(this.servletContext);
verify(this.servletContext).addFilter("mockFilter", this.filter); verify(this.servletContext).addFilter("mockFilter", this.filter);
} }
@Test
public void disable() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean();
bean.setFilter(this.filter);
bean.setEnabled(false);
bean.onStartup(this.servletContext);
verify(this.servletContext, times(0)).addFilter("mockFilter", this.filter);
}
@Test @Test
public void setFilterMustNotBeNull() throws Exception { public void setFilterMustNotBeNull() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(); FilterRegistrationBean bean = new FilterRegistrationBean();
...@@ -142,20 +49,12 @@ public class FilterRegistrationBeanTests { ...@@ -142,20 +49,12 @@ public class FilterRegistrationBeanTests {
} }
@Test @Test
public void createServletMustNotBeNull() throws Exception { public void constructFilterMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("Filter must not be null"); this.thrown.expectMessage("Filter must not be null");
new FilterRegistrationBean(null); new FilterRegistrationBean(null);
} }
@Test
public void setServletRegistrationBeanMustNotBeNull() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletRegistrationBeans must not be null");
bean.setServletRegistrationBeans(null);
}
@Test @Test
public void createServletRegistrationBeanMustNotBeNull() throws Exception { public void createServletRegistrationBeanMustNotBeNull() throws Exception {
this.thrown.expect(IllegalArgumentException.class); this.thrown.expect(IllegalArgumentException.class);
...@@ -163,88 +62,15 @@ public class FilterRegistrationBeanTests { ...@@ -163,88 +62,15 @@ public class FilterRegistrationBeanTests {
new FilterRegistrationBean(this.filter, (ServletRegistrationBean[]) null); new FilterRegistrationBean(this.filter, (ServletRegistrationBean[]) null);
} }
@Test @Override
public void addServletRegistrationBeanMustNotBeNull() throws Exception { protected AbstractFilterRegistrationBean createFilterRegistrationBean(
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter); ServletRegistrationBean... servletRegistrationBeans) {
this.thrown.expect(IllegalArgumentException.class); return new FilterRegistrationBean(this.filter, servletRegistrationBeans);
this.thrown.expectMessage("ServletRegistrationBeans must not be null");
bean.addServletRegistrationBeans((ServletRegistrationBean[]) null);
}
@Test
public void setServletRegistrationBeanReplacesValue() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter,
mockServletRegistation("a"));
bean.setServletRegistrationBeans(new LinkedHashSet<ServletRegistrationBean>(
Arrays.asList(mockServletRegistation("b"))));
bean.onStartup(this.servletContext);
verify(this.registration).addMappingForServletNames(
FilterRegistrationBean.ASYNC_DISPATCHER_TYPES, false, "b");
}
@Test
public void modifyInitParameters() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
bean.addInitParameter("a", "b");
bean.getInitParameters().put("a", "c");
bean.onStartup(this.servletContext);
verify(this.registration).setInitParameters(Collections.singletonMap("a", "c"));
}
@Test
public void setUrlPatternMustNotBeNull() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("UrlPatterns must not be null");
bean.setUrlPatterns(null);
}
@Test
public void addUrlPatternMustNotBeNull() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("UrlPatterns must not be null");
bean.addUrlPatterns((String[]) null);
} }
@Test @Override
public void setServletNameMustNotBeNull() throws Exception { protected Filter getExpectedFilter() {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter); return eq(this.filter);
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletNames must not be null");
bean.setServletNames(null);
} }
@Test
public void addServletNameMustNotBeNull() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
this.thrown.expect(IllegalArgumentException.class);
this.thrown.expectMessage("ServletNames must not be null");
bean.addServletNames((String[]) null);
}
@Test
public void withSpecificDispatcherTypes() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
bean.setDispatcherTypes(DispatcherType.INCLUDE, DispatcherType.FORWARD);
bean.onStartup(this.servletContext);
verify(this.registration).addMappingForUrlPatterns(
EnumSet.of(DispatcherType.INCLUDE, DispatcherType.FORWARD), false, "/*");
}
@Test
public void withSpecificDispatcherTypesEnumSet() throws Exception {
FilterRegistrationBean bean = new FilterRegistrationBean(this.filter);
EnumSet<DispatcherType> types = EnumSet.of(DispatcherType.INCLUDE,
DispatcherType.FORWARD);
bean.setDispatcherTypes(types);
bean.onStartup(this.servletContext);
verify(this.registration).addMappingForUrlPatterns(types, false, "/*");
}
private ServletRegistrationBean mockServletRegistation(String name) {
ServletRegistrationBean bean = new ServletRegistrationBean();
bean.setName(name);
return bean;
}
} }
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment