Skip to content

Commit 7daaaae

Browse files
leminh98leminh98
andauthored
Query: Adds Weighted RRF capability to LINQ (#5308)
# Pull Request Template ## Description This PR adds support for weights in the CosmosLINQExtension method. RRF now features two different signatures to allow users to specify weights in their queries ## Type of change Please delete options that are not relevant. - [] Bug fix (non-breaking change which fixes an issue) - [x] New feature (non-breaking change which adds functionality) - [] Breaking change (fix or feature that would cause existing functionality to not work as expected) - [] This change requires a documentation update ## Closing issues To automatically close an issue: closes #IssueNumber --------- Co-authored-by: leminh98 <leminh@microsoft.com>
1 parent b16759f commit 7daaaae

6 files changed

Lines changed: 211 additions & 5 deletions

File tree

Microsoft.Azure.Cosmos/src/Linq/BuiltinFunctions/OtherBuiltinSystemFunctions.cs

Lines changed: 15 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -24,17 +24,21 @@ public RRFVisit()
2424
true,
2525
new List<Type[]>()
2626
{
27-
new Type[]{typeof(double[])}
27+
new Type[]{typeof(double[])},
28+
new Type[]{typeof(double[]), typeof(double[])}
2829
})
2930
{
3031
}
3132

3233
protected override SqlScalarExpression VisitImplicit(MethodCallExpression methodCallExpression, TranslationContext context)
3334
{
34-
if (methodCallExpression.Arguments.Count == 1
35-
&& methodCallExpression.Arguments[0] is NewArrayExpression argumentsExpressions)
35+
if (methodCallExpression.Arguments.Count != 1 && methodCallExpression.Arguments.Count != 2)
36+
{
37+
throw new DocumentQueryException("Invalid Argument Count.");
38+
}
39+
40+
if (methodCallExpression.Arguments[0] is NewArrayExpression argumentsExpressions)
3641
{
37-
// For RRF, We don't need to care about the first argument, it is the object itself and have no relevance to the computation
3842
ReadOnlyCollection<Expression> functionListExpression = argumentsExpressions.Expressions;
3943
List<SqlScalarExpression> arguments = new List<SqlScalarExpression>();
4044
foreach (Expression argument in functionListExpression)
@@ -65,10 +69,16 @@ protected override SqlScalarExpression VisitImplicit(MethodCallExpression method
6569
arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(argument, context));
6670
}
6771

72+
// Append the weights array as the last argument, if one was supplied
73+
if (methodCallExpression.Arguments.Count == 2)
74+
{
75+
arguments.Add(ExpressionToSql.VisitNonSubqueryScalarExpression(methodCallExpression.Arguments[1], context));
76+
}
77+
6878
return SqlFunctionCallScalarExpression.CreateBuiltin(SqlFunctionCallScalarExpression.Names.RRF, arguments.ToImmutableArray());
6979
}
7080

71-
return null;
81+
throw new DocumentQueryException(string.Format(CultureInfo.CurrentCulture, "Method {0} is not supported with the given argument list.", methodCallExpression.Method.Name));
7282
}
7383

7484
protected override SqlScalarExpression VisitExplicit(MethodCallExpression methodCallExpression, TranslationContext context)

Microsoft.Azure.Cosmos/src/Linq/CosmosLinqExtensions.cs

Lines changed: 22 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -470,6 +470,28 @@ public static double RRF(params double[] scoringFunctions)
470470
throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented);
471471
}
472472

