Skip to content

Commit 79c2d24

Browse files
committed
Generalize simplification as suggested
1 parent f4472db commit 79c2d24

File tree

2 files changed

+30
-19
lines changed

2 files changed

+30
-19
lines changed

Diff for: src/EFCore.Relational/Query/SqlExpressionFactory.cs

+28-17
Original file line numberDiff line numberDiff line change
@@ -826,8 +826,8 @@ public virtual SqlExpression Case(
826826
}
827827

828828
// Simplify:
829-
// a == null ? null : a -> a
830-
// a != null ? a : null -> a
829+
// a == b ? b : a -> a
830+
// a != b ? a : b -> a
831831
// And lift:
832832
// a == b ? null : a -> NULLIF(a, b)
833833
// a != b ? a : null -> NULLIF(a, b)
@@ -838,28 +838,39 @@ public virtual SqlExpression Case(
838838
Test: SqlBinaryExpression { OperatorType: ExpressionType.Equal or ExpressionType.NotEqual } binary,
839839
Result: var result
840840
}
841-
]
842-
&& binary.OperatorType switch
843-
{
844-
ExpressionType.Equal when result is SqlConstantExpression { Value: null } && elseResult is not null => elseResult,
845-
ExpressionType.NotEqual when elseResult is null or SqlConstantExpression { Value: null } => result,
846-
_ => null
847-
} is SqlExpression conditionalResult)
841+
])
848842
{
849843
var (left, right) = (binary.Left, binary.Right);
850844

851-
if (left.Equals(conditionalResult))
845+
// Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasonining below
846+
var (ifEqual, ifNotEqual) = binary.OperatorType is ExpressionType.Equal
847+
? (result, elseResult ?? Constant(null, result.Type, result.TypeMapping))
848+
: (elseResult ?? Constant(null, result.Type, result.TypeMapping), result);
849+
850+
if (left.Equals(ifNotEqual))
852851
{
853-
return right is SqlConstantExpression { Value: null }
854-
? left
855-
: Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
852+
switch (ifEqual)
853+
{
854+
// a == b ? b : a -> a
855+
case SqlConstantExpression { Value: null }:
856+
return Function("NULLIF", [left, right], nullable: true, [false, false], left.Type, left.TypeMapping);
857+
// a == b ? null : a -> NULLIF(a, b)
858+
case var _ when ifEqual.Equals(right):
859+
return left;
860+
}
856861
}
857862

858-
if (right.Equals(conditionalResult))
863+
if (right.Equals(ifNotEqual))
859864
{
860-
return left is SqlConstantExpression { Value: null }
861-
? right
862-
: Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
865+
switch (ifEqual)
866+
{
867+
// b == a ? b : a -> a
868+
case SqlConstantExpression { Value: null }:
869+
return Function("NULLIF", [right, left], nullable: true, [false, false], right.Type, right.TypeMapping);
870+
// b == a ? null : a -> NULLIF(a, b)
871+
case var _ when ifEqual.Equals(left):
872+
return right;
873+
}
863874
}
864875
}
865876

Diff for: test/EFCore.Specification.Tests/Query/Translations/OperatorTranslationsTestBase.cs

+2-2
Original file line numberDiff line numberDiff line change
@@ -16,15 +16,15 @@ public virtual Task Conditional_simplifiable_equality(bool async)
1616
=> AssertQuery(
1717
async,
1818
// ReSharper disable once MergeConditionalExpression
19-
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == null ? null : x.Int) > 1));
19+
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int == 9 ? 9 : x.Int) > 1));
2020

2121
[ConditionalTheory]
2222
[MemberData(nameof(IsAsyncData))]
2323
public virtual Task Conditional_simplifiable_inequality(bool async)
2424
=> AssertQuery(
2525
async,
2626
// ReSharper disable once MergeConditionalExpression
27-
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != null ? x.Int : null) > 1));
27+
cs => cs.Set<NullableBasicTypesEntity>().Where(x => (x.Int != 8 ? x.Int : 8) > 1));
2828

2929
// In relational providers, x == a ? null : x ("un-coalescing conditional") is translated to SQL NULLIF
3030

0 commit comments

Comments
 (0)