Skip to content

Commit d87e0d3

Browse files
authored
Translate string.Join/Concat with ordering on SQLite (#38344)
Fixes #32201
1 parent 16ef9a3 commit d87e0d3

6 files changed

Lines changed: 383 additions & 20 deletions

File tree

Lines changed: 218 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,218 @@
1+
// Licensed to the .NET Foundation under one or more agreements.
2+
// The .NET Foundation licenses this file to you under the MIT license.
3+
4+
using Microsoft.EntityFrameworkCore.Query.SqlExpressions;
5+
6+
// ReSharper disable once CheckNamespace
7+
namespace Microsoft.EntityFrameworkCore.Sqlite.Query.Internal;
8+
9+
/// <summary>
10+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
11+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
12+
/// any release. You should only use it directly in your code with extreme caution and knowing that
13+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
14+
/// </summary>
15+
public class SqliteAggregateFunctionExpression : SqlExpression
16+
{
17+
private static ConstructorInfo? _quotingConstructor;
18+
19+
/// <summary>
20+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
21+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
22+
/// any release. You should only use it directly in your code with extreme caution and knowing that
23+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
24+
/// </summary>
25+
public SqliteAggregateFunctionExpression(
26+
string name,
27+
IReadOnlyList<SqlExpression> arguments,
28+
IReadOnlyList<OrderingExpression> orderings,
29+
bool nullable,
30+
IEnumerable<bool> argumentsPropagateNullability,
31+
Type type,
32+
RelationalTypeMapping? typeMapping)
33+
: base(type, typeMapping)
34+
{
35+
Name = name;
36+
Arguments = arguments.ToList();
37+
Orderings = orderings;
38+
IsNullable = nullable;
39+
ArgumentsPropagateNullability = argumentsPropagateNullability.ToList();
40+
}
41+
42+
/// <summary>
43+
/// The name of the aggregate SQL function, e.g. <c>group_concat</c>.
44+
/// </summary>
45+
public virtual string Name { get; }
46+
47+
/// <summary>
48+
/// The arguments passed to the aggregate function.
49+
/// </summary>
50+
public virtual IReadOnlyList<SqlExpression> Arguments { get; }
51+
52+
/// <summary>
53+
/// The orderings applied to the aggregated input, rendered inside the function call as
54+
/// <c>group_concat(value, separator ORDER BY ...)</c>.
55+
/// </summary>
56+
public virtual IReadOnlyList<OrderingExpression> Orderings { get; }
57+
58+
/// <summary>
59+
/// Whether the expression is nullable.
60+
/// </summary>
61+
public virtual bool IsNullable { get; }
62+
63+
/// <summary>
64+
/// For each argument, whether a <see langword="null" /> value propagates to a <see langword="null" /> result.
65+
/// </summary>
66+
public virtual IReadOnlyList<bool> ArgumentsPropagateNullability { get; }
67+
68+
/// <inheritdoc />
69+
protected override Expression VisitChildren(ExpressionVisitor visitor)
70+
{
71+
SqlExpression[]? arguments = null;
72+
for (var i = 0; i < Arguments.Count; i++)
73+
{
74+
var visitedArgument = (SqlExpression)visitor.Visit(Arguments[i]);
75+
if (visitedArgument != Arguments[i] && arguments is null)
76+
{
77+
arguments = new SqlExpression[Arguments.Count];
78+
79+
for (var j = 0; j < i; j++)
80+
{
81+
arguments[j] = Arguments[j];
82+
}
83+
}
84+
85+
if (arguments is not null)
86+
{
87+
arguments[i] = visitedArgument;
88+
}
89+
}
90+
91+
OrderingExpression[]? orderings = null;
92+
for (var i = 0; i < Orderings.Count; i++)
93+
{
94+
var visitedOrdering = (OrderingExpression)visitor.Visit(Orderings[i]);
95+
if (visitedOrdering != Orderings[i] && orderings is null)
96+
{
97+
orderings = new OrderingExpression[Orderings.Count];
98+
99+
for (var j = 0; j < i; j++)
100+
{
101+
orderings[j] = Orderings[j];
102+
}
103+
}
104+
105+
if (orderings is not null)
106+
{
107+
orderings[i] = visitedOrdering;
108+
}
109+
}
110+
111+
return arguments is not null || orderings is not null
112+
? new SqliteAggregateFunctionExpression(
113+
Name,
114+
arguments ?? Arguments,
115+
orderings ?? Orderings,
116+
IsNullable,
117+
ArgumentsPropagateNullability,
118+
Type,
119+
TypeMapping)
120+
: this;
121+
}
122+
123+
/// <summary>
124+
/// Applies the given type mapping, returning a new expression.
125+
/// </summary>
126+
public virtual SqliteAggregateFunctionExpression ApplyTypeMapping(RelationalTypeMapping? typeMapping)
127+
=> new(
128+
Name,
129+
Arguments,
130+
Orderings,
131+
IsNullable,
132+
ArgumentsPropagateNullability,
133+
Type,
134+
typeMapping ?? TypeMapping);
135+
136+
/// <summary>
137+
/// Returns a new expression with the given arguments and orderings, or this instance if nothing changed.
138+
/// </summary>
139+
public virtual SqliteAggregateFunctionExpression Update(
140+
IReadOnlyList<SqlExpression> arguments,
141+
IReadOnlyList<OrderingExpression> orderings)
142+
=> (ReferenceEquals(arguments, Arguments) || arguments.SequenceEqual(Arguments))
143+
&& (ReferenceEquals(orderings, Orderings) || orderings.SequenceEqual(Orderings))
144+
? this
145+
: new SqliteAggregateFunctionExpression(
146+
Name,
147+
arguments,
148+
orderings,
149+
IsNullable,
150+
ArgumentsPropagateNullability,
151+
Type,
152+
TypeMapping);
153+
154+
/// <inheritdoc />
155+
public override Expression Quote()
156+
=> New(
157+
_quotingConstructor ??= typeof(SqliteAggregateFunctionExpression).GetConstructor(
158+
[
159+
typeof(string), typeof(IReadOnlyList<SqlExpression>), typeof(IReadOnlyList<OrderingExpression>), typeof(bool),
160+
typeof(IEnumerable<bool>), typeof(Type), typeof(RelationalTypeMapping)
161+
])!,
162+
Constant(Name),
163+
NewArrayInit(typeof(SqlExpression), initializers: Arguments.Select(a => a.Quote())),
164+
NewArrayInit(typeof(OrderingExpression), Orderings.Select(o => o.Quote())),
165+
Constant(IsNullable),
166+
NewArrayInit(typeof(bool), initializers: ArgumentsPropagateNullability.Select(n => Constant(n))),
167+
Constant(Type),
168+
RelationalExpressionQuotingUtilities.QuoteTypeMapping(TypeMapping));
169+
170+
/// <inheritdoc />
171+
protected override void Print(ExpressionPrinter expressionPrinter)
172+
{
173+
expressionPrinter.Append(Name);
174+
175+
expressionPrinter.Append("(");
176+
expressionPrinter.VisitCollection(Arguments);
177+
178+
if (Orderings.Count > 0)
179+
{
180+
expressionPrinter.Append(" ORDER BY ");
181+
expressionPrinter.VisitCollection(Orderings);
182+
}
183+
184+
expressionPrinter.Append(")");
185+
}
186+
187+
/// <inheritdoc />
188+
public override bool Equals(object? obj)
189+
=> obj is SqliteAggregateFunctionExpression sqliteAggregateFunctionExpression && Equals(sqliteAggregateFunctionExpression);
190+
191+
private bool Equals(SqliteAggregateFunctionExpression? other)
192+
=> ReferenceEquals(this, other)
193+
|| other is not null
194+
&& base.Equals(other)
195+
&& Name == other.Name
196+
&& Arguments.SequenceEqual(other.Arguments)
197+
&& Orderings.SequenceEqual(other.Orderings);
198+
199+
/// <inheritdoc />
200+
public override int GetHashCode()
201+
{
202+
var hash = new HashCode();
203+
hash.Add(base.GetHashCode());
204+
hash.Add(Name);
205+
206+
for (var i = 0; i < Arguments.Count; i++)
207+
{
208+
hash.Add(Arguments[i]);
209+
}
210+
211+
for (var i = 0; i < Orderings.Count; i++)
212+
{
213+
hash.Add(Orderings[i]);
214+
}
215+
216+
return hash.ToHashCode();
217+
}
218+
}

src/EFCore.Sqlite.Core/Query/Internal/SqliteQuerySqlGenerator.cs

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -36,6 +36,10 @@ protected override Expression VisitExtension(Expression extensionExpression)
3636
GenerateJsonEach(jsonEachExpression);
3737
return extensionExpression;
3838

39+
case SqliteAggregateFunctionExpression aggregateFunctionExpression:
40+
GenerateAggregateFunction(aggregateFunctionExpression);
41+
return extensionExpression;
42+
3943
default:
4044
return base.VisitExtension(extensionExpression);
4145
}
@@ -174,6 +178,40 @@ private void GenerateRegexp(RegexpExpression regexpExpression, bool negated = fa
174178
Visit(regexpExpression.Pattern);
175179
}
176180

