Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 12 additions & 0 deletions src/Compilers/CSharp/Portable/Binder/Binder_Patterns.cs
Original file line number Diff line number Diff line change
Expand Up @@ -650,6 +650,7 @@ internal BoundPattern BindConstantPatternWithFallbackToTypePattern(
BindingDiagnosticBag diagnostics,
out bool hasUnionMatching)
{
NamedTypeSymbol? unionTypeOnEntry = unionType;
NamedTypeSymbol? unionTypeOverride = PrepareForUnionMatchingIfAppropriateAndReturnUnionType(node, ref inputType, diagnostics);
hasUnionMatching = false;

Expand All @@ -676,6 +677,17 @@ internal BoundPattern BindConstantPatternWithFallbackToTypePattern(
diagnostics.Add(ErrorCode.ERR_CannotMatchOnINumberBase, node.Location, inputType);
}

if (constantValueOpt == ConstantValue.Null && unionTypeOverride?.IsValueType == false)
{
Debug.Assert(hasUnionMatching);

// Special case of a null test for a class Union. Its meaning is equivalent to: (<union instance> is null or <union instance>.Value is null)
// Therefore, the type isn't narrowed by this pattern and the following pattern, if any, will do union matching from scratch.
unionType = unionTypeOnEntry;
return new BoundConstantPattern(
node, convertedExpression, constantValueOpt, isUnionMatching: true, inputType: unionTypeOverride, narrowedType: unionTypeOverride, hasErrors);
}

return new BoundConstantPattern(
node, convertedExpression, constantValueOpt ?? ConstantValue.Bad, isUnionMatching: hasUnionMatching, inputType: unionTypeOverride ?? inputType, convertedType, hasErrors || constantValueOpt is null);
}
Expand Down
18 changes: 18 additions & 0 deletions src/Compilers/CSharp/Portable/Binder/UnionMatchingRewriter.cs
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,24 @@ private static BoundPatternWithUnionMatching CreatePatternWithUnionMatching(Name
node = (BoundConstantPattern)base.VisitConstantPattern(node)!;
if (node.IsUnionMatching)
{
Debug.Assert(node.InputType is NamedTypeSymbol { IsUnionType: true });

if (node.ConstantValue == ConstantValue.Null && !node.InputType.IsValueType && node.NarrowedType.Equals(node.InputType, TypeCompareKind.AllIgnoreOptions))
{
// Special case of a null test for a class Union. Its meaning is equivalent to: (<union instance> is null or <union instance>.Value is null)
BoundPatternWithUnionMatching underlyingValueMatching = CreatePatternWithUnionMatching(
(NamedTypeSymbol)node.InputType,
node.Update(node.Value, node.ConstantValue, isUnionMatching: false, inputType: ObjectType, narrowedType: ObjectType));

return new BoundBinaryPattern(
node.Syntax, disjunction: true,
left: node.Update(node.Value, node.ConstantValue, isUnionMatching: false, node.InputType, node.InputType),
right: RewritePatternWithUnionMatchingToPropertyPattern(underlyingValueMatching),
inputType: node.InputType,
narrowedType: node.InputType)
{ WasCompilerGenerated = true };
}

return CreatePatternWithUnionMatching(
(NamedTypeSymbol)node.InputType,
node.Update(node.Value, node.ConstantValue, isUnionMatching: false, inputType: ObjectType, narrowedType: node.NarrowedType));
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,8 @@ private partial void Validate()
if (IsUnionMatching)
{
Debug.Assert(NarrowedType.IsObjectType() ||
NarrowedType.Equals(Value.Type, TypeCompareKind.AllIgnoreOptions));
NarrowedType.Equals(Value.Type, TypeCompareKind.AllIgnoreOptions) ||
(ConstantValue == ConstantValue.Null && !InputType.IsValueType && NarrowedType.Equals(InputType, TypeCompareKind.AllIgnoreOptions)));
}
else
{
Expand Down
90 changes: 70 additions & 20 deletions src/Compilers/CSharp/Portable/FlowAnalysis/NullableWalker.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7928,12 +7928,30 @@ private void ApplyMemberPostConditions(int receiverSlot, MethodSymbol method)
applyMemberPostConditions(receiverSlot, type, notNullMembers, ref State);
}

if (method.ReturnType.SpecialType == SpecialType.System_Boolean
&& !(notNullWhenTrueMembers.IsEmpty && notNullWhenFalseMembers.IsEmpty))
if (method.ReturnType.SpecialType == SpecialType.System_Boolean)
{
Split();
applyMemberPostConditions(receiverSlot, type, notNullWhenTrueMembers, ref StateWhenTrue);
applyMemberPostConditions(receiverSlot, type, notNullWhenFalseMembers, ref StateWhenFalse);
if (!(notNullWhenTrueMembers.IsEmpty && notNullWhenFalseMembers.IsEmpty))
{
Split();
applyMemberPostConditions(receiverSlot, type, notNullWhenTrueMembers, ref StateWhenTrue);
applyMemberPostConditions(receiverSlot, type, notNullWhenFalseMembers, ref StateWhenFalse);
}

if (method is MethodSymbol
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be reasonable to share the isTryGetValueSignature helper from GetUnionTypeTryGetValueMethod here.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It might be reasonable to share the isTryGetValueSignature helper from GetUnionTypeTryGetValueMethod here.

If I remember correctly the logic doesn't match exactly.

{
Name: WellKnownMemberNames.TryGetValueMethodName,
ReturnType.SpecialType: SpecialType.System_Boolean,
DeclaredAccessibility: Accessibility.Public,
RefKind: RefKind.None,
Parameters: [{ RefKind: RefKind.Out, Type: var parameterType }]
} tryGetValue &&
tryGetValue.ContainingType is { IsUnionType: true } unionType &&
(object)tryGetValue.OriginalDefinition == Binder.GetUnionTypeTryGetValueMethod(unionType, parameterType)?.OriginalDefinition && // Looking for TryGetValue with exact type match at this call site
Binder.GetUnionTypeValuePropertyNoUseSiteDiagnostics(unionType) is { } unionValue)
{
Split();
markMemberAsNotNull(receiverSlot, ref StateWhenTrue, unionValue);
}
}

method = method.OverriddenMethod;
Expand Down Expand Up @@ -7972,18 +7990,25 @@ void markMembersAsNotNull(int receiverSlot, TypeSymbol type, string memberName,
{
case SymbolKind.Field:
case SymbolKind.Property:
if (GetOrCreateSlot(member, receiverSlot) is int memberSlot &&
memberSlot > 0)
{
SetState(ref state, memberSlot, NullableFlowState.NotNull);
}
state = markMemberAsNotNull(receiverSlot, ref state, member);
break;
case SymbolKind.Event:
case SymbolKind.Method:
break;
}
}
}