473+
/// <summary>
474+
/// This system function is used to combine two or more scores provided by other scoring functions.
475+
/// For more information, see https://learn.microsoft.com/en-us/azure/cosmos-db/nosql/query/rrf.
476+
/// This method is to be used in LINQ expressions only and will be evaluated on server.
477+
/// There's no implementation provided in the client library.
478+
/// </summary>
479+
/// <param name="scoringFunctions">the scoring functions to combine. Valid functions are FullTextScore and VectorDistance. </param>
480+
/// <param name="weights">the weights to apply to the scoring functions, one weight per scoring function, in the same order.</param>
481+
/// <returns>Returns the combined scores of the scoring functions.</returns>
482+
/// <example>
483+
/// <code>
484+
/// <![CDATA[
485+
/// var matched = documents.OrderByRank(document => CosmosLinqExtensions.RRF(new double[] { document.Name.FullTextScore(<keyword1>), document.Address.FullTextScore(<keyword2>) }, new double[] { 1.0, 2.0 }));
486+
/// ]]>
487+
/// </code>
488+
/// </example>
489+
public static double RRF(double[] scoringFunctions, double[] weights)
490+
{
491+
// The "this" keyword is intentionally omitted from the first parameter because it causes undesirable serialization when Expression.ToString() is called on this method
492+
throw new NotImplementedException(ClientResources.ExtensionMethodNotImplemented);
493+
}
494+
473495
/// <summary>
474496
/// This method generate query definition from LINQ query.
475497
/// </summary>
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
<Results>
2+
<Result>
3+
<Input>
4+
<Description><![CDATA[Standard weighted RRF calls]]></Description>
5+
<Expression><![CDATA[query.OrderByRank(doc => RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField.FullTextScore(new [] {"test1", "text2"})}, new [] {1, 2})).Select(doc => doc.Pk)]]></Expression>
6+
</Input>
7+
<Output>
8+
<SqlQuery><![CDATA[
9+
SELECT VALUE root["Pk"]
10+
FROM root
11+
ORDER BY RANK RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["StringField"], "test1", "text2"), [1, 2])]]></SqlQuery>
12+
<Results><![CDATA[[
13+
"Test",
14+
"Test"
15+
]]]></Results>
16+
</Output>
17+
</Result>
18+
<Result>
19+
<Input>
20+
<Description><![CDATA[Standard weighted RRF calls using anonymous types]]></Description>
21+
<Expression><![CDATA[query.OrderByRank(doc => RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField.FullTextScore(new [] {"test1", "text2"})}, new [] {1, 2})).Select(doc => doc.Pk)]]></Expression>
22+
</Input>
23+
<Output>
24+
<SqlQuery><![CDATA[
25+
SELECT VALUE root["Pk"]
26+
FROM root
27+
ORDER BY RANK RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["StringField"], "test1", "text2"), [1, 2])]]></SqlQuery>
28+
<Results><![CDATA[[
29+
"Test",
30+
"Test"
31+
]]]></Results>
32+
</Output>
33+
</Result>
34+
<Result>
35+
<Input>
36+
<Description><![CDATA[Weighted RRF with weights and functions not in a list]]></Description>
37+
<Expression><![CDATA[query.OrderByRank(doc => RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField2.FullTextScore(new [] {"test1", "test2", "test3"}), 1, 2})).Select(doc => doc.Pk)]]></Expression>
38+
</Input>
39+
<Output>
40+
<SqlQuery><![CDATA[]]></SqlQuery>
41+
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore, VectorDistance.]]></ErrorMessage>
42+
</Output>
43+
</Result>
44+
<Result>
45+
<Input>
46+
<Description><![CDATA[Weighted RRF with weights array first]]></Description>
47+
<Expression><![CDATA[query.OrderByRank(doc => RRF(new [] {1, 2}, new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField.FullTextScore(new [] {"test1", "text2"})})).Select(doc => doc.Pk)]]></Expression>
48+
</Input>
49+
<Output>
50+
<SqlQuery><![CDATA[]]></SqlQuery>
51+
<ErrorMessage><![CDATA[Method RRF is not supported with the given argument list.]]></ErrorMessage>
52+
</Output>
53+
</Result>
54+
<Result>
55+
<Input>
56+
<Description><![CDATA[Weighted RRF with mixed and matched values/functions in array]]></Description>
57+
<Expression><![CDATA[query.OrderByRank(doc => RRF(new [] {1, doc.StringField.FullTextScore(new [] {"test1"})}, new [] {2, doc.StringField.FullTextScore(new [] {"test1", "text2"})})).Select(doc => doc.Pk)]]></Expression>
58+
</Input>
59+
<Output>
60+
<SqlQuery><![CDATA[]]></SqlQuery>
61+
<ErrorMessage><![CDATA[Expressions of type System.Double is not supported as an argument to CosmosLinqExtensions.RRF. Supported expressions are method calls to FullTextScore, VectorDistance.]]></ErrorMessage>
62+
</Output>
63+
</Result>
64+
<Result>
65+
<Input>
66+
<Description><![CDATA[Weighted RRF with mixed and matched values/functions in array 2]]></Description>
67+
<Expression><![CDATA[query.OrderByRank(doc => RRF(new [] {doc.StringField.FullTextScore(new [] {"test1"}), doc.StringField.FullTextScore(new [] {"test1"})}, new [] {2, doc.StringField.FullTextScore(new [] {"test1", "text2"})})).Select(doc => doc.Pk)]]></Expression>
68+
</Input>
69+
<Output>
70+
<SqlQuery><![CDATA[
71+
SELECT VALUE root["Pk"]
72+
FROM root
73+
ORDER BY RANK RRF(FullTextScore(root["StringField"], "test1"), FullTextScore(root["StringField"], "test1"), [2, FullTextScore(root["StringField"], "test1", "text2")])]]></SqlQuery>
74+
<ErrorMessage><![CDATA[Status Code: BadRequest,{"errors":[{"severity":"Error","location":{"start":34,"end":200},"code":"SC2229","message":"The last parameter of the RRF function is an optional array of weights. When present, it must be a literal array of numbers, one for each of the component scores used for the RRF function. The length of this array must be the same as the number of the component scores."}]},0x800A0B00]]></ErrorMessage>
75+
</Output>
76+
</Result>
77+
</Results>

Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Linq/LinqTranslationBaselineTests.cs

