Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -35,8 +35,9 @@ internal sealed class TokenCredentialCache : IDisposable
// The token refresh retries half the time. Given default of 1hr it will retry at 30m, 15, 7.5, 3.75, 1.875
// If the background refresh fails with less than a minute then just allow the request to hit the exception.
public static readonly TimeSpan MinimumTimeBetweenBackgroundRefreshInterval = TimeSpan.FromMinutes(1);

private const string ScopeFormat = "https://{0}/.default";

private readonly TokenRequestContext tokenRequestContext;
private readonly TokenCredential tokenCredential;
private readonly CancellationTokenSource cancellationTokenSource;
Expand All @@ -62,13 +63,17 @@ internal TokenCredentialCache(
if (accountEndpoint == null)
{
throw new ArgumentNullException(nameof(accountEndpoint));
}

this.tokenRequestContext = new TokenRequestContext(new string[]
{
string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host)
});

}

string? scopeOverride = ConfigurationManager.AADScopeOverrideValue(defaultValue: null);

this.tokenRequestContext = new TokenRequestContext(new string[]
{
!string.IsNullOrEmpty(scopeOverride)
? scopeOverride
: string.Format(TokenCredentialCache.ScopeFormat, accountEndpoint.Host)
});

if (backgroundTokenCredentialRefreshInterval.HasValue)
{
if (backgroundTokenCredentialRefreshInterval.Value <= TimeSpan.Zero)
Expand Down
21 changes: 20 additions & 1 deletion Microsoft.Azure.Cosmos/src/Util/ConfigurationManager.cs
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,12 @@ internal static class ConfigurationManager
/// <summary>
/// Environment variable name to enable thin client mode.
/// </summary>
internal static readonly string ThinClientModeEnabled = "AZURE_COSMOS_THIN_CLIENT_ENABLED";
internal static readonly string ThinClientModeEnabled = "AZURE_COSMOS_THIN_CLIENT_ENABLED";

/// <summary>
/// Environment variable to override AAD scope.
/// </summary>
internal static readonly string AADScopeOverride = "AZURE_COSMOS_AAD_SCOPE_OVERRIDE";

/// <summary>
/// A read-only string containing the environment variable name for capturing the consecutive failure count for reads, before triggering per partition
Expand Down Expand Up @@ -183,6 +188,20 @@ public static bool IsThinClientEnabled(
.GetEnvironmentVariable(
variable: ConfigurationManager.ThinClientModeEnabled,
defaultValue: defaultValue);
}

/// <summary>
/// Gets the AAD scope value to override.
/// </summary>
/// <param name="defaultValue">Emoty string for AAD scope if no scope value is provided.</param>
/// <returns>AAD scope value.</returns>
public static string AADScopeOverrideValue(
string defaultValue)
{
return ConfigurationManager
.GetEnvironmentVariable(
variable: ConfigurationManager.AADScopeOverride,
defaultValue: defaultValue);
}

/// <summary>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ public async Task AadMockTest(ConnectionMode connectionMode)
try
{
(string endpoint, string authKey) = TestCommon.GetAccountInfo();
LocalEmulatorTokenCredential simpleEmulatorTokenCredential = new LocalEmulatorTokenCredential(authKey);
LocalEmulatorTokenCredential simpleEmulatorTokenCredential = new LocalEmulatorTokenCredential(expectedScope: "https://127.0.0.1/.default", masterKey: authKey);
CosmosClientOptions clientOptions = new CosmosClientOptions()
{
ConnectionMode = connectionMode,
Expand Down Expand Up @@ -140,8 +140,9 @@ void GetAadTokenCallBack(

(string endpoint, string authKey) = TestCommon.GetAccountInfo();
LocalEmulatorTokenCredential simpleEmulatorTokenCredential = new LocalEmulatorTokenCredential(
authKey,
GetAadTokenCallBack);
expectedScope: "https://127.0.0.1/.default",
masterKey: authKey,
getTokenCallback: GetAadTokenCallBack);

CosmosClientOptions clientOptions = new CosmosClientOptions()
{
Expand Down Expand Up @@ -191,8 +192,9 @@ void GetAadTokenCallBack(

(string endpoint, string authKey) = TestCommon.GetAccountInfo();
LocalEmulatorTokenCredential simpleEmulatorTokenCredential = new LocalEmulatorTokenCredential(
authKey,
GetAadTokenCallBack);
expectedScope: "https://127.0.0.1/.default",
masterKey: authKey,
getTokenCallback: GetAadTokenCallBack);

CosmosClientOptions clientOptions = new CosmosClientOptions()
{
Expand Down Expand Up @@ -232,8 +234,9 @@ void GetAadTokenCallBack(

(string endpoint, string authKey) = TestCommon.GetAccountInfo();
LocalEmulatorTokenCredential simpleEmulatorTokenCredential = new LocalEmulatorTokenCredential(
authKey,
GetAadTokenCallBack);
expectedScope: "https://127.0.0.1/.default",
masterKey: authKey,
getTokenCallback: GetAadTokenCallBack);

CosmosClientOptions clientOptions = new CosmosClientOptions()
{
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -86,51 +86,51 @@ FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<List<int>>{
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22, 85, 57 },
}),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 } }),
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))
OFFSET 5 LIMIT 10",
new List<List<int>>{
new List<int>{ 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
new List<int>{ 24, 77, 76, 80, 2, 22, 57, 85 },
new List<int>{ 24, 77, 76, 80, 2, 22, 85, 57 },
}),
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<List<int>>{new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 } }),
new List<List<int>>{new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2 } }),
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))
OFFSET 0 LIMIT 13",
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66 } }),
OFFSET 0 LIMIT 11",
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 2, 22 } }),
MakeSanityTest($@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']), VectorDistance(c.vector, {SampleVector}))",
new List<List<int>>{new List<int>{ 21, 75, 37, 24, 26, 35, 49, 87, 55, 9 } }),
new List<List<int>>{new List<int>{ 21, 37, 75, 26, 35, 24, 87, 55, 49, 9 } }),
MakeSanityTest($@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(VectorDistance(c.vector, {SampleVector}), FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']))",
new List<List<int>>{new List<int>{ 21, 75, 37, 24, 26, 35, 49, 87, 55, 9 } }),
new List<List<int>>{new List<int>{ 21, 37, 75, 26, 35, 24, 87, 55, 49, 9 } }),
MakeSanityTest($@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
ORDER BY RANK RRF(VectorDistance(c.vector, {SampleVector}), FullTextScore(c.title, ['John']), VectorDistance(c.image, {SampleVector}), VectorDistance(c.backup_image, {SampleVector}), FullTextScore(c.text, ['United States']))",
new List<List<int>>{new List<int>{ 21, 75, 37, 24, 26, 35, 49, 87, 55, 9 } }),
new List<List<int>>{new List<int>{ 21, 37, 75, 26, 35, 24, 87, 55, 49, 9 } }),
};