LocalState markMemberAsNotNull(int receiverSlot, ref LocalState state, Symbol member)
{
if (GetOrCreateSlot(member, receiverSlot) is int memberSlot &&
memberSlot > 0)
{
SetState(ref state, memberSlot, NullableFlowState.NotNull);
}

return state;
}
}

private ImmutableArray<VisitResult> VisitArgumentsEvaluate(
Expand Down Expand Up @@ -8286,7 +8311,7 @@ private void VisitArgumentOutboundAssignmentsAndPostConditions(

var parameterValue = new BoundParameter(argument.Syntax, parameter);
var lValueType = result.LValueType;
trackNullableStateForAssignment(parameterValue, lValueType, MakeSlot(argument), parameterWithState, argument.IsSuppressed, parameterAnnotations);
trackNullableStateForAssignment(parameterValue, lValueType, MakeSlot(argument), parameterWithState, argument.IsSuppressed, parameterAnnotations, refKind, parameter);

// check whether parameter would unsafely let a null out in the worse case
if (!argument.IsSuppressed)
Expand Down Expand Up @@ -8330,7 +8355,7 @@ private void VisitArgumentOutboundAssignmentsAndPostConditions(
CheckDisallowedNullAssignment(parameterWithState, leftAnnotations, argument.Syntax);

AdjustSetValue(argument, ref parameterWithState);
trackNullableStateForAssignment(parameterValue, lValueType, MakeSlot(argument), parameterWithState, argument.IsSuppressed, parameterAnnotations);
trackNullableStateForAssignment(parameterValue, lValueType, MakeSlot(argument), parameterWithState, argument.IsSuppressed, parameterAnnotations, refKind, parameter);

// report warnings if parameter would unsafely let a null out in the worst case
if (!argument.IsSuppressed)
Expand Down Expand Up @@ -8368,9 +8393,9 @@ FlowAnalysisAnnotations notNullBasedOnParameters(FlowAnalysisAnnotations paramet
return parameterAnnotations;
}

void trackNullableStateForAssignment(BoundExpression parameterValue, TypeWithAnnotations lValueType, int targetSlot, TypeWithState parameterWithState, bool isSuppressed, FlowAnalysisAnnotations parameterAnnotations)
void trackNullableStateForAssignment(BoundExpression parameterValue, TypeWithAnnotations lValueType, int targetSlot, TypeWithState parameterWithState, bool isSuppressed, FlowAnalysisAnnotations parameterAnnotations, RefKind refKind, ParameterSymbol parameter)
{
if (!IsConditionalState && !hasConditionalPostCondition(parameterAnnotations))
if (!IsConditionalState && !hasConditionalPostCondition(parameterAnnotations, refKind, parameter))
{
TrackNullableStateForAssignment(parameterValue, lValueType, targetSlot, parameterWithState.WithSuppression(isSuppressed));
}
Expand All @@ -8381,7 +8406,7 @@ void trackNullableStateForAssignment(BoundExpression parameterValue, TypeWithAnn

SetState(StateWhenTrue);
// Note: the suppression applies over the post-condition attributes
TrackNullableStateForAssignment(parameterValue, lValueType, targetSlot, applyPostConditionsWhenTrue(parameterWithState, parameterAnnotations).WithSuppression(isSuppressed));
TrackNullableStateForAssignment(parameterValue, lValueType, targetSlot, applyPostConditionsWhenTrue(parameterWithState, parameterAnnotations, refKind, parameter).WithSuppression(isSuppressed));
Debug.Assert(!IsConditionalState);
var newWhenTrue = State.Clone();

Expand All @@ -8393,10 +8418,35 @@ void trackNullableStateForAssignment(BoundExpression parameterValue, TypeWithAnn
}
}

static bool hasConditionalPostCondition(FlowAnalysisAnnotations annotations)
static bool hasConditionalPostCondition(FlowAnalysisAnnotations annotations, RefKind refKind, ParameterSymbol parameter)
{
if ((((annotations & FlowAnalysisAnnotations.MaybeNullWhenTrue) != 0) ^ ((annotations & FlowAnalysisAnnotations.MaybeNullWhenFalse) != 0)) ||
(((annotations & FlowAnalysisAnnotations.NotNullWhenTrue) != 0) ^ ((annotations & FlowAnalysisAnnotations.NotNullWhenFalse) != 0)))
{
return true;
}

return isUnionTryGetValueValue(refKind, parameter);
}

static bool isUnionTryGetValueValue(RefKind refKind, ParameterSymbol parameter)
{
return (((annotations & FlowAnalysisAnnotations.MaybeNullWhenTrue) != 0) ^ ((annotations & FlowAnalysisAnnotations.MaybeNullWhenFalse) != 0)) ||
(((annotations & FlowAnalysisAnnotations.NotNullWhenTrue) != 0) ^ ((annotations & FlowAnalysisAnnotations.NotNullWhenFalse) != 0));
if (refKind == RefKind.Out &&
parameter.ContainingSymbol is MethodSymbol
{
Name: WellKnownMemberNames.TryGetValueMethodName,
ReturnType.SpecialType: SpecialType.System_Boolean,
DeclaredAccessibility: Accessibility.Public,
RefKind: RefKind.None,
ParameterCount: 1
} tryGetValue &&
tryGetValue.ContainingType is { IsUnionType: true } unionType &&
(object)tryGetValue.OriginalDefinition == Binder.GetUnionTypeTryGetValueMethod(unionType, parameter.Type)?.OriginalDefinition) // Looking for TryGetValue with exact type match at this call site
{
return true;
}

return false;
}

static TypeWithState applyPostConditionsUnconditionally(TypeWithState typeWithState, FlowAnalysisAnnotations annotations)
Expand All @@ -8416,7 +8466,7 @@ static TypeWithState applyPostConditionsUnconditionally(TypeWithState typeWithSt
return typeWithState;
}

static TypeWithState applyPostConditionsWhenTrue(TypeWithState typeWithState, FlowAnalysisAnnotations annotations)
static TypeWithState applyPostConditionsWhenTrue(TypeWithState typeWithState, FlowAnalysisAnnotations annotations, RefKind refKind, ParameterSymbol parameter)
{
bool notNullWhenTrue = (annotations & FlowAnalysisAnnotations.NotNullWhenTrue) != 0;
bool maybeNullWhenTrue = (annotations & FlowAnalysisAnnotations.MaybeNullWhenTrue) != 0;
Expand All @@ -8427,7 +8477,7 @@ static TypeWithState applyPostConditionsWhenTrue(TypeWithState typeWithState, Fl
// [MaybeNull, NotNullWhen(true)] means [MaybeNullWhen(false)]
return TypeWithState.Create(typeWithState.Type, NullableFlowState.MaybeDefault);
}
else if (notNullWhenTrue)
else if (notNullWhenTrue || isUnionTryGetValueValue(refKind, parameter))
{
return TypeWithState.Create(typeWithState.Type, NullableFlowState.NotNull);
}
Expand Down
Loading
Loading