Lines changed: 89 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -667,6 +667,95 @@ static DataObject createDataObj(Random random)
667667
this.ExecuteTestSuite(inputs);
668668
}
669669

670+
[TestMethod]
671+
public void TestWeightedRRF()
672+
{
673+
const int Records = 2;
674+
const int MaxStringLength = 100;
675+
static DataObject createDataObj(Random random)
676+
{
677+
DataObject obj = new DataObject
678+
{
679+
StringField = LinqTestsCommon.RandomString(random, random.Next(MaxStringLength)),
680+
IntField = 1,
681+
Id = Guid.NewGuid().ToString(),
682+
Pk = "Test"
683+
};
684+
return obj;
685+
}
686+
Func<bool, IQueryable<DataObject>> getQuery = LinqTestsCommon.GenerateTestCosmosData(createDataObj, Records, testContainer);
687+
688+
List<LinqTestInput> inputs = new List<LinqTestInput>
689+
{
690+
// public static double RRF(double[] scoringFunctions, double[] weights)
691+
new LinqTestInput("Standard weighted RRF calls", b => getQuery(b)
692+
.OrderByRank(doc => RRF(new double[]
693+
{
694+
doc.StringField.FullTextScore(new string[] { "test1" }),
695+
doc.StringField.FullTextScore(new string[] { "test1", "text2" })
696+
},
697+
new double[] { 1.0, 2.0 } ))
698+
.Select(doc => doc.Pk)),
699+
700+
new LinqTestInput("Standard weighted RRF calls using anonymous types", b => getQuery(b)
701+
.OrderByRank(doc => RRF(new []
702+
{
703+
doc.StringField.FullTextScore(new string[] { "test1" }),
704+
doc.StringField.FullTextScore(new string[] { "test1", "text2" })
705+
},
706+
new [] { 1.0, 2.0 } ))
707+
.Select(doc => doc.Pk)),
708+
709+
// Negative case: scoring functions and weights are passed as loose arguments instead of two arrays
710+
new LinqTestInput("Weighted RRF with weights and functions not in a list", b => getQuery(b)
711+
.OrderByRank(doc => RRF(doc.StringField.FullTextScore(new string[] { "test1" }),
712+
doc.StringField2.FullTextScore(new string[] { "test1", "test2", "test3" }),
713+
1.0,
714+
2.0))
715+
.Select(doc => doc.Pk)),
716+
new LinqTestInput("Weighted RRF with weights array first", b => getQuery(b)
717+
.OrderByRank(doc => RRF(new double[] { 1.0, 2.0 },
718+
new double[]
719+
{
720+
doc.StringField.FullTextScore(new string[] { "test1" }),
721+
doc.StringField.FullTextScore(new string[] { "test1", "text2" })
722+
}))
723+
.Select(doc => doc.Pk)),
724+
new LinqTestInput("Weighted RRF with mixed and matched values/functions in array", b => getQuery(b)
725+
.OrderByRank(doc => RRF(new double[] {
726+
1.0,
727+
doc.StringField.FullTextScore(new string[] { "test1" }) },
728+
new double[]
729+
{
730+
2.0,
731+
doc.StringField.FullTextScore(new string[] { "test1", "text2" })
732+
}))
733+
.Select(doc => doc.Pk)),
734+
new LinqTestInput("Weighted RRF with mixed and matched values/functions in array 2", b => getQuery(b)
735+
.OrderByRank(doc => RRF(new double[] {
736+
doc.StringField.FullTextScore(new string[] { "test1" }),
737+
doc.StringField.FullTextScore(new string[] { "test1" }) },
738+
new double[]
739+
{
740+
2.0,
741+
doc.StringField.FullTextScore(new string[] { "test1", "text2" })
742+
}))
743+
.Select(doc => doc.Pk)),
744+
745+
746+
};
747+
748+
foreach (LinqTestInput input in inputs)
749+
{
750+
// OrderBy are not supported client side.
751+
// Therefore this method is verified with baseline only.
752+
input.skipVerification = true;
753+
input.serializeOutput = true;
754+
}
755+
756+
this.ExecuteTestSuite(inputs);
757+
}
758+
670759
[TestMethod]
671760
public void TestOrderByRankFunctionComposeWithOtherFunctions()
672761
{

Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.EmulatorTests/Microsoft.Azure.Cosmos.EmulatorTests.csproj

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -250,6 +250,9 @@
250250
<Content Include="BaselineTest\TestBaseline\LinqTranslationBaselineTests.TestRRFOrderByRankFunction.xml">
251251
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
252252
</Content>
253+
<Content Include="BaselineTest\TestBaseline\LinqTranslationBaselineTests.TestWeightedRRF.xml">
254+
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
255+
</Content>
253256
<Content Include="BaselineTest\TestBaseline\LinqTranslationBaselineTests.TestOrderByRankFunctionComposeWithOtherFunctions.xml">
254257
<CopyToOutputDirectory>PreserveNewest</CopyToOutputDirectory>
255258
</Content>

Microsoft.Azure.Cosmos/tests/Microsoft.Azure.Cosmos.Tests/Contracts/DotNetSDKAPI.json

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6366,6 +6366,11 @@
63666366
],
63676367
"MethodInfo": "Double FullTextScore[TSource](TSource, System.String[]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:True;IsConstructor:False;IsFinal:False;"
63686368
},
6369+
"Double RRF(Double[], Double[])": {
6370+
"Type": "Method",
6371+
"Attributes": [],
6372+
"MethodInfo": "Double RRF(Double[], Double[]);IsAbstract:False;IsStatic:True;IsVirtual:False;IsGenericMethod:False;IsConstructor:False;IsFinal:False;"
6373+
},
63696374
"Double RRF(Double[])": {
63706375
"Type": "Method",
63716376
"Attributes": [],

0 commit comments

Comments
 (0)