Skip to content

Move runtime async method validation into initial binding #78310

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
Merged
149 changes: 144 additions & 5 deletions src/Compilers/CSharp/Portable/Binder/Binder_Await.cs
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@
using System.Diagnostics.CodeAnalysis;
using Microsoft.CodeAnalysis.CSharp.Symbols;
using Microsoft.CodeAnalysis.CSharp.Syntax;
using Microsoft.CodeAnalysis.PooledObjects;
using Roslyn.Utilities;

namespace Microsoft.CodeAnalysis.CSharp
Expand Down Expand Up @@ -37,7 +38,7 @@ private BoundAwaitExpression BindAwait(BoundExpression expression, SyntaxNode no
// The expression await t is classified the same way as the expression (t).GetAwaiter().GetResult(). Thus,
// if the return type of GetResult is void, the await-expression is classified as nothing. If it has a
// non-void return type T, the await-expression is classified as a value of type T.
TypeSymbol awaitExpressionType = info.GetResult?.ReturnType ?? (hasErrors ? CreateErrorType() : Compilation.DynamicType);
TypeSymbol awaitExpressionType = (info.GetResult ?? info.RuntimeAsyncAwaitMethod)?.ReturnType ?? (hasErrors ? CreateErrorType() : Compilation.DynamicType);

return new BoundAwaitExpression(node, expression, info, debugInfo: default, awaitExpressionType, hasErrors);
}
Expand All @@ -58,11 +59,12 @@ internal BoundAwaitableInfo BindAwaitInfo(BoundAwaitableValuePlaceholder placeho
out PropertySymbol? isCompleted,
out MethodSymbol? getResult,
getAwaiterGetResultCall: out _,
out MethodSymbol? runtimeAsyncAwaitCall,
node,
diagnostics);
hasErrors |= hasGetAwaitableErrors;

return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true };
return new BoundAwaitableInfo(node, placeholder, isDynamic: isDynamic, getAwaiter, isCompleted, getResult, runtimeAsyncAwaitCall, hasErrors: hasGetAwaitableErrors) { WasCompilerGenerated = true };
}

/// <summary>
Expand Down Expand Up @@ -123,7 +125,7 @@ private bool CouldBeAwaited(BoundExpression expression)
return false;
}

return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _,
return GetAwaitableExpressionInfo(expression, getAwaiterGetResultCall: out _, runtimeAsyncAwaitCall: out _,
node: syntax, diagnostics: BindingDiagnosticBag.Discarded);
}

Expand Down Expand Up @@ -242,10 +244,11 @@ private bool ReportBadAwaitContext(SyntaxNodeOrToken nodeOrToken, BindingDiagnos
internal bool GetAwaitableExpressionInfo(
BoundExpression expression,
out BoundExpression? getAwaiterGetResultCall,
out MethodSymbol? runtimeAsyncAwaitCall,
SyntaxNode node,
BindingDiagnosticBag diagnostics)
{
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, node, diagnostics);
return GetAwaitableExpressionInfo(expression, expression, out _, out _, out _, out _, out getAwaiterGetResultCall, out runtimeAsyncAwaitCall, node, diagnostics);
}

