Commit bc814d20 authored by Andy Wilkinson's avatar Andy Wilkinson

Prevent authenticated principal from clashing with argument of same name

Closes gh-11988
parent 1772a154
...@@ -25,6 +25,7 @@ import org.junit.Test; ...@@ -25,6 +25,7 @@ import org.junit.Test;
import org.springframework.boot.actuate.endpoint.annotation.Endpoint; import org.springframework.boot.actuate.endpoint.annotation.Endpoint;
import org.springframework.boot.actuate.endpoint.annotation.ReadOperation; import org.springframework.boot.actuate.endpoint.annotation.ReadOperation;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.convert.ConversionServiceParameterValueMapper; import org.springframework.boot.actuate.endpoint.invoke.convert.ConversionServiceParameterValueMapper;
import org.springframework.boot.actuate.endpoint.invoker.cache.CachingOperationInvokerAdvisor; import org.springframework.boot.actuate.endpoint.invoker.cache.CachingOperationInvokerAdvisor;
import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes; import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes;
...@@ -57,7 +58,9 @@ public class CloudFoundryWebEndpointDiscovererTests { ...@@ -57,7 +58,9 @@ public class CloudFoundryWebEndpointDiscovererTests {
for (ExposableWebEndpoint endpoint : endpoints) { for (ExposableWebEndpoint endpoint : endpoints) {
if (endpoint.getId().equals("health")) { if (endpoint.getId().equals("health")) {
WebOperation operation = endpoint.getOperations().iterator().next(); WebOperation operation = endpoint.getOperations().iterator().next();
assertThat(operation.invoke(Collections.emptyMap())).isEqualTo("cf"); assertThat(operation
.invoke(new InvocationContext(null, Collections.emptyMap())))
.isEqualTo("cf");
} }
} }
}); });
......
...@@ -16,7 +16,7 @@ ...@@ -16,7 +16,7 @@
package org.springframework.boot.actuate.endpoint; package org.springframework.boot.actuate.endpoint;
import java.util.Map; import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
/** /**
* An operation on an {@link ExposableEndpoint endpoint}. * An operation on an {@link ExposableEndpoint endpoint}.
...@@ -34,10 +34,10 @@ public interface Operation { ...@@ -34,10 +34,10 @@ public interface Operation {
OperationType getType(); OperationType getType();
/** /**
* Invoke the underlying operation using the given {@code arguments}. * Invoke the underlying operation using the given {@code context}.
* @param arguments the arguments to pass to the operation * @param context the context in to use when invoking the operation
* @return the result of the operation, may be {@code null} * @return the result of the operation, may be {@code null}
*/ */
Object invoke(Map<String, Object> arguments); Object invoke(InvocationContext context);
} }
...@@ -16,10 +16,9 @@ ...@@ -16,10 +16,9 @@
package org.springframework.boot.actuate.endpoint.annotation; package org.springframework.boot.actuate.endpoint.annotation;
import java.util.Map;
import org.springframework.boot.actuate.endpoint.Operation; import org.springframework.boot.actuate.endpoint.Operation;
import org.springframework.boot.actuate.endpoint.OperationType; import org.springframework.boot.actuate.endpoint.OperationType;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import org.springframework.boot.actuate.endpoint.invoke.reflect.OperationMethod; import org.springframework.boot.actuate.endpoint.invoke.reflect.OperationMethod;
import org.springframework.core.style.ToStringCreator; import org.springframework.core.style.ToStringCreator;
...@@ -58,8 +57,8 @@ public abstract class AbstractDiscoveredOperation implements Operation { ...@@ -58,8 +57,8 @@ public abstract class AbstractDiscoveredOperation implements Operation {
} }
@Override @Override
public Object invoke(Map<String, Object> arguments) { public Object invoke(InvocationContext context) {
return this.invoker.invoke(arguments); return this.invoker.invoke(context);
} }
@Override @Override
......
/*
* Copyright 2012-2018 the original author or authors.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
* You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/
package org.springframework.boot.actuate.endpoint.invoke;
import java.security.Principal;
import java.util.Map;
import org.springframework.util.Assert;
/**
* The context for the {@link OperationInvoker invocation of an operation}.
*
* @author Andy Wilkinson
* @since 2.0.0
*/
public class InvocationContext {
private final Principal principal;
private final Map<String, Object> arguments;
/**
* Creates a new context for an operation being invoked by the given {@code principal}
* with the given available {@code arguments}.
*
* @param principal the principal invoking the operation. May be {@code null}
* @param arguments the arguments available to the operation. Never {@code null}
*/
public InvocationContext(Principal principal, Map<String, Object> arguments) {
Assert.notNull(arguments, "Arguments must not be null");
this.principal = principal;
this.arguments = arguments;
}
public Principal getPrincipal() {
return this.principal;
}
public Map<String, Object> getArguments() {
return this.arguments;
}
}
...@@ -16,8 +16,6 @@ ...@@ -16,8 +16,6 @@
package org.springframework.boot.actuate.endpoint.invoke; package org.springframework.boot.actuate.endpoint.invoke;
import java.util.Map;
/** /**
* Interface to perform an operation invocation. * Interface to perform an operation invocation.
* *
...@@ -29,11 +27,11 @@ import java.util.Map; ...@@ -29,11 +27,11 @@ import java.util.Map;
public interface OperationInvoker { public interface OperationInvoker {
/** /**
* Invoke the underlying operation using the given {@code arguments}. * Invoke the underlying operation using the given {@code context}.
* @param arguments the arguments to pass to the operation * @param context the context to use to invoke the operation
* @return the result of the operation, may be {@code null} * @return the result of the operation, may be {@code null}
* @throws MissingParametersException if parameters are missing * @throws MissingParametersException if parameters are missing
*/ */
Object invoke(Map<String, Object> arguments) throws MissingParametersException; Object invoke(InvocationContext context) throws MissingParametersException;
} }
...@@ -17,10 +17,11 @@ ...@@ -17,10 +17,11 @@
package org.springframework.boot.actuate.endpoint.invoke.reflect; package org.springframework.boot.actuate.endpoint.invoke.reflect;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.util.Map; import java.security.Principal;
import java.util.Set; import java.util.Set;
import java.util.stream.Collectors; import java.util.stream.Collectors;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.MissingParametersException; import org.springframework.boot.actuate.endpoint.invoke.MissingParametersException;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import org.springframework.boot.actuate.endpoint.invoke.OperationParameter; import org.springframework.boot.actuate.endpoint.invoke.OperationParameter;
...@@ -66,39 +67,44 @@ public class ReflectiveOperationInvoker implements OperationInvoker { ...@@ -66,39 +67,44 @@ public class ReflectiveOperationInvoker implements OperationInvoker {
} }
@Override @Override
public Object invoke(Map<String, Object> arguments) { public Object invoke(InvocationContext context) {
validateRequiredParameters(arguments); validateRequiredParameters(context);
Method method = this.operationMethod.getMethod(); Method method = this.operationMethod.getMethod();
Object[] resolvedArguments = resolveArguments(arguments); Object[] resolvedArguments = resolveArguments(context);
ReflectionUtils.makeAccessible(method); ReflectionUtils.makeAccessible(method);
return ReflectionUtils.invokeMethod(method, this.target, resolvedArguments); return ReflectionUtils.invokeMethod(method, this.target, resolvedArguments);
} }
private void validateRequiredParameters(Map<String, Object> arguments) { private void validateRequiredParameters(InvocationContext context) {
Set<OperationParameter> missing = this.operationMethod.getParameters().stream() Set<OperationParameter> missing = this.operationMethod.getParameters().stream()
.filter((parameter) -> isMissing(arguments, parameter)) .filter((parameter) -> isMissing(context, parameter))
.collect(Collectors.toSet()); .collect(Collectors.toSet());
if (!missing.isEmpty()) { if (!missing.isEmpty()) {
throw new MissingParametersException(missing); throw new MissingParametersException(missing);
} }
} }
private boolean isMissing(Map<String, Object> arguments, private boolean isMissing(InvocationContext context, OperationParameter parameter) {
OperationParameter parameter) {
if (!parameter.isMandatory()) { if (!parameter.isMandatory()) {
return false; return false;
} }
return arguments.get(parameter.getName()) == null; if (Principal.class.equals(parameter.getType())) {
return context.getPrincipal() == null;
}
return context.getArguments().get(parameter.getName()) == null;
} }
private Object[] resolveArguments(Map<String, Object> arguments) { private Object[] resolveArguments(InvocationContext context) {
return this.operationMethod.getParameters().stream() return this.operationMethod.getParameters().stream()
.map((parameter) -> resolveArgument(parameter, arguments)).toArray(); .map((parameter) -> resolveArgument(parameter, context)).toArray();
} }
private Object resolveArgument(OperationParameter parameter, private Object resolveArgument(OperationParameter parameter,
Map<String, Object> arguments) { InvocationContext context) {
Object value = arguments.get(parameter.getName()); if (Principal.class.equals(parameter.getType())) {
return context.getPrincipal();
}
Object value = context.getArguments().get(parameter.getName());
return this.parameterValueMapper.mapParameterValue(parameter, value); return this.parameterValueMapper.mapParameterValue(parameter, value);
} }
......
...@@ -19,6 +19,7 @@ package org.springframework.boot.actuate.endpoint.invoker.cache; ...@@ -19,6 +19,7 @@ package org.springframework.boot.actuate.endpoint.invoker.cache;
import java.util.Map; import java.util.Map;
import java.util.Objects; import java.util.Objects;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ObjectUtils; import org.springframework.util.ObjectUtils;
...@@ -59,21 +60,25 @@ public class CachingOperationInvoker implements OperationInvoker { ...@@ -59,21 +60,25 @@ public class CachingOperationInvoker implements OperationInvoker {
} }
@Override @Override
public Object invoke(Map<String, Object> arguments) { public Object invoke(InvocationContext context) {
if (hasArgument(arguments)) { if (hasInput(context)) {
return this.invoker.invoke(arguments); return this.invoker.invoke(context);
} }
long accessTime = System.currentTimeMillis(); long accessTime = System.currentTimeMillis();
CachedResponse cached = this.cachedResponse; CachedResponse cached = this.cachedResponse;
if (cached == null || cached.isStale(accessTime, this.timeToLive)) { if (cached == null || cached.isStale(accessTime, this.timeToLive)) {
Object response = this.invoker.invoke(arguments); Object response = this.invoker.invoke(context);
this.cachedResponse = new CachedResponse(response, accessTime); this.cachedResponse = new CachedResponse(response, accessTime);
return response; return response;
} }
return cached.getResponse(); return cached.getResponse();
} }
private boolean hasArgument(Map<String, Object> arguments) { private boolean hasInput(InvocationContext context) {
if (context.getPrincipal() != null) {
return true;
}
Map<String, Object> arguments = context.getArguments();
if (!ObjectUtils.isEmpty(arguments)) { if (!ObjectUtils.isEmpty(arguments)) {
return arguments.values().stream().anyMatch(Objects::nonNull); return arguments.values().stream().anyMatch(Objects::nonNull);
} }
......
...@@ -32,6 +32,7 @@ import javax.management.ReflectionException; ...@@ -32,6 +32,7 @@ import javax.management.ReflectionException;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException; import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.util.Assert; import org.springframework.util.Assert;
import org.springframework.util.ClassUtils; import org.springframework.util.ClassUtils;
...@@ -96,7 +97,7 @@ public class EndpointMBean implements DynamicMBean { ...@@ -96,7 +97,7 @@ public class EndpointMBean implements DynamicMBean {
String[] parameterNames = operation.getParameters().stream() String[] parameterNames = operation.getParameters().stream()
.map(JmxOperationParameter::getName).toArray(String[]::new); .map(JmxOperationParameter::getName).toArray(String[]::new);
Map<String, Object> arguments = getArguments(parameterNames, params); Map<String, Object> arguments = getArguments(parameterNames, params);
Object result = operation.invoke(arguments); Object result = operation.invoke(new InvocationContext(null, arguments));
if (REACTOR_PRESENT) { if (REACTOR_PRESENT) {
result = ReactiveHandler.handle(result); result = ReactiveHandler.handle(result);
} }
......
...@@ -18,7 +18,6 @@ package org.springframework.boot.actuate.endpoint.web.jersey; ...@@ -18,7 +18,6 @@ package org.springframework.boot.actuate.endpoint.web.jersey;
import java.io.IOException; import java.io.IOException;
import java.io.InputStream; import java.io.InputStream;
import java.security.Principal;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Collection; import java.util.Collection;
import java.util.Collections; import java.util.Collections;
...@@ -40,6 +39,7 @@ import org.glassfish.jersey.server.model.Resource.Builder; ...@@ -40,6 +39,7 @@ import org.glassfish.jersey.server.model.Resource.Builder;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException; import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.web.EndpointLinksResolver; import org.springframework.boot.actuate.endpoint.web.EndpointLinksResolver;
import org.springframework.boot.actuate.endpoint.web.EndpointMapping; import org.springframework.boot.actuate.endpoint.web.EndpointMapping;
import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes; import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes;
...@@ -148,12 +148,9 @@ public class JerseyEndpointResourceFactory { ...@@ -148,12 +148,9 @@ public class JerseyEndpointResourceFactory {
} }
arguments.putAll(extractPathParameters(data)); arguments.putAll(extractPathParameters(data));
arguments.putAll(extractQueryParameters(data)); arguments.putAll(extractQueryParameters(data));
Principal principal = data.getSecurityContext().getUserPrincipal();
if (principal != null) {
arguments.put("principal", principal);
}
try { try {
Object response = this.operation.invoke(arguments); Object response = this.operation.invoke(new InvocationContext(
data.getSecurityContext().getUserPrincipal(), arguments));
return convertToJaxRsResponse(response, data.getRequest().getMethod()); return convertToJaxRsResponse(response, data.getRequest().getMethod());
} }
catch (InvalidEndpointRequestException ex) { catch (InvalidEndpointRequestException ex) {
......
...@@ -29,6 +29,7 @@ import reactor.core.scheduler.Schedulers; ...@@ -29,6 +29,7 @@ import reactor.core.scheduler.Schedulers;
import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException; import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException;
import org.springframework.boot.actuate.endpoint.OperationType; import org.springframework.boot.actuate.endpoint.OperationType;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import org.springframework.boot.actuate.endpoint.web.EndpointMapping; import org.springframework.boot.actuate.endpoint.web.EndpointMapping;
import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes; import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes;
...@@ -221,14 +222,14 @@ public abstract class AbstractWebFluxEndpointHandlerMapping ...@@ -221,14 +222,14 @@ public abstract class AbstractWebFluxEndpointHandlerMapping
} }
@Override @Override
public Object invoke(Map<String, Object> arguments) { public Object invoke(InvocationContext context) {
return Mono.create((sink) -> Schedulers.elastic() return Mono.create(
.schedule(() -> invoke(arguments, sink))); (sink) -> Schedulers.elastic().schedule(() -> invoke(context, sink)));
} }
private void invoke(Map<String, Object> arguments, MonoSink<Object> sink) { private void invoke(InvocationContext context, MonoSink<Object> sink) {
try { try {
Object result = this.invoker.invoke(arguments); Object result = this.invoker.invoke(context);
sink.success(result); sink.success(result);
} }
catch (Exception ex) { catch (Exception ex) {
...@@ -275,15 +276,17 @@ public abstract class AbstractWebFluxEndpointHandlerMapping ...@@ -275,15 +276,17 @@ public abstract class AbstractWebFluxEndpointHandlerMapping
Map<String, String> body) { Map<String, String> body) {
return exchange.getPrincipal().defaultIfEmpty(NO_PRINCIPAL) return exchange.getPrincipal().defaultIfEmpty(NO_PRINCIPAL)
.flatMap((principal) -> { .flatMap((principal) -> {
Map<String, Object> arguments = getArguments(exchange, principal, Map<String, Object> arguments = getArguments(exchange, body);
body); return handleResult(
return handleResult((Publisher<?>) this.invoker.invoke(arguments), (Publisher<?>) this.invoker.invoke(new InvocationContext(
principal == NO_PRINCIPAL ? null : principal,
arguments)),
exchange.getRequest().getMethod()); exchange.getRequest().getMethod());
}); });
} }
private Map<String, Object> getArguments(ServerWebExchange exchange, private Map<String, Object> getArguments(ServerWebExchange exchange,
Principal principal, Map<String, String> body) { Map<String, String> body) {
Map<String, Object> arguments = new LinkedHashMap<>(); Map<String, Object> arguments = new LinkedHashMap<>();
arguments.putAll(getTemplateVariables(exchange)); arguments.putAll(getTemplateVariables(exchange));
if (body != null) { if (body != null) {
...@@ -291,9 +294,6 @@ public abstract class AbstractWebFluxEndpointHandlerMapping ...@@ -291,9 +294,6 @@ public abstract class AbstractWebFluxEndpointHandlerMapping
} }
exchange.getRequest().getQueryParams().forEach((name, values) -> arguments exchange.getRequest().getQueryParams().forEach((name, values) -> arguments
.put(name, values.size() == 1 ? values.get(0) : values)); .put(name, values.size() == 1 ? values.get(0) : values));
if (principal != null && principal != NO_PRINCIPAL) {
arguments.put("principal", principal);
}
return arguments; return arguments;
} }
......
...@@ -17,7 +17,6 @@ ...@@ -17,7 +17,6 @@
package org.springframework.boot.actuate.endpoint.web.servlet; package org.springframework.boot.actuate.endpoint.web.servlet;
import java.lang.reflect.Method; import java.lang.reflect.Method;
import java.security.Principal;
import java.util.Arrays; import java.util.Arrays;
import java.util.Collection; import java.util.Collection;
import java.util.LinkedHashMap; import java.util.LinkedHashMap;
...@@ -29,6 +28,7 @@ import javax.servlet.http.HttpServletResponse; ...@@ -29,6 +28,7 @@ import javax.servlet.http.HttpServletResponse;
import org.springframework.beans.factory.InitializingBean; import org.springframework.beans.factory.InitializingBean;
import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException; import org.springframework.boot.actuate.endpoint.InvalidEndpointRequestException;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import org.springframework.boot.actuate.endpoint.web.EndpointMapping; import org.springframework.boot.actuate.endpoint.web.EndpointMapping;
import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes; import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes;
...@@ -241,7 +241,9 @@ public abstract class AbstractWebMvcEndpointHandlerMapping ...@@ -241,7 +241,9 @@ public abstract class AbstractWebMvcEndpointHandlerMapping
@RequestBody(required = false) Map<String, String> body) { @RequestBody(required = false) Map<String, String> body) {
Map<String, Object> arguments = getArguments(request, body); Map<String, Object> arguments = getArguments(request, body);
try { try {
return handleResult(this.invoker.invoke(arguments), return handleResult(
this.invoker.invoke(new InvocationContext(
request.getUserPrincipal(), arguments)),
HttpMethod.valueOf(request.getMethod())); HttpMethod.valueOf(request.getMethod()));
} }
catch (InvalidEndpointRequestException ex) { catch (InvalidEndpointRequestException ex) {
...@@ -258,10 +260,6 @@ public abstract class AbstractWebMvcEndpointHandlerMapping ...@@ -258,10 +260,6 @@ public abstract class AbstractWebMvcEndpointHandlerMapping
} }
request.getParameterMap().forEach((name, values) -> arguments.put(name, request.getParameterMap().forEach((name, values) -> arguments.put(name,
values.length == 1 ? values[0] : Arrays.asList(values))); values.length == 1 ? values[0] : Arrays.asList(values)));
Principal principal = request.getUserPrincipal();
if (principal != null) {
arguments.put("principal", principal);
}
return arguments; return arguments;
} }
......
...@@ -26,6 +26,7 @@ import org.junit.Before; ...@@ -26,6 +26,7 @@ import org.junit.Before;
import org.junit.Test; import org.junit.Test;
import org.springframework.boot.actuate.endpoint.OperationType; import org.springframework.boot.actuate.endpoint.OperationType;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvokerAdvisor; import org.springframework.boot.actuate.endpoint.invoke.OperationInvokerAdvisor;
import org.springframework.boot.actuate.endpoint.invoke.OperationParameters; import org.springframework.boot.actuate.endpoint.invoke.OperationParameters;
...@@ -105,7 +106,7 @@ public class DiscoveredOperationsFactoryTests { ...@@ -105,7 +106,7 @@ public class DiscoveredOperationsFactoryTests {
TestOperation operation = getFirst( TestOperation operation = getFirst(
this.factory.createOperations("test", new ExampleWithParams())); this.factory.createOperations("test", new ExampleWithParams()));
Map<String, Object> params = Collections.singletonMap("name", 123); Map<String, Object> params = Collections.singletonMap("name", 123);
Object result = operation.invoke(params); Object result = operation.invoke(new InvocationContext(null, params));
assertThat(result).isEqualTo("123"); assertThat(result).isEqualTo("123");
} }
...@@ -115,7 +116,7 @@ public class DiscoveredOperationsFactoryTests { ...@@ -115,7 +116,7 @@ public class DiscoveredOperationsFactoryTests {
this.invokerAdvisors.add(advisor); this.invokerAdvisors.add(advisor);
TestOperation operation = getFirst( TestOperation operation = getFirst(
this.factory.createOperations("test", new ExampleRead())); this.factory.createOperations("test", new ExampleRead()));
operation.invoke(Collections.emptyMap()); operation.invoke(new InvocationContext(null, Collections.emptyMap()));
assertThat(advisor.getEndpointId()).isEqualTo("test"); assertThat(advisor.getEndpointId()).isEqualTo("test");
assertThat(advisor.getOperationType()).isEqualTo(OperationType.READ); assertThat(advisor.getOperationType()).isEqualTo(OperationType.READ);
assertThat(advisor.getParameters()).isEmpty(); assertThat(advisor.getParameters()).isEmpty();
......
...@@ -24,6 +24,7 @@ import org.junit.Test; ...@@ -24,6 +24,7 @@ import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.springframework.boot.actuate.endpoint.OperationType; import org.springframework.boot.actuate.endpoint.OperationType;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.MissingParametersException; import org.springframework.boot.actuate.endpoint.invoke.MissingParametersException;
import org.springframework.boot.actuate.endpoint.invoke.ParameterValueMapper; import org.springframework.boot.actuate.endpoint.invoke.ParameterValueMapper;
import org.springframework.lang.Nullable; import org.springframework.lang.Nullable;
...@@ -83,7 +84,8 @@ public class ReflectiveOperationInvokerTests { ...@@ -83,7 +84,8 @@ public class ReflectiveOperationInvokerTests {
public void invokeShouldInvokeMethod() { public void invokeShouldInvokeMethod() {
ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target, ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target,
this.operationMethod, this.parameterValueMapper); this.operationMethod, this.parameterValueMapper);
Object result = invoker.invoke(Collections.singletonMap("name", "boot")); Object result = invoker.invoke(
new InvocationContext(null, Collections.singletonMap("name", "boot")));
assertThat(result).isEqualTo("toob"); assertThat(result).isEqualTo("toob");
} }
...@@ -92,7 +94,8 @@ public class ReflectiveOperationInvokerTests { ...@@ -92,7 +94,8 @@ public class ReflectiveOperationInvokerTests {
ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target, ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target,
this.operationMethod, this.parameterValueMapper); this.operationMethod, this.parameterValueMapper);
this.thrown.expect(MissingParametersException.class); this.thrown.expect(MissingParametersException.class);
invoker.invoke(Collections.singletonMap("name", null)); invoker.invoke(
new InvocationContext(null, Collections.singletonMap("name", null)));
} }
@Test @Test
...@@ -101,7 +104,8 @@ public class ReflectiveOperationInvokerTests { ...@@ -101,7 +104,8 @@ public class ReflectiveOperationInvokerTests {
Example.class, "reverseNullable", String.class), OperationType.READ); Example.class, "reverseNullable", String.class), OperationType.READ);
ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target, ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target,
operationMethod, this.parameterValueMapper); operationMethod, this.parameterValueMapper);
Object result = invoker.invoke(Collections.singletonMap("name", null)); Object result = invoker.invoke(
new InvocationContext(null, Collections.singletonMap("name", null)));
assertThat(result).isEqualTo("llun"); assertThat(result).isEqualTo("llun");
} }
...@@ -109,7 +113,8 @@ public class ReflectiveOperationInvokerTests { ...@@ -109,7 +113,8 @@ public class ReflectiveOperationInvokerTests {
public void invokeShouldResolveParameters() { public void invokeShouldResolveParameters() {
ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target, ReflectiveOperationInvoker invoker = new ReflectiveOperationInvoker(this.target,
this.operationMethod, this.parameterValueMapper); this.operationMethod, this.parameterValueMapper);
Object result = invoker.invoke(Collections.singletonMap("name", 1234)); Object result = invoker.invoke(
new InvocationContext(null, Collections.singletonMap("name", 1234)));
assertThat(result).isEqualTo("4321"); assertThat(result).isEqualTo("4321");
} }
......
...@@ -24,6 +24,7 @@ import org.junit.Rule; ...@@ -24,6 +24,7 @@ import org.junit.Rule;
import org.junit.Test; import org.junit.Test;
import org.junit.rules.ExpectedException; import org.junit.rules.ExpectedException;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker; import org.springframework.boot.actuate.endpoint.invoke.OperationInvoker;
import static org.assertj.core.api.Assertions.assertThat; import static org.assertj.core.api.Assertions.assertThat;
...@@ -66,12 +67,13 @@ public class CachingOperationInvokerTests { ...@@ -66,12 +67,13 @@ public class CachingOperationInvokerTests {
private void assertCacheIsUsed(Map<String, Object> parameters) { private void assertCacheIsUsed(Map<String, Object> parameters) {
OperationInvoker target = mock(OperationInvoker.class); OperationInvoker target = mock(OperationInvoker.class);
Object expected = new Object(); Object expected = new Object();
given(target.invoke(parameters)).willReturn(expected); InvocationContext context = new InvocationContext(null, parameters);
given(target.invoke(context)).willReturn(expected);
CachingOperationInvoker invoker = new CachingOperationInvoker(target, 500L); CachingOperationInvoker invoker = new CachingOperationInvoker(target, 500L);
Object response = invoker.invoke(parameters); Object response = invoker.invoke(context);
assertThat(response).isSameAs(expected); assertThat(response).isSameAs(expected);
verify(target, times(1)).invoke(parameters); verify(target, times(1)).invoke(context);
Object cachedResponse = invoker.invoke(parameters); Object cachedResponse = invoker.invoke(context);
assertThat(cachedResponse).isSameAs(response); assertThat(cachedResponse).isSameAs(response);
verifyNoMoreInteractions(target); verifyNoMoreInteractions(target);
} }
...@@ -82,24 +84,26 @@ public class CachingOperationInvokerTests { ...@@ -82,24 +84,26 @@ public class CachingOperationInvokerTests {
Map<String, Object> parameters = new HashMap<>(); Map<String, Object> parameters = new HashMap<>();
parameters.put("test", "value"); parameters.put("test", "value");
parameters.put("something", null); parameters.put("something", null);
given(target.invoke(parameters)).willReturn(new Object()); InvocationContext context = new InvocationContext(null, parameters);
given(target.invoke(context)).willReturn(new Object());
CachingOperationInvoker invoker = new CachingOperationInvoker(target, 500L); CachingOperationInvoker invoker = new CachingOperationInvoker(target, 500L);
invoker.invoke(parameters); invoker.invoke(context);
invoker.invoke(parameters); invoker.invoke(context);
invoker.invoke(parameters); invoker.invoke(context);
verify(target, times(3)).invoke(parameters); verify(target, times(3)).invoke(context);
} }
@Test @Test
public void targetInvokedWhenCacheExpires() throws InterruptedException { public void targetInvokedWhenCacheExpires() throws InterruptedException {
OperationInvoker target = mock(OperationInvoker.class); OperationInvoker target = mock(OperationInvoker.class);
Map<String, Object> parameters = new HashMap<>(); Map<String, Object> parameters = new HashMap<>();
given(target.invoke(parameters)).willReturn(new Object()); InvocationContext context = new InvocationContext(null, parameters);
given(target.invoke(context)).willReturn(new Object());
CachingOperationInvoker invoker = new CachingOperationInvoker(target, 50L); CachingOperationInvoker invoker = new CachingOperationInvoker(target, 50L);
invoker.invoke(parameters); invoker.invoke(context);
Thread.sleep(55); Thread.sleep(55);
invoker.invoke(parameters); invoker.invoke(context);
verify(target, times(2)).invoke(parameters); verify(target, times(2)).invoke(context);
} }
} }
...@@ -22,6 +22,7 @@ import java.util.Map; ...@@ -22,6 +22,7 @@ import java.util.Map;
import java.util.function.Function; import java.util.function.Function;
import org.springframework.boot.actuate.endpoint.OperationType; import org.springframework.boot.actuate.endpoint.OperationType;
import org.springframework.boot.actuate.endpoint.invoke.InvocationContext;
/** /**
* Test {@link JmxOperation} implementation. * Test {@link JmxOperation} implementation.
...@@ -66,8 +67,9 @@ public class TestJmxOperation implements JmxOperation { ...@@ -66,8 +67,9 @@ public class TestJmxOperation implements JmxOperation {
} }
@Override @Override
public Object invoke(Map<String, Object> arguments) { public Object invoke(InvocationContext context) {
return (this.invoke == null ? "result" : this.invoke.apply(arguments)); return (this.invoke == null ? "result"
: this.invoke.apply(context.getArguments()));
} }
@Override @Override
......
...@@ -25,6 +25,7 @@ import java.util.List; ...@@ -25,6 +25,7 @@ import java.util.List;
import java.util.Map; import java.util.Map;
import java.util.function.BiConsumer; import java.util.function.BiConsumer;
import java.util.function.Consumer; import java.util.function.Consumer;
import java.util.function.Supplier;
import org.junit.Test; import org.junit.Test;
import reactor.core.publisher.Mono; import reactor.core.publisher.Mono;
...@@ -37,6 +38,7 @@ import org.springframework.boot.actuate.endpoint.annotation.WriteOperation; ...@@ -37,6 +38,7 @@ import org.springframework.boot.actuate.endpoint.annotation.WriteOperation;
import org.springframework.boot.actuate.endpoint.web.WebEndpointResponse; import org.springframework.boot.actuate.endpoint.web.WebEndpointResponse;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ConfigurableApplicationContext; import org.springframework.context.ConfigurableApplicationContext;
import org.springframework.context.annotation.AnnotationConfigRegistry;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import; import org.springframework.context.annotation.Import;
...@@ -57,7 +59,7 @@ import static org.mockito.Mockito.verify; ...@@ -57,7 +59,7 @@ import static org.mockito.Mockito.verify;
* @param <T> the type of application context used by the tests * @param <T> the type of application context used by the tests
* @author Andy Wilkinson * @author Andy Wilkinson
*/ */
public abstract class AbstractWebEndpointIntegrationTests<T extends ConfigurableApplicationContext> { public abstract class AbstractWebEndpointIntegrationTests<T extends ConfigurableApplicationContext & AnnotationConfigRegistry> {
private static final Duration TIMEOUT = Duration.ofMinutes(6); private static final Duration TIMEOUT = Duration.ofMinutes(6);
...@@ -65,10 +67,14 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable ...@@ -65,10 +67,14 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
private static final String JSON_MEDIA_TYPE_PATTERN = "application/json(;charset=UTF-8)?"; private static final String JSON_MEDIA_TYPE_PATTERN = "application/json(;charset=UTF-8)?";
private final Class<?> exporterConfiguration; private final Supplier<T> applicationContextSupplier;
protected AbstractWebEndpointIntegrationTests(Class<?> exporterConfiguration) { private final Consumer<T> authenticatedContextCustomizer;
this.exporterConfiguration = exporterConfiguration;
protected AbstractWebEndpointIntegrationTests(Supplier<T> applicationContextSupplier,
Consumer<T> authenticatedContextCustomizer) {
this.applicationContextSupplier = applicationContextSupplier;
this.authenticatedContextCustomizer = authenticatedContextCustomizer;
} }
@Test @Test
...@@ -337,13 +343,23 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable ...@@ -337,13 +343,23 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
@Test @Test
public void principalIsAvailableWhenRequestHasAPrincipal() { public void principalIsAvailableWhenRequestHasAPrincipal() {
load(getSecuredPrincipalEndpointConfiguration(), load((context) -> {
(client) -> client.get().uri("/principal") this.authenticatedContextCustomizer.accept(context);
.accept(MediaType.APPLICATION_JSON).exchange().expectStatus() context.register(PrincipalEndpointConfiguration.class);
.isOk().expectBody(String.class).isEqualTo("Alice")); }, (client) -> client.get().uri("/principal").accept(MediaType.APPLICATION_JSON)
.exchange().expectStatus().isOk().expectBody(String.class)
.isEqualTo("Alice"));
} }
protected abstract T createApplicationContext(Class<?>... config); @Test
public void operationWithAQueryNamedPrincipalCanBeAccessedWhenAuthenticated() {
load((context) -> {
this.authenticatedContextCustomizer.accept(context);
context.register(PrincipalQueryEndpointConfiguration.class);
}, (client) -> client.get().uri("/principalquery?principal=Zoe")
.accept(MediaType.APPLICATION_JSON).exchange().expectStatus().isOk()
.expectBody(String.class).isEqualTo("Zoe"));
}
protected abstract int getPort(T context); protected abstract int getPort(T context);
...@@ -356,40 +372,47 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable ...@@ -356,40 +372,47 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
private void load(Class<?> configuration, private void load(Class<?> configuration,
BiConsumer<ApplicationContext, WebTestClient> consumer) { BiConsumer<ApplicationContext, WebTestClient> consumer) {
load(configuration, "/endpoints", consumer); load((context) -> context.register(configuration), "/endpoints", consumer);
} }
private void load(Class<?> configuration, String endpointPath, protected void load(Class<?> configuration, Consumer<WebTestClient> clientConsumer) {
BiConsumer<ApplicationContext, WebTestClient> consumer) { load((context) -> context.register(configuration), "/endpoints",
T context = createApplicationContext(configuration, this.exporterConfiguration); (context, client) -> clientConsumer.accept(client));
context.getEnvironment().getPropertySources().addLast(new MapPropertySource(
"test", Collections.singletonMap("endpointPath", endpointPath)));
context.refresh();
try {
InetSocketAddress address = new InetSocketAddress(getPort(context));
String url = "http://" + address.getHostString() + ":" + address.getPort()
+ endpointPath;
consumer.accept(context, WebTestClient.bindToServer().baseUrl(url)
.responseTimeout(TIMEOUT).build());
}
finally {
context.close();
}
} }
protected abstract Class<?> getSecuredPrincipalEndpointConfiguration(); protected void load(Consumer<T> contextCustomizer,
Consumer<WebTestClient> clientConsumer) {
protected void load(Class<?> configuration, Consumer<WebTestClient> clientConsumer) { load(contextCustomizer, "/endpoints",
load(configuration, "/endpoints",
(context, client) -> clientConsumer.accept(client)); (context, client) -> clientConsumer.accept(client));
} }
protected void load(Class<?> configuration, String endpointPath, protected void load(Class<?> configuration, String endpointPath,
Consumer<WebTestClient> clientConsumer) { Consumer<WebTestClient> clientConsumer) {
load(configuration, endpointPath, load((context) -> context.register(configuration), endpointPath,
(context, client) -> clientConsumer.accept(client)); (context, client) -> clientConsumer.accept(client));
} }
private void load(Consumer<T> contextCustomizer, String endpointPath,
BiConsumer<ApplicationContext, WebTestClient> consumer) {
T applicationContext = this.applicationContextSupplier.get();
contextCustomizer.accept(applicationContext);
applicationContext.getEnvironment().getPropertySources()
.addLast(new MapPropertySource("test",
Collections.singletonMap("endpointPath", endpointPath)));
applicationContext.refresh();
try {
InetSocketAddress address = new InetSocketAddress(
getPort(applicationContext));
String url = "http://" + address.getHostString() + ":" + address.getPort()
+ endpointPath;
consumer.accept(applicationContext, WebTestClient.bindToServer().baseUrl(url)
.responseTimeout(TIMEOUT).build());
}
finally {
applicationContext.close();
}
}
@Configuration @Configuration
@Import(BaseConfiguration.class) @Import(BaseConfiguration.class)
protected static class TestEndpointConfiguration { protected static class TestEndpointConfiguration {
...@@ -547,6 +570,17 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable ...@@ -547,6 +570,17 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
} }
@Configuration
@Import(BaseConfiguration.class)
protected static class PrincipalQueryEndpointConfiguration {
@Bean
public PrincipalQueryEndpoint principalQueryEndpoint() {
return new PrincipalQueryEndpoint();
}
}
@Endpoint(id = "test") @Endpoint(id = "test")
static class TestEndpoint { static class TestEndpoint {
...@@ -735,6 +769,16 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable ...@@ -735,6 +769,16 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
} }
@Endpoint(id = "principalquery")
static class PrincipalQueryEndpoint {
@ReadOperation
public String read(String principal) {
return principal;
}
}
public interface EndpointDelegate { public interface EndpointDelegate {
void write(); void write();
......
...@@ -45,7 +45,6 @@ import org.springframework.boot.web.servlet.ServletRegistrationBean; ...@@ -45,7 +45,6 @@ import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.test.web.reactive.server.WebTestClient; import org.springframework.test.web.reactive.server.WebTestClient;
...@@ -61,17 +60,21 @@ public class JerseyWebEndpointIntegrationTests extends ...@@ -61,17 +60,21 @@ public class JerseyWebEndpointIntegrationTests extends
AbstractWebEndpointIntegrationTests<AnnotationConfigServletWebServerApplicationContext> { AbstractWebEndpointIntegrationTests<AnnotationConfigServletWebServerApplicationContext> {
public JerseyWebEndpointIntegrationTests() { public JerseyWebEndpointIntegrationTests() {
super(JerseyConfiguration.class); super(JerseyWebEndpointIntegrationTests::createApplicationContext,
JerseyWebEndpointIntegrationTests::applyAuthenticatedConfiguration);
} }
@Override private static AnnotationConfigServletWebServerApplicationContext createApplicationContext() {
protected AnnotationConfigServletWebServerApplicationContext createApplicationContext(
Class<?>... config) {
AnnotationConfigServletWebServerApplicationContext context = new AnnotationConfigServletWebServerApplicationContext(); AnnotationConfigServletWebServerApplicationContext context = new AnnotationConfigServletWebServerApplicationContext();
context.register(config); context.register(JerseyConfiguration.class);
return context; return context;
} }
private static void applyAuthenticatedConfiguration(
AnnotationConfigServletWebServerApplicationContext context) {
context.register(AuthenticatedConfiguration.class);
}
@Override @Override
protected int getPort(AnnotationConfigServletWebServerApplicationContext context) { protected int getPort(AnnotationConfigServletWebServerApplicationContext context) {
return context.getWebServer().getPort(); return context.getWebServer().getPort();
...@@ -83,11 +86,6 @@ public class JerseyWebEndpointIntegrationTests extends ...@@ -83,11 +86,6 @@ public class JerseyWebEndpointIntegrationTests extends
// Jersey doesn't support the general error page handling // Jersey doesn't support the general error page handling
} }
@Override
protected Class<?> getSecuredPrincipalEndpointConfiguration() {
return SecuredPrincipalEndpointConfiguration.class;
}
@Configuration @Configuration
static class JerseyConfiguration { static class JerseyConfiguration {
...@@ -123,8 +121,7 @@ public class JerseyWebEndpointIntegrationTests extends ...@@ -123,8 +121,7 @@ public class JerseyWebEndpointIntegrationTests extends
} }
@Configuration @Configuration
@Import(PrincipalEndpointConfiguration.class) static class AuthenticatedConfiguration {
static class SecuredPrincipalEndpointConfiguration {
@Bean @Bean
public Filter securityFilter() { public Filter securityFilter() {
......
...@@ -31,13 +31,11 @@ import org.springframework.boot.autoconfigure.ImportAutoConfiguration; ...@@ -31,13 +31,11 @@ import org.springframework.boot.autoconfigure.ImportAutoConfiguration;
import org.springframework.boot.autoconfigure.web.reactive.error.ErrorWebFluxAutoConfiguration; import org.springframework.boot.autoconfigure.web.reactive.error.ErrorWebFluxAutoConfiguration;
import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory; import org.springframework.boot.web.embedded.netty.NettyReactiveWebServerFactory;
import org.springframework.boot.web.reactive.context.AnnotationConfigReactiveWebServerApplicationContext; import org.springframework.boot.web.reactive.context.AnnotationConfigReactiveWebServerApplicationContext;
import org.springframework.boot.web.reactive.context.ReactiveWebServerApplicationContext;
import org.springframework.boot.web.reactive.context.ReactiveWebServerInitializedEvent; import org.springframework.boot.web.reactive.context.ReactiveWebServerInitializedEvent;
import org.springframework.context.ApplicationContext; import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener; import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
...@@ -58,11 +56,24 @@ import static org.assertj.core.api.Assertions.assertThat; ...@@ -58,11 +56,24 @@ import static org.assertj.core.api.Assertions.assertThat;
* @author Andy Wilkinson * @author Andy Wilkinson
* @see WebFluxEndpointHandlerMapping * @see WebFluxEndpointHandlerMapping
*/ */
public class WebFluxEndpointIntegrationTests public class WebFluxEndpointIntegrationTests extends
extends AbstractWebEndpointIntegrationTests<ReactiveWebServerApplicationContext> { AbstractWebEndpointIntegrationTests<AnnotationConfigReactiveWebServerApplicationContext> {
public WebFluxEndpointIntegrationTests() { public WebFluxEndpointIntegrationTests() {
super(ReactiveConfiguration.class); super(WebFluxEndpointIntegrationTests::createApplicationContext,
WebFluxEndpointIntegrationTests::applyAuthenticatedConfiguration);
}
private static AnnotationConfigReactiveWebServerApplicationContext createApplicationContext() {
AnnotationConfigReactiveWebServerApplicationContext context = new AnnotationConfigReactiveWebServerApplicationContext();
context.register(ReactiveConfiguration.class);
return context;
}
private static void applyAuthenticatedConfiguration(
AnnotationConfigReactiveWebServerApplicationContext context) {
context.register(AuthenticatedConfiguration.class);
} }
@Test @Test
...@@ -89,23 +100,10 @@ public class WebFluxEndpointIntegrationTests ...@@ -89,23 +100,10 @@ public class WebFluxEndpointIntegrationTests
} }
@Override @Override
protected AnnotationConfigReactiveWebServerApplicationContext createApplicationContext( protected int getPort(AnnotationConfigReactiveWebServerApplicationContext context) {
Class<?>... config) {
AnnotationConfigReactiveWebServerApplicationContext context = new AnnotationConfigReactiveWebServerApplicationContext();
context.register(config);
return context;
}
@Override
protected int getPort(ReactiveWebServerApplicationContext context) {
return context.getBean(ReactiveConfiguration.class).port; return context.getBean(ReactiveConfiguration.class).port;
} }
@Override
protected Class<?> getSecuredPrincipalEndpointConfiguration() {
return SecuredPrincipalEndpointConfiguration.class;
}
@Configuration @Configuration
@EnableWebFlux @EnableWebFlux
@ImportAutoConfiguration(ErrorWebFluxAutoConfiguration.class) @ImportAutoConfiguration(ErrorWebFluxAutoConfiguration.class)
...@@ -144,8 +142,8 @@ public class WebFluxEndpointIntegrationTests ...@@ -144,8 +142,8 @@ public class WebFluxEndpointIntegrationTests
} }
@Import(PrincipalEndpointConfiguration.class) @Configuration
static class SecuredPrincipalEndpointConfiguration { static class AuthenticatedConfiguration {
@Bean @Bean
public WebFilter webFilter() { public WebFilter webFilter() {
......
...@@ -45,7 +45,6 @@ import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactor ...@@ -45,7 +45,6 @@ import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactor
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext; import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean; import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration; import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment; import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus; import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType; import org.springframework.http.MediaType;
...@@ -64,7 +63,19 @@ public class MvcWebEndpointIntegrationTests extends ...@@ -64,7 +63,19 @@ public class MvcWebEndpointIntegrationTests extends
AbstractWebEndpointIntegrationTests<AnnotationConfigServletWebServerApplicationContext> { AbstractWebEndpointIntegrationTests<AnnotationConfigServletWebServerApplicationContext> {
public MvcWebEndpointIntegrationTests() { public MvcWebEndpointIntegrationTests() {
super(WebMvcConfiguration.class); super(MvcWebEndpointIntegrationTests::createApplicationContext,
MvcWebEndpointIntegrationTests::applyAuthenticatedConfiguration);
}
private static AnnotationConfigServletWebServerApplicationContext createApplicationContext() {
AnnotationConfigServletWebServerApplicationContext context = new AnnotationConfigServletWebServerApplicationContext();
context.register(WebMvcConfiguration.class);
return context;
}
private static void applyAuthenticatedConfiguration(
AnnotationConfigServletWebServerApplicationContext context) {
context.register(AuthenticatedConfiguration.class);
} }
@Test @Test
...@@ -90,24 +101,11 @@ public class MvcWebEndpointIntegrationTests extends ...@@ -90,24 +101,11 @@ public class MvcWebEndpointIntegrationTests extends
}); });
} }
@Override
protected AnnotationConfigServletWebServerApplicationContext createApplicationContext(
Class<?>... config) {
AnnotationConfigServletWebServerApplicationContext context = new AnnotationConfigServletWebServerApplicationContext();
context.register(config);
return context;
}
@Override @Override
protected int getPort(AnnotationConfigServletWebServerApplicationContext context) { protected int getPort(AnnotationConfigServletWebServerApplicationContext context) {
return context.getWebServer().getPort(); return context.getWebServer().getPort();
} }
@Override
protected Class<?> getSecuredPrincipalEndpointConfiguration() {
return SecuredPrincipalEndpointConfiguration.class;
}
@Configuration @Configuration
@ImportAutoConfiguration({ JacksonAutoConfiguration.class, @ImportAutoConfiguration({ JacksonAutoConfiguration.class,
HttpMessageConvertersAutoConfiguration.class, HttpMessageConvertersAutoConfiguration.class,
...@@ -137,8 +135,7 @@ public class MvcWebEndpointIntegrationTests extends ...@@ -137,8 +135,7 @@ public class MvcWebEndpointIntegrationTests extends
} }
@Configuration @Configuration
@Import(PrincipalEndpointConfiguration.class) static class AuthenticatedConfiguration {
static class SecuredPrincipalEndpointConfiguration {
@Bean @Bean
public Filter securityFilter() { public Filter securityFilter() {
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment