Closed
Description
Hi, I think I'm seeing a bug in Umlaut when it's called from Yota. Here's an MWE:
```julia
using TransformVariables, Umlaut, Yota

tr = as((a = asℝ, b = asℝ))

function f(x)
    nt = transform(tr, x)
    nt.a + nt.b
end

# This works
val, tape = trace(f, zeros(2))

# This works too
val, tape = trace(f, zeros(2); ctx=Yota.GradCtx)

# Throws `ERROR: Code for this Method is not available.`
grad(f, zeros(2))
```

To get some idea of where it's breaking, ...
```
julia> Umlaut.print_stack_trace()
[1] _transform_tuple(flag::TransformVariables.LogJacFlag, x::AbstractVector, index, ts) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:163
[2] _transform_tuple(flag::TransformVariables.LogJacFlag, x::AbstractVector, index, ts) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:163
[3] transform_tuple(flag::TransformVariables.LogJacFlag, tt::Tuple{Vararg{TransformVariables.AbstractTransform, N}} where N, x, index) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:175
[4] transform_with(flag::TransformVariables.LogJacFlag, tt::TransformVariables.TransformTuple{<:NamedTuple}, x, index) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/aggregation.jl:227
[5] transform(t::TransformVariables.VectorTransform, x::AbstractVector) in TransformVariables at /home/chad/.julia/packages/TransformVariables/XMykI/src/generic.jl:265
[6] f(x) in Main at REPL[5]:1
```

To get more detail, I looked at the original stack trace, which includes a call to `Umlaut.getcode`. After adding `@show f, types` to that function, I see that it's trying to call `code_lowered(tuple, (Float64,))`, which leads to the crash.
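For context: `tuple` is a `Core.Builtin`, implemented in C rather than in Julia, so it has no lowered Julia code for `code_lowered` to return. Below is a minimal sketch of a guard a tracer could apply before asking for lowered code. `safe_code_lowered` is a hypothetical name, not part of Umlaut's API, and treating builtins as primitive calls is an assumption about how a fix might look, not how Umlaut actually handles it.

```julia
# Hypothetical helper (not part of Umlaut): skip builtins such as `tuple`,
# which are implemented in C and have no lowered Julia code.
function safe_code_lowered(f, types::Tuple)
    if f isa Core.Builtin
        return nothing  # caller would record this as a primitive call instead
    end
    return code_lowered(f, types)
end

# `tuple` is a builtin, so the guard short-circuits instead of erroring.
@show tuple isa Core.Builtin
@show safe_code_lowered(tuple, (Float64,))

# An ordinary generic function still returns its lowered code.
ci = safe_code_lowered(+, (Float64, Float64))
@show length(ci)
```

Returning `nothing` lets the caller stop recursing at the builtin and record the call on the tape as-is, rather than crashing inside the tracer.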