private bool GetAwaitableExpressionInfo(
Expand All @@ -256,6 +259,7 @@ private bool GetAwaitableExpressionInfo(
out PropertySymbol? isCompleted,
out MethodSymbol? getResult,
out BoundExpression? getAwaiterGetResultCall,
out MethodSymbol? runtimeAsyncAwaitCall,
SyntaxNode node,
BindingDiagnosticBag diagnostics)
{
Expand All @@ -266,6 +270,7 @@ private bool GetAwaitableExpressionInfo(
isCompleted = null;
getResult = null;
getAwaiterGetResultCall = null;
runtimeAsyncAwaitCall = null;

if (!ValidateAwaitedExpression(expression, node, diagnostics))
{
Expand All @@ -274,10 +279,21 @@ private bool GetAwaitableExpressionInfo(

if (expression.HasDynamicType())
{
// PROTOTYPE: Handle runtime async here
isDynamic = true;
return true;
}

var isRuntimeAsyncEnabled = Compilation.IsRuntimeAsyncEnabledIn(this.ContainingMemberOrLambda);

// When RuntimeAsync is enabled, we first check for whether there is an AsyncHelpers.Await method that can handle the expression.
// PROTOTYPE: Do the full algorithm specified in https://github.com/dotnet/roslyn/pull/77957

if (tryGetRuntimeAwaitHelper(out runtimeAsyncAwaitCall))
{
return true;
}

if (!GetGetAwaiterMethod(getAwaiterArgument, node, diagnostics, out getAwaiter))
{
return false;
Expand All @@ -286,7 +302,130 @@ private bool GetAwaitableExpressionInfo(
TypeSymbol awaiterType = getAwaiter.Type!;
return GetIsCompletedProperty(awaiterType, node, expression.Type!, diagnostics, out isCompleted)
&& AwaiterImplementsINotifyCompletion(awaiterType, node, diagnostics)
&& GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall);
&& GetGetResultMethod(getAwaiter, node, expression.Type!, diagnostics, out getResult, out getAwaiterGetResultCall)
&& (!isRuntimeAsyncEnabled || getRuntimeAwaitAwaiter(awaiterType, out runtimeAsyncAwaitCall));

bool tryGetRuntimeAwaitHelper(out MethodSymbol? runtimeAwaitHelper)
{
if (!isRuntimeAsyncEnabled)
{
runtimeAwaitHelper = null;
return false;
}

var exprOriginalType = expression.Type!.OriginalDefinition;
SpecialMember awaitCall;
TypeWithAnnotations? maybeNestedType = null;
if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTask;
}
else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task_T, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitTaskT_T;
maybeNestedType = ((NamedTypeSymbol)expression.Type).TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
}
else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTask;
}
else if (ReferenceEquals(exprOriginalType, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask_T, diagnostics, expression.Syntax)))
{
awaitCall = SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitValueTaskT_T;
maybeNestedType = ((NamedTypeSymbol)expression.Type).TypeArgumentsWithAnnotationsNoUseSiteDiagnostics[0];
}
else
{
runtimeAwaitHelper = null;
return false;
}

runtimeAwaitHelper = (MethodSymbol)GetSpecialTypeMember(awaitCall, diagnostics, expression.Syntax);

if (runtimeAwaitHelper is null)
{
return false;
}

if (maybeNestedType is { } nestedType)
{
Debug.Assert(runtimeAwaitHelper.TypeParameters.Length == 1);
runtimeAwaitHelper = runtimeAwaitHelper.Construct([nestedType]);
checkMethodGenericConstraints(runtimeAwaitHelper, diagnostics, expression.Syntax.Location);
}
#if DEBUG
else
{
Debug.Assert(runtimeAwaitHelper.TypeParameters.Length == 0);
}
#endif

return true;
}

bool getRuntimeAwaitAwaiter(TypeSymbol awaiterType, out MethodSymbol? runtimeAwaitAwaiterMethod)
{
// Use site info is discarded because we don't actually do this conversion, we just need to know which generic
// method to call.
var discardedUseSiteInfo = CompoundUseSiteInfo<AssemblySymbol>.Discarded;
var useUnsafeAwait = Compilation.Conversions.ClassifyImplicitConversionFromType(
awaiterType,
Compilation.GetSpecialType(InternalSpecialType.System_Runtime_CompilerServices_ICriticalNotifyCompletion),
ref discardedUseSiteInfo).IsImplicit;

var awaitMethod = (MethodSymbol?)GetSpecialTypeMember(
useUnsafeAwait
? SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__UnsafeAwaitAwaiter_TAwaiter
: SpecialMember.System_Runtime_CompilerServices_AsyncHelpers__AwaitAwaiter_TAwaiter,
diagnostics,
expression.Syntax);

if (awaitMethod is null)
{
runtimeAwaitAwaiterMethod = null;
return false;
}

Debug.Assert(awaitMethod is { Arity: 1 });

runtimeAwaitAwaiterMethod = awaitMethod.Construct(awaiterType);
checkMethodGenericConstraints(runtimeAwaitAwaiterMethod, diagnostics, expression.Syntax.Location);

return true;
}

void checkMethodGenericConstraints(MethodSymbol method, BindingDiagnosticBag diagnostics, Location location)
{
var diagnosticsBuilder = ArrayBuilder<TypeParameterDiagnosticInfo>.GetInstance();
ArrayBuilder<TypeParameterDiagnosticInfo>? useSiteDiagnosticsBuilder = null;
ConstraintsHelper.CheckMethodConstraints(
method,
new ConstraintsHelper.CheckConstraintsArgs(this.Compilation, this.Conversions, includeNullability: false, location, diagnostics: null),
diagnosticsBuilder,
nullabilityDiagnosticsBuilderOpt: null,
ref useSiteDiagnosticsBuilder);

foreach (var pair in diagnosticsBuilder)
{
if (pair.UseSiteInfo.DiagnosticInfo is { } diagnosticInfo)
{
diagnostics.Add(diagnosticInfo, location);
}
diagnosticsBuilder.Free();
}

if (useSiteDiagnosticsBuilder is { })
{
foreach (var pair in useSiteDiagnosticsBuilder)
{
if (pair.UseSiteInfo.DiagnosticInfo is { } diagnosticInfo)
{
diagnostics.Add(diagnosticInfo, location);
}
}
useSiteDiagnosticsBuilder.Free();
}
}
}

/// <summary>
Expand Down
6 changes: 3 additions & 3 deletions src/Compilers/CSharp/Portable/Binder/Binder_Symbols.cs
Original file line number Diff line number Diff line change
Expand Up @@ -1654,20 +1654,20 @@ NamespaceOrTypeOrAliasSymbolWithAnnotations convertToUnboundGenericType()
}
}

