Skip to content

Commit 5720494

Browse files
committed
Added AwaitHelper to properly wait for ValueTasks.
1 parent b525ba3 commit 5720494

File tree

14 files changed

+430
-189
lines changed

14 files changed

+430
-189
lines changed

src/BenchmarkDotNet/Code/DeclarationsProvider.cs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
6363
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
6464
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
6565
{
66-
return $"() => {method.Name}().GetAwaiter().GetResult()";
66+
return $"() => awaitHelper.GetResult({method.Name}())";
6767
}
6868

6969
return method.Name;
@@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
149149
{
150150
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }
151151

152-
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
153-
// and will eventually throw actual exception, not aggregated one
154152
public override string WorkloadMethodDelegate(string passArguments)
155-
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
153+
=> $"({passArguments}) => {{ awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";
156154

157-
public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
155+
public override string GetWorkloadMethodCall(string passArguments) => $"awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
158156

159157
protected override Type WorkloadMethodReturnType => typeof(void);
160158
}
@@ -168,11 +166,9 @@ public GenericTaskDeclarationsProvider(Descriptor descriptor) : base(descriptor)
168166

169167
protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();
170168

171-
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
172-
// and will eventually throw actual exception, not aggregated one
173169
public override string WorkloadMethodDelegate(string passArguments)
174-
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
170+
=> $"({passArguments}) => {{ return awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";
175171

176-
public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
172+
public override string GetWorkloadMethodCall(string passArguments) => $"awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
177173
}
178174
}
Lines changed: 112 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,112 @@
1+
using System;
2+
using System.Linq;
3+
using System.Reflection;
4+
using System.Runtime.CompilerServices;
5+
using System.Threading.Tasks;
6+
7+
namespace BenchmarkDotNet.Helpers
8+
{
9+
public class AwaitHelper
10+
{
11+
private readonly object awaiterLock = new object();
12+
private readonly Action awaiterCallback;
13+
private bool awaiterCompleted;
14+
15+
public AwaitHelper()
16+
{
17+
awaiterCallback = AwaiterCallback;
18+
}
19+
20+
private void AwaiterCallback()
21+
{
22+
lock (awaiterLock)
23+
{
24+
awaiterCompleted = true;
25+
System.Threading.Monitor.Pulse(awaiterLock);
26+
}
27+
}
28+
29+
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
30+
// and will eventually throw actual exception, not aggregated one
31+
public void GetResult(Task task)
32+
{
33+
task.GetAwaiter().GetResult();
34+
}
35+
36+
public T GetResult<T>(Task<T> task)
37+
{
38+
return task.GetAwaiter().GetResult();
39+
}
40+
41+
// It is illegal to call GetResult from an uncomplete ValueTask, so we must hook up a callback.
42+
public void GetResult(ValueTask task)
43+
{
44+
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
45+
var awaiter = task.ConfigureAwait(false).GetAwaiter();
46+
if (!awaiter.IsCompleted)
47+
{
48+
lock (awaiterLock)
49+
{
50+
awaiterCompleted = false;
51+
awaiter.UnsafeOnCompleted(awaiterCallback);
52+
// Check if the callback executed synchronously before blocking.
53+
if (!awaiterCompleted)
54+
{
55+
System.Threading.Monitor.Wait(awaiterLock);
56+
}
57+
}
58+
}
59+
awaiter.GetResult();
60+
}
61+
62+
public T GetResult<T>(ValueTask<T> task)
63+
{
64+
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
65+
var awaiter = task.ConfigureAwait(false).GetAwaiter();
66+
if (!awaiter.IsCompleted)
67+
{
68+
lock (awaiterLock)
69+
{
70+
awaiterCompleted = false;
71+
awaiter.UnsafeOnCompleted(awaiterCallback);
72+
// Check if the callback executed synchronously before blocking.
73+
if (!awaiterCompleted)
74+
{
75+
System.Threading.Monitor.Wait(awaiterLock);
76+
}
77+
}
78+
}
79+
return awaiter.GetResult();
80+
}
81+
82+
internal static MethodInfo GetGetResultMethod(Type taskType)
83+
{
84+
if (!taskType.IsGenericType)
85+
{
86+
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Instance, null, new Type[1] { taskType }, null);
87+
}
88+
89+
Type compareType = taskType.GetGenericTypeDefinition() == typeof(ValueTask<>) ? typeof(ValueTask<>)
90+
: typeof(Task).IsAssignableFrom(taskType.GetGenericTypeDefinition()) ? typeof(Task<>)
91+
: null;
92+
if (compareType == null)
93+
{
94+
return null;
95+
}
96+
var resultType = taskType
97+
.GetMethod(nameof(Task.GetAwaiter), BindingFlags.Public | BindingFlags.Instance)
98+
.ReturnType
99+
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlags.Public | BindingFlags.Instance)
100+
.ReturnType;
101+
return typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Instance)
102+
.First(m =>
103+
{
104+
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
105+
Type paramType = m.GetParameters().First().ParameterType;
106+
// We have to compare the types indirectly, == check doesn't work.
107+
return paramType.Assembly == compareType.Assembly && paramType.Namespace == compareType.Namespace && paramType.Name == compareType.Name;
108+
})
109+
.MakeGenericMethod(new[] { resultType });
110+
}
111+
}
112+
}

