Skip to content

Commit fa2790e

Browse files
committed
Added AwaitHelper to properly wait for ValueTasks.
1 parent b525ba3 commit fa2790e

File tree

14 files changed

+426
-189
lines changed

14 files changed

+426
-189
lines changed

src/BenchmarkDotNet/Code/DeclarationsProvider.cs

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -63,7 +63,7 @@ private string GetMethodName(MethodInfo method)
6363
(method.ReturnType.GetGenericTypeDefinition() == typeof(Task<>) ||
6464
method.ReturnType.GetGenericTypeDefinition() == typeof(ValueTask<>))))
6565
{
66-
return $"() => {method.Name}().GetAwaiter().GetResult()";
66+
return $"() => awaitHelper.GetResult({method.Name}())";
6767
}
6868

6969
return method.Name;
@@ -149,12 +149,10 @@ internal class TaskDeclarationsProvider : VoidDeclarationsProvider
149149
{
150150
public TaskDeclarationsProvider(Descriptor descriptor) : base(descriptor) { }
151151

152-
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
153-
// and will eventually throw actual exception, not aggregated one
154152
public override string WorkloadMethodDelegate(string passArguments)
155-
=> $"({passArguments}) => {{ {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
153+
=> $"({passArguments}) => {{ awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";
156154

157-
public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
155+
public override string GetWorkloadMethodCall(string passArguments) => $"awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
158156

159157
protected override Type WorkloadMethodReturnType => typeof(void);
160158
}
@@ -168,11 +166,9 @@ public GenericTaskDeclarationsProvider(Descriptor descriptor) : base(descriptor)
168166

169167
protected override Type WorkloadMethodReturnType => Descriptor.WorkloadMethod.ReturnType.GetTypeInfo().GetGenericArguments().Single();
170168

171-
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
172-
// and will eventually throw actual exception, not aggregated one
173169
public override string WorkloadMethodDelegate(string passArguments)
174-
=> $"({passArguments}) => {{ return {Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult(); }}";
170+
=> $"({passArguments}) => {{ return awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments})); }}";
175171

176-
public override string GetWorkloadMethodCall(string passArguments) => $"{Descriptor.WorkloadMethod.Name}({passArguments}).GetAwaiter().GetResult()";
172+
public override string GetWorkloadMethodCall(string passArguments) => $"awaitHelper.GetResult({Descriptor.WorkloadMethod.Name}({passArguments}))";
177173
}
178174
}
Lines changed: 108 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
using System;
2+
using System.Linq;
3+
using System.Reflection;
4+
using System.Threading.Tasks;
5+
6+
namespace BenchmarkDotNet.Helpers
7+
{
8+
public class AwaitHelper
9+
{
10+
private readonly object awaiterLock = new object();
11+
private readonly Action awaiterCallback;
12+
private bool awaiterCompleted;
13+
14+
public AwaitHelper()
15+
{
16+
awaiterCallback = AwaiterCallback;
17+
}
18+
19+
private void AwaiterCallback()
20+
{
21+
lock (awaiterLock)
22+
{
23+
awaiterCompleted = true;
24+
System.Threading.Monitor.Pulse(awaiterLock);
25+
}
26+
}
27+
28+
// we use GetAwaiter().GetResult() because it's fastest way to obtain the result in blocking way,
29+
// and will eventually throw actual exception, not aggregated one
30+
public void GetResult(Task task)
31+
{
32+
task.GetAwaiter().GetResult();
33+
}
34+
35+
public T GetResult<T>(Task<T> task)
36+
{
37+
return task.GetAwaiter().GetResult();
38+
}
39+
40+
// It is illegal to call GetResult from an uncomplete ValueTask, so we must hook up a callback.
41+
public void GetResult(ValueTask task)
42+
{
43+
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
44+
var awaiter = task.ConfigureAwait(false).GetAwaiter();
45+
if (!awaiter.IsCompleted)
46+
{
47+
lock (awaiterLock)
48+
{
49+
awaiterCompleted = false;
50+
awaiter.UnsafeOnCompleted(awaiterCallback);
51+
// Check if the callback executed synchronously before blocking.
52+
if (!awaiterCompleted)
53+
{
54+
System.Threading.Monitor.Wait(awaiterLock);
55+
}
56+
}
57+
}
58+
awaiter.GetResult();
59+
}
60+
61+
public T GetResult<T>(ValueTask<T> task)
62+
{
63+
// Don't continue on the captured context, as that may result in a deadlock if the user runs this in-process.
64+
var awaiter = task.ConfigureAwait(false).GetAwaiter();
65+
if (!awaiter.IsCompleted)
66+
{
67+
lock (awaiterLock)
68+
{
69+
awaiterCompleted = false;
70+
awaiter.UnsafeOnCompleted(awaiterCallback);
71+
// Check if the callback executed synchronously before blocking.
72+
if (!awaiterCompleted)
73+
{
74+
System.Threading.Monitor.Wait(awaiterLock);
75+
}
76+
}
77+
}
78+
return awaiter.GetResult();
79+
}
80+
81+
internal static MethodInfo GetGetResultMethod(Type taskType)
82+
{
83+
if (taskType.IsGenericType)
84+
{
85+
Type compareType = taskType.GetGenericTypeDefinition() == typeof(Task<>)
86+
? typeof(Task<>)
87+
: taskType.GetGenericTypeDefinition() == typeof(ValueTask<>)
88+
? typeof(ValueTask<>)
89+
: null;
90+
return compareType == null
91+
? null
92+
: typeof(AwaitHelper).GetMethods(BindingFlags.Public | BindingFlags.Instance)
93+
.First(m =>
94+
{
95+
if (m.Name != nameof(AwaitHelper.GetResult)) return false;
96+
Type paramType = m.GetParameters().First().ParameterType;
97+
// We have to compare the types indirectly, == check doesn't work.
98+
return paramType.Assembly == compareType.Assembly && paramType.Namespace == compareType.Namespace && paramType.Name == compareType.Name;
99+
})
100+
.MakeGenericMethod(taskType.GetGenericArguments());
101+
}
102+
else
103+
{
104+
return typeof(AwaitHelper).GetMethod(nameof(AwaitHelper.GetResult), BindingFlags.Public | BindingFlags.Instance, null, new Type[1] { taskType }, null);
105+
}
106+
}
107+
}
108+
}

src/BenchmarkDotNet/Helpers/Reflection.Emit/IlGeneratorStatementExtensions.cs

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,37 @@ public static void EmitVoidReturn(this ILGenerator ilBuilder, MethodBuilder meth
4242
ilBuilder.Emit(OpCodes.Ret);
4343
}
4444

