Skip to content

Commit e720a0e

Browse files
authored
Fix AsyncEx.AwaitTask cancellation (#46)
#24 (comment)
1 parent 8705e22 commit e720a0e

File tree

2 files changed

+81
-92
lines changed

2 files changed

+81
-92
lines changed

src/IcedTasks/AsyncEx.fs

Lines changed: 19 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -25,35 +25,27 @@ type AsyncEx =
2525
/// This is based on <see href="https://stackoverflow.com/a/66815960">How to use awaitable inside async?</see> and <see href="https://github.com/fsharp/fslang-suggestions/issues/840">Async.Await overload (esp. AwaitTask without throwing AggregateException)</see>
2626
/// </remarks>
2727
static member inline AwaitAwaiter(awaiter: 'Awaiter) =
28+
let inline handleFinished (onNext: 'a -> unit, onError: exn -> unit, awaiter) =
29+
try
30+
onNext (Awaiter.GetResult awaiter)
31+
with
32+
| :? AggregateException as ae when ae.InnerExceptions.Count = 1 ->
33+
onError ae.InnerExceptions.[0]
34+
| e ->
35+
// Why not handle TaskCanceledException/OperationCanceledException?
36+
// From https://github.com/dotnet/fsharp/blob/89e641108e8773e8d5731437a2b944510de52567/src/FSharp.Core/async.fs#L1228-L1231:
37+
// A cancelled task calls the exception continuation with TaskCanceledException, since it may not represent cancellation of
38+
// the overall async (they may be governed by different cancellation tokens, or
39+
// the task may not have a cancellation token at all).
40+
onError e
41+
2842
Async.FromContinuations(fun (onNext, onError, onCancel) ->
2943
if Awaiter.IsCompleted awaiter then
30-
try
31-
onNext (Awaiter.GetResult awaiter)
32-
with
33-
| :? TaskCanceledException as ce -> onCancel ce
34-
| :? OperationCanceledException as ce -> onCancel ce
35-
| :? AggregateException as ae ->
36-
if ae.InnerExceptions.Count = 1 then
37-
onError ae.InnerExceptions.[0]
38-
else
39-
onError ae
40-
| e -> onError e
44+
handleFinished (onNext, onError, awaiter)
4145
else
42-
Awaiter.OnCompleted(
46+
Awaiter.UnsafeOnCompleted(
4347
awaiter,
44-
(fun () ->
45-
try
46-
onNext (Awaiter.GetResult awaiter)
47-
with
48-
| :? TaskCanceledException as ce -> onCancel ce
49-
| :? OperationCanceledException as ce -> onCancel ce
50-
| :? AggregateException as ae ->
51-
if ae.InnerExceptions.Count = 1 then
52-
onError ae.InnerExceptions.[0]
53-
else
54-
onError ae
55-
| e -> onError e
56-
)
48+
(fun () -> handleFinished (onNext, onError, awaiter))
5749
)
5850
)
5951

@@ -78,39 +70,7 @@ type AsyncEx =
7870
/// <remarks>
7971
/// This is based on <see href="https://github.com/fsharp/fslang-suggestions/issues/840">Async.Await overload (esp. AwaitTask without throwing AggregateException)</see>
8072
/// </remarks>
81-
static member AwaitTask(task: Task) : Async<unit> =
82-
Async.FromContinuations(fun (onNext, onError, onCancel) ->
83-
if task.IsCompleted then
84-
if task.IsFaulted then
85-
let e = task.Exception
86-
87-
if e.InnerExceptions.Count = 1 then
88-
onError e.InnerExceptions.[0]
89-
else
90-
onError e
91-
elif task.IsCanceled then
92-
onCancel (TaskCanceledException(task))
93-
else
94-
onNext ()
95-
else
96-
task.ContinueWith(
97-
(fun (task: Task) ->
98-
if task.IsFaulted then
99-
let e = task.Exception
100-
101-
if e.InnerExceptions.Count = 1 then
102-
onError e.InnerExceptions.[0]
103-
else
104-
onError e
105-
elif task.IsCanceled then
106-
onCancel (TaskCanceledException(task))
107-
else
108-
onNext ()
109-
),
110-
TaskContinuationOptions.ExecuteSynchronously
111-
)
112-
|> ignore
113-
)
73+
static member AwaitTask(task: Task) : Async<unit> = AsyncEx.AwaitAwaitable(task)
11474

11575
/// <summary>
11676
/// Return an asynchronous computation that will wait for the given Task to complete and return
@@ -122,39 +82,7 @@ type AsyncEx =
12282
/// This is based on <see href="https://github.com/fsharp/fslang-suggestions/issues/840">Async.Await overload (esp. AwaitTask without throwing AggregateException)</see>
12383
/// </remarks>
12484
static member AwaitTask(task: Task<'T>) : Async<'T> =
125-
Async.FromContinuations(fun (onNext, onError, onCancel) ->
126-
127-
if task.IsCompleted then
128-
if task.IsFaulted then
129-
let e = task.Exception
130-
131-
if e.InnerExceptions.Count = 1 then
132-
onError e.InnerExceptions.[0]
133-
else
134-
onError e
135-
elif task.IsCanceled then
136-
onCancel (TaskCanceledException(task))
137-
else
138-
onNext task.Result
139-
else
140-
task.ContinueWith(
141-
(fun (task: Task<'T>) ->
142-
if task.IsFaulted then
143-
let e = task.Exception
144-
145-
if e.InnerExceptions.Count = 1 then
146-
onError e.InnerExceptions.[0]
147-
else
148-
onError e
149-
elif task.IsCanceled then
150-
onCancel (TaskCanceledException(task))
151-
else
152-
onNext task.Result
153-
),
154-
TaskContinuationOptions.ExecuteSynchronously
155-
)
156-
|> ignore
157-
)
85+
AsyncEx.AwaitAwaiter(Awaitable.GetTaskAwaiter task)
15886

15987

16088
/// <summary>

tests/IcedTasks.Tests/AsyncExTests.fs

Lines changed: 62 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -811,7 +811,68 @@ module PolyfillTest =
811811
let! result = outer
812812
Expect.equal result () "Should return the data"
813813
}
814+
815+
let withCancellation (ct: CancellationToken) (a: Async<'a>) : Async<'a> =
816+
async {
817+
let! ct2 = Async.CancellationToken
818+
use cts = CancellationTokenSource.CreateLinkedTokenSource(ct, ct2)
819+
let tcs = new TaskCompletionSource<'a>()
820+
821+
use _reg =
822+
cts.Token.Register(fun () ->
823+
tcs.TrySetCanceled(cts.Token)
824+
|> ignore
825+
)
826+
827+
let a =
828+
async {
829+
try
830+
let! a = a
831+
832+
tcs.TrySetResult a
833+
|> ignore
834+
with ex ->
835+
tcs.TrySetException ex
836+
|> ignore
837+
}
838+
839+
Async.Start(a, cts.Token)
840+
841+
return!
842+
tcs.Task
843+
|> AsyncEx.AwaitTask
844+
}
845+
846+
testCase "Don't cancel everything if one task cancels"
847+
<| fun () ->
848+
use cts = new CancellationTokenSource()
849+
cts.CancelAfter(100)
850+
851+
let doWork i =
852+
asyncEx {
853+
try
854+
let! _ =
855+
Async.Sleep(100)
856+
|> withCancellation cts.Token
857+
858+
()
859+
with :? OperationCanceledException as e ->
860+
()
861+
}
862+
863+
Seq.init
864+
(Environment.ProcessorCount
865+
* 2)
866+
doWork
867+
|> Async.Parallel
868+
|> Async.RunSynchronously
869+
|> ignore
814870
]
815871

872+
816873
[<Tests>]
817-
let asyncExTests = testList "IcedTasks.Polyfill.Async" [ builderTests ]
874+
let asyncExTests =
875+
testList "IcedTasks.Polyfill.Async" [
876+
builderTests
877+
878+
]

0 commit comments

Comments
 (0)