Skip to content
Open
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.Runtime.InteropServices;
using System.Threading;
using System.Threading.Tasks;
Expand All @@ -12,10 +13,19 @@ namespace Microsoft.Identity.Client.KeyAttestation
{
/// <summary>
/// Static facade for attesting a Credential Guard/CNG key and getting a JWT back.
/// Caches valid MAA tokens to avoid redundant native DLL calls and network round-trips.
/// Key discovery / rotation is the caller's responsibility.
/// </summary>
internal static class PopKeyAttestor
{
// In-process cache: maps "{endpoint}|{clientId}" → AttestationToken (JWT + expiry).
private static readonly ConcurrentDictionary<string, AttestationToken> s_tokenCache =
new ConcurrentDictionary<string, AttestationToken>(StringComparer.OrdinalIgnoreCase);

// Tokens within this window of expiry are considered stale and will be refreshed.
// Matches MSAL's AccessTokenExpirationBuffer (5 minutes).
internal static TimeSpan s_expirationBuffer = TimeSpan.FromMinutes(5);

/// <summary>
/// Test hook to inject a mock attestation provider for unit testing.
/// When set, this delegate is called instead of loading the native DLL.
Expand All @@ -25,8 +35,18 @@ internal static class PopKeyAttestor
/// Tests should not run in parallel when using this hook to avoid race conditions.
/// </remarks>
internal static Func<string, SafeHandle, string, CancellationToken, Task<AttestationResult>> s_testAttestationProvider;

/// <summary>
/// Resets the MAA token cache. Call from [TestCleanup] to prevent cache state leaking between tests.
/// </summary>
internal static void ResetCacheForTest()
{
s_tokenCache.Clear();
}

/// <summary>
/// Asynchronously attests a Credential Guard/CNG key with the remote attestation service and returns a JWT.
/// Returns a cached token if one is available and not within the expiration buffer.
/// Wraps the synchronous <see cref="AttestationClient.Attest"/> in a Task.Run so callers can
/// avoid blocking. Cancellation only applies before the native call starts.
/// </summary>
Expand All @@ -40,39 +60,87 @@ public static Task<AttestationResult> AttestCredentialGuardAsync(
string clientId,
CancellationToken cancellationToken = default)
{
if (keyHandle is null)
throw new ArgumentNullException(nameof(keyHandle));

if (string.IsNullOrWhiteSpace(endpoint))
throw new ArgumentNullException(nameof(endpoint));

cancellationToken.ThrowIfCancellationRequested();

// Check the in-process cache before making any native/network calls.
// Key validation is intentionally deferred: a cache hit returns immediately
// without requiring (or validating) the key handle.
string cacheKey = BuildCacheKey(endpoint, clientId);
if (TryGetCachedToken(cacheKey, out AttestationToken cached))
{
return Task.FromResult(new AttestationResult(
AttestationStatus.Success, cached, cached.Token, 0, null));
}

// Cache miss — validate the key handle before any native/network call.
if (keyHandle is null)
throw new ArgumentNullException(nameof(keyHandle));

if (keyHandle.IsInvalid)
throw new ArgumentException("keyHandle is invalid", nameof(keyHandle));

var safeNCryptKeyHandle = keyHandle as SafeNCryptKeyHandle
?? throw new ArgumentException("keyHandle must be a SafeNCryptKeyHandle. Only Windows CNG keys are supported.", nameof(keyHandle));

cancellationToken.ThrowIfCancellationRequested();

// Check for test provider to avoid loading native DLL in unit tests
// Check for test provider to avoid loading native DLL in unit tests.
if (s_testAttestationProvider != null)
{
return s_testAttestationProvider(endpoint, keyHandle, clientId, cancellationToken);
return AttestAndCacheAsync(
s_testAttestationProvider(endpoint, keyHandle, clientId, cancellationToken),
cacheKey);
}

return Task.Run(() =>
{
try
{
using var client = new AttestationClient();
return client.Attest(endpoint, safeNCryptKeyHandle, clientId ?? string.Empty);
}
catch (Exception ex)
return AttestAndCacheAsync(
Task.Run(() =>
{
// Map any managed exception to AttestationStatus.Exception for consistency.
return new AttestationResult(AttestationStatus.Exception, null, string.Empty, -1, ex.Message);
}
}, cancellationToken);
try
{
using var client = new AttestationClient();
return client.Attest(endpoint, safeNCryptKeyHandle, clientId ?? string.Empty);
}
catch (Exception ex)
{
return new AttestationResult(AttestationStatus.Exception, null, string.Empty, -1, ex.Message);
}
}, cancellationToken),
cacheKey);
}

/// <summary>
/// Awaits the attestation task and writes the result to the cache on success.
/// </summary>
private static async Task<AttestationResult> AttestAndCacheAsync(
Task<AttestationResult> attestTask,
string cacheKey)
{
AttestationResult result = await attestTask.ConfigureAwait(false);

if (result.Status == AttestationStatus.Success && result.Token != null)
{
s_tokenCache[cacheKey] = result.Token;
}

return result;
}

private static bool TryGetCachedToken(string cacheKey, out AttestationToken token)
{
if (s_tokenCache.TryGetValue(cacheKey, out token) &&
token.ExpiresOn - s_expirationBuffer > DateTimeOffset.UtcNow)
{
return true;
}

token = null;
return false;
}

private static string BuildCacheKey(string endpoint, string clientId)
{
return $"{endpoint}|{clientId ?? string.Empty}";
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
// Licensed under the MIT License.

using System;
using System.Collections.Concurrent;
using System.IO;
using System.Security.Cryptography;
using System.Security.Cryptography.X509Certificates;
Expand All @@ -13,6 +14,7 @@
using Microsoft.Identity.Client.Internal;
using Microsoft.Identity.Client.Internal.Logger;
using Microsoft.Identity.Client.KeyAttestation;
using Microsoft.Identity.Client.KeyAttestation.Attestation;
using Microsoft.Identity.Client.ManagedIdentity;
using Microsoft.Identity.Client.ManagedIdentity.KeyProviders;
using Microsoft.Identity.Client.ManagedIdentity.V2;
Expand Down Expand Up @@ -64,6 +66,9 @@ public void ImdsV2Tests_Cleanup()

// Reset test provider to ensure clean state for other tests
PopKeyAttestor.s_testAttestationProvider = null;

// Reset MAA token cache so cached tokens don't leak between tests
PopKeyAttestor.ResetCacheForTest();
}

private void AddMocksToGetEntraToken(
Expand Down Expand Up @@ -948,6 +953,153 @@ await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
}
}

