diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyEncryptionKey.cs b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyEncryptionKey.cs new file mode 100644 index 0000000000..41cafe7bf3 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyEncryptionKey.cs @@ -0,0 +1,183 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace CachingKeyResolverSample +{ + using System; + using System.Collections.Concurrent; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core.Cryptography; + + /// + /// A decorator around that caches + /// results in memory. On an AKV failure, the cached unwrapped bytes are returned as a + /// stale-while-revalidate fallback — making the data path resilient to AKV outages. + /// + /// This is the critical missing piece for AKV resilience. The + /// caches the key handle (Resolve), but the + /// handle's UnwrapKey is still a live HTTP POST to AKV. This class caches that + /// HTTP result so the data path survives AKV unavailability. + /// + /// + /// + /// Thread-safe: uses . + /// WrapKey is NOT cached — wrapping is a write-path operation that + /// must always go to AKV for correctness. + /// Stale fallback is best-effort: if the KEK was rotated in AKV and the + /// cache holds bytes unwrapped with the old KEK, decryption will fail + /// naturally because the DEK bytes won't match. This is safe. + /// + /// + public sealed class CachingKeyEncryptionKey : IKeyEncryptionKey + { + private readonly IKeyEncryptionKey innerKey; + private readonly TimeSpan cacheTtl; + + /// + /// Cache of unwrapped DEK bytes, keyed by the hex-encoded encrypted key. + /// + private readonly ConcurrentDictionary unwrapCache; + + /// + /// Initializes a new instance of . + /// + /// The real (e.g., from AKV). + /// How long unwrapped bytes are considered fresh. After this, + /// the next call attempts a live unwrap but falls back to cached bytes on failure. + /// Optional shared cache across multiple key instances. + /// If null, a private cache is created. + public CachingKeyEncryptionKey( + IKeyEncryptionKey innerKey, + TimeSpan cacheTtl, + ConcurrentDictionary sharedCache = null) + { + this.innerKey = innerKey ?? throw new ArgumentNullException(nameof(innerKey)); + this.cacheTtl = cacheTtl; + this.unwrapCache = sharedCache ?? new ConcurrentDictionary(StringComparer.Ordinal); + } + + /// + public string KeyId => this.innerKey.KeyId; + + /// + public byte[] WrapKey(string algorithm, ReadOnlyMemory key, CancellationToken cancellationToken = default) + { + // WrapKey is a write-path operation — always go to AKV. + return this.innerKey.WrapKey(algorithm, key, cancellationToken); + } + + /// + public async Task WrapKeyAsync(string algorithm, ReadOnlyMemory key, CancellationToken cancellationToken = default) + { + return await this.innerKey.WrapKeyAsync(algorithm, key, cancellationToken).ConfigureAwait(false); + } + + /// + public byte[] UnwrapKey(string algorithm, ReadOnlyMemory encryptedKey, CancellationToken cancellationToken = default) + { + string cacheKey = Convert.ToHexString(encryptedKey.Span); + + // Fresh cache hit → return immediately. + if (this.TryGetFresh(cacheKey, out byte[] cached)) + { + return cached; + } + + // Attempt live unwrap. + try + { + byte[] unwrapped = this.innerKey.UnwrapKey(algorithm, encryptedKey, cancellationToken); + this.PutCache(cacheKey, unwrapped); + return unwrapped; + } + catch (Exception) when (this.TryGetStale(cacheKey, out byte[] stale)) + { + // AKV failure + stale cache available → serve stale bytes. + return stale; + } + } + + /// + public async Task UnwrapKeyAsync(string algorithm, ReadOnlyMemory encryptedKey, CancellationToken cancellationToken = default) + { + string cacheKey = Convert.ToHexString(encryptedKey.Span); + + // Fresh cache hit → return immediately. + if (this.TryGetFresh(cacheKey, out byte[] cached)) + { + return cached; + } + + // Attempt live unwrap. + try + { + byte[] unwrapped = await this.innerKey.UnwrapKeyAsync(algorithm, encryptedKey, cancellationToken).ConfigureAwait(false); + this.PutCache(cacheKey, unwrapped); + return unwrapped; + } + catch (Exception) when (this.TryGetStale(cacheKey, out byte[] stale)) + { + // AKV failure + stale cache available → serve stale bytes. + return stale; + } + } + + private bool TryGetFresh(string cacheKey, out byte[] unwrappedBytes) + { + if (this.unwrapCache.TryGetValue(cacheKey, out CachedUnwrapEntry entry) + && DateTime.UtcNow < entry.FreshUntilUtc) + { + unwrappedBytes = entry.UnwrappedBytes; + return true; + } + + unwrappedBytes = null; + return false; + } + + private bool TryGetStale(string cacheKey, out byte[] unwrappedBytes) + { + // Stale entries are usable regardless of TTL — the key bytes don't "go bad." + // Security: if the KEK was rotated, the stale bytes will produce incorrect + // DEK bytes and decryption will fail, so this is safe. + if (this.unwrapCache.TryGetValue(cacheKey, out CachedUnwrapEntry entry)) + { + unwrappedBytes = entry.UnwrappedBytes; + return true; + } + + unwrappedBytes = null; + return false; + } + + private void PutCache(string cacheKey, byte[] unwrappedBytes) + { + this.unwrapCache[cacheKey] = new CachedUnwrapEntry(unwrappedBytes, DateTime.UtcNow.Add(this.cacheTtl)); + } + + /// + /// Immutable record holding cached unwrapped key bytes and their freshness timestamp. + /// + public sealed class CachedUnwrapEntry + { + public CachedUnwrapEntry(byte[] unwrappedBytes, DateTime freshUntilUtc) + { + this.UnwrappedBytes = unwrappedBytes; + this.FreshUntilUtc = freshUntilUtc; + } + + /// + /// The raw unwrapped DEK bytes. + /// + public byte[] UnwrappedBytes { get; } + + /// + /// When this entry is considered stale. After this time, a live unwrap is + /// attempted first, but the entry can still be used as a fallback on failure. + /// + public DateTime FreshUntilUtc { get; } + } + } +} diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolver.cs b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolver.cs new file mode 100644 index 0000000000..1af0a4862d --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolver.cs @@ -0,0 +1,235 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace CachingKeyResolverSample +{ + using System; + using System.Collections.Concurrent; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core; + using Azure.Core.Cryptography; + using Azure.Security.KeyVault.Keys.Cryptography; + + /// + /// A caching wrapper around (or any ) + /// that keeps resolved instances in memory so that subsequent + /// calls to and return instantly with zero I/O. + /// + /// A background timer proactively refreshes entries that are approaching expiry so that + /// callers never experience a cache miss on the hot path. + /// + /// + /// This is a customer-facing sample. Review the caching strategy, TTL, and security + /// implications for your own workload before using in production. + /// + public sealed class CachingKeyResolver : IKeyEncryptionKeyResolver, IDisposable, IAsyncDisposable + { + private readonly IKeyEncryptionKeyResolver innerResolver; + private readonly CachingKeyResolverOptions options; + private readonly ConcurrentDictionary cache; + private readonly ConcurrentDictionary refreshesInFlight; + private readonly Timer refreshTimer; + private readonly CancellationTokenSource disposalCts; + private readonly ConcurrentDictionary sharedUnwrapCache; + private int disposed; + + /// + /// Raised when a key is served from cache (true) or resolved from the inner resolver (false). + /// Useful for diagnostics and testing. + /// + public event Action OnCacheAccess; + + /// + /// Initializes a new instance of using an Azure Key Vault + /// created from the provided . + /// + /// Token credential used to authenticate to Azure Key Vault. + /// Caching options. If null, defaults are used. + public CachingKeyResolver(TokenCredential credential, CachingKeyResolverOptions options = null) + : this(new KeyResolver(credential), options) + { + } + + /// + /// Initializes a new instance of using the provided + /// inner resolver. This constructor is useful for testing or wrapping custom resolvers. + /// + /// The inner resolver to delegate to on cache miss. + /// Caching options. If null, defaults are used. + public CachingKeyResolver(IKeyEncryptionKeyResolver innerResolver, CachingKeyResolverOptions options = null) + { + this.innerResolver = innerResolver ?? throw new ArgumentNullException(nameof(innerResolver)); + this.options = options ?? new CachingKeyResolverOptions(); + this.cache = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + this.refreshesInFlight = new ConcurrentDictionary(StringComparer.OrdinalIgnoreCase); + this.disposalCts = new CancellationTokenSource(); + this.sharedUnwrapCache = this.options.UnwrapKeyCacheTimeToLive > TimeSpan.Zero + ? new ConcurrentDictionary(StringComparer.Ordinal) + : null; + + this.refreshTimer = new Timer( + callback: this.BackgroundRefreshCallback, + state: null, + dueTime: this.options.RefreshTimerInterval, + period: this.options.RefreshTimerInterval); + } + + /// + public IKeyEncryptionKey Resolve(string keyId, CancellationToken cancellationToken = default) + { + this.ThrowIfDisposed(); + + if (this.cache.TryGetValue(keyId, out CachedKeyEntry entry) && entry.ExpiresUtc > DateTime.UtcNow) + { + this.OnCacheAccess?.Invoke(keyId, true); + return entry.Key; + } + + IKeyEncryptionKey resolved = this.innerResolver.Resolve(keyId, cancellationToken); + IKeyEncryptionKey stored = this.CacheEntry(keyId, resolved); + this.OnCacheAccess?.Invoke(keyId, false); + return stored; + } + + /// + public async Task ResolveAsync(string keyId, CancellationToken cancellationToken = default) + { + this.ThrowIfDisposed(); + + if (this.cache.TryGetValue(keyId, out CachedKeyEntry entry) && entry.ExpiresUtc > DateTime.UtcNow) + { + this.OnCacheAccess?.Invoke(keyId, true); + return entry.Key; + } + + IKeyEncryptionKey resolved = await this.innerResolver.ResolveAsync(keyId, cancellationToken).ConfigureAwait(false); + IKeyEncryptionKey stored = this.CacheEntry(keyId, resolved); + this.OnCacheAccess?.Invoke(keyId, false); + return stored; + } + + /// + public void Dispose() + { + if (Interlocked.Exchange(ref this.disposed, 1) != 0) + { + return; + } + + this.refreshTimer.Dispose(); + this.disposalCts.Cancel(); + this.disposalCts.Dispose(); + this.cache.Clear(); + } + + /// + public async ValueTask DisposeAsync() + { + if (Interlocked.Exchange(ref this.disposed, 1) != 0) + { + return; + } + + await this.refreshTimer.DisposeAsync().ConfigureAwait(false); + this.disposalCts.Cancel(); + this.disposalCts.Dispose(); + this.cache.Clear(); + } + + private IKeyEncryptionKey CacheEntry(string keyId, IKeyEncryptionKey key) + { + IKeyEncryptionKey wrappedKey = this.WrapKeyIfEnabled(key); + CachedKeyEntry newEntry = new CachedKeyEntry(wrappedKey, DateTime.UtcNow.Add(this.options.KeyCacheTimeToLive)); + this.cache[keyId] = newEntry; + return wrappedKey; + } + + /// + /// Wraps the resolved key with if + /// is configured. + /// Already-wrapped keys are returned as-is to avoid double-wrapping on refresh. + /// + private IKeyEncryptionKey WrapKeyIfEnabled(IKeyEncryptionKey key) + { + if (this.sharedUnwrapCache == null || key is CachingKeyEncryptionKey) + { + return key; + } + + return new CachingKeyEncryptionKey(key, this.options.UnwrapKeyCacheTimeToLive, this.sharedUnwrapCache); + } + + private void BackgroundRefreshCallback(object state) + { + if (this.disposed != 0) + { + return; + } + + foreach (var kvp in this.cache) + { + string keyId = kvp.Key; + CachedKeyEntry entry = kvp.Value; + + TimeSpan timeUntilExpiry = entry.ExpiresUtc - DateTime.UtcNow; + + if (timeUntilExpiry <= this.options.ProactiveRefreshThreshold) + { + if (!this.refreshesInFlight.TryAdd(keyId, 0)) + { + // A refresh is already in flight for this key. + continue; + } + + _ = Task.Run(async () => + { + try + { + CancellationToken ct = this.disposalCts.Token; + IKeyEncryptionKey refreshed = await this.innerResolver + .ResolveAsync(keyId, ct) + .ConfigureAwait(false); + + this.CacheEntry(keyId, refreshed); + } + catch + { + // Swallow: failed refresh should not evict the existing cached entry. + // The old entry remains usable until it fully expires. + } + finally + { + this.refreshesInFlight.TryRemove(keyId, out _); + } + }); + } + } + } + + private void ThrowIfDisposed() + { + if (this.disposed != 0) + { + throw new ObjectDisposedException(nameof(CachingKeyResolver)); + } + } + + /// + /// Internal representation of a cached key entry. + /// + internal sealed class CachedKeyEntry + { + public CachedKeyEntry(IKeyEncryptionKey key, DateTime expiresUtc) + { + this.Key = key; + this.ExpiresUtc = expiresUtc; + } + + public IKeyEncryptionKey Key { get; } + + public DateTime ExpiresUtc { get; } + } + } +} diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolverOptions.cs b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolverOptions.cs new file mode 100644 index 0000000000..c7469a3ed9 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolverOptions.cs @@ -0,0 +1,52 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace CachingKeyResolverSample +{ + using System; + + /// + /// Configuration options for . + /// + public sealed class CachingKeyResolverOptions + { + /// + /// How long cached key encryption keys remain valid before they must be re-resolved + /// from the inner resolver. Default: 2 hours. + /// + /// + /// This value must be shorter than your key rotation interval to ensure that + /// rotated keys are picked up in a timely manner. + /// + public TimeSpan KeyCacheTimeToLive { get; init; } = TimeSpan.FromHours(2); + + /// + /// How long before a cache entry expires to proactively start a background refresh. + /// Default: 5 minutes. + /// + /// + /// Setting this to a value larger than effectively + /// disables proactive refresh. + /// + public TimeSpan ProactiveRefreshThreshold { get; init; } = TimeSpan.FromMinutes(5); + + /// + /// How often the background timer fires to check for cache entries needing refresh. + /// Default: 1 minute. + /// + public TimeSpan RefreshTimerInterval { get; init; } = TimeSpan.FromMinutes(1); + + /// + /// How long cached UnwrapKey results (raw DEK bytes) are considered fresh. + /// After this time, a live AKV call is attempted first, but the cached bytes are + /// still used as a stale fallback if AKV is unavailable. + /// Default: 24 hours — safe when key rotation is infrequent. + /// + /// + /// Set to to disable UnwrapKey caching entirely + /// (the resolver will still cache the key handle from Resolve). + /// + public TimeSpan UnwrapKeyCacheTimeToLive { get; init; } = TimeSpan.FromHours(24); + } +} diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolverSample.csproj b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolverSample.csproj new file mode 100644 index 0000000000..6cda4c64b5 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/CachingKeyResolverSample.csproj @@ -0,0 +1,25 @@ + + + + Exe + net8.0 + CachingKeyResolverSample + 10.0 + + + + + + + + + + + + + + + + + + diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Program.cs b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Program.cs new file mode 100644 index 0000000000..611602f93d --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Program.cs @@ -0,0 +1,191 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace CachingKeyResolverSample +{ + using System; + using System.Collections.ObjectModel; + using System.Threading.Tasks; + using Azure.Identity; + using Microsoft.Azure.Cosmos; + using Microsoft.Azure.Cosmos.Encryption; + + /// + /// Sample demonstrating how to wrap Azure Key Vault's KeyResolver with an + /// in-memory cache to eliminate synchronous AKV I/O during concurrent + /// encryption / decryption operations. + /// + /// + /// Prerequisites: + /// + /// An Azure Cosmos DB account with client-side encryption support. + /// An Azure Key Vault with a key created for encryption. + /// Azure.Identity credentials configured (DefaultAzureCredential). + /// + /// + /// Set the following environment variables before running: + /// + /// COSMOS_ENDPOINT — Cosmos DB account URI + /// COSMOS_KEY — Cosmos DB account key + /// KEY_VAULT_URL — Azure Key Vault URI (e.g. https://my-vault.vault.azure.net/) + /// ENCRYPTION_KEY_NAME — Name of the key in Key Vault + /// + /// + public class Program + { + private const string DatabaseId = "CachingKeyResolverSampleDb"; + private const string ContainerId = "EncryptedItems"; + private const string DekName = "sampleDek"; + + public static async Task Main(string[] args) + { + string cosmosEndpoint = Environment.GetEnvironmentVariable("COSMOS_ENDPOINT"); + string cosmosKey = Environment.GetEnvironmentVariable("COSMOS_KEY"); + string keyVaultUrl = Environment.GetEnvironmentVariable("KEY_VAULT_URL"); + string encryptionKeyName = Environment.GetEnvironmentVariable("ENCRYPTION_KEY_NAME"); + + if (string.IsNullOrEmpty(cosmosEndpoint) + || string.IsNullOrEmpty(cosmosKey) + || string.IsNullOrEmpty(keyVaultUrl) + || string.IsNullOrEmpty(encryptionKeyName)) + { + Console.WriteLine("ERROR: Set COSMOS_ENDPOINT, COSMOS_KEY, KEY_VAULT_URL, and ENCRYPTION_KEY_NAME environment variables."); + return; + } + + string keyId = $"{keyVaultUrl.TrimEnd('/')}/keys/{encryptionKeyName}"; + + int cacheHits = 0; + int cacheMisses = 0; + + // ── Create CachingKeyResolver ─────────────────────────────────── + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(2), + ProactiveRefreshThreshold = TimeSpan.FromMinutes(5), + RefreshTimerInterval = TimeSpan.FromMinutes(1), + }; + + using CachingKeyResolver cachingResolver = new CachingKeyResolver( + new DefaultAzureCredential(), + options); + + cachingResolver.OnCacheAccess += (id, isHit) => + { + if (isHit) + { + System.Threading.Interlocked.Increment(ref cacheHits); + Console.WriteLine($" [CACHE HIT] {id}"); + } + else + { + System.Threading.Interlocked.Increment(ref cacheMisses); + Console.WriteLine($" [CACHE MISS] {id}"); + } + }; + + // ── Create Cosmos client with encryption ──────────────────────── + CosmosClient baseClient = new CosmosClient(cosmosEndpoint, cosmosKey); + CosmosClient encryptionClient = baseClient.WithEncryption( + cachingResolver, + KeyEncryptionKeyResolverName.AzureKeyVault); + + try + { + Console.WriteLine("Setting up database and container..."); + + Database database = await encryptionClient.CreateDatabaseIfNotExistsAsync(DatabaseId); + + // Create client encryption key (DEK) wrapped by the AKV key (KEK). + EncryptionKeyWrapMetadata wrapMetadata = new EncryptionKeyWrapMetadata( + KeyEncryptionKeyResolverName.AzureKeyVault, + encryptionKeyName, + keyId, + "RSA-OAEP"); + + try + { + await database.CreateClientEncryptionKeyAsync( + DekName, + DataEncryptionAlgorithm.AeadAes256CbcHmacSha256, + wrapMetadata); + Console.WriteLine($"Created client encryption key '{DekName}'."); + } + catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Conflict) + { + Console.WriteLine($"Client encryption key '{DekName}' already exists."); + } + + // Create container with encryption policy. + Collection paths = new Collection + { + new ClientEncryptionIncludedPath + { + Path = "/secret", + ClientEncryptionKeyId = DekName, + EncryptionType = "Deterministic", + EncryptionAlgorithm = "AEAD_AES_256_CBC_HMAC_SHA256", + }, + }; + + ClientEncryptionPolicy policy = new ClientEncryptionPolicy(paths, policyFormatVersion: 2); + ContainerProperties containerProps = new ContainerProperties(ContainerId, "/pk") + { + ClientEncryptionPolicy = policy, + }; + + Container container; + try + { + container = await database.CreateContainerAsync(containerProps); + Console.WriteLine($"Created container '{ContainerId}' with encryption policy."); + } + catch (CosmosException ex) when (ex.StatusCode == System.Net.HttpStatusCode.Conflict) + { + container = database.GetContainer(ContainerId); + Console.WriteLine($"Container '{ContainerId}' already exists."); + } + + container = await container.InitializeEncryptionAsync(); + + // ── Insert an encrypted document ──────────────────────────── + var document = new + { + id = Guid.NewGuid().ToString(), + pk = "sample", + secret = "This value is encrypted at rest.", + visible = "This value is NOT encrypted.", + }; + + Console.WriteLine("\nInserting encrypted document..."); + await container.CreateItemAsync(document, new PartitionKey(document.pk)); + Console.WriteLine("Document inserted."); + + // ── Read back the document ────────────────────────────────── + Console.WriteLine("Reading document back..."); + var response = await container.ReadItemAsync(document.id, new PartitionKey(document.pk)); + Console.WriteLine($"Read document: {response.Resource}"); + + // ── Summary ───────────────────────────────────────────────── + Console.WriteLine($"\n══════════════════════════════════════"); + Console.WriteLine($" Cache hits: {cacheHits}"); + Console.WriteLine($" Cache misses: {cacheMisses}"); + Console.WriteLine($"══════════════════════════════════════"); + } + finally + { + // Cleanup + try + { + await encryptionClient.GetDatabase(DatabaseId).DeleteAsync(); + Console.WriteLine($"\nCleaned up database '{DatabaseId}'."); + } + catch + { + // Best-effort cleanup. + } + } + } + } +} diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyEncryptionKeyTests.cs b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyEncryptionKeyTests.cs new file mode 100644 index 0000000000..d132e382c0 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyEncryptionKeyTests.cs @@ -0,0 +1,285 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace CachingKeyResolverSample.Tests +{ + using System; + using System.Collections.Concurrent; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core.Cryptography; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Moq; + + [TestClass] + public class CachingKeyEncryptionKeyTests + { + private static readonly byte[] TestEncryptedKey = new byte[] { 0x01, 0x02, 0x03, 0x04 }; + private static readonly byte[] TestUnwrappedKey = new byte[] { 0xAA, 0xBB, 0xCC, 0xDD }; + private const string TestAlgorithm = "RSA-OAEP"; + + #region Fresh cache hit tests + + [TestMethod] + public void UnwrapKey_FirstCall_DelegatesToInnerKey() + { + Mock mockInner = CreateMockInnerKey(); + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromHours(1)); + + byte[] result = sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + + CollectionAssert.AreEqual(TestUnwrappedKey, result); + mockInner.Verify( + k => k.UnwrapKey(TestAlgorithm, It.IsAny>(), It.IsAny()), + Times.Once); + } + + [TestMethod] + public void UnwrapKey_SecondCall_ReturnsCached_NoInnerCall() + { + Mock mockInner = CreateMockInnerKey(); + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromHours(1)); + + sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + byte[] result2 = sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + + CollectionAssert.AreEqual(TestUnwrappedKey, result2); + mockInner.Verify( + k => k.UnwrapKey(TestAlgorithm, It.IsAny>(), It.IsAny()), + Times.Once); // only the first call hit AKV + } + + [TestMethod] + public async Task UnwrapKeyAsync_SecondCall_ReturnsCached() + { + Mock mockInner = CreateMockInnerKeyAsync(); + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromHours(1)); + + await sut.UnwrapKeyAsync(TestAlgorithm, TestEncryptedKey); + byte[] result2 = await sut.UnwrapKeyAsync(TestAlgorithm, TestEncryptedKey); + + CollectionAssert.AreEqual(TestUnwrappedKey, result2); + mockInner.Verify( + k => k.UnwrapKeyAsync(TestAlgorithm, It.IsAny>(), It.IsAny()), + Times.Once); + } + + #endregion + + #region Stale fallback tests + + [TestMethod] + public void UnwrapKey_AkvDown_StaleEntryExists_ReturnsStale() + { + // First call succeeds and caches with a very short TTL. + Mock mockInner = CreateMockInnerKey(); + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromMilliseconds(1)); + + sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + + // Wait for TTL to expire. + Thread.Sleep(50); + + // Now make inner fail (AKV is down). + mockInner.Setup(k => k.UnwrapKey(It.IsAny(), It.IsAny>(), It.IsAny())) + .Throws(new InvalidOperationException("AKV is down")); + + // Should return stale cached bytes. + byte[] result = sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + CollectionAssert.AreEqual(TestUnwrappedKey, result); + } + + [TestMethod] + public async Task UnwrapKeyAsync_AkvDown_StaleEntryExists_ReturnsStale() + { + Mock mockInner = CreateMockInnerKeyAsync(); + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromMilliseconds(1)); + + await sut.UnwrapKeyAsync(TestAlgorithm, TestEncryptedKey); + await Task.Delay(50); + + // AKV goes down. + mockInner.Setup(k => k.UnwrapKeyAsync(It.IsAny(), It.IsAny>(), It.IsAny())) + .ThrowsAsync(new InvalidOperationException("AKV is down")); + + byte[] result = await sut.UnwrapKeyAsync(TestAlgorithm, TestEncryptedKey); + CollectionAssert.AreEqual(TestUnwrappedKey, result); + } + + [TestMethod] + [ExpectedException(typeof(InvalidOperationException))] + public void UnwrapKey_AkvDown_NoCachedEntry_Throws() + { + Mock mockInner = new Mock(); + mockInner.Setup(k => k.UnwrapKey(It.IsAny(), It.IsAny>(), It.IsAny())) + .Throws(new InvalidOperationException("AKV is down")); + + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromHours(1)); + + // No prior successful call → nothing to fall back to → must throw. + sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + } + + #endregion + + #region TTL expiry + refresh tests + + [TestMethod] + public void UnwrapKey_StaleEntry_AkvUp_RefreshesCache() + { + byte[] newUnwrappedKey = new byte[] { 0x11, 0x22, 0x33, 0x44 }; + Mock mockInner = CreateMockInnerKey(); + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromMilliseconds(1)); + + sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + Thread.Sleep(50); + + // AKV returns new key bytes (simulating key rotation). + mockInner.Setup(k => k.UnwrapKey(It.IsAny(), It.IsAny>(), It.IsAny())) + .Returns(newUnwrappedKey); + + byte[] result = sut.UnwrapKey(TestAlgorithm, TestEncryptedKey); + CollectionAssert.AreEqual(newUnwrappedKey, result); + } + + #endregion + + #region Shared cache tests + + [TestMethod] + public void SharedCache_MultipleKeyInstances_ShareEntries() + { + ConcurrentDictionary sharedCache = new(); + Mock mockInner1 = CreateMockInnerKey(); + Mock mockInner2 = CreateMockInnerKey(); + + CachingKeyEncryptionKey key1 = new CachingKeyEncryptionKey(mockInner1.Object, TimeSpan.FromHours(1), sharedCache); + CachingKeyEncryptionKey key2 = new CachingKeyEncryptionKey(mockInner2.Object, TimeSpan.FromHours(1), sharedCache); + + // Key1 unwraps and populates shared cache. + key1.UnwrapKey(TestAlgorithm, TestEncryptedKey); + + // Key2 should get cache hit — inner not called. + byte[] result = key2.UnwrapKey(TestAlgorithm, TestEncryptedKey); + CollectionAssert.AreEqual(TestUnwrappedKey, result); + + mockInner2.Verify( + k => k.UnwrapKey(It.IsAny(), It.IsAny>(), It.IsAny()), + Times.Never); + } + + #endregion + + #region WrapKey passthrough tests + + [TestMethod] + public void WrapKey_AlwaysDelegatesToInner() + { + byte[] plainKey = new byte[] { 0xAA, 0xBB }; + byte[] wrappedResult = new byte[] { 0xCC, 0xDD }; + + Mock mockInner = new Mock(); + mockInner.Setup(k => k.WrapKey(It.IsAny(), It.IsAny>(), It.IsAny())) + .Returns(wrappedResult); + + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromHours(1)); + + byte[] result1 = sut.WrapKey(TestAlgorithm, plainKey); + byte[] result2 = sut.WrapKey(TestAlgorithm, plainKey); + + CollectionAssert.AreEqual(wrappedResult, result1); + mockInner.Verify( + k => k.WrapKey(TestAlgorithm, It.IsAny>(), It.IsAny()), + Times.Exactly(2)); // always goes to AKV + } + + #endregion + + #region KeyId passthrough test + + [TestMethod] + public void KeyId_DelegatesToInner() + { + const string expectedKeyId = "https://my-vault.vault.azure.net/keys/my-key/abc123"; + Mock mockInner = new Mock(); + mockInner.Setup(k => k.KeyId).Returns(expectedKeyId); + + CachingKeyEncryptionKey sut = new CachingKeyEncryptionKey(mockInner.Object, TimeSpan.FromHours(1)); + + Assert.AreEqual(expectedKeyId, sut.KeyId); + } + + #endregion + + #region Integration with CachingKeyResolver + + [TestMethod] + public void CachingKeyResolver_WrapsWithCachingKeyEncryptionKey_WhenOptionEnabled() + { + Mock mockResolver = new Mock(MockBehavior.Strict); + Mock rawKey = new Mock(); + rawKey.Setup(k => k.KeyId).Returns("test-key"); + mockResolver.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(rawKey.Object); + + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), + UnwrapKeyCacheTimeToLive = TimeSpan.FromHours(24), + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + IKeyEncryptionKey result = sut.Resolve("test-key"); + + Assert.IsInstanceOfType(result, typeof(CachingKeyEncryptionKey)); + } + + [TestMethod] + public void CachingKeyResolver_DoesNotWrap_WhenUnwrapCacheDisabled() + { + Mock mockResolver = new Mock(MockBehavior.Strict); + Mock rawKey = new Mock(); + rawKey.Setup(k => k.KeyId).Returns("test-key"); + mockResolver.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(rawKey.Object); + + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), + UnwrapKeyCacheTimeToLive = TimeSpan.Zero, + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + IKeyEncryptionKey result = sut.Resolve("test-key"); + + Assert.IsNotInstanceOfType(result, typeof(CachingKeyEncryptionKey)); + } + + #endregion + + #region Helpers + + private static Mock CreateMockInnerKey() + { + Mock mock = new Mock(); + mock.Setup(k => k.KeyId).Returns("test-key"); + mock.Setup(k => k.UnwrapKey(It.IsAny(), It.IsAny>(), It.IsAny())) + .Returns(TestUnwrappedKey); + return mock; + } + + private static Mock CreateMockInnerKeyAsync() + { + Mock mock = new Mock(); + mock.Setup(k => k.KeyId).Returns("test-key"); + mock.Setup(k => k.UnwrapKeyAsync(It.IsAny(), It.IsAny>(), It.IsAny())) + .ReturnsAsync(TestUnwrappedKey); + return mock; + } + + #endregion + } +} diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyResolverSample.Tests.csproj b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyResolverSample.Tests.csproj new file mode 100644 index 0000000000..3fdec80177 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyResolverSample.Tests.csproj @@ -0,0 +1,21 @@ + + + + net8.0 + CachingKeyResolverSample.Tests + 10.0 + false + + + + + + + + + + + + + + diff --git a/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyResolverTests.cs b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyResolverTests.cs new file mode 100644 index 0000000000..b6a7e462d8 --- /dev/null +++ b/Microsoft.Azure.Cosmos.Samples/Usage/CachingKeyResolverSample/Tests/CachingKeyResolverTests.cs @@ -0,0 +1,332 @@ +// ------------------------------------------------------------ +// Copyright (c) Microsoft Corporation. All rights reserved. +// ------------------------------------------------------------ + +namespace CachingKeyResolverSample.Tests +{ + using System; + using System.Collections.Concurrent; + using System.Collections.Generic; + using System.Linq; + using System.Threading; + using System.Threading.Tasks; + using Azure.Core.Cryptography; + using Microsoft.VisualStudio.TestTools.UnitTesting; + using Moq; + + [TestClass] + public class CachingKeyResolverTests + { + private const string TestKeyId = "https://my-vault.vault.azure.net/keys/my-key/abc123"; + + private static Mock CreateMockResolver() + { + Mock mock = new Mock(MockBehavior.Strict); + Mock mockKey = new Mock(); + mockKey.Setup(k => k.KeyId).Returns(TestKeyId); + + mock.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(mockKey.Object); + + mock.Setup(r => r.ResolveAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(mockKey.Object); + + return mock; + } + + [TestMethod] + public void CacheHitReturnsImmediately() + { + Mock mockResolver = CreateMockResolver(); + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), // prevent timer interference + UnwrapKeyCacheTimeToLive = TimeSpan.Zero, // disable unwrap wrapping for reference equality test + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + // First call: cache miss — inner resolver is invoked. + IKeyEncryptionKey result1 = sut.Resolve(TestKeyId); + // Second call: cache hit — inner resolver should NOT be invoked again. + IKeyEncryptionKey result2 = sut.Resolve(TestKeyId); + + Assert.AreSame(result1, result2); + mockResolver.Verify(r => r.Resolve(TestKeyId, It.IsAny()), Times.Once); + } + + [TestMethod] + public void CacheMissFallsThrough() + { + Mock mockResolver = CreateMockResolver(); + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + IKeyEncryptionKey result = sut.Resolve(TestKeyId); + + Assert.IsNotNull(result); + mockResolver.Verify(r => r.Resolve(TestKeyId, It.IsAny()), Times.Once); + } + + [TestMethod] + public void ExpiredEntryIsReResolved() + { + Mock mockResolver = new Mock(MockBehavior.Strict); + Mock key1 = new Mock(); + Mock key2 = new Mock(); + + int callCount = 0; + mockResolver.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(() => + { + Interlocked.Increment(ref callCount); + return callCount == 1 ? key1.Object : key2.Object; + }); + + // Also setup async for background refresh timer + mockResolver.Setup(r => r.ResolveAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(() => key2.Object); + + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromMilliseconds(200), + ProactiveRefreshThreshold = TimeSpan.Zero, // disable proactive refresh + RefreshTimerInterval = TimeSpan.FromHours(1), // prevent timer interference + UnwrapKeyCacheTimeToLive = TimeSpan.Zero, // disable wrapping for AreSame assertions + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + IKeyEncryptionKey first = sut.Resolve(TestKeyId); + Assert.AreSame(key1.Object, first); + + // Wait for entry to expire. + Thread.Sleep(300); + + IKeyEncryptionKey second = sut.Resolve(TestKeyId); + Assert.AreSame(key2.Object, second); + Assert.AreEqual(2, callCount); + } + + [TestMethod] + public void ConcurrentResolveCallsAreSafe() + { + Mock mockResolver = CreateMockResolver(); + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), + UnwrapKeyCacheTimeToLive = TimeSpan.Zero, // disable unwrap wrapping for reference equality + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + ConcurrentBag results = new ConcurrentBag(); + List tasks = new List(); + + for (int i = 0; i < 100; i++) + { + tasks.Add(Task.Run(() => + { + IKeyEncryptionKey key = sut.Resolve(TestKeyId); + results.Add(key); + })); + } + + Task.WaitAll(tasks.ToArray()); + + Assert.AreEqual(100, results.Count); + // All should return the same cached instance (after the first resolution). + Assert.IsTrue(results.All(k => k == results.First())); + } + + [TestMethod] + public async Task BackgroundRefreshUpdatesCache() + { + Mock mockResolver = new Mock(MockBehavior.Strict); + Mock originalKey = new Mock(); + Mock refreshedKey = new Mock(); + + int resolveAsyncCallCount = 0; + + mockResolver.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(originalKey.Object); + + mockResolver.Setup(r => r.ResolveAsync(It.IsAny(), It.IsAny())) + .ReturnsAsync(() => + { + Interlocked.Increment(ref resolveAsyncCallCount); + return refreshedKey.Object; + }); + + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromMilliseconds(500), + ProactiveRefreshThreshold = TimeSpan.FromMilliseconds(400), // refresh when < 400ms left + RefreshTimerInterval = TimeSpan.FromMilliseconds(100), // check every 100ms + UnwrapKeyCacheTimeToLive = TimeSpan.Zero, // disable wrapping for AreSame assertions + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + // Seed the cache. + IKeyEncryptionKey first = sut.Resolve(TestKeyId); + Assert.AreSame(originalKey.Object, first); + + // Wait for proactive refresh: the entry expires at ~500ms, + // proactive threshold is 400ms, so refresh should fire at ~100ms. + // Timer checks every 100ms. + await Task.Delay(600); + + // The cache should now contain the refreshed key. + IKeyEncryptionKey after = sut.Resolve(TestKeyId); + Assert.AreSame(refreshedKey.Object, after); + Assert.IsTrue(resolveAsyncCallCount >= 1, $"Expected at least 1 async resolve call, got {resolveAsyncCallCount}"); + } + + [TestMethod] + public async Task BackgroundRefreshFailureDoesNotEvict() + { + Mock mockResolver = new Mock(MockBehavior.Strict); + Mock originalKey = new Mock(); + + mockResolver.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(originalKey.Object); + + // Async resolve (used by background refresh) throws. + mockResolver.Setup(r => r.ResolveAsync(It.IsAny(), It.IsAny())) + .ThrowsAsync(new InvalidOperationException("Simulated AKV failure")); + + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromMilliseconds(500), + ProactiveRefreshThreshold = TimeSpan.FromMilliseconds(400), + RefreshTimerInterval = TimeSpan.FromMilliseconds(100), + UnwrapKeyCacheTimeToLive = TimeSpan.Zero, // disable wrapping for AreSame assertions + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + // Seed the cache. + IKeyEncryptionKey first = sut.Resolve(TestKeyId); + Assert.AreSame(originalKey.Object, first); + + // Wait for background refresh attempt (which will fail). + await Task.Delay(400); + + // The original entry should still be in cache (not evicted). + IKeyEncryptionKey afterFailedRefresh = sut.Resolve(TestKeyId); + Assert.AreSame(originalKey.Object, afterFailedRefresh); + } + + [TestMethod] + public void DisposeStopsTimer() + { + Mock mockResolver = CreateMockResolver(); + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromMilliseconds(50), + }; + + CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + // Seed cache. + sut.Resolve(TestKeyId); + + // Dispose. + sut.Dispose(); + + // After dispose, inner resolver should not receive any new calls. + int callsBefore = mockResolver.Invocations.Count; + Thread.Sleep(200); // Wait several timer intervals. + int callsAfter = mockResolver.Invocations.Count; + + Assert.AreEqual(callsBefore, callsAfter, "No new resolve calls should happen after Dispose."); + } + + [TestMethod] + public void ResolveAfterDisposeThrows() + { + Mock mockResolver = CreateMockResolver(); + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), + }; + + CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + sut.Dispose(); + + Assert.ThrowsException(() => sut.Resolve(TestKeyId)); + } + + [TestMethod] + public async Task ResolveAsyncAfterDisposeThrows() + { + Mock mockResolver = CreateMockResolver(); + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromHours(1), + RefreshTimerInterval = TimeSpan.FromHours(1), + }; + + CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + sut.Dispose(); + + await Assert.ThrowsExceptionAsync( + () => sut.ResolveAsync(TestKeyId)); + } + + [TestMethod] + public async Task RefreshDeduplication() + { + Mock mockResolver = new Mock(MockBehavior.Strict); + Mock key = new Mock(); + + int asyncResolveCount = 0; + + mockResolver.Setup(r => r.Resolve(It.IsAny(), It.IsAny())) + .Returns(key.Object); + + // Async resolve is slow, simulating AKV latency. + mockResolver.Setup(r => r.ResolveAsync(It.IsAny(), It.IsAny())) + .Returns(async (string id, CancellationToken ct) => + { + Interlocked.Increment(ref asyncResolveCount); + await Task.Delay(300, ct); + return key.Object; + }); + + CachingKeyResolverOptions options = new CachingKeyResolverOptions + { + KeyCacheTimeToLive = TimeSpan.FromMilliseconds(400), + ProactiveRefreshThreshold = TimeSpan.FromMilliseconds(350), // Almost always eligible + RefreshTimerInterval = TimeSpan.FromMilliseconds(50), // Fire rapidly + }; + + using CachingKeyResolver sut = new CachingKeyResolver(mockResolver.Object, options); + + // Seed cache. + sut.Resolve(TestKeyId); + + // Wait for multiple timer ticks to fire while refresh is in flight. + // The resolve is 300ms, timer is 50ms, so multiple ticks will fire + // during one refresh. Deduplication should prevent multiple concurrent resolves. + await Task.Delay(800); + + // We expect one initial sync Resolve + a small number of async refresh calls. + // Without deduplication, we'd see many more async calls. + // With deduplication, the number of async calls should be low (1-3 depending on timing). + Assert.IsTrue( + asyncResolveCount <= 4, + $"Expected at most 4 async resolve calls due to deduplication, got {asyncResolveCount}"); + } + } +}