Commit 1975d511 authored by Andy Wilkinson's avatar Andy Wilkinson

Add support for injecting a Principal into web endpoint operations

Closes gh-11941
parent d8de8752
......@@ -18,6 +18,7 @@ package org.springframework.boot.actuate.endpoint.web.jersey;
import java.io.IOException;
import java.io.InputStream;
import java.security.Principal;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Collections;
......@@ -150,6 +151,10 @@ public class JerseyEndpointResourceFactory {
}
arguments.putAll(extractPathParameters(data));
arguments.putAll(extractQueryParameters(data));
Principal principal = data.getSecurityContext().getUserPrincipal();
if (principal != null) {
arguments.put("principal", principal);
}
try {
Object response = this.operation.invoke(arguments);
return convertToJaxRsResponse(response, data.getRequest().getMethod());
......
......@@ -17,6 +17,7 @@
package org.springframework.boot.actuate.endpoint.web.reactive;
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.Collection;
import java.util.LinkedHashMap;
import java.util.Map;
......@@ -251,24 +252,38 @@ public abstract class AbstractWebFluxEndpointHandlerMapping
* Adapter class to convert an {@link OperationInvoker} into a
* {@link ReactiveWebOperation}.
*/
private class ReactiveWebOperationAdapter implements ReactiveWebOperation {
private static final class ReactiveWebOperationAdapter
implements ReactiveWebOperation {
private static final Principal NO_PRINCIPAL = new Principal() {
@Override
public String getName() {
throw new UnsupportedOperationException();
}
};
private final OperationInvoker invoker;
ReactiveWebOperationAdapter(OperationInvoker invoker) {
private ReactiveWebOperationAdapter(OperationInvoker invoker) {
this.invoker = invoker;
}
@Override
public Mono<ResponseEntity<Object>> handle(ServerWebExchange exchange,
Map<String, String> body) {
Map<String, Object> arguments = getArguments(exchange, body);
return handleResult((Publisher<?>) this.invoker.invoke(arguments),
exchange.getRequest().getMethod());
return exchange.getPrincipal().defaultIfEmpty(NO_PRINCIPAL)
.flatMap((principal) -> {
Map<String, Object> arguments = getArguments(exchange, principal,
body);
return handleResult((Publisher<?>) this.invoker.invoke(arguments),
exchange.getRequest().getMethod());
});
}
private Map<String, Object> getArguments(ServerWebExchange exchange,
Map<String, String> body) {
Principal principal, Map<String, String> body) {
Map<String, Object> arguments = new LinkedHashMap<>();
arguments.putAll(getTemplateVariables(exchange));
if (body != null) {
......@@ -276,6 +291,9 @@ public abstract class AbstractWebFluxEndpointHandlerMapping
}
exchange.getRequest().getQueryParams().forEach((name, values) -> arguments
.put(name, values.size() == 1 ? values.get(0) : values));
if (principal != null && principal != NO_PRINCIPAL) {
arguments.put("principal", principal);
}
return arguments;
}
......
......@@ -17,6 +17,7 @@
package org.springframework.boot.actuate.endpoint.web.servlet;
import java.lang.reflect.Method;
import java.security.Principal;
import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashMap;
......@@ -257,6 +258,10 @@ public abstract class AbstractWebMvcEndpointHandlerMapping
}
request.getParameterMap().forEach((name, values) -> arguments.put(name,
values.length == 1 ? values[0] : Arrays.asList(values)));
Principal principal = request.getUserPrincipal();
if (principal != null) {
arguments.put("principal", principal);
}
return arguments;
}
......
......@@ -17,6 +17,7 @@
package org.springframework.boot.actuate.endpoint.web.annotation;
import java.net.InetSocketAddress;
import java.security.Principal;
import java.time.Duration;
import java.util.Collections;
import java.util.HashMap;
......@@ -326,6 +327,22 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
.valueMatches("Content-Type", JSON_MEDIA_TYPE_PATTERN));
}
@Test
public void principalIsNullWhenRequestHasNoPrincipal() {
load(PrincipalEndpointConfiguration.class,
(client) -> client.get().uri("/principal")
.accept(MediaType.APPLICATION_JSON).exchange().expectStatus()
.isOk().expectBody(String.class).isEqualTo("None"));
}
@Test
public void principalIsAvailableWhenRequestHasAPrincipal() {
load(getSecuredPrincipalEndpointConfiguration(),
(client) -> client.get().uri("/principal")
.accept(MediaType.APPLICATION_JSON).exchange().expectStatus()
.isOk().expectBody(String.class).isEqualTo("Alice"));
}
protected abstract T createApplicationContext(Class<?>... config);
protected abstract int getPort(T context);
......@@ -360,6 +377,8 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
}
}
protected abstract Class<?> getSecuredPrincipalEndpointConfiguration();
protected void load(Class<?> configuration, Consumer<WebTestClient> clientConsumer) {
load(configuration, "/endpoints",
(context, client) -> clientConsumer.accept(client));
......@@ -517,6 +536,17 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
}
@Configuration
@Import(BaseConfiguration.class)
protected static class PrincipalEndpointConfiguration {
@Bean
public PrincipalEndpoint principalEndpoint() {
return new PrincipalEndpoint();
}
}
@Endpoint(id = "test")
static class TestEndpoint {
......@@ -695,6 +725,16 @@ public abstract class AbstractWebEndpointIntegrationTests<T extends Configurable
}
@Endpoint(id = "principal")
static class PrincipalEndpoint {
@ReadOperation
public String read(@Nullable Principal principal) {
return principal == null ? "None" : principal.getName();
}
}
public interface EndpointDelegate {
void write();
......
......@@ -16,9 +16,17 @@
package org.springframework.boot.actuate.endpoint.web.jersey;
import java.io.IOException;
import java.security.Principal;
import java.util.Collection;
import java.util.HashSet;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import javax.ws.rs.ext.ContextResolver;
import com.fasterxml.jackson.databind.ObjectMapper;
......@@ -36,9 +44,11 @@ import org.springframework.boot.web.servlet.ServletRegistrationBean;
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus;
import org.springframework.test.web.reactive.server.WebTestClient;
import org.springframework.web.filter.OncePerRequestFilter;
/**
* Integration tests for web endpoints exposed using Jersey.
......@@ -72,6 +82,11 @@ public class JerseyWebEndpointIntegrationTests extends
// Jersey doesn't support the general error page handling
}
@Override
protected Class<?> getSecuredPrincipalEndpointConfiguration() {
return SecuredPrincipalEndpointConfiguration.class;
}
@Configuration
static class JerseyConfiguration {
......@@ -105,6 +120,43 @@ public class JerseyWebEndpointIntegrationTests extends
}
@Configuration
@Import(PrincipalEndpointConfiguration.class)
static class SecuredPrincipalEndpointConfiguration {
@Bean
public Filter securityFilter() {
return new OncePerRequestFilter() {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
filterChain.doFilter(new HttpServletRequestWrapper(request) {
@Override
public Principal getUserPrincipal() {
return new Principal() {
@Override
public String getName() {
return "Alice";
}
};
}
}, response);
}
};
}
}
private static final class ObjectMapperContextResolver
implements ContextResolver<ObjectMapper> {
......
......@@ -16,9 +16,11 @@
package org.springframework.boot.actuate.endpoint.web.reactive;
import java.security.Principal;
import java.util.Arrays;
import org.junit.Test;
import reactor.core.publisher.Mono;
import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes;
import org.springframework.boot.actuate.endpoint.web.annotation.AbstractWebEndpointIntegrationTests;
......@@ -34,12 +36,17 @@ import org.springframework.context.ApplicationContext;
import org.springframework.context.ApplicationListener;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.http.server.reactive.HttpHandler;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.reactive.config.EnableWebFlux;
import org.springframework.web.server.ServerWebExchange;
import org.springframework.web.server.ServerWebExchangeDecorator;
import org.springframework.web.server.WebFilter;
import org.springframework.web.server.WebFilterChain;
import org.springframework.web.server.adapter.WebHttpHandlerBuilder;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -95,6 +102,11 @@ public class WebFluxEndpointIntegrationTests
return context.getBean(ReactiveConfiguration.class).port;
}
@Override
protected Class<?> getSecuredPrincipalEndpointConfiguration() {
return SecuredPrincipalEndpointConfiguration.class;
}
@Configuration
@EnableWebFlux
@ImportAutoConfiguration(ErrorWebFluxAutoConfiguration.class)
......@@ -132,4 +144,36 @@ public class WebFluxEndpointIntegrationTests
}
@Import(PrincipalEndpointConfiguration.class)
static class SecuredPrincipalEndpointConfiguration {
@Bean
public WebFilter webFilter() {
return new WebFilter() {
@Override
public Mono<Void> filter(ServerWebExchange exchange,
WebFilterChain chain) {
return chain.filter(new ServerWebExchangeDecorator(exchange) {
@Override
public Mono<Principal> getPrincipal() {
return Mono.just(new Principal() {
@Override
public String getName() {
return "Alice";
}
});
}
});
}
};
}
}
}
......@@ -16,8 +16,17 @@
package org.springframework.boot.actuate.endpoint.web.servlet;
import java.io.IOException;
import java.security.Principal;
import java.util.Arrays;
import javax.servlet.Filter;
import javax.servlet.FilterChain;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletRequestWrapper;
import javax.servlet.http.HttpServletResponse;
import org.junit.Test;
import org.springframework.boot.actuate.endpoint.web.EndpointMediaTypes;
......@@ -35,10 +44,12 @@ import org.springframework.boot.web.embedded.tomcat.TomcatServletWebServerFactor
import org.springframework.boot.web.servlet.context.AnnotationConfigServletWebServerApplicationContext;
import org.springframework.context.annotation.Bean;
import org.springframework.context.annotation.Configuration;
import org.springframework.context.annotation.Import;
import org.springframework.core.env.Environment;
import org.springframework.http.HttpStatus;
import org.springframework.http.MediaType;
import org.springframework.web.cors.CorsConfiguration;
import org.springframework.web.filter.OncePerRequestFilter;
import static org.assertj.core.api.Assertions.assertThat;
......@@ -93,6 +104,11 @@ public class MvcWebEndpointIntegrationTests extends
return context.getWebServer().getPort();
}
@Override
protected Class<?> getSecuredPrincipalEndpointConfiguration() {
return SecuredPrincipalEndpointConfiguration.class;
}
@Configuration
@ImportAutoConfiguration({ JacksonAutoConfiguration.class,
HttpMessageConvertersAutoConfiguration.class,
......@@ -120,4 +136,41 @@ public class MvcWebEndpointIntegrationTests extends
}
@Configuration
@Import(PrincipalEndpointConfiguration.class)
static class SecuredPrincipalEndpointConfiguration {
@Bean
public Filter securityFilter() {
return new OncePerRequestFilter() {
@Override
protected void doFilterInternal(HttpServletRequest request,
HttpServletResponse response, FilterChain filterChain)
throws ServletException, IOException {
filterChain.doFilter(new HttpServletRequestWrapper(request) {
@Override
public Principal getUserPrincipal() {
return new Principal() {
@Override
public String getName() {
return "Alice";
}
};
}
}, response);
}
};
}
}
}
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment