Skip to content

Commit 22d0d46

Browse files
KDr2yebai
andauthored
New construction for TapedTask (#155)
* new construction for TapedTask * Update Project.toml Co-authored-by: Hong Ge <[email protected]>
1 parent 9990abf commit 22d0d46

File tree

4 files changed

+24
-7
lines changed

4 files changed

+24
-7
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.8"
6+
version = "0.8.1"
77

88
[deps]
99
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"

Diff for: src/tapedfunction.jl

+4-4
Original file line numberDiff line numberDiff line change
@@ -58,9 +58,9 @@ mutable struct TapedFunction{F, TapeType}
5858
binding_values::Bindings
5959
arg_binding_slots::Vector{Int} # arg indices in binding_values
6060
retval_binding_slot::Int # 0 indicates the function has not returned
61-
deepcopy_types::Vector{Any}
61+
deepcopy_types::Type # use a Union type for multiple types
6262

63-
function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=[]) where {F, T}
63+
function TapedFunction{F, T}(f::F, args...; cache=false, deepcopy_types=Union{}) where {F, T}
6464
args_type = _accurate_typeof.(args)
6565
cache_key = (f, args_type...)
6666

@@ -78,7 +78,7 @@ mutable struct TapedFunction{F, TapeType}
7878
return tf
7979
end
8080

81-
TapedFunction(f, args...; cache=false, deepcopy_types=[]) =
81+
TapedFunction(f, args...; cache=false, deepcopy_types=Union{}) =
8282
TapedFunction{typeof(f), RawTape}(f, args...; cache=cache, deepcopy_types=deepcopy_types)
8383

8484
function TapedFunction{F, T0}(tf::TapedFunction{F, T1}) where {F, T0, T1}
@@ -472,7 +472,7 @@ tape_shallowcopy(x::Core.Box) = Core.Box(tape_shallowcopy(x.contents))
472472
tape_deepcopy(x::Core.Box) = Core.Box(tape_deepcopy(x.contents))
473473

474474
function _tape_copy(v, deepcopy_types)
475-
if any(t -> isa(v, t), deepcopy_types)
475+
if isa(v, deepcopy_types)
476476
tape_deepcopy(v)
477477
else
478478
tape_shallowcopy(v)

Diff for: src/tapedtask.jl

+3-2
Original file line numberDiff line numberDiff line change
@@ -65,12 +65,13 @@ end
6565

6666
# NOTE: evaluating model without a trace, see
6767
# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
68-
function TapedTask(f, args...; deepcopy_types=[Array, Ref]) # deepcoy Array and Ref by default.
68+
function TapedTask(f, args...; deepcopy_types=Union{Array, Ref}) # deepcoy Array and Ref by default.
6969
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy_types)
7070
TapedTask(tf, args...)
7171
end
7272

73-
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...)
73+
TapedTask(finfo::Tuple{Any, Type}, args...) = TapedTask(finfo[1], args...; deepcopy_types=finfo[2])
74+
TapedTask(t::TapedTask, args...) = TapedTask(func(t), args...; deepcopy_types=t.tf.deepcopy_types)
7475
func(t::TapedTask) = t.tf.func
7576

7677
#=

Diff for: test/tapedtask.jl

+16
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,20 @@
11
@testset "tapedtask" begin
2+
@testset "construction" begin
3+
function f()
4+
t = 1
5+
while true
6+
produce(t)
7+
t = 1 + t
8+
end
9+
end
10+
11+
ttask = TapedTask(f)
12+
@test consume(ttask) == 1
13+
14+
ttask = TapedTask((f, Union{}))
15+
@test consume(ttask) == 1
16+
end
17+
218
@testset "iteration" begin
319
function f()
420
t = 1

0 commit comments

Comments
 (0)