[TestMethod]
public async Task MaaTokenCache_Hit_DoesNotCallAttestationProviderAgain()
{
// Arrange: track how many times the attestation provider is called
int providerCallCount = 0;
PopKeyAttestor.s_testAttestationProvider = (endpoint, keyHandle, clientId, ct) =>
{
Interlocked.Increment(ref providerCallCount);
var fakeJwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.fake.sig";
var token = new AttestationToken(fakeJwt, DateTimeOffset.UtcNow.AddHours(1));
return Task.FromResult(new AttestationResult(AttestationStatus.Success, token, fakeJwt, 0, string.Empty));
};

using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
ManagedIdentityClient.ResetSourceForTest();
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var rawCert = CreateRawCertForCsrKeyWithCnDc(
Constants.ManagedIdentityDefaultClientId, TestConstants.TenantId,
DateTimeOffset.UtcNow.AddHours(25));

var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

// First acquire: mints cert + calls MAA
AddMocksToGetEntraToken(httpManager, certificateRequestCertificate: rawCert);
await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationSupport()
.ExecuteAsync().ConfigureAwait(false);

Assert.AreEqual(1, providerCallCount, "MAA should be called once on first acquire.");

// Second acquire: cert is cached, so MAA token cache should also be hit
var second = await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationSupport()
.ExecuteAsync().ConfigureAwait(false);

Assert.AreEqual(1, providerCallCount, "MAA should NOT be called again when token is cached.");
Assert.AreEqual(TokenSource.Cache, second.AuthenticationResultMetadata.TokenSource);
}
}

