Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -57,12 +57,14 @@
import com.yahoo.rdl.Struct;

import com.linecorp.armeria.client.ClientFactory;
import com.linecorp.armeria.client.ClientTlsSpec;
import com.linecorp.armeria.client.ClientTlsSpecBuilder;
import com.linecorp.armeria.client.athenz.ZtsBaseClient;
import com.linecorp.armeria.common.TlsKeyPair;
import com.linecorp.armeria.common.TlsProvider;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.util.TlsEngineType;
import com.linecorp.armeria.internal.common.SslContextFactory;
import com.linecorp.armeria.internal.common.SslContextFactory.SslContextMode;

import io.netty.handler.ssl.JdkSslContext;

Expand Down Expand Up @@ -238,13 +240,26 @@ private static JwtsSigningKeyResolver newDefaultJwtsSigningKeyResolver(ZtsBaseCl
}
final ClientFactory clientFactory = ztsBaseClient.clientFactory();
final TlsProvider tlsProvider = clientFactory.options().tlsProvider();
final SslContextFactory sslContextFactory = new SslContextFactory(tlsProvider, TlsEngineType.JDK,
null, clientFactory.meterRegistry());
final JdkSslContext sslContext = (JdkSslContext) sslContextFactory.getOrCreate(SslContextMode.CLIENT,
"*");
final SslContextFactory sslContextFactory = new SslContextFactory(null, clientFactory.meterRegistry());
final ClientTlsSpec clientTlsSpec = toTlsSpec(tlsProvider);
final JdkSslContext sslContext = (JdkSslContext) sslContextFactory.getOrCreate(clientTlsSpec);
return new JwtsSigningKeyResolver(ztsUri + oauth2KeysPath, sslContext.context(), proxyUriStr);
}

private static ClientTlsSpec toTlsSpec(TlsProvider tlsProvider) {
final ClientTlsSpecBuilder builder = ClientTlsSpec.builder();
final TlsKeyPair tlsKeyPair = tlsProvider.keyPair("*");
if (tlsKeyPair != null) {
builder.tlsKeyPair(tlsKeyPair);
}
final List<X509Certificate> trustedCertificates = tlsProvider.trustedCertificates("*");
if (trustedCertificates != null) {
builder.trustedCertificates(trustedCertificates);
}
builder.engineType(TlsEngineType.JDK);
return builder.build();
}

/**
* Set the role token allowed offset. this might be necessary
* if the client and server are not ntp synchronized, and we
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,13 @@ public BlockingWebClientRequestPreparation responseTimeoutMode(ResponseTimeoutMo
return this;
}

@Override
@UnstableApi
public BlockingWebClientRequestPreparation clientTlsSpec(ClientTlsSpec clientTlsSpec) {
delegate.clientTlsSpec(clientTlsSpec);
return this;
}

@Override
public BlockingWebClientRequestPreparation requestOptions(RequestOptions requestOptions) {
delegate.requestOptions(requestOptions);
Expand Down
83 changes: 25 additions & 58 deletions core/src/main/java/com/linecorp/armeria/client/Bootstraps.java
Original file line number Diff line number Diff line change
Expand Up @@ -19,15 +19,14 @@
import static com.linecorp.armeria.common.SessionProtocol.httpAndHttpsValues;

import java.lang.reflect.Array;
import java.net.InetSocketAddress;
import java.net.SocketAddress;
import java.util.Objects;
import java.util.Set;

import com.linecorp.armeria.common.SerializationFormat;
import com.linecorp.armeria.common.SessionProtocol;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.internal.common.SslContextFactory;
import com.linecorp.armeria.internal.common.SslContextFactory.SslContextMode;

import io.netty.bootstrap.Bootstrap;
import io.netty.channel.Channel;
Expand All @@ -39,28 +38,24 @@
final class Bootstraps {

private final EventLoop eventLoop;
private final SslContext sslCtxHttp1Only;
private final SslContext sslCtxHttp1Or2;
@Nullable
private final SslContextFactory sslContextFactory;

private final HttpClientFactory clientFactory;
private final Bootstrap inetBaseBootstrap;
private final DefaultSslContexts defaultSslContexts;
@Nullable
private final Bootstrap unixBaseBootstrap;
private final Bootstrap[][] inetBootstraps;
private final Bootstrap @Nullable [][] unixBootstraps;

Bootstraps(HttpClientFactory clientFactory, EventLoop eventLoop,
SslContext sslCtxHttp1Or2, SslContext sslCtxHttp1Only,
@Nullable SslContextFactory sslContextFactory) {
SslContextFactory sslContextFactory, DefaultSslContexts defaultSslContexts) {
this.eventLoop = eventLoop;
this.sslCtxHttp1Or2 = sslCtxHttp1Or2;
this.sslCtxHttp1Only = sslCtxHttp1Only;
this.sslContextFactory = sslContextFactory;
this.clientFactory = clientFactory;

inetBaseBootstrap = clientFactory.newInetBootstrap();
this.defaultSslContexts = defaultSslContexts;
inetBaseBootstrap.group(eventLoop);
inetBootstraps = staticBootstrapMap(inetBaseBootstrap);

Expand All @@ -80,20 +75,13 @@ private Bootstrap[][] staticBootstrapMap(Bootstrap baseBootstrap) {
// Attempting to access the array with an unallowed protocol will trigger NPE,
// which will help us find a bug.
for (SessionProtocol p : sessionProtocols) {
final SslContext sslCtx = determineSslContext(p);
final SslContext sslCtx = p.isTls() ? defaultSslContexts.getSslContext(p) : null;
createAndSetBootstrap(baseBootstrap, maps, p, sslCtx, true);
createAndSetBootstrap(baseBootstrap, maps, p, sslCtx, false);
}
return maps;
}

/**
* Determine {@link SslContext} by the specified {@link SessionProtocol}.
*/
SslContext determineSslContext(SessionProtocol desiredProtocol) {
return desiredProtocol.isExplicitHttp1() ? sslCtxHttp1Only : sslCtxHttp1Or2;
}