45+
public static void EmitSetFieldToNewInstance(
46+
this ILGenerator ilBuilder,
47+
FieldBuilder field,
48+
Type instanceType)
49+
{
50+
if (field.IsStatic)
51+
throw new ArgumentException("The field should be instance field", nameof(field));
52+
53+
if (instanceType != null)
54+
{
55+
/*
56+
IL_0006: ldarg.0
57+
IL_0007: newobj instance void BenchmarkDotNet.Helpers.AwaitHelper::.ctor()
58+
IL_000c: stfld class BenchmarkDotNet.Helpers.AwaitHelper BenchmarkDotNet.Autogenerated.Runnable_0::awaitHelper
59+
*/
60+
var ctor = instanceType.GetConstructor(Array.Empty<Type>());
61+
if (ctor == null)
62+
throw new InvalidOperationException($"Bug: instanceType {instanceType.Name} does not have a 0-parameter accessible constructor.");
63+
64+
ilBuilder.Emit(OpCodes.Ldarg_0);
65+
ilBuilder.Emit(OpCodes.Newobj, ctor);
66+
ilBuilder.Emit(OpCodes.Stfld, field);
67+
}
68+
else
69+
{
70+
ilBuilder.Emit(OpCodes.Ldarg_0);
71+
ilBuilder.Emit(OpCodes.Ldnull);
72+
ilBuilder.Emit(OpCodes.Stfld, field);
73+
}
74+
}
75+
4576
public static void EmitSetDelegateToThisField(
4677
this ILGenerator ilBuilder,
4778
FieldBuilder delegateField,

src/BenchmarkDotNet/Templates/BenchmarkType.txt

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -57,6 +57,8 @@
5757

5858
public Runnable_$ID$()
5959
{
60+
awaitHelper = new BenchmarkDotNet.Helpers.AwaitHelper();
61+
6062
globalSetupAction = $GlobalSetupMethodName$;
6163
globalCleanupAction = $GlobalCleanupMethodName$;
6264
iterationSetupAction = $IterationSetupMethodName$;
@@ -66,6 +68,8 @@
6668
$InitializeArgumentFields$
6769
}
6870

71+
private readonly BenchmarkDotNet.Helpers.AwaitHelper awaitHelper;
72+
6973
private System.Action globalSetupAction;
7074
private System.Action globalCleanupAction;
7175
private System.Action iterationSetupAction;

src/BenchmarkDotNet/Toolchains/InProcess.Emit.Implementation/ConsumableTypeInfo.cs

Lines changed: 15 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,8 @@
11
using BenchmarkDotNet.Engines;
22
using JetBrains.Annotations;
33
using System;
4+
using System.Collections.Generic;
5+
using System.Linq;
46
using System.Reflection;
57
using System.Runtime.CompilerServices;
68
using System.Threading.Tasks;
@@ -17,28 +19,24 @@ public ConsumableTypeInfo(Type methodReturnType)
1719

1820
OriginMethodReturnType = methodReturnType;
1921

20-
// Please note this code does not support await over extension methods.
21-
var getAwaiterMethod = methodReturnType.GetMethod(nameof(Task<int>.GetAwaiter), BindingFlagsPublicInstance);
22-
if (getAwaiterMethod == null)
22+
// Only support (Value)Task for parity with other toolchains (and so we can use AwaitHelper).
23+
IsAwaitable = methodReturnType == typeof(Task) || methodReturnType == typeof(ValueTask)
24+
|| (methodReturnType.GetTypeInfo().IsGenericType
25+
&& (methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(Task<>)
26+
|| methodReturnType.GetTypeInfo().GetGenericTypeDefinition() == typeof(ValueTask<>)));
27+
28+
if (!IsAwaitable)
2329
{
2430
WorkloadMethodReturnType = methodReturnType;
2531
}
2632
else
2733
{
28-
var getResultMethod = getAwaiterMethod
34+
WorkloadMethodReturnType = methodReturnType
35+
.GetMethod(nameof(Task.GetAwaiter), BindingFlagsPublicInstance)
2936
.ReturnType
30-
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance);
31-
32-
if (getResultMethod == null)
33-
{
34-
WorkloadMethodReturnType = methodReturnType;
35-
}
36-
else
37-
{
38-
WorkloadMethodReturnType = getResultMethod.ReturnType;
39-
GetAwaiterMethod = getAwaiterMethod;
40-
GetResultMethod = getResultMethod;
41-
}
37+
.GetMethod(nameof(TaskAwaiter.GetResult), BindingFlagsPublicInstance)
38+
.ReturnType;
39+
GetResultMethod = Helpers.AwaitHelper.GetGetResultMethod(methodReturnType);
4240
}
4341

4442
if (WorkloadMethodReturnType == null)
@@ -78,8 +76,6 @@ public ConsumableTypeInfo(Type methodReturnType)
7876
[NotNull]
7977
public Type OverheadMethodReturnType { get; }
8078

81-
[CanBeNull]
82-
public MethodInfo GetAwaiterMethod { get; }
8379
[CanBeNull]
8480
public MethodInfo GetResultMethod { get; }
8581

@@ -89,6 +85,6 @@ public ConsumableTypeInfo(Type methodReturnType)
8985
[CanBeNull]
9086
public FieldInfo WorkloadConsumableField { get; }
9187

92-
public bool IsAwaitable => GetAwaiterMethod != null && GetResultMethod != null;
88+
public bool IsAwaitable { get; }
9389
}
9490
}

0 commit comments

Comments
 (0)