[TestMethod]
public async Task MaaTokenCache_ExpiredToken_CallsAttestationProviderAgain()
{
// Arrange: use a very short expiration buffer so we can simulate expiry
var originalBuffer = PopKeyAttestor.s_expirationBuffer;
PopKeyAttestor.s_expirationBuffer = TimeSpan.Zero;

int providerCallCount = 0;
PopKeyAttestor.s_testAttestationProvider = (endpoint, keyHandle, clientId, ct) =>
{
Interlocked.Increment(ref providerCallCount);
// Return a token that is already expired
var expiredJwt = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.expired.sig";
var token = new AttestationToken(expiredJwt, DateTimeOffset.UtcNow.AddSeconds(-1));
return Task.FromResult(new AttestationResult(AttestationStatus.Success, token, expiredJwt, 0, string.Empty));
};

try
{
using (new EnvVariableContext())
using (var httpManager = new MockHttpManager())
{
ManagedIdentityClient.ResetSourceForTest();
SetEnvironmentVariables(ManagedIdentitySource.ImdsV2, TestConstants.ImdsEndpoint);

var rawCert = CreateRawCertForCsrKeyWithCnDc(
Constants.ManagedIdentityDefaultClientId, TestConstants.TenantId,
DateTimeOffset.UtcNow.AddHours(25));

var mi = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);

// First acquire
AddMocksToGetEntraToken(httpManager, certificateRequestCertificate: rawCert);
await mi.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationSupport()
.ExecuteAsync().ConfigureAwait(false);

Assert.AreEqual(1, providerCallCount);

// Force-refresh to bypass the MSAL token cache and re-enter cert/MAA logic
ManagedIdentityClient.ResetSourceForTest();
ImdsV2ManagedIdentitySource.ResetCertCacheForTest();

var mi2 = await CreateManagedIdentityAsync(httpManager, managedIdentityKeyType: ManagedIdentityKeyType.KeyGuard).ConfigureAwait(false);
AddMocksToGetEntraToken(httpManager, certificateRequestCertificate: rawCert);

await mi2.AcquireTokenForManagedIdentity(ManagedIdentityTests.Resource)
.WithMtlsProofOfPossession()
.WithAttestationSupport()
.ExecuteAsync().ConfigureAwait(false);

Assert.AreEqual(2, providerCallCount, "MAA should be called again when cached token is expired.");
}
}
finally
{
PopKeyAttestor.s_expirationBuffer = originalBuffer;
}
}

[TestMethod]
public async Task MaaTokenCache_DifferentEndpoints_CachedSeparately()
{
// Arrange: two providers for two different endpoints
var callCounts = new ConcurrentDictionary<string, int>();
PopKeyAttestor.s_testAttestationProvider = (endpoint, keyHandle, clientId, ct) =>
{
callCounts.AddOrUpdate(endpoint, 1, (_, c) => c + 1);
var fakeJwt = $"eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.{Uri.EscapeDataString(endpoint)}.sig";
var token = new AttestationToken(fakeJwt, DateTimeOffset.UtcNow.AddHours(1));
return Task.FromResult(new AttestationResult(AttestationStatus.Success, token, fakeJwt, 0, string.Empty));
};

const string endpoint1 = "https://eastus.attestation.azure.net";
const string endpoint2 = "https://westus.attestation.azure.net";

// Use a real RSACng key handle so validation passes on first (cache-miss) calls
using var rsa1 = new RSACng(2048);
using var rsa2 = new RSACng(2048);
var handle1 = rsa1.Key.Handle;
var handle2 = rsa2.Key.Handle;

// First calls: cache miss → MAA is invoked
var result1a = await PopKeyAttestor.AttestCredentialGuardAsync(
endpoint1, handle1, "client1", CancellationToken.None).ConfigureAwait(false);
var result2a = await PopKeyAttestor.AttestCredentialGuardAsync(
endpoint2, handle2, "client1", CancellationToken.None).ConfigureAwait(false);

// Second calls: cache hit → key handle not needed, pass null
var result1b = await PopKeyAttestor.AttestCredentialGuardAsync(
endpoint1, null, "client1", CancellationToken.None).ConfigureAwait(false);
var result2b = await PopKeyAttestor.AttestCredentialGuardAsync(
endpoint2, null, "client1", CancellationToken.None).ConfigureAwait(false);

Assert.AreEqual(1, callCounts[endpoint1], "endpoint1 MAA should only be called once.");
Assert.AreEqual(1, callCounts[endpoint2], "endpoint2 MAA should only be called once.");
Assert.AreEqual(result1a.Jwt, result1b.Jwt, "endpoint1 cached JWT should match.");
Assert.AreEqual(result2a.Jwt, result2b.Jwt, "endpoint2 cached JWT should match.");
Assert.AreNotEqual(result1a.Jwt, result2a.Jwt, "Different endpoints should produce different tokens.");
}

#endregion

#region Cached certificate tests
Expand Down Expand Up @@ -1722,3 +1874,4 @@ private static void AssertCertSubjectCnDc(X509Certificate2 cert, string expected
#endregion
}
}

Loading