Skip to content

Commit c01c152

Browse files
authored
Pass CancellationToken to GetAsyncEnumerator (#48)
1 parent 722aa5d commit c01c152

File tree

6 files changed

+176
-19
lines changed

6 files changed

+176
-19
lines changed

src/IcedTasks/CancellableTaskBuilderBase.fs

Lines changed: 11 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -741,13 +741,17 @@ module CancellableTaskBase =
741741
source: #IAsyncEnumerable<'T>,
742742
body: 'T -> CancellableTaskBaseCode<_, unit, 'Builder>
743743
) : CancellableTaskBaseCode<_, _, 'Builder> =
744-
745-
this.Using(
746-
source.GetAsyncEnumerator CancellationToken.None,
747-
(fun (e: IAsyncEnumerator<'T>) ->
748-
this.WhileAsync(
749-
(fun () -> Awaitable.GetAwaiter(e.MoveNextAsync())),
750-
(fun sm -> (body e.Current).Invoke(&sm))
744+
this.Bind(
745+
this.Source((fun (ct: CancellationToken) -> ValueTask<_> ct)),
746+
(fun ct ->
747+
this.Using(
748+
source.GetAsyncEnumerator ct,
749+
(fun (e: IAsyncEnumerator<'T>) ->
750+
this.WhileAsync(
751+
(fun () -> Awaitable.GetAwaiter(e.MoveNextAsync())),
752+
(fun sm -> (body e.Current).Invoke(&sm))
753+
)
754+
)
751755
)
752756
)
753757
)

tests/IcedTasks.Tests/AsyncExTests.fs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,41 @@ module AsyncExTests =
741741
)
742742
}
743743

744+
745+
testCaseAsync "IAsyncEnumerator receives CancellationToken"
746+
<| async {
747+
do!
748+
asyncEx {
749+
750+
let mutable index = 0
751+
let loops = 10
752+
753+
let asyncSeq =
754+
AsyncEnumerable.forXtoY
755+
0
756+
loops
757+
(fun _ -> valueTaskUnit { do! Task.Yield() })
758+
759+
use cts = new CancellationTokenSource()
760+
761+
let actual =
762+
asyncEx {
763+
for (i: int) in asyncSeq do
764+
do! Task.Yield()
765+
index <- index + 1
766+
}
767+
768+
do!
769+
Async.StartAsTask(actual, cancellationToken = cts.Token)
770+
|> Async.AwaitTask
771+
772+
Expect.equal
773+
asyncSeq.LastEnumerator.Value.CancellationToken
774+
cts.Token
775+
""
776+
}
777+
}
778+
744779
]
745780
]
746781

tests/IcedTasks.Tests/CancellablePoolingValueTaskTests.fs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,39 @@ module CancellablePoolingValueTaskTests =
851851
}
852852
)
853853
}
854+
855+
856+
testCaseAsync "IAsyncEnumerator receives CancellationToken"
857+
<| async {
858+
do!
859+
cancellablePoolingValueTask {
860+
861+
let mutable index = 0
862+
let loops = 10
863+
864+
let asyncSeq =
865+
AsyncEnumerable.forXtoY
866+
0
867+
loops
868+
(fun _ -> valueTaskUnit { do! Task.Yield() })
869+
870+
use cts = new CancellationTokenSource()
871+
872+
let actual =
873+
cancellablePoolingValueTask {
874+
for (i: int) in asyncSeq do
875+
do! Task.Yield()
876+
index <- index + 1
877+
}
878+
879+
do! actual cts.Token
880+
881+
Expect.equal
882+
asyncSeq.LastEnumerator.Value.CancellationToken
883+
cts.Token
884+
""
885+
}
886+
}
854887
]
855888

856889

tests/IcedTasks.Tests/CancellableTaskTests.fs

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -812,6 +812,41 @@ module CancellableTaskTests =
812812
}
813813
)
814814
}
815+
816+
817+
testCaseAsync "IAsyncEnumerator receives CancellationToken"
818+
<| async {
819+
820+
do!
821+
cancellableTask {
822+
823+
let mutable index = 0
824+
let loops = 10
825+
826+
let asyncSeq =
827+
AsyncEnumerable.forXtoY
828+
0
829+
loops
830+
(fun _ -> valueTaskUnit { do! Task.Yield() })
831+
832+
use cts = new CancellationTokenSource()
833+
834+
let actual =
835+
cancellableTask {
836+
for (i: int) in asyncSeq do
837+
do! Task.Yield()
838+
index <- index + 1
839+
}
840+
841+
do! actual cts.Token
842+
843+
Expect.equal
844+
asyncSeq.LastEnumerator.Value.CancellationToken
845+
cts.Token
846+
""
847+
}
848+
849+
}
815850
]
816851
testList "MergeSources" [
817852

tests/IcedTasks.Tests/CancellableValueTaskTests.fs

Lines changed: 33 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -851,6 +851,39 @@ module CancellableValueTaskTests =
851851
)
852852
}
853853

854+
855+
testCaseAsync "IAsyncEnumerator receives CancellationToken"
856+
<| async {
857+
do!
858+
cancellableValueTask {
859+
860+
let mutable index = 0
861+
let loops = 10
862+
863+
let asyncSeq =
864+
AsyncEnumerable.forXtoY
865+
0
866+
loops
867+
(fun _ -> valueTaskUnit { do! Task.Yield() })
868+
869+
use cts = new CancellationTokenSource()
870+
871+
let actual =
872+
cancellableValueTask {
873+
for (i: int) in asyncSeq do
874+
do! Task.Yield()
875+
index <- index + 1
876+
}
877+
878+
do! actual cts.Token
879+
880+
Expect.equal
881+
asyncSeq.LastEnumerator.Value.CancellationToken
882+
cts.Token
883+
""
884+
}
885+
}
886+
854887
]
855888

856889
testList "MergeSources" [

tests/IcedTasks.Tests/Expect.fs

Lines changed: 29 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -171,23 +171,40 @@ module AsyncEnumerable =
171171
open System.Collections.Generic
172172
open System.Threading
173173

174-
type AsyncEnumerable<'T>(e: IEnumerable<'T>, beforeMoveNext: Func<_, ValueTask>) =
174+
type AsyncEnumerator<'T>(current, moveNext, dispose, cancellationToken: CancellationToken) =
175+
member this.CancellationToken = cancellationToken
175176

176-
member this.GetAsyncEnumerator(ct) =
177-
let enumerator = e.GetEnumerator()
177+
interface IAsyncEnumerator<'T> with
178+
member this.Current = current ()
179+
member this.MoveNextAsync() = moveNext ()
180+
member this.DisposeAsync() = dispose ()
178181

179-
{ new IAsyncEnumerator<'T> with
180-
member this.Current = enumerator.Current
182+
type AsyncEnumerable<'T>(e: IEnumerable<'T>, beforeMoveNext: Func<_, ValueTask>) =
181183

182-
member this.MoveNextAsync() =
183-
valueTask {
184-
do! beforeMoveNext.Invoke(ct)
185-
return enumerator.MoveNext()
186-
}
184+
let mutable lastEnumerator = None
185+
member this.LastEnumerator = lastEnumerator
187186

188-
member this.DisposeAsync() = valueTaskUnit { enumerator.Dispose() }
187+
member this.GetAsyncEnumerator(ct) =
188+
let enumerator = e.GetEnumerator()
189189

190-
}
190+
lastEnumerator <-
191+
Some
192+
<| AsyncEnumerator(
193+
(fun () -> enumerator.Current),
194+
(fun () ->
195+
valueTask {
196+
do! beforeMoveNext.Invoke(ct)
197+
return enumerator.MoveNext()
198+
}
199+
),
200+
(fun () ->
201+
enumerator.Dispose()
202+
|> ValueTask
203+
),
204+
ct
205+
)
206+
207+
lastEnumerator.Value
191208

192209
interface IAsyncEnumerable<'T> with
193210
member this.GetAsyncEnumerator(ct: CancellationToken) = this.GetAsyncEnumerator(ct)

0 commit comments

Comments
 (0)