await this.RunTests(testCases);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,15 +18,18 @@ public class LocalEmulatorTokenCredential : TokenCredential
private readonly DateTime? DefaultDateTime = null;
private readonly Action<TokenRequestContext, CancellationToken> GetTokenCallback;
private readonly string masterKey;
private readonly string expectedScope;

internal LocalEmulatorTokenCredential(
internal LocalEmulatorTokenCredential(
string expectedScope,
string masterKey = null,
Action<TokenRequestContext, CancellationToken> getTokenCallback = null,
DateTime? defaultDateTime = null)
{
this.masterKey = masterKey;
this.GetTokenCallback = getTokenCallback;
this.DefaultDateTime = defaultDateTime;
this.DefaultDateTime = defaultDateTime;
this.expectedScope = expectedScope;
}

public override AccessToken GetToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
Expand All @@ -40,9 +43,8 @@ public override ValueTask<AccessToken> GetTokenAsync(TokenRequestContext request
}

private AccessToken GetAccessToken(TokenRequestContext requestContext, CancellationToken cancellationToken)
{
// Verify that the request context is a valid URI
Assert.AreEqual("https://127.0.0.1/.default", requestContext.Scopes.First());
{
Assert.AreEqual(this.expectedScope, requestContext.Scopes.First());

this.GetTokenCallback?.Invoke(
requestContext,
Expand Down
Loading
Loading