@@ -826,8 +826,8 @@ public virtual SqlExpression Case(
826
826
}
827
827
828
828
// Simplify:
829
- // a == null ? null : a -> a
830
- // a != null ? a : null -> a
829
+ // a == b ? b : a -> a
830
+ // a != b ? a : b -> a
831
831
// And lift:
832
832
// a == b ? null : a -> NULLIF(a, b)
833
833
// a != b ? a : null -> NULLIF(a, b)
@@ -838,28 +838,39 @@ public virtual SqlExpression Case(
838
838
Test : SqlBinaryExpression { OperatorType : ExpressionType . Equal or ExpressionType . NotEqual } binary ,
839
839
Result : var result
840
840
}
841
- ]
842
- && binary . OperatorType switch
843
- {
844
- ExpressionType . Equal when result is SqlConstantExpression { Value : null } && elseResult is not null => elseResult ,
845
- ExpressionType . NotEqual when elseResult is null or SqlConstantExpression { Value : null } => result ,
846
- _ => null
847
- } is SqlExpression conditionalResult )
841
+ ] )
848
842
{
849
843
var ( left , right ) = ( binary . Left , binary . Right ) ;
850
844
851
- if ( left . Equals ( conditionalResult ) )
845
+ // Reverse ifEqual/ifNotEqual for ExpressionType.NotEqual for easier reasonining below
846
+ var ( ifEqual , ifNotEqual ) = binary . OperatorType is ExpressionType . Equal
847
+ ? ( result , elseResult ?? Constant ( null , result . Type , result . TypeMapping ) )
848
+ : ( elseResult ?? Constant ( null , result . Type , result . TypeMapping ) , result ) ;
849
+
850
+ if ( left . Equals ( ifNotEqual ) )
852
851
{
853
- return right is SqlConstantExpression { Value : null }
854
- ? left
855
- : Function ( "NULLIF" , [ left , right ] , nullable : true , [ false , false ] , left . Type , left . TypeMapping ) ;
852
+ switch ( ifEqual )
853
+ {
854
+ // a == b ? b : a -> a
855
+ case SqlConstantExpression { Value : null } :
856
+ return Function ( "NULLIF" , [ left , right ] , nullable : true , [ false , false ] , left . Type , left . TypeMapping ) ;
857
+ // a == b ? null : a -> NULLIF(a, b)
858
+ case var _ when ifEqual . Equals ( right ) :
859
+ return left ;
860
+ }
856
861
}
857
862
858
- if ( right . Equals ( conditionalResult ) )
863
+ if ( right . Equals ( ifNotEqual ) )
859
864
{
860
- return left is SqlConstantExpression { Value : null }
861
- ? right
862
- : Function ( "NULLIF" , [ right , left ] , nullable : true , [ false , false ] , right . Type , right . TypeMapping ) ;
865
+ switch ( ifEqual )
866
+ {
867
+ // b == a ? b : a -> a
868
+ case SqlConstantExpression { Value : null } :
869
+ return Function ( "NULLIF" , [ right , left ] , nullable : true , [ false , false ] , right . Type , right . TypeMapping ) ;
870
+ // b == a ? null : a -> NULLIF(a, b)
871
+ case var _ when ifEqual . Equals ( left ) :
872
+ return right ;
873
+ }
863
874
}
864
875
}
865
876
0 commit comments