Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -244,8 +244,11 @@ private async ValueTask<bool> MoveNextAsync_RunComponentQueriesAsync(ITrace trac
return await this.MoveNextAsync_DrainSingletonComponentAsync(trace, cancellationToken);
}

IReadOnlyList<ComponentWeight> componentWeights = ExtractComponentWeights(this.hybridSearchQueryInfo);

TryCatch<(IReadOnlyList<HybridSearchQueryResult>, QueryPage)> tryCollateSortedPipelineStageResults = await CollateSortedPipelineStageResultsAsync(
this.queryPipelineStages,
componentWeights,
this.maxConcurrency,
trace,
cancellationToken);
Expand Down Expand Up @@ -393,8 +396,28 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(
return TryCatch<List<IQueryPipelineStage>>.FromResult(queryPipelineStages);
}

/// <summary>
/// Builds one <see cref="ComponentWeight"/> per component query of the hybrid search query plan.
/// </summary>
/// <param name="hybridSearchQueryInfo">Query plan information holding the component queries and the optional component weights.</param>
/// <returns>
/// A list with one entry per component query, in component order. Each entry pairs the component's weight
/// (1.0 when the plan carries no weights) with the sort order of that component's single ORDER BY expression.
/// </returns>
private static IReadOnlyList<ComponentWeight> ExtractComponentWeights(HybridSearchQueryInfo hybridSearchQueryInfo)
{
    int componentCount = hybridSearchQueryInfo.ComponentQueryInfos.Count;

    // A missing or empty weight list means every component gets the default weight.
    bool useDefaultComponentWeight = (hybridSearchQueryInfo.ComponentWeights == null) || (hybridSearchQueryInfo.ComponentWeights.Count == 0);

    // The weights are indexed by component position below, so a non-empty weight list
    // has to line up one-to-one with the component queries.
    Debug.Assert(
        useDefaultComponentWeight || (hybridSearchQueryInfo.ComponentWeights.Count == componentCount),
        "The number of component weights should match the number of component queries");

    List<ComponentWeight> result = new List<ComponentWeight>(componentCount);
    for (int index = 0; index < componentCount; ++index)
    {
        QueryInfo queryInfo = hybridSearchQueryInfo.ComponentQueryInfos[index];
        Debug.Assert(queryInfo.HasOrderBy, "The component query should have an order by");
        Debug.Assert(queryInfo.HasNonStreamingOrderBy, "The component query is a non streaming order by");
        Debug.Assert(queryInfo.OrderBy.Count == 1, "The component query should have exactly one order by expression");

        double componentWeight = useDefaultComponentWeight ? 1.0 : hybridSearchQueryInfo.ComponentWeights[index];
        result.Add(new ComponentWeight(componentWeight, queryInfo.OrderBy[0]));
    }

    return result;
}

private static async ValueTask<TryCatch<(IReadOnlyList<HybridSearchQueryResult>, QueryPage)>> CollateSortedPipelineStageResultsAsync(
IReadOnlyList<IQueryPipelineStage> queryPipelineStages,
IReadOnlyList<ComponentWeight> componentWeights,
int maxConcurrency,
ITrace trace,
CancellationToken cancellationToken)
Expand Down Expand Up @@ -439,14 +462,14 @@ private static TryCatch<List<IQueryPipelineStage>> CreateQueryPipelineStages(

IReadOnlyList<List<ScoreTuple>> componentScores = tryGetComponentScores.Result;

foreach (List<ScoreTuple> scoreTuples in componentScores)
for (int index = 0; index < componentScores.Count; ++index)
{
scoreTuples.Sort((x, y) => (-1) * x.Score.CompareTo(y.Score)); // sort descending, since higher scores are better
componentScores[index].Sort((x, y) => componentWeights[index].Comparison(x.Score, y.Score));
}

int[,] ranks = ComputeRanks(componentScores);

ComputeRrfScores(ranks, queryResults);
ComputeRrfScores(ranks, componentWeights, queryResults);

HybridSearchDebugTraceHelpers.TraceQueryResultsWithRanks(queryResults, ranks);

Expand Down Expand Up @@ -578,7 +601,7 @@ private static TryCatch<IReadOnlyList<List<ScoreTuple>>> RetrieveComponentScores
for (int index = 0; index < componentScores[componentIndex].Count; ++index)
{
// Identical scores should have the same rank
if ((index > 0) && (componentScores[componentIndex][index].Score < componentScores[componentIndex][index - 1].Score))
if ((index > 0) && (componentScores[componentIndex][index].Score != componentScores[componentIndex][index - 1].Score))
{
++rank;
}
Expand All @@ -592,16 +615,18 @@ private static TryCatch<IReadOnlyList<List<ScoreTuple>>> RetrieveComponentScores

private static void ComputeRrfScores(
int[,] ranks,
IReadOnlyList<ComponentWeight> componentWeights,
List<HybridSearchQueryResult> queryResults)
{
int componentCount = ranks.GetLength(0);
Debug.Assert(componentWeights.Count == componentCount, "The number of component weights should match the number of components");

for (int index = 0; index < queryResults.Count; ++index)
{
double rrfScore = 0;
for (int componentIndex = 0; componentIndex < componentCount; ++componentIndex)
{
rrfScore += 1.0 / (RrfConstant + ranks[componentIndex, index]);
rrfScore += componentWeights[componentIndex].Weight / (RrfConstant + ranks[componentIndex, index]);
}

queryResults[index] = queryResults[index].WithScore(rrfScore);
Expand Down Expand Up @@ -750,6 +775,24 @@ private static string FormatComponentQueryTextWorkaround(string format, GlobalFu
return TryCatch<(GlobalFullTextSearchStatistics, QueryPage)>.FromResult((globalStatisticsAggregator.GetResult(), queryPage));
}

/// <summary>
/// Carries the weight and sort order of a single component query, together with a
/// score comparison that follows that sort order.
/// </summary>
private class ComponentWeight
{
    /// <summary>Sort order supplied at construction.</summary>
    public SortOrder SortOrder { get; }

    /// <summary>Weight supplied at construction.</summary>
    public double Weight { get; }

    /// <summary>
    /// Compares two scores: natural double ordering when the sort order is ascending,
    /// the reverse of it otherwise.
    /// </summary>
    public Comparison<double> Comparison { get; }

    public ComponentWeight(double weight, SortOrder sortOrder)
    {
        this.Weight = weight;
        this.SortOrder = sortOrder;

        if (sortOrder == SortOrder.Ascending)
        {
            this.Comparison = (left, right) => left.CompareTo(right);
        }
        else
        {
            // Any non-ascending order flips the sign of the natural comparison.
            this.Comparison = (left, right) => -left.CompareTo(right);
        }
    }
}

private readonly struct ScoreTuple
{
public double Score { get; }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,7 @@ enum QueryFeatures : ulong
ListAndSetAggregate = 1 << 12,
CountIf = 1 << 13,
HybridSearch = 1 << 14,
WeightedRankFusion = 1 << 15,
WeightedRankFusion = 1 << 15,
HybridSearchSkipOrderByRewrite = 1 << 16,
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@ internal static class QueryPlanRetriever
| QueryFeatures.DCount
| QueryFeatures.NonStreamingOrderBy
| QueryFeatures.CountIf
| QueryFeatures.HybridSearch;
| QueryFeatures.HybridSearch
| QueryFeatures.WeightedRankFusion;

// SupportedQueryFeatures with the NonStreamingOrderBy flag cleared.
private static readonly QueryFeatures SupportedQueryFeaturesWithoutNonStreamingOrderBy =
SupportedQueryFeatures & (~QueryFeatures.NonStreamingOrderBy);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,19 +58,6 @@ public sealed class HybridSearchQueryTests : QueryTestsBase

[TestMethod]
public async Task SanityTests()
{
CosmosArray documentsArray = await LoadDocuments();
IEnumerable<string> documents = documentsArray.Select(document => document.ToString());

await this.CreateIngestQueryDeleteAsync(
connectionModes: ConnectionModes.Direct, // | ConnectionModes.Gateway,
collectionTypes: CollectionTypes.MultiPartition, // | CollectionTypes.SinglePartition,
documents: documents,
query: RunSanityTests,
indexingPolicy: CompositeIndexPolicy);
}

private static async Task RunSanityTests(Container container, IReadOnlyList<CosmosObject> _)
{
List<SanityTestCase> testCases = new List<SanityTestCase>
{
Expand Down Expand Up @@ -146,6 +133,68 @@ ORDER BY RANK RRF(VectorDistance(c.vector, {SampleVector}), FullTextScore(c.titl
new List<List<int>>{new List<int>{ 21, 75, 37, 24, 26, 35, 49, 87, 55, 9 } }),
};

await this.RunTests(testCases);
}

/// <summary>
/// Runs ORDER BY RANK RRF(...) queries whose last argument is an array of per-component weights.
/// Currently skipped through the Ignore attribute below.
/// </summary>
[TestMethod]
[Ignore("This test is disabled because it needs an emulator refresh.")]
public async Task WeightedRankFusionTests()
{
// Each case pairs a query with a list of integer lists; presumably each inner list is an
// acceptable ordering of the returned c.index values - confirm against MakeSanityTest.
List<SanityTestCase> testCases = new List<SanityTestCase>
{
// Both components weighted 1.
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']), [1, 1])",
new List<List<int>>{
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
}),
// Both components weighted 10; the expected lists are the same as in the [1, 1] case.
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']), [10, 10])",
new List<List<int>>{
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 57, 85 },
new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25, 22, 2, 66, 85, 57 },
}),
// Fractional weights combined with TOP 10; the expected list equals the first ten entries of the [1, 1] lists.
MakeSanityTest(@"
SELECT TOP 10 c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']), [0.1, 0.1])",
new List<List<int>>{ new List<int>{ 61, 51, 49, 54, 75, 24, 77, 76, 80, 25 } }),
// Negative weights.
// NOTE(review): the first expected list below ends in 81, a value that appears in no other
// expected list of this test, while 61 (present everywhere else) is missing from it.
// The WHERE clause is the same as in the cases above, so this looks like a typo for 61 - confirm.
MakeSanityTest(@"
SELECT c.index AS Index, c.title AS Title, c.text AS Text
FROM c
WHERE FullTextContains(c.title, 'John') OR FullTextContains(c.text, 'John') OR FullTextContains(c.text, 'United States')
ORDER BY RANK RRF(FullTextScore(c.title, ['John']), FullTextScore(c.text, ['United States']), [-1, -1])",
new List<List<int>>{
new List<int>{ 85, 57, 66, 2, 22, 25, 77, 76, 80, 75, 24, 49, 54, 51, 81 },
new List<int>{ 57, 85, 2, 66, 22, 25, 80, 76, 77, 24, 75, 54, 49, 51, 61 },
}),
};

await this.RunTests(testCases);
}

/// <summary>
/// Loads the test documents and passes them, serialized to strings, to CreateIngestQueryDeleteAsync,
/// with a query callback that runs the supplied test cases against the container it receives.
/// </summary>
/// <param name="testCases">Test cases forwarded to the static RunTests overload.</param>
private async Task RunTests(IEnumerable<SanityTestCase> testCases)
{
    CosmosArray loadedDocuments = await LoadDocuments();

    // Left as a lazy projection; it is enumerated by whoever consumes the documents argument.
    IEnumerable<string> serializedDocuments = loadedDocuments.Select(item => item.ToString());

    await this.CreateIngestQueryDeleteAsync(
        connectionModes: ConnectionModes.Direct, // | ConnectionModes.Gateway,
        collectionTypes: CollectionTypes.MultiPartition, // | CollectionTypes.SinglePartition,
        documents: serializedDocuments,
        query: (container, _) => RunTests(container, testCases),
        indexingPolicy: CompositeIndexPolicy);
}

private static async Task RunTests(Container container, IEnumerable<SanityTestCase> testCases)
{
foreach (SanityTestCase testCase in testCases)
{
List<TextDocument> result = await RunQueryCombinationsAsync<TextDocument>(
Expand Down
Loading