Commit 54dc712

Merge pull request #141 from dfdx/fix/auto-seed-on-cuda
Fix device of seed in the :auto mode
2 parents ea4f470 + ff58b31 commit 54dc712
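
Context for the fix: with seed=:auto, back! used to build the seed as ones(eltype(zval), size(zval)), which always allocates a plain CPU Array. When the traced output zval is a CuArray, the backward pass then mixes host and device arrays. The change below routes seed construction through array_like, which calls similar(example, sz), so the seed inherits the array type, and therefore the device, of zval. A minimal sketch of the idea, assuming CUDA.jl is installed and a GPU is available (values are illustrative, not taken from the test suite):

using CUDA

# Generalized helper from src/helpers.jl: `similar` preserves the array type
# of `example`, so the result lives on the same device as `example`.
array_like(value, example, sz=(1,)) = fill!(similar(example, sz), value)

zval = cu([2.0, 4.0, 6.0])               # stand-in for the tape's output value
seed = array_like(1, zval, size(zval))   # CuArray of ones with the same shape
@assert seed isa CuArray && size(seed) == size(zval)

# The previous code would have produced a host array instead:
cpu_seed = ones(eltype(zval), size(zval))   # Array{Float32}, i.e. wrong device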

File tree

4 files changed: 15 additions, 7 deletions

Project.toml

Lines changed: 1 addition & 1 deletion
@@ -1,7 +1,7 @@
 name = "Yota"
 uuid = "cd998857-8626-517d-b929-70ad188a48f0"
 authors = ["Andrei Zhabinski <andrei.zhabinski@gmail.com>"]
-version = "0.8.4"
+version = "0.8.5"

 [deps]
 ChainRules = "082447d4-558c-5d27-93f4-14fc19e9eca2"

src/grad.jl

Lines changed: 2 additions & 5 deletions
@@ -132,8 +132,6 @@ with the following chain or calls:
 where `val = fn(args...)` and `pb` is the pullback function.
 """
 function chainrules_transform!(tape::Tape)
-    # global TAPE = tape
-    # error("")
     i = 1
     while i <= length(tape)
         # tape[V(i)] isa Call && tape[V(i)].fn == Core.kwcall && break
@@ -183,7 +181,6 @@ function step_back!(tape::Tape, y::Variable)
     end
     for (i, x) in enumerate(y_fargs)
         if x isa V
-            global STATE = (tape, y, y_fargs, i, x)
             dx = push!(tape, mkcall(getfield, dxs, i; line="d$y/d$x"))
             # @debug "Updating derivative: $x -> $dx"
             set_or_add_deriv!(tape, x, dx)
@@ -208,8 +205,8 @@ function back!(tape::Tape; seed=1)
         error("Gradient of a vector-valued function requires a seed")
     elseif seed == :auto
         zval = tape[z].val
-        # @assert zval isa Number || zval isa AbstractArray
-        seed = zval isa AbstractArray ? ones(eltype(zval), size(zval)) : one(zval)
+        @assert zval isa Number || zval isa AbstractArray
+        seed = zval isa AbstractArray ? array_like(1, zval, size(zval)) : one(zval)
     end
     dy = push!(tape, Constant(seed; line="seed for $(tape[V(1)].val)"))
     # save seed var to use in compilation later
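
Taken out of Yota's internals, the fixed :auto branch amounts to the following standalone sketch; auto_seed is a hypothetical wrapper introduced only for illustration, and array_like is repeated here so the snippet is self-contained:

array_like(value, example, sz=(1,)) = fill!(similar(example, sz), value)

# Hypothetical wrapper mirroring the `elseif seed == :auto` branch above.
function auto_seed(zval)
    @assert zval isa Number || zval isa AbstractArray
    return zval isa AbstractArray ? array_like(1, zval, size(zval)) : one(zval)
end

auto_seed(3.5)               # 1.0 -- a scalar output gets a scalar seed
auto_seed(rand(Float32, 3))  # Vector{Float32} of ones, same array type as the input
# On a machine with a working GPU, auto_seed(CUDA.rand(3)) would return a CuArray of ones.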

src/helpers.jl

Lines changed: 1 addition & 1 deletion
@@ -88,7 +88,7 @@ end
 unbroadcast_prod_y(x::ArrayOrBroadcasted, y::ArrayOrBroadcasted, Δ) = unbroadcast_prod_x(y, x, Δ)

 # device_like(example, a) = (device = guess_device([example]); device(a))
-array_like(value, example) = fill!(similar(example, (1,)), value)
+array_like(value, example, sz=(1,)) = fill!(similar(example, sz), value)

 # unbroadcast_prod_x(x::Number, y::AbstractArray, Δ) = unbroadcast_prod_x(device_like(y, [x]), y, Δ)[1]
 unbroadcast_prod_x(x::Number, y::ArrayOrBroadcasted, Δ) = unbroadcast_prod_x(array_like(x, y), y, Δ)[1]
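
The only change here is the optional sz argument, defaulting to (1,), so the existing single-element caller unbroadcast_prod_x keeps its behavior while back! can now request a full-sized seed. A quick CPU-only illustration of the extended signature (values chosen arbitrarily):

array_like(value, example, sz=(1,)) = fill!(similar(example, sz), value)

x = rand(Float32, 5)
array_like(2, x)          # 1-element Vector{Float32}: [2.0f0]  (default size, old behavior)
array_like(0, x, (2, 2))  # 2×2 Matrix{Float32} of zeros        (new: explicit size)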

test/test_grad.jl

Lines changed: 11 additions & 0 deletions
@@ -254,6 +254,17 @@ end
     val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=ones(3))
     @test val == [2.0, 4.0, 6.0]
     @test g == (ZeroTangent(), [2.0, 2.0, 2.0])
+
+    val, g = grad(x -> 2x, [1.0, 2.0, 3.0]; seed=:auto)
+    @test val == [2.0, 4.0, 6.0]
+    @test g == (ZeroTangent(), [2.0, 2.0, 2.0])
+
+    if CUDA.functional()
+        CUDA.allowscalar(false)
+        val, g = grad(x -> 2x, cu([1.0, 2.0, 3.0]); seed=:auto)
+        @test val == cu([2.0, 4.0, 6.0])
+        @test g == (ZeroTangent(), cu([2.0, 2.0, 2.0]))
+    end
 end