181+
private void GenerateAggregateFunction(SqliteAggregateFunctionExpression aggregateFunctionExpression)
182+
{
183+
Sql.Append(aggregateFunctionExpression.Name).Append("(");
184+
185+
for (var i = 0; i < aggregateFunctionExpression.Arguments.Count; i++)
186+
{
187+
if (i > 0)
188+
{
189+
Sql.Append(", ");
190+
}
191+
192+
Visit(aggregateFunctionExpression.Arguments[i]);
193+
}
194+
195+
// Unlike SQL Server's "WITHIN GROUP (ORDER BY ...)", SQLite renders the ordering inside the function
196+
// parentheses: group_concat(value, separator ORDER BY ...). Supported since SQLite 3.44.0.
197+
if (aggregateFunctionExpression.Orderings.Count > 0)
198+
{
199+
Sql.Append(" ORDER BY ");
200+
201+
for (var i = 0; i < aggregateFunctionExpression.Orderings.Count; i++)
202+
{
203+
if (i > 0)
204+
{
205+
Sql.Append(", ");
206+
}
207+
208+
Visit(aggregateFunctionExpression.Orderings[i]);
209+
}
210+
}
211+
212+
Sql.Append(")");
213+
}
214+
177215
/// <summary>
178216
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
179217
/// the same compatibility standards as public APIs. It may be changed or removed without notice in