internal NamedTypeSymbol GetSpecialType(SpecialType typeId, BindingDiagnosticBag diagnostics, SyntaxNode node)
internal NamedTypeSymbol GetSpecialType(ExtendedSpecialType typeId, BindingDiagnosticBag diagnostics, SyntaxNode node)
{
return GetSpecialType(this.Compilation, typeId, node, diagnostics);
}

internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, SpecialType typeId, SyntaxNode node, BindingDiagnosticBag diagnostics)
internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, ExtendedSpecialType typeId, SyntaxNode node, BindingDiagnosticBag diagnostics)
{
NamedTypeSymbol typeSymbol = compilation.GetSpecialType(typeId);
Debug.Assert((object)typeSymbol != null, "Expect an error type if special type isn't found");
ReportUseSite(typeSymbol, diagnostics, node);
return typeSymbol;
}

internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, SpecialType typeId, Location location, BindingDiagnosticBag diagnostics)
internal static NamedTypeSymbol GetSpecialType(CSharpCompilation compilation, ExtendedSpecialType typeId, Location location, BindingDiagnosticBag diagnostics)
{
NamedTypeSymbol typeSymbol = compilation.GetSpecialType(typeId);
Debug.Assert((object)typeSymbol != null, "Expect an error type if special type isn't found");
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -150,7 +150,7 @@ internal static BoundStatement BindUsingStatementOrDeclarationFromParts(SyntaxNo

if (awaitableTypeOpt is null)
{
awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null) { WasCompilerGenerated = true };
awaitOpt = new BoundAwaitableInfo(syntax, awaitableInstancePlaceholder: null, isDynamic: true, getAwaiter: null, isCompleted: null, getResult: null, runtimeAsyncAwaitMethod: null) { WasCompilerGenerated = true };
}
else
{
Expand Down
4 changes: 4 additions & 0 deletions src/Compilers/CSharp/Portable/BoundTree/BoundNodes.xml
Original file line number Diff line number Diff line change
Expand Up @@ -693,6 +693,10 @@
<Field Name="GetAwaiter" Type="BoundExpression?" Null="allow"/>
<Field Name="IsCompleted" Type="PropertySymbol?" Null="allow"/>
<Field Name="GetResult" Type="MethodSymbol?" Null="allow"/>
<!-- Refers to the runtime async helper we call for awaiting. Either this is an instance of an AsyncHelpers.Await call, and
GetAwaiter, IsCompleted, and GetResult are null, or this is AsyncHelpers.AwaitAwaiter/UnsafeAwaitAwaiter, and the other
fields are not null. -->
<Field Name="RuntimeAsyncAwaitMethod" Type="MethodSymbol?" Null="allow"/>
</Node>

<Node Name="BoundAwaitExpression" Base="BoundExpression">
Expand Down
25 changes: 20 additions & 5 deletions src/Compilers/CSharp/Portable/Compilation/CSharpCompilation.cs
Original file line number Diff line number Diff line change
Expand Up @@ -316,7 +316,7 @@ internal bool IsNullableAnalysisEnabledAlways
/// Returns true if this method should be processed with runtime async handling instead
/// of compiler async state machine generation.
/// </summary>
internal bool IsRuntimeAsyncEnabledIn(MethodSymbol method)
internal bool IsRuntimeAsyncEnabledIn(Symbol? symbol)
{
// PROTOTYPE: EE tests fail this assert, handle and test
//Debug.Assert(ReferenceEquals(method.ContainingAssembly, Assembly));
Expand All @@ -325,7 +325,21 @@ internal bool IsRuntimeAsyncEnabledIn(MethodSymbol method)
return false;
}

return method switch
if (symbol is not MethodSymbol method)
{
return false;
}

var methodReturn = method.ReturnType.OriginalDefinition;
if (!ReferenceEquals(methodReturn, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task))
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It would be nice if there a way to say something like methodReturn.InternalSpecialType is InternalSpecialType.Task or Task_T or ValueTask or ValueTask_T. No need to try and dig up a way to do that though.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Actually, I think we can, since this is special types, not well-known types.

&& !ReferenceEquals(methodReturn, GetSpecialType(InternalSpecialType.System_Threading_Tasks_Task_T))
&& !ReferenceEquals(methodReturn, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask))
&& !ReferenceEquals(methodReturn, GetSpecialType(InternalSpecialType.System_Threading_Tasks_ValueTask_T)))
{
return false;
Copy link
Member

@jcouv jcouv Apr 25, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Is there a corresponding update to the design doc? Or should we have a follow-up comment to make the void-returning method scenario work at some point? #Pending

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Need to update the design doc, the runtime-side is very clear that only Task/ValueTask methods can be runtime async.

}

