@@ -132,8 +132,6 @@ with the following chain or calls:
132132where `val = fn(args...)` and `pb` is the pullback function.
133133"""
134134function chainrules_transform!(tape:: Tape )
135- # global TAPE = tape
136- # error("")
137135 i = 1
138136 while i <= length(tape)
139137 # tape[V(i)] isa Call && tape[V(i)].fn == Core.kwcall && break
@@ -183,7 +181,6 @@ function step_back!(tape::Tape, y::Variable)
183181 end
184182 for (i, x) in enumerate(y_fargs)
185183 if x isa V
186- global STATE = (tape, y, y_fargs, i, x)
187184 dx = push!(tape, mkcall(getfield, dxs, i; line= " d$y /d$x " ))
188185 # @debug "Updating derivative: $x -> $dx"
189186 set_or_add_deriv!(tape, x, dx)
@@ -208,8 +205,8 @@ function back!(tape::Tape; seed=1)
208205 error(" Gradient of a vector-valued function requires a seed" )
209206 elseif seed == :auto
210207 zval = tape[z]. val
211- # @assert zval isa Number || zval isa AbstractArray
212- seed = zval isa AbstractArray ? ones(eltype( zval) , size(zval)) : one(zval)
208+ @assert zval isa Number || zval isa AbstractArray
209+ seed = zval isa AbstractArray ? array_like( 1 , zval, size(zval)) : one(zval)
213210 end
214211 dy = push!(tape, Constant(seed; line= " seed for $(tape[V(1 )]. val) " ))
215212 # save seed var to use in compilation later
0 commit comments