-
-
Couldn't load subscription status.
- Fork 37
Description
Here's a @dynamo I use to trace code execution (full MWE below):
@dynamo function (t::IRTracer)(fargs...)
ir = IR(fargs...)
ir == nothing && return # intrinsic functions
for (v, st) in ir
ex = st.expr
if Meta.isexpr(ex, :call)
ir[v] = Expr(:call, record_or_recurse!, self, ex.args...)
else
ir[v] = Expr(:call, identity, ex)
end
end
return ir
end

Essentially, it does two things:
- Replaces every call to `f` with `record_or_recurse!(..., f, ...)`, which either records `f` in the list of operations or recursively applies the transformation to `f`.
- Replaces all other statements (e.g. constants) with a call to the `identity` function.
When I apply this code to the function `f = x -> sum(x; dims=1)` (or any other function with keyword arguments), the following error is printed, although the code runs fine and returns the correct result:
Internal error: encountered unexpected error in runtime:
AssertionError(msg="argextype only works on argument-position values")
argextype at ./compiler/utilities.jl:166
argextype at ./compiler/utilities.jl:158 [inlined]
call_sig at ./compiler/ssair/inlining.jl:882
process_simple! at ./compiler/ssair/inlining.jl:956
assemble_inline_todo! at ./compiler/ssair/inlining.jl:999
ssa_inlining_pass! at ./compiler/ssair/inlining.jl:74 [inlined]
run_passes at ./compiler/ssair/driver.jl:138
optimize at ./compiler/optimize.jl:174
typeinf at ./compiler/typeinfer.jl:33
typeinf_edge at ./compiler/typeinfer.jl:484
...
Original function IR (@code_ir f(x)):
1: (%1, %2)
%3 = (:dims,)
%4 = Core.apply_type(Core.NamedTuple, %3)
%5 = Core.tuple(1)
%6 = (%4)(%5)
%7 = Core.kwfunc(Main.sum)
%8 = (%7)(%6, Main.sum, %2)
return %8

Transformed IR (`@code_ir t(f, x)`, where `t::IRTracer`):
1: (%1, %2)
%3 = Base.getfield(%2, 1)
%4 = Base.getfield(%2, 2)
%5 = (identity)((:dims,))
%6 = (record_or_recurse!)(%1, Core.apply_type, Core.NamedTuple, %5)
%7 = (record_or_recurse!)(%1, Core.tuple, 1)
%8 = (record_or_recurse!)(%1, %6, %7)
%9 = (record_or_recurse!)(%1, Core.kwfunc, Main.sum)
%10 = (record_or_recurse!)(%1, %9, %8, Main.sum, %4)
return %10

If I comment out either of the transformations above (the one for function calls or the one for constants), the error disappears.
My best guess so far is that the compiler attempts to infer the type of the `dims` argument and asserts that it is still a constant. Since I replaced it with a call to `identity(:dims)`, the compiler pass fails.
Does this theory sound reasonable? If so, is there some metadata about the `:dims` variable that I should update to make this work?
MWE:
import IRTools: IR, @code_ir, @dynamo, self, var
# Default set of callables that the tracer executes and records directly instead of
# tracing into: the keyword sorter of `sum` and `Core.apply_type`.
const PRIMITIVES = let kwsum = Core.kwfunc(sum)
    Set([kwsum, Core.apply_type])
end
"""
    IRTracer(primitives, ops)

Mutable state of a call tracer.

# Fields
- `primitives::Set{Any}`: callables that are executed and recorded rather than traced into.
- `ops::Vector{Any}`: log of recorded calls, each stored as the tuple `(f, args...)`.
"""
mutable struct IRTracer
    primitives::Set{Any}   # callables treated as opaque primitives
    ops::Vector{Any}       # recorded `(f, args...)` tuples, in call order
end
"""
    IRTracer(; primitives=PRIMITIVES)

Create a tracer with an empty operation log.

`primitives` may be any iterable of callables (a `Set`, `Vector`, `Tuple`, ...). It is
copied into a fresh `Set{Any}`, so later changes to this tracer's primitive set never
mutate the caller's collection or the shared `PRIMITIVES` constant.
"""
function IRTracer(; primitives=PRIMITIVES)
    # Explicit copy: relying on the implicit `convert(Set{Any}, primitives)` would alias
    # a caller-supplied `Set{Any}` and reject non-`Set` iterables with a MethodError.
    return IRTracer(Set{Any}(primitives), Any[])
end
# Compact display: show only how many operations the tracer has recorded so far.
function Base.show(io::IO, t::IRTracer)
    nops = length(t.ops)
    return print(io, "IRTracer(", nops, ")")
end
"""
    record_or_recurse!(t::IRTracer, fn, args...)

Evaluate `fn(args...)` on behalf of the tracer `t`.

If `fn` is a primitive (a member of `t.primitives`, or a `NamedTuple` type), it is
called directly and the full call tuple `(fn, args...)` is appended to `t.ops` after
the call returns. Otherwise the call is traced into via `t(fn, args...)`.

Returns the result of the call.
"""
function record_or_recurse!(t::IRTracer, fargs...)
    fn = first(fargs)
    args = Base.tail(fargs)
    isprimitive = fn in t.primitives || (fn isa Type && fn <: NamedTuple)
    # Non-primitives are traced into recursively and are not logged themselves.
    isprimitive || return t(fn, args...)
    result = fn(args...)
    push!(t.ops, fargs)
    return result
end
"""
    (t::IRTracer)(f, args...)

Dynamo that rewrites the IR of `f(args...)` so that execution goes through the tracer:

- every `:call` statement `g(xs...)` becomes `record_or_recurse!(t, g, xs...)`;
- every plain (non-`Expr`) statement value, such as a literal, is wrapped in `identity(...)`;
- any other `Expr` statement is left exactly as it was.

When no IR is available for the call (`IR(...)` returns `nothing`), the dynamo returns
`nothing` and the call is left untransformed.
"""
@dynamo function (t::IRTracer)(fargs...)
    ir = IR(fargs...)
    ir === nothing && return  # e.g. intrinsic functions: no IR to transform
    for (v, st) in ir
        ex = st.expr
        if Meta.isexpr(ex, :call)
            ir[v] = Expr(:call, record_or_recurse!, self, ex.args...)
        elseif !(ex isa Expr)
            # Plain values are valid call arguments, so wrapping them is safe.
            ir[v] = Expr(:call, identity, ex)
        end
        # Non-call `Expr` statements (for example `:new` or `:foreigncall`) are
        # deliberately not wrapped. `identity(<Expr>)` would put an `Expr` in
        # call-argument position. That is presumably what trips the optimizer's
        # "argextype only works on argument-position values" assertion.
    end
    return ir
end
# Minimal reproduction: trace a call that passes a keyword argument.
x = rand(2, 4)
f = x -> sum(x, dims=1)
t = IRTracer()
t(f, x)