Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,35 +35,54 @@ private sealed class SafeSslContextCache : SafeHandleCache<SslContextCacheKey, S

internal readonly struct SslContextCacheKey : IEquatable<SslContextCacheKey>
{
private const int ThumbprintSize = 64; // SHA512 size

public readonly bool IsClient;
public readonly byte[]? CertificateThumbprint;
public readonly ReadOnlyMemory<byte> CertificateThumbprints;
public readonly SslProtocols SslProtocols;

public SslContextCacheKey(bool isClient, SslProtocols sslProtocols, byte[]? certificateThumbprint)
public SslContextCacheKey(bool isClient, SslProtocols sslProtocols, SslStreamCertificateContext? certContext)
{
IsClient = isClient;
SslProtocols = sslProtocols;
CertificateThumbprint = certificateThumbprint;

CertificateThumbprints = ReadOnlyMemory<byte>.Empty;

if (certContext != null)
{
int certCount = 1 + certContext.IntermediateCertificates.Count;
byte[] certificateThumbprints = new byte[certCount * ThumbprintSize];

bool success = certContext.TargetCertificate.TryGetCertHash(HashAlgorithmName.SHA512, certificateThumbprints.AsSpan(0, ThumbprintSize), out _);
Debug.Assert(success);

certCount = 1;
foreach (X509Certificate2 intermediate in certContext.IntermediateCertificates)
{
success = intermediate.TryGetCertHash(HashAlgorithmName.SHA512, certificateThumbprints.AsSpan(certCount * ThumbprintSize, ThumbprintSize), out _);
Debug.Assert(success);
certCount++;
}

CertificateThumbprints = certificateThumbprints;
}
}

public override bool Equals(object? obj) => obj is SslContextCacheKey key && Equals(key);

public bool Equals(SslContextCacheKey other) =>

IsClient == other.IsClient &&
SslProtocols == other.SslProtocols &&
(CertificateThumbprint == null && other.CertificateThumbprint == null ||
CertificateThumbprint != null && other.CertificateThumbprint != null && CertificateThumbprint.AsSpan().SequenceEqual(other.CertificateThumbprint));
CertificateThumbprints.Span.SequenceEqual(other.CertificateThumbprints.Span) &&
SslProtocols == other.SslProtocols;

public override int GetHashCode()
{
HashCode hash = default;

hash.Add(IsClient);
hash.AddBytes(CertificateThumbprints.Span);
hash.Add(SslProtocols);
if (CertificateThumbprint != null)
{
hash.AddBytes(CertificateThumbprint);
}

return hash.ToHashCode();
}
Expand Down Expand Up @@ -172,7 +191,7 @@ internal static SafeSslContextHandle GetOrCreateSslContextHandle(SslAuthenticati
var key = new SslContextCacheKey(
sslAuthenticationOptions.IsClient,
sslAuthenticationOptions.IsClient ? protocols : serverProtocolCacheKey,
sslAuthenticationOptions.CertificateContext?.TargetCertificate.GetCertHash(HashAlgorithmName.SHA512));
sslAuthenticationOptions.CertificateContext);
return s_sslContexts.GetOrCreate(key, static (args) =>
{
var (sslAuthOptions, protocols, allowCached) = args;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -46,24 +46,40 @@ private sealed class MsQuicConfigurationCache : SafeHandleCache<CacheKey, MsQuic

private readonly struct CacheKey : IEquatable<CacheKey>
{
public readonly List<byte[]> CertificateThumbprints;
private const int ThumbprintSize = 64; // SHA512 size

public readonly ReadOnlyMemory<byte> CertificateThumbprints;
public readonly QUIC_CREDENTIAL_FLAGS Flags;
public readonly QUIC_SETTINGS Settings;
public readonly List<SslApplicationProtocol> ApplicationProtocols;
public readonly QUIC_ALLOWED_CIPHER_SUITE_FLAGS AllowedCipherSuites;

public CacheKey(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certificate? certificate, ReadOnlyCollection<X509Certificate2>? intermediates, List<SslApplicationProtocol> alpnProtocols, QUIC_ALLOWED_CIPHER_SUITE_FLAGS allowedCipherSuites)
{
CertificateThumbprints = certificate == null ? new List<byte[]>() : new List<byte[]> { certificate.GetCertHash(HashAlgorithmName.SHA512) };
int certCount = certificate == null ? 0 : 1;
certCount += intermediates?.Count ?? 0;
byte[] certificateThumbprints = new byte[certCount * ThumbprintSize];

certCount = 0;
if (certificate != null)
{
bool success = certificate.TryGetCertHash(HashAlgorithmName.SHA512, certificateThumbprints.AsSpan(0, ThumbprintSize), out _);
Debug.Assert(success);
certCount++;
}

if (intermediates != null)
{
foreach (X509Certificate2 intermediate in intermediates)
{
CertificateThumbprints.Add(intermediate.GetCertHash(HashAlgorithmName.SHA512));
bool success = intermediate.TryGetCertHash(HashAlgorithmName.SHA512, certificateThumbprints.AsSpan(certCount * ThumbprintSize, ThumbprintSize), out _);
Debug.Assert(success);
certCount++;
}
}

CertificateThumbprints = certificateThumbprints;

Flags = flags;
Settings = settings;
// make defensive copy to prevent modification (the list comes from user code)
Expand All @@ -75,19 +91,11 @@ public CacheKey(QUIC_SETTINGS settings, QUIC_CREDENTIAL_FLAGS flags, X509Certifi

public bool Equals(CacheKey other)
{
if (CertificateThumbprints.Count != other.CertificateThumbprints.Count)
if (!CertificateThumbprints.Span.SequenceEqual(other.CertificateThumbprints.Span))
{
return false;
}

for (int i = 0; i < CertificateThumbprints.Count; i++)
{
if (!CertificateThumbprints[i].AsSpan().SequenceEqual(other.CertificateThumbprints[i]))
{
return false;
}
}

if (ApplicationProtocols.Count != other.ApplicationProtocols.Count)
{
return false;
Expand All @@ -111,11 +119,7 @@ public override int GetHashCode()
{
HashCode hash = default;

foreach (var thumbprint in CertificateThumbprints)
{
hash.AddBytes(thumbprint);
}

hash.AddBytes(CertificateThumbprints.Span);
hash.Add(Flags);
hash.Add(Settings);

Expand Down
Loading