private Bootstrap select(boolean isDomainSocket, SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
final Bootstrap[][] bootstraps = isDomainSocket ? unixBootstraps : inetBootstraps;
Expand All @@ -102,7 +90,7 @@ private Bootstrap select(boolean isDomainSocket, SessionProtocol desiredProtocol
}

private void createAndSetBootstrap(Bootstrap baseBootstrap, Bootstrap[][] maps,
SessionProtocol desiredProtocol, SslContext sslContext,
SessionProtocol desiredProtocol, @Nullable SslContext sslContext,
boolean webSocket) {
maps[desiredProtocol.ordinal()][toIndex(webSocket)] = newBootstrap(baseBootstrap, desiredProtocol,
sslContext, webSocket, false);
Expand All @@ -121,7 +109,7 @@ private static int toIndex(SerializationFormat serializationFormat) {
* {@link SessionProtocol} and {@link SerializationFormat}.
*/
Bootstrap getOrCreate(SocketAddress remoteAddress, SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
SerializationFormat serializationFormat, ClientTlsSpec tlsSpec) {
if (!httpAndHttpsValues().contains(desiredProtocol)) {
throw new IllegalArgumentException("Unsupported session protocol: " + desiredProtocol);
}
Expand All @@ -132,71 +120,50 @@ Bootstrap getOrCreate(SocketAddress remoteAddress, SessionProtocol desiredProtoc
eventLoop.getClass().getName());
}

if (sslContextFactory == null || !desiredProtocol.isTls()) {
if (!desiredProtocol.isTls()) {
return select(isDomainSocket, desiredProtocol, serializationFormat);
}
final ClientTlsSpec defaultTlsSpec = defaultSslContexts.getClientTlsSpec(desiredProtocol);
if (Objects.equals(defaultTlsSpec, tlsSpec)) {
return select(isDomainSocket, desiredProtocol, serializationFormat);
}

final Bootstrap baseBootstrap = isDomainSocket ? unixBaseBootstrap : inetBaseBootstrap;
assert baseBootstrap != null;
return newBootstrap(baseBootstrap, remoteAddress, desiredProtocol, serializationFormat);
return newBootstrap(baseBootstrap, desiredProtocol, serializationFormat, tlsSpec);
}

private Bootstrap newBootstrap(Bootstrap baseBootstrap, SocketAddress remoteAddress,
private Bootstrap newBootstrap(Bootstrap baseBootstrap,
SessionProtocol desiredProtocol,
SerializationFormat serializationFormat) {
SerializationFormat serializationFormat, ClientTlsSpec tlsSpec) {
final boolean webSocket = serializationFormat == SerializationFormat.WS;
final SslContext sslContext = newSslContext(remoteAddress, desiredProtocol);
final SslContext sslContext = sslContextFactory.getOrCreate(tlsSpec);
return newBootstrap(baseBootstrap, desiredProtocol, sslContext, webSocket, true);
}

private Bootstrap newBootstrap(Bootstrap baseBootstrap, SessionProtocol desiredProtocol,
SslContext sslContext, boolean webSocket, boolean closeSslContext) {
@Nullable SslContext sslContext, boolean webSocket,
boolean closeSslContext) {
final Bootstrap bootstrap = baseBootstrap.clone();
bootstrap.handler(clientChannelInitializer(desiredProtocol, sslContext, webSocket, closeSslContext));
return bootstrap;
}

SslContext getOrCreateSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) {
if (sslContextFactory == null) {
return determineSslContext(desiredProtocol);
} else {
return newSslContext(remoteAddress, desiredProtocol);
}
SslContext getOrCreateSslContext(ClientTlsSpec tlsSpec) {
return sslContextFactory.getOrCreate(tlsSpec);
}

private SslContext newSslContext(SocketAddress remoteAddress, SessionProtocol desiredProtocol) {
final String hostname;
if (remoteAddress instanceof InetSocketAddress) {
hostname = ((InetSocketAddress) remoteAddress).getHostString();
} else {
assert remoteAddress instanceof DomainSocketAddress;
hostname = "unix:" + ((DomainSocketAddress) remoteAddress).path();
}

final SslContextMode sslContextMode =
desiredProtocol.isExplicitHttp1() ? SslContextFactory.SslContextMode.CLIENT_HTTP1_ONLY
: SslContextFactory.SslContextMode.CLIENT;
assert sslContextFactory != null;
return sslContextFactory.getOrCreate(sslContextMode, hostname);
}

boolean shouldReleaseSslContext(SslContext sslContext) {
return sslContext != sslCtxHttp1Only && sslContext != sslCtxHttp1Or2;
}

void releaseSslContext(SslContext sslContext) {
if (sslContextFactory != null) {
sslContextFactory.release(sslContext);
}
void release(SslContext sslContext) {
sslContextFactory.release(sslContext);
}

private ChannelInitializer<Channel> clientChannelInitializer(SessionProtocol p, SslContext sslCtx,
private ChannelInitializer<Channel> clientChannelInitializer(SessionProtocol p, @Nullable SslContext sslCtx,
boolean webSocket, boolean closeSslContext) {
return new ChannelInitializer<Channel>() {
@Override
protected void initChannel(Channel ch) throws Exception {
if (closeSslContext) {
ch.closeFuture().addListener(unused -> releaseSslContext(sslCtx));
if (closeSslContext && sslCtx != null) {
ch.closeFuture().addListener(unused -> release(sslCtx));
}
ch.pipeline().addLast(new HttpClientPipelineConfigurator(
clientFactory, webSocket, p, sslCtx));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,7 @@
import com.google.common.base.MoreObjects.ToStringHelper;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.ImmutableMap;
import com.google.common.collect.ImmutableSet;
import com.google.common.primitives.Ints;

import com.linecorp.armeria.client.proxy.ProxyConfig;
Expand All @@ -68,7 +69,6 @@
import com.linecorp.armeria.common.outlier.OutlierDetection;
import com.linecorp.armeria.common.util.EventLoopGroups;
import com.linecorp.armeria.common.util.TlsEngineType;
import com.linecorp.armeria.internal.common.IgnoreHostsTrustManager;
import com.linecorp.armeria.internal.common.RequestContextUtil;
import com.linecorp.armeria.internal.common.util.ChannelUtil;

Expand Down Expand Up @@ -132,6 +132,7 @@ public final class ClientFactoryBuilder implements TlsSetters {
private TlsProvider tlsProvider;
@Nullable
private ClientTlsConfig tlsConfig;
private ClientTlsSpec clientTlsSpec = ClientTlsSpec.of();
private boolean staticTlsSettingsSet;
private boolean autoCloseConnectionPoolListener = true;

Expand Down Expand Up @@ -421,8 +422,10 @@ public ClientFactoryBuilder tls(PrivateKey key, @Nullable String keyPassword,
@Override
public ClientFactoryBuilder tls(TlsKeyPair tlsKeyPair) {
requireNonNull(tlsKeyPair, "tlsKeyPair");
return tlsCustomizer(customizer -> customizer.keyManager(tlsKeyPair.privateKey(),
tlsKeyPair.certificateChain()));
ensureNoTlsProvider();
staticTlsSettingsSet = true;
clientTlsSpec = clientTlsSpec.toBuilder().tlsKeyPair(tlsKeyPair).build();
return this;
}

/**
Expand All @@ -431,7 +434,10 @@ public ClientFactoryBuilder tls(TlsKeyPair tlsKeyPair) {
@Override
public ClientFactoryBuilder tls(KeyManagerFactory keyManagerFactory) {
requireNonNull(keyManagerFactory, "keyManagerFactory");
return tlsCustomizer(customizer -> customizer.keyManager(keyManagerFactory));
ensureNoTlsProvider();
staticTlsSettingsSet = true;
clientTlsSpec = clientTlsSpec.toBuilder().keyManagerFactory(keyManagerFactory).build();
return this;
}

/**
Expand Down Expand Up @@ -1073,13 +1079,8 @@ private ClientFactoryOptions buildOptions() {
if (tlsConfig != null) {
option(ClientFactoryOptions.TLS_CONFIG, tlsConfig);
}
} else {
if (tlsNoVerifySet) {
tlsCustomizer(b -> b.trustManager(InsecureTrustManagerFactory.INSTANCE));
} else if (!insecureHosts.isEmpty()) {
tlsCustomizer(b -> b.trustManager(IgnoreHostsTrustManager.of(insecureHosts)));
}
}
option(ClientFactoryOptions.CLIENT_TLS_SPEC, clientTlsSpec);

final ClientFactoryOptions newOptions = ClientFactoryOptions.of(options.values());
final long maxConnectionAgeMillis = newOptions.maxConnectionAgeMillis();
Expand Down Expand Up @@ -1124,7 +1125,9 @@ private ClientFactoryOptions buildOptions() {
* Returns a newly-created {@link ClientFactory} based on the properties of this builder.
*/
public ClientFactory build() {
return new DefaultClientFactory(new HttpClientFactory(buildOptions(), autoCloseConnectionPoolListener));
final ClientFactoryOptions options = buildOptions();
return new DefaultClientFactory(new HttpClientFactory(
options, autoCloseConnectionPoolListener, tlsNoVerifySet, ImmutableSet.copyOf(insecureHosts)));
}

@Override
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -374,6 +374,10 @@ private static long clampedDefaultMaxClientConnectionAge() {
public static final ClientFactoryOption<Consumer<? super ChannelPipeline>> CHANNEL_PIPELINE_CUSTOMIZER =
ClientFactoryOption.define("CHANNEL_PIPELINE_CUSTOMIZER", v -> { /* no-op */ });

@UnstableApi
public static final ClientFactoryOption<ClientTlsSpec> CLIENT_TLS_SPEC =
ClientFactoryOption.define("CLIENT_TLS_SPEC", ClientTlsSpec.of());

private static final ClientFactoryOptions EMPTY = new ClientFactoryOptions(ImmutableList.of());

/**
Expand Down Expand Up @@ -744,4 +748,12 @@ public ClientTlsConfig tlsConfig() {
public Consumer<? super ChannelPipeline> channelPipelineCustomizer() {
return get(CHANNEL_PIPELINE_CUSTOMIZER);
}

/**
* Returns the default {@link ClientTlsSpec}.
*/
@UnstableApi
public ClientTlsSpec clientTlsSpec() {
return get(CLIENT_TLS_SPEC);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -619,6 +619,13 @@ default boolean isTimedOut() {
@UnstableApi
ResponseTimeoutMode responseTimeoutMode();

/**
* Returns the request-specific TLS configuration.
*/
@UnstableApi
@Nullable
ClientTlsSpec clientTlsSpec();

@Override
default ClientRequestContext unwrap() {
return (ClientRequestContext) RequestContext.super.unwrap();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
import com.linecorp.armeria.common.RequestId;
import com.linecorp.armeria.common.RpcRequest;
import com.linecorp.armeria.common.annotation.Nullable;
import com.linecorp.armeria.common.annotation.UnstableApi;
import com.linecorp.armeria.common.util.TimeoutMode;

/**
Expand Down Expand Up @@ -171,6 +172,12 @@ public ResponseTimeoutMode responseTimeoutMode() {
return unwrap().responseTimeoutMode();
}

@Override
@UnstableApi
public @Nullable ClientTlsSpec clientTlsSpec() {
return unwrap().clientTlsSpec();
}

@Override
public void hook(Supplier<? extends AutoCloseable> contextHook) {
unwrap().hook(contextHook);
Expand Down
Loading
Loading