Skip to content

Commit 70f5c8d

Browse files
committed
Proper Task.zip implementation
1 parent 50b1b7b commit 70f5c8d

File tree

2 files changed

+129
-24
lines changed

2 files changed

+129
-24
lines changed

src/FSharpPlus/Extensions/Task.fs

+94-23
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@ namespace FSharpPlus
77
module Task =
88

99
open System
10+
open System.Threading
1011
open System.Threading.Tasks
1112

1213
let private (|Canceled|Faulted|Completed|) (t: Task<'a>) =
@@ -135,26 +136,91 @@ module Task =
135136
tcs.Task
136137

137138
/// <summary>Creates a task workflow from two workflows 'x' and 'y', mapping its results with 'f'.</summary>
138-
/// <remarks>Similar to lift2 but although workflows are started in sequence they might end independently in different order.</remarks>
139-
/// <param name="f">The mapping function.</param>
140-
/// <param name="x">First task workflow.</param>
141-
/// <param name="y">Second task workflow.</param>
142-
let map2 f x y = task {
143-
let! x' = x
144-
let! y' = y
145-
return f x' y' }
139+
/// <remarks>Similar to lift2 but although workflows are started in sequence they might end independently in different order
140+
/// and all errors are collected.
141+
/// </remarks>
142+
/// <param name="mapper">The mapping function.</param>
143+
/// <param name="task1">First task workflow.</param>
144+
/// <param name="task2">Second task workflow.</param>
145+
let map2 mapper (task1: Task<'T1>) (task2: Task<'T2>) : Task<'U> =
146+
if task1.Status = TaskStatus.RanToCompletion && task2.Status = TaskStatus.RanToCompletion then
147+
try Task.FromResult (mapper task1.Result task2.Result)
148+
with e ->
149+
let tcs = TaskCompletionSource<_> ()
150+
tcs.SetException e
151+
tcs.Task
152+
else
153+
let tcs = TaskCompletionSource<_> ()
154+
let r1 = ref Unchecked.defaultof<_>
155+
let r2 = ref Unchecked.defaultof<_>
156+
let mutable cancelled = false
157+
let failures = [|IReadOnlyCollection.empty; IReadOnlyCollection.empty|]
158+
let pending = ref 2
159+
160+
let trySet () =
161+
if Interlocked.Decrement pending = 0 then
162+
let noFailures = Array.forall IReadOnlyCollection.isEmpty failures
163+
if noFailures && not cancelled then
164+
try tcs.SetResult (mapper r1.Value r2.Value)
165+
with e -> tcs.SetException e
166+
elif noFailures then tcs.SetCanceled ()
167+
else tcs.SetException (failures |> Seq.map AggregateException |> Seq.reduce Exception.add).InnerExceptions
168+
169+
let k (v: ref<'k>) i t =
170+
match t with
171+
| Canceled -> cancelled <- true
172+
| Faulted e -> failures[i] <- e.InnerExceptions
173+
| Completed r -> v.Value <- r
174+
trySet ()
175+
176+
task1.ContinueWith (k r1 0) |> ignore
177+
task2.ContinueWith (k r2 1) |> ignore
178+
tcs.Task
146179

147180
/// <summary>Creates a task workflow from three workflows 'x', 'y' and z, mapping its results with 'f'.</summary>
148-
/// <remarks>Similar to lift3 but although workflows are started in sequence they might end independently in different order.</remarks>
149-
/// <param name="f">The mapping function.</param>
150-
/// <param name="x">First task workflow.</param>
151-
/// <param name="y">Second task workflow.</param>
152-
/// <param name="z">Third task workflow.</param>
153-
let map3 f x y z = task {
154-
let! x' = x
155-
let! y' = y
156-
let! z' = z
157-
return f x' y' z' }
181+
/// <remarks>Similar to lift3 but although workflows are started in sequence they might end independently in different order
182+
/// and all errors are collected.
183+
/// </remarks>
184+
/// <param name="mapper">The mapping function.</param>
185+
/// <param name="task1">First task workflow.</param>
186+
/// <param name="task2">Second task workflow.</param>
187+
/// <param name="task3">Third task workflow.</param>
188+
let map3 mapper (task1: Task<'T1>) (task2: Task<'T2>) (task3: Task<'T3>) : Task<'U> =
189+
if task1.Status = TaskStatus.RanToCompletion && task2.Status = TaskStatus.RanToCompletion && task3.Status = TaskStatus.RanToCompletion then
190+
try Task.FromResult (mapper task1.Result task2.Result task3.Result)
191+
with e ->
192+
let tcs = TaskCompletionSource<_> ()
193+
tcs.SetException e
194+
tcs.Task
195+
else
196+
let tcs = TaskCompletionSource<_> ()
197+
let r1 = ref Unchecked.defaultof<_>
198+
let r2 = ref Unchecked.defaultof<_>
199+
let r3 = ref Unchecked.defaultof<_>
200+
let mutable cancelled = false
201+
let failures = [|IReadOnlyCollection.empty<exn>; IReadOnlyCollection.empty; IReadOnlyCollection.empty|]
202+
let pending = ref 3
203+
204+
let trySet () =
205+
if Interlocked.Decrement pending = 0 then
206+
let noFailures = Array.forall isNull failures
207+
if noFailures && not cancelled then
208+
try tcs.SetResult (mapper r1.Value r2.Value r3.Value)
209+
with e -> tcs.SetException e
210+
elif noFailures then tcs.SetCanceled ()
211+
else tcs.SetException (failures |> Seq.concat |> Seq.fold Exception.add (AggregateException ())).InnerExceptions
212+
213+
let k (v: ref<'k>) i t =
214+
match t with
215+
| Canceled -> cancelled <- true
216+
| Faulted e -> failures[i] <- e.InnerExceptions
217+
| Completed r -> v.Value <- r
218+
trySet ()
219+
220+
task1.ContinueWith (k r1 0) |> ignore
221+
task2.ContinueWith (k r2 1) |> ignore
222+
task3.ContinueWith (k r3 2) |> ignore
223+
tcs.Task
158224

159225
/// <summary>Creates a task workflow that is the result of applying the resulting function of a task workflow
160226
/// to the resulting value of another task workflow</summary>
@@ -242,11 +308,16 @@ module Task =
242308
tcs.Task
243309

244310
/// <summary>Creates a task workflow from two workflows 'x' and 'y', tupling its results.</summary>
245-
/// <remarks>Similar to zipSequentially but although workflows are started in sequence they might end independently in different order.</remarks>
246-
let zip x y = task {
247-
let! x' = x
248-
let! y' = y
249-
return x', y' }
311+
/// <remarks>Similar to zipSequentially but although workflows are started in sequence they might end independently in different order
312+
/// and all errors are collected.
313+
/// </remarks>
314+
let zip (task1: Task<'T1>) (task2: Task<'T2>) = map2 (fun x y -> x, y) task1 task2
315+
316+
/// <summary>Creates a task workflow from two workflows 'x', 'y' and 'z', tupling its results.</summary>
317+
/// <remarks>Similar to zipSequentially but although workflows are started in sequence they might end independently in different order
318+
/// and all errors are collected.
319+
/// </remarks>
320+
let zip3 (task1: Task<'T1>) (task2: Task<'T2>) (task3: Task<'T3>) = map3 (fun x y z -> x, y, z) task1 task2 task3
250321

251322
/// Flattens two nested tasks into one.
252323
let join (source: Task<Task<'T>>) : Task<'T> = source.Unwrap()

tests/FSharpPlus.Tests/Task.fs

+35-1
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,7 @@ module Task =
1212
exception TestException of string
1313

1414
module TaskTests =
15+
open System.Threading
1516

1617
let createTask isFailed delay value =
1718
if not isFailed && delay = 0 then Task.FromResult value
@@ -39,11 +40,17 @@ module Task =
3940
let b = Task.zipSequentially x1 x2
4041
require b.IsCompleted "Task.zipSequentially didn't short-circuit"
4142

43+
let b1 = Task.zip3 x1 x2 x3
44+
require b1.IsCompleted "Task.zip3 didn't short-circuit"
45+
4246
let c = Task.lift2 (+) x1 x2
4347
require c.IsCompleted "Task.lift2 didn't short-circuit"
4448

4549
let d = Task.lift3 (fun x y z -> x + y + z) x1 x2 x3
46-
require d.IsCompleted "Task.lift3 didn't short-circiut"
50+
require d.IsCompleted "Task.lift3 didn't short-circuit"
51+
52+
let d2 = Task.map3 (fun x y z -> x + y + z) x1 x2 x3
53+
require d2.IsCompleted "Task.map3 didn't short-circuit"
4754

4855
[<Test>]
4956
let erroredTasks () =
@@ -151,6 +158,33 @@ module Task =
151158
r19.Exception.InnerExceptions |> areEquivalent [TestException "Ouch, can't create: 2"]
152159
let r20 = Task.lift3 (mapping3 false) (x1 ()) (x2 ()) (e3 ())
153160
r20.Exception.InnerExceptions |> areEquivalent [TestException "Ouch, can't create: 3"]
161+
162+
[<Test>]
163+
let testTaskZip () =
164+
let t1 = createTask true 0 1
165+
let t2 = createTask true 0 2
166+
let t3 = createTask true 0 3
167+
168+
let c = new CancellationToken true
169+
let t4 = Task.FromCanceled<int> c
170+
171+
let t5 = createTask false 0 5
172+
let t6 = createTask false 0 6
173+
174+
let t12 = Task.WhenAll [t1; t2]
175+
let t12t12 = Task.WhenAll [t12; t12]
176+
177+
let t12123 = Task.zip3 t12t12 (Task.WhenAll [t3; t3]) t4
178+
let ac1 = t12123.Exception.InnerExceptions |> Seq.map (fun x -> int (Char.GetNumericValue x.Message.[35]))
179+
180+
CollectionAssert.AreEquivalent ([1; 2; 1; 2; 3], ac1, "Task.zip(3) should add only non already existing exceptions.")
181+
182+
let t13 = Task.zip3 (Task.zip t1 t3) t4 (Task.zip t5 t6)
183+
Assert.AreEqual (true, t13.IsFaulted, "Task.zip(3) between a value, an exception and a cancellation -> exception wins.")
184+
let ac2 = t13.Exception.InnerExceptions |> Seq.map (fun x -> int (Char.GetNumericValue x.Message.[35]))
185+
CollectionAssert.AreEquivalent ([1; 3], ac2, "Task.zip between 2 exceptions => both exceptions returned, even after combining with cancellation and values.")
186+
187+
154188

155189
module TaskBuilderTests =
156190

0 commit comments

Comments
 (0)