src/EFCore.Sqlite.Core/Query/Internal/SqliteSqlNullabilityProcessor.cs

Lines changed: 63 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -41,6 +41,8 @@ protected override SqlExpression VisitCustomSqlExpression(
4141
{
4242
GlobExpression globExpression => VisitGlob(globExpression, allowOptimizedExpansion, out nullable),
4343
RegexpExpression regexpExpression => VisitRegexp(regexpExpression, allowOptimizedExpansion, out nullable),
44+
SqliteAggregateFunctionExpression aggregateFunctionExpression
45+
=> VisitAggregateFunction(aggregateFunctionExpression, allowOptimizedExpansion, out nullable),
4446
_ => base.VisitCustomSqlExpression(sqlExpression, allowOptimizedExpansion, out nullable)
4547
};
4648

@@ -84,6 +86,67 @@ protected virtual SqlExpression VisitRegexp(
8486
return regexpExpression.Update(match, pattern);
8587
}
8688

89+
/// <summary>
90+
/// This is an internal API that supports the Entity Framework Core infrastructure and not subject to
91+
/// the same compatibility standards as public APIs. It may be changed or removed without notice in
92+
/// any release. You should only use it directly in your code with extreme caution and knowing that
93+
/// doing so can result in application failures when updating to a new Entity Framework Core release.
94+
/// </summary>
95+
protected virtual SqlExpression VisitAggregateFunction(
96+
SqliteAggregateFunctionExpression aggregateFunctionExpression,
97+
bool allowOptimizedExpansion,
98+
out bool nullable)
99+
{
100+
nullable = aggregateFunctionExpression.IsNullable;
101+
102+
SqlExpression[]? arguments = null;
103+
for (var i = 0; i < aggregateFunctionExpression.Arguments.Count; i++)
104+
{
105+
var visitedArgument = Visit(aggregateFunctionExpression.Arguments[i], out _);
106+
if (visitedArgument != aggregateFunctionExpression.Arguments[i] && arguments is null)
107+
{
108+
arguments = new SqlExpression[aggregateFunctionExpression.Arguments.Count];
109+
110+
for (var j = 0; j < i; j++)
111+
{
112+
arguments[j] = aggregateFunctionExpression.Arguments[j];
113+
}
114+
}
115+
116+
if (arguments is not null)
117+
{
118+
arguments[i] = visitedArgument;
119+
}
120+
}
121+
122+
OrderingExpression[]? orderings = null;
123+
for (var i = 0; i < aggregateFunctionExpression.Orderings.Count; i++)
124+
{
125+
var ordering = aggregateFunctionExpression.Orderings[i];
126+
var visitedOrdering = ordering.Update(Visit(ordering.Expression, out _));
127+
if (visitedOrdering != aggregateFunctionExpression.Orderings[i] && orderings is null)
128+
{
129+
orderings = new OrderingExpression[aggregateFunctionExpression.Orderings.Count];
130+
131+
for (var j = 0; j < i; j++)
132+
{
133+
orderings[j] = aggregateFunctionExpression.Orderings[j];
134+
}
135+
}
136+
137+
if (orderings is not null)
138+
{
139+
orderings[i] = visitedOrdering;
140+
}
141+
}
142+
143+
return arguments is not null || orderings is not null
144+
? aggregateFunctionExpression.Update(
145+
arguments ?? aggregateFunctionExpression.Arguments,
146+
orderings ?? aggregateFunctionExpression.Orderings)
147+
: aggregateFunctionExpression;
148+
}
149+
87150
/// <inheritdoc />
88151
protected override SqlExpression VisitSqlFunction(
89152
SqlFunctionExpression sqlFunctionExpression,

0 commit comments

Comments
 (0)