Skip to content

Commit 333b279

Browse files
authored
Merge Tape and TapedFunction (#105)
* merge Tape and TapedFunction * Tape => RawTape * trivias update * remove gettape * merge run(::RawTape) and (tf::TapedFunction)(args...) * minor update
1 parent 64f90e6 commit 333b279

File tree

5 files changed

+114
-139
lines changed

5 files changed

+114
-139
lines changed

Diff for: perf/p0.jl

+3-3
Original file line numberDiff line numberDiff line change
@@ -29,9 +29,9 @@ args = m.evaluator[2:end];
2929
@show "CTask construction..."
3030
t = @btime Libtask.CTask(f, args...)
3131
# schedule(t.task) # work fine!
32-
# @show Libtask.result(t.tf.tape)
32+
# @show Libtask.result(t.tf)
3333
@show "Step in a tape..."
34-
@btime Libtask.step_in(t.tf.tape, args)
34+
@btime Libtask.step_in(t.tf, args)
3535

3636
# Case 2: SMC sampler
3737

@@ -44,4 +44,4 @@ t = @btime Libtask.CTask(m.evaluator[1], m.evaluator[2:end]...);
4444
# schedule(t.task)
4545
# @show Libtask.result(t.tf.tape)
4646
@show "Step in a tape..."
47-
@btime Libtask.step_in(t.tf.tape, m.evaluator[2:end])
47+
@btime Libtask.step_in(t.tf, m.evaluator[2:end])

Diff for: perf/p2.jl

+2-2
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,6 @@ args = m.evaluator[2:end]
5858

5959
t = Libtask.CTask(f, args...)
6060

61-
Libtask.step_in(t.tf.tape, args)
61+
Libtask.step_in(t.tf, args)
6262

63-
@show Libtask.result(t.tf.tape)
63+
@show Libtask.result(t.tf)

Diff for: src/tapedfunction.jl

+85-106
Original file line numberDiff line numberDiff line change
@@ -1,72 +1,105 @@
11
abstract type AbstractInstruction end
2-
3-
mutable struct Tape
4-
tape::Vector{<:AbstractInstruction}
5-
counter::Int
6-
owner
7-
end
2+
abstract type Taped end
3+
const RawTape = Vector{AbstractInstruction}
84

95
"""
106
Instruction
117
128
An `Instruction` stands for a function call
139
"""
14-
mutable struct Instruction{F} <: AbstractInstruction
15-
fun::F
10+
mutable struct Instruction{F, T<:Taped} <: AbstractInstruction
11+
func::F
1612
input::Tuple
1713
output
18-
tape::Tape
14+
tape::T
1915
end
2016

21-
Tape() = Tape(Vector{AbstractInstruction}(), 1, nothing)
22-
Tape(owner) = Tape(Vector{AbstractInstruction}(), 1, owner)
23-
MacroTools.@forward Tape.tape Base.iterate, Base.length
24-
MacroTools.@forward Tape.tape Base.push!, Base.getindex, Base.lastindex
25-
const NULL_TAPE = Tape()
26-
27-
function setowner!(tape::Tape, owner)
28-
tape.owner = owner
29-
return tape
17+
mutable struct TapedFunction{F} <: Taped
18+
func::F # maybe a function or a callable obejct
19+
arity::Int
20+
ir::Union{Nothing, IRTools.IR}
21+
tape::RawTape
22+
counter::Int
23+
owner
24+
function TapedFunction(f::F; arity::Int=-1) where {F}
25+
new{F}(f, arity, nothing, RawTape(), 1, nothing)
26+
end
3027
end
3128

3229
mutable struct Box{T}
3330
val::T
3431
end
3532

33+
## methods for Box
3634
val(x) = x
3735
val(x::Box) = x.val
36+
val(x::TapedFunction) = x.func
3837
box(x) = Box(x)
3938
box(x::Box) = x
39+
Base.show(io::IO, box::Box) = print(io, "Box(", box.val, ")")
4040

41-
gettape(x) = nothing
42-
gettape(x::Instruction) = x.tape
43-
function gettape(x::Tuple)
44-
for i in x
45-
gettape(i) != nothing && return gettape(i)
46-
end
41+
## methods for RawTape and Taped
42+
MacroTools.@forward TapedFunction.tape Base.iterate, Base.length
43+
MacroTools.@forward TapedFunction.tape Base.push!, Base.getindex, Base.lastindex
44+
45+
result(t::RawTape) = isempty(t) ? nothing : val(t[end].output)
46+
result(t::TapedFunction) = result(t.tape)
47+
48+
function increase_counter!(t::TapedFunction)
49+
t.counter > length(t) && return
50+
# instr = t[t.counter]
51+
t.counter += 1
52+
return t
4753
end
48-
result(t::Tape) = isempty(t) ? nothing : val(t[end].output)
4954

50-
function Base.show(io::IO, box::Box)
51-
println(io, "Box($(box.val))")
55+
function reset!(tf::TapedFunction, ir::IRTools.IR, tape::RawTape)
56+
tf.ir = ir
57+
tf.tape = tape
58+
return tf
5259
end
5360

54-
function Base.show(io::IO, instruction::AbstractInstruction)
55-
println(io, "A $(typeof(instruction))")
61+
function (tf::TapedFunction)(args...)
62+
if isempty(tf.tape)
63+
ir = IRTools.@code_ir tf.func(args...)
64+
ir = intercept(ir; recorder=:track!)
65+
tf.ir = ir
66+
tf.tape = RawTape()
67+
tf2 = IRTools.evalir(ir, tf, args...)
68+
@assert tf === tf2
69+
else
70+
# run the raw tape
71+
if length(args) > 0
72+
input = map(box, args)
73+
tf.tape[1].input = input
74+
end
75+
for instruction in tf.tape
76+
instruction()
77+
end
78+
end
79+
return result(tf)
5680
end
5781

58-
function Base.show(io::IO, instruction::Instruction)
59-
fun = instruction.fun
60-
tape = instruction.tape
61-
println(io, "Instruction($(fun)$(map(val, instruction.input)), tape=$(objectid(tape)))")
82+
function Base.show(io::IO, tf::TapedFunction)
83+
buf = IOBuffer()
84+
println(buf, "TapedFunction:")
85+
println(buf, "* .func => $(tf.func)")
86+
println(buf, "* .ir =>")
87+
println(buf, "------------------")
88+
println(buf, tf.ir)
89+
println(buf, "------------------")
90+
println(buf, "* .tape =>")
91+
println(buf, "------------------")
92+
println(buf, tf.tape)
93+
println(buf, "------------------")
94+
print(io, String(take!(buf)))
6295
end
6396

64-
function Base.show(io::IO, tp::Tape)
97+
function Base.show(io::IO, tp::RawTape)
6598
# we use an extra IOBuffer to collect all the data and then
6699
# output it once to avoid output interrupt during task context
67100
# switching
68101
buf = IOBuffer()
69-
print(buf, "$(length(tp))-element Tape")
102+
print(buf, "$(length(tp))-element RawTape")
70103
isempty(tp) || println(buf, ":")
71104
i = 1
72105
for instruction in tp
@@ -77,10 +110,19 @@ function Base.show(io::IO, tp::Tape)
77110
print(io, String(take!(buf)))
78111
end
79112

113+
## methods for Instruction
114+
Base.show(io::IO, instruction::AbstractInstruction) = print(io, "A ", typeof(instruction))
115+
116+
function Base.show(io::IO, instruction::Instruction)
117+
func = instruction.func
118+
tape = instruction.tape
119+
println(io, "Instruction($(func)$(map(val, instruction.input)), tape=$(objectid(tape)))")
120+
end
121+
80122
function (instr::Instruction{F})() where F
81123
# catch run-time exceptions / errors.
82124
try
83-
output = instr.fun(map(val, instr.input)...)
125+
output = instr.func(map(val, instr.input)...)
84126
instr.output.val = output
85127
catch e
86128
println(e, catch_backtrace());
@@ -101,26 +143,9 @@ function (instr::Instruction{typeof(_new)})()
101143
end
102144
end
103145

146+
## internal functions
104147

105-
function increase_counter!(t::Tape)
106-
t.counter > length(t) && return
107-
# instr = t[t.counter]
108-
t.counter += 1
109-
return t
110-
end
111-
112-
function run(tape::Tape, args...)
113-
if length(args) > 0
114-
input = map(box, args)
115-
tape[1].input = input
116-
end
117-
for instruction in tape
118-
instruction()
119-
increase_counter!(tape)
120-
end
121-
end
122-
123-
function run_and_record!(tape::Tape, f, args...)
148+
function track!(tape::Taped, f, args...)
124149
f = val(f) # f maybe a Boxed closure
125150
output = try
126151
box(f(map(val, args)...))
@@ -133,7 +158,7 @@ function run_and_record!(tape::Tape, f, args...)
133158
return output
134159
end
135160

136-
function run_and_record!(tape::Tape, ::typeof(_new), args...)
161+
function track!(tape::Taped, ::typeof(_new), args...)
137162
output = try
138163
expr = Expr(:new, map(val, args)...)
139164
box(eval(expr))
@@ -171,9 +196,11 @@ function _replace_args(args, pairs::Dict)
171196
end
172197
end
173198

174-
function intercept(ir; recorder=:run_and_record!)
199+
function intercept(ir; recorder=:track!)
175200
ir == nothing && return
176-
tape = pushfirst!(ir, IRTools.xcall(@__MODULE__, :Tape))
201+
# we use tf instead of the original function as the first argument
202+
# get the TapedFunction
203+
tape = pushfirst!(ir, IRTools.xcall(Base, :identity, IRTools.arguments(ir)[1]))
177204

178205
# box the args
179206
first_blk = IRTools.blocks(ir)[1]
@@ -229,51 +256,3 @@ function intercept(ir; recorder=:run_and_record!)
229256
unbox_condition(ir)
230257
return ir
231258
end
232-
233-
mutable struct TapedFunction
234-
func # ::Function # maybe a callable obejct
235-
arity::Int
236-
ir::Union{Nothing, IRTools.IR}
237-
tape::Tape
238-
owner
239-
function TapedFunction(f; arity::Int=-1)
240-
new(f, arity, nothing, NULL_TAPE, nothing)
241-
end
242-
end
243-
244-
function reset!(tf::TapedFunction, ir::IRTools.IR, tape::Tape)
245-
tf.ir = ir
246-
tf.tape = tape
247-
setowner!(tape, tf)
248-
return tf
249-
end
250-
251-
function (tf::TapedFunction)(args...)
252-
if isempty(tf.tape)
253-
ir = IRTools.@code_ir tf.func(args...)
254-
ir = intercept(ir; recorder=:run_and_record!)
255-
tape = IRTools.evalir(ir, tf.func, args...)
256-
tf.ir = ir
257-
tf.tape = tape
258-
setowner!(tape, tf)
259-
return result(tape)
260-
end
261-
# TODO: use cache
262-
run(tf.tape, args...)
263-
return result(tf.tape)
264-
end
265-
266-
function Base.show(io::IO, tf::TapedFunction)
267-
buf = IOBuffer()
268-
println(buf, "TapedFunction:")
269-
println(buf, "* .func => $(tf.func)")
270-
println(buf, "* .ir =>")
271-
println(buf, "------------------")
272-
println(buf, tf.ir)
273-
println(buf, "------------------")
274-
println(buf, "* .tape =>")
275-
println(buf, "------------------")
276-
println(buf, tf.tape)
277-
println(buf, "------------------")
278-
print(io, String(take!(buf)))
279-
end

0 commit comments

Comments
 (0)