src/BenchmarkDotNet/Helpers/Reflection.Emit/IlGeneratorStatementExtensions.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,37 @@ public static void EmitVoidReturn(this ILGenerator ilBuilder, MethodBuilder meth
4242
ilBuilder.Emit(OpCodes.Ret);
4343
}
4444

45+
public static void EmitSetFieldToNewInstance(
46+
this ILGenerator ilBuilder,
47+
FieldBuilder field,
48+
Type instanceType)
49+
{
50+
if (field.IsStatic)
51+
throw new ArgumentException("The field should be instance field", nameof(field));
52+
53+
if (instanceType != null)
54+
{
55+
/*
56+
IL_0006: ldarg.0
57+
IL_0007: newobj instance void BenchmarkDotNet.Helpers.AwaitHelper::.ctor()
58+
IL_000c: stfld class BenchmarkDotNet.Helpers.AwaitHelper BenchmarkDotNet.Autogenerated.Runnable_0::awaitHelper
59+
*/
60+
var ctor = instanceType.GetConstructor(Array.Empty<Type>());
61+
if (ctor == null)
62+
throw new InvalidOperationException($"Bug: instanceType {instanceType.Name} does not have a 0-parameter accessible constructor.");
63+
64+
ilBuilder.Emit(OpCodes.Ldarg_0);
65+
ilBuilder.Emit(OpCodes.Newobj, ctor);
66+
ilBuilder.Emit(OpCodes.Stfld, field);
67+
}
68+
else
69+
{
70+
ilBuilder.Emit(OpCodes.Ldarg_0);
71+
ilBuilder.Emit(OpCodes.Ldnull);
72+
ilBuilder.Emit(OpCodes.Stfld, field);
73+
}
74+
}
75+
4576
public static void EmitSetDelegateToThisField(
4677
this ILGenerator ilBuilder,
4778
FieldBuilder delegateField,

src/BenchmarkDotNet/Templates/BenchmarkType.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757

5858
public Runnable_$ID$()
5959
{
60+
awaitHelper = new BenchmarkDotNet.Helpers.AwaitHelper();
61+
6062
globalSetupAction = $GlobalSetupMethodName$;
6163
globalCleanupAction = $GlobalCleanupMethodName$;
6264
iterationSetupAction = $IterationSetupMethodName$;
@@ -66,6 +68,8 @@
6668
$InitializeArgumentFields$
6769
}
6870

71+
private readonly BenchmarkDotNet.Helpers.AwaitHelper awaitHelper;
72+
6973
private System.Action globalSetupAction;
7074
private System.Action globalCleanupAction;
7175
private System.Action iterationSetupAction;

src/BenchmarkDotNet/Toolchains/InProcess.Emit.Implementation/ConsumableTypeInfo.cs

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using BenchmarkDotNet.Engines;
22
using JetBrains.Annotations;
33
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
46
using System.Reflection;
57
using System.Runtime.CompilerServices;
68
using System.Threading.Tasks;
@@ -17,28 +19,24 @@ public ConsumableTypeInfo(Type methodReturnType)
1719

1820
OriginMethodReturnType = methodReturnType;
1921

20-
// Please note this code does not support await over extension methods.
21-
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
22-
if (getAwaiterMethod == null)
22+
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
23+
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
24+
|| (methodReturnType.GetTypeInfo().IsGenericType
25+
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
26+
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));
27+
28+
if (!IsAwaitable)
2329
{
2430
WorkloadMethodReturnType = methodReturnType;
2531
}
2632
else
2733
{
28-
var getResultMethod = getAwaiterMethod
34+
WorkloadMethodReturnType = methodReturnType
35+
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
2936
.ReturnType
30-
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);
31-
32-
if (getResultMethod == null)
33-
{
34-
WorkloadMethodReturnType = methodReturnType;
35-
}
36-
else
37-
{
38-
WorkloadMethodReturnType = getResultMethod.ReturnType;
39-
GetAwaiterMethod = getAwaiterMethod;
40-
GetResultMethod = getResultMethod;
41-
}
37+
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
38+
.ReturnType;
39+
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
4240
}
4341

4442
if (WorkloadMethodReturnType == null)
@@ -78,8 +76,6 @@ public ConsumableTypeInfo(Type methodReturnType)
7876
[NotNull]
7977
public Type OverheadMethodReturnType { get; }
8078

81-
[CanBeNull]
82-
public MethodInfo GetAwaiterMethod { get; }
8379
[CanBeNull]
8480
public MethodInfo GetResultMethod { get; }
8581

@@ -89,6 +85,6 @@ public ConsumableTypeInfo(Type methodReturnType)
8985
[CanBeNull]
9086
public FieldInfo WorkloadConsumableField { get; }
9187

92-
public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
88+
public bool IsAwaitable { get; }
9389
}
9490
}

0 commit comments

Comments
 (0)