return symbol switch
{
SourceMethodSymbol { IsRuntimeAsyncEnabledInMethod: ThreeState.True } => true,
SourceMethodSymbol { IsRuntimeAsyncEnabledInMethod: ThreeState.False } => false,
Expand Down Expand Up @@ -2211,11 +2225,12 @@ internal bool ReturnsAwaitableToVoidOrInt(MethodSymbol method, BindingDiagnostic
var dumbInstance = new BoundLiteral(syntax, ConstantValue.Null, namedType);
var binder = GetBinder(syntax);
BoundExpression? result;
var success = binder.GetAwaitableExpressionInfo(dumbInstance, out result, syntax, diagnostics);
var success = binder.GetAwaitableExpressionInfo(dumbInstance, out result, out MethodSymbol? runtimeAwaitMethod, syntax, diagnostics);

RoslynDebug.Assert(!namedType.IsDynamic());
return success &&
(result!.Type!.IsVoidType() || result.Type!.SpecialType == SpecialType.System_Int32);
Debug.Assert(result is { Type: not null } || runtimeAwaitMethod is { ReturnType: not null });
var returnType = result?.Type ?? runtimeAwaitMethod!.ReturnType;
return success && (returnType.IsVoidType() || returnType.SpecialType == SpecialType.System_Int32);
}

/// <summary>
Expand Down
Loading
Loading