Commit ee89e0ef authored by Andy Wilkinson's avatar Andy Wilkinson

Fix client auth with Jetty

Fixes gh-17541
parent e07889b0
...@@ -68,7 +68,8 @@ class SslServerCustomizer implements JettyServerCustomizer { ...@@ -68,7 +68,8 @@ class SslServerCustomizer implements JettyServerCustomizer {
@Override @Override
public void customize(Server server) { public void customize(Server server) {
SslContextFactory sslContextFactory = new SslContextFactory(); SslContextFactory.Server sslContextFactory = new SslContextFactory.Server();
sslContextFactory.setEndpointIdentificationAlgorithm(null);
configureSsl(sslContextFactory, this.ssl, this.sslStoreProvider); configureSsl(sslContextFactory, this.ssl, this.sslStoreProvider);
ServerConnector connector = createConnector(server, sslContextFactory, this.address); ServerConnector connector = createConnector(server, sslContextFactory, this.address);
server.setConnectors(new Connector[] { connector }); server.setConnectors(new Connector[] { connector });
...@@ -131,7 +132,7 @@ class SslServerCustomizer implements JettyServerCustomizer { ...@@ -131,7 +132,7 @@ class SslServerCustomizer implements JettyServerCustomizer {
* @param ssl the ssl details. * @param ssl the ssl details.
* @param sslStoreProvider the ssl store provider * @param sslStoreProvider the ssl store provider
*/ */
protected void configureSsl(SslContextFactory factory, Ssl ssl, SslStoreProvider sslStoreProvider) { protected void configureSsl(SslContextFactory.Server factory, Ssl ssl, SslStoreProvider sslStoreProvider) {
factory.setProtocol(ssl.getProtocol()); factory.setProtocol(ssl.getProtocol());
configureSslClientAuth(factory, ssl); configureSslClientAuth(factory, ssl);
configureSslPasswords(factory, ssl); configureSslPasswords(factory, ssl);
...@@ -158,7 +159,7 @@ class SslServerCustomizer implements JettyServerCustomizer { ...@@ -158,7 +159,7 @@ class SslServerCustomizer implements JettyServerCustomizer {
} }
} }
private void configureSslClientAuth(SslContextFactory factory, Ssl ssl) { private void configureSslClientAuth(SslContextFactory.Server factory, Ssl ssl) {
if (ssl.getClientAuth() == Ssl.ClientAuth.NEED) { if (ssl.getClientAuth() == Ssl.ClientAuth.NEED) {
factory.setNeedClientAuth(true); factory.setNeedClientAuth(true);
factory.setWantClientAuth(true); factory.setWantClientAuth(true);
......
...@@ -81,7 +81,8 @@ public class SslServerCustomizerTests { ...@@ -81,7 +81,8 @@ public class SslServerCustomizerTests {
Ssl ssl = new Ssl(); Ssl ssl = new Ssl();
SslServerCustomizer customizer = new SslServerCustomizer(null, ssl, null, null); SslServerCustomizer customizer = new SslServerCustomizer(null, ssl, null, null);
assertThatExceptionOfType(Exception.class) assertThatExceptionOfType(Exception.class)
.isThrownBy(() -> customizer.configureSsl(new SslContextFactory(), ssl, null)).satisfies((ex) -> { .isThrownBy(() -> customizer.configureSsl(new SslContextFactory.Server(), ssl, null))
.satisfies((ex) -> {
assertThat(ex).isInstanceOf(WebServerException.class); assertThat(ex).isInstanceOf(WebServerException.class);
assertThat(ex).hasMessageContaining("Could not load key store 'null'"); assertThat(ex).hasMessageContaining("Could not load key store 'null'");
}); });
......
...@@ -21,15 +21,11 @@ import java.io.FileInputStream; ...@@ -21,15 +21,11 @@ import java.io.FileInputStream;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.nio.charset.StandardCharsets; import java.nio.charset.StandardCharsets;
import java.security.KeyStore; import java.security.KeyStore;
import java.security.PrivateKey;
import java.security.cert.X509Certificate;
import java.time.Duration; import java.time.Duration;
import java.util.Arrays; import java.util.Arrays;
import javax.net.ssl.KeyManager;
import javax.net.ssl.KeyManagerFactory; import javax.net.ssl.KeyManagerFactory;
import javax.net.ssl.SSLException; import javax.net.ssl.SSLException;
import javax.net.ssl.X509KeyManager;
import io.netty.channel.ChannelHandlerContext; import io.netty.channel.ChannelHandlerContext;
import io.netty.channel.ChannelInboundHandlerAdapter; import io.netty.channel.ChannelInboundHandlerAdapter;
...@@ -170,23 +166,12 @@ public abstract class AbstractReactiveWebServerFactoryTests { ...@@ -170,23 +166,12 @@ public abstract class AbstractReactiveWebServerFactoryTests {
KeyManagerFactory clientKeyManagerFactory = KeyManagerFactory KeyManagerFactory clientKeyManagerFactory = KeyManagerFactory
.getInstance(KeyManagerFactory.getDefaultAlgorithm()); .getInstance(KeyManagerFactory.getDefaultAlgorithm());
clientKeyManagerFactory.init(clientKeyStore, "password".toCharArray()); clientKeyManagerFactory.init(clientKeyStore, "password".toCharArray());
for (KeyManager keyManager : clientKeyManagerFactory.getKeyManagers()) {
if (keyManager instanceof X509KeyManager) {
X509KeyManager x509KeyManager = (X509KeyManager) keyManager;
PrivateKey privateKey = x509KeyManager.getPrivateKey("spring-boot");
if (privateKey != null) {
X509Certificate[] certificateChain = x509KeyManager.getCertificateChain("spring-boot");
SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK) SslContextBuilder builder = SslContextBuilder.forClient().sslProvider(SslProvider.JDK)
.trustManager(InsecureTrustManagerFactory.INSTANCE) .trustManager(InsecureTrustManagerFactory.INSTANCE).keyManager(clientKeyManagerFactory);
.keyManager(privateKey, certificateChain);
HttpClient client = HttpClient.create().wiretap(true) HttpClient client = HttpClient.create().wiretap(true)
.secure((sslContextSpec) -> sslContextSpec.sslContext(builder)); .secure((sslContextSpec) -> sslContextSpec.sslContext(builder));
return new ReactorClientHttpConnector(client); return new ReactorClientHttpConnector(client);
} }
}
}
throw new IllegalStateException("Key with alias 'spring-boot' not found");
}
protected void testClientAuthSuccess(Ssl sslConfiguration, ReactorClientHttpConnector clientConnector) { protected void testClientAuthSuccess(Ssl sslConfiguration, ReactorClientHttpConnector clientConnector) {
AbstractReactiveWebServerFactory factory = getFactory(); AbstractReactiveWebServerFactory factory = getFactory();
......
...@@ -25,7 +25,6 @@ import java.io.PrintWriter; ...@@ -25,7 +25,6 @@ import java.io.PrintWriter;
import java.net.InetSocketAddress; import java.net.InetSocketAddress;
import java.net.MalformedURLException; import java.net.MalformedURLException;
import java.net.ServerSocket; import java.net.ServerSocket;
import java.net.Socket;
import java.net.URI; import java.net.URI;
import java.net.URISyntaxException; import java.net.URISyntaxException;
import java.net.URL; import java.net.URL;
...@@ -75,8 +74,6 @@ import org.apache.http.conn.ssl.TrustSelfSignedStrategy; ...@@ -75,8 +74,6 @@ import org.apache.http.conn.ssl.TrustSelfSignedStrategy;
import org.apache.http.impl.client.HttpClientBuilder; import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients; import org.apache.http.impl.client.HttpClients;
import org.apache.http.protocol.HttpContext; import org.apache.http.protocol.HttpContext;
import org.apache.http.ssl.PrivateKeyDetails;
import org.apache.http.ssl.PrivateKeyStrategy;
import org.apache.http.ssl.SSLContextBuilder; import org.apache.http.ssl.SSLContextBuilder;
import org.apache.http.ssl.TrustStrategy; import org.apache.http.ssl.TrustStrategy;
import org.apache.jasper.EmbeddedServletOptions; import org.apache.jasper.EmbeddedServletOptions;
...@@ -402,7 +399,7 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -402,7 +399,7 @@ public abstract class AbstractServletWebServerFactoryTests {
new ExampleServlet(true, false), "/hello"); new ExampleServlet(true, false), "/hello");
this.webServer = factory.getWebServer(registration); this.webServer = factory.getWebServer(registration);
this.webServer.start(); this.webServer.start();
TrustStrategy trustStrategy = new SerialNumberValidatingTrustSelfSignedStrategy("5c7ae101"); TrustStrategy trustStrategy = new SerialNumberValidatingTrustSelfSignedStrategy("3a3aaec8");
SSLContext sslContext = new SSLContextBuilder().loadTrustMaterial(null, trustStrategy).build(); SSLContext sslContext = new SSLContextBuilder().loadTrustMaterial(null, trustStrategy).build();
HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(new SSLConnectionSocketFactory(sslContext)) HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(new SSLConnectionSocketFactory(sslContext))
.build(); .build();
...@@ -464,14 +461,7 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -464,14 +461,7 @@ public abstract class AbstractServletWebServerFactoryTests {
keyStore.load(new FileInputStream(new File("src/test/resources/test.p12")), "secret".toCharArray()); keyStore.load(new FileInputStream(new File("src/test/resources/test.p12")), "secret".toCharArray());
SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy()) new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy())
.loadKeyMaterial(keyStore, "secret".toCharArray(), new PrivateKeyStrategy() { .loadKeyMaterial(keyStore, "secret".toCharArray()).build());
@Override
public String chooseAlias(Map<String, PrivateKeyDetails> aliases, Socket socket) {
return "spring-boot";
}
}).build());
HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory).build(); HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory).build();
HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient); HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test"); assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test");
...@@ -488,13 +478,7 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -488,13 +478,7 @@ public abstract class AbstractServletWebServerFactoryTests {
keyStore.load(new FileInputStream(new File("src/test/resources/test.jks")), "secret".toCharArray()); keyStore.load(new FileInputStream(new File("src/test/resources/test.jks")), "secret".toCharArray());
SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy()) new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy())
.loadKeyMaterial(keyStore, "password".toCharArray(), new PrivateKeyStrategy() { .loadKeyMaterial(keyStore, "password".toCharArray()).build());
@Override
public String chooseAlias(Map<String, PrivateKeyDetails> aliases, Socket socket) {
return "spring-boot";
}
}).build());
HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory).build(); HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory).build();
HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient); HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test"); assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test");
...@@ -565,13 +549,7 @@ public abstract class AbstractServletWebServerFactoryTests { ...@@ -565,13 +549,7 @@ public abstract class AbstractServletWebServerFactoryTests {
keyStore.load(new FileInputStream(new File("src/test/resources/test.jks")), "secret".toCharArray()); keyStore.load(new FileInputStream(new File("src/test/resources/test.jks")), "secret".toCharArray());
SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory( SSLConnectionSocketFactory socketFactory = new SSLConnectionSocketFactory(
new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy()) new SSLContextBuilder().loadTrustMaterial(null, new TrustSelfSignedStrategy())
.loadKeyMaterial(keyStore, "password".toCharArray(), new PrivateKeyStrategy() { .loadKeyMaterial(keyStore, "password".toCharArray()).build());
@Override
public String chooseAlias(Map<String, PrivateKeyDetails> aliases, Socket socket) {
return "spring-boot";
}
}).build());
HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory).build(); HttpClient httpClient = HttpClients.custom().setSSLSocketFactory(socketFactory).build();
HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient); HttpComponentsClientHttpRequestFactory requestFactory = new HttpComponentsClientHttpRequestFactory(httpClient);
assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test"); assertThat(getResponse(getLocalUrl("https", "/test.txt"), requestFactory)).isEqualTo("test");
......
Markdown is supported
0% or
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment