Skip to content

Commit 5ec9f2b

Browse files
Union types (#157)
* Union types * Update test/tape_copy.jl * Update Project.toml Co-authored-by: Hong Ge <[email protected]>
1 parent bafaa1f commit 5ec9f2b

File tree

3 files changed

+31
-3
lines changed

3 files changed

+31
-3
lines changed

Diff for: Project.toml

+1-1
Original file line numberDiff line numberDiff line change
@@ -3,7 +3,7 @@ uuid = "6f1fad26-d15e-5dc8-ae53-837a1d7b8c9f"
33
license = "MIT"
44
desc = "Tape based task copying in Turing"
55
repo = "https://github.com/TuringLang/Libtask.jl.git"
6-
version = "0.8.2"
6+
version = "0.8.3"
77

88
[deps]
99
FunctionWrappers = "069b7b12-0de2-55c6-9aab-29f3d0a68a2e"

Diff for: src/tapedtask.jl

+9-2
Original file line numberDiff line numberDiff line change
@@ -63,10 +63,17 @@ function TapedTask(tf::TapedFunction, args...)
6363
return t
6464
end
6565

66+
BASE_COPY_TYPES = Union{Array, Ref}
67+
6668
# NOTE: evaluating model without a trace, see
6769
# https://github.com/TuringLang/Turing.jl/pull/1757#diff-8d16dd13c316055e55f300cd24294bb2f73f46cbcb5a481f8936ff56939da7ceR329
68-
function TapedTask(f, args...; deepcopy_types=Union{Array, Ref}) # deepcoy Array and Ref by default.
69-
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy_types)
70+
function TapedTask(f, args...; deepcopy_types=nothing) # deepcoy Array and Ref by default.
71+
if isnothing(deepcopy_types)
72+
deepcopy = BASE_COPY_TYPES
73+
else
74+
deepcopy = Union{BASE_COPY_TYPES, deepcopy_types}
75+
end
76+
tf = TapedFunction(f, args...; cache=true, deepcopy_types=deepcopy)
7077
TapedTask(tf, args...)
7178
end
7279

Diff for: test/tape_copy.jl

+21
Original file line numberDiff line numberDiff line change
@@ -171,4 +171,25 @@
171171
y[][2] = 19
172172
@test y[][2] == 19
173173
end
174+
175+
@testset "override deepcopy_types #57" begin
176+
struct DummyType end
177+
178+
function f(start::Int)
179+
t = [start]
180+
while true
181+
produce(t[1])
182+
t[1] = 1 + t[1]
183+
end
184+
end
185+
186+
ttask = TapedTask(f, 0; deepcopy_types=DummyType)
187+
consume(ttask)
188+
189+
ttask2 = copy(ttask)
190+
consume(ttask2)
191+
192+
@test consume(ttask) == 1
193+
@test consume(ttask2) == 2
194+
end
174195
end

0 commit comments

Comments
 (0)