Commit f14c794 (parent: 6557b64)

fix: simpleRNN works with reactant

4 files changed, +25 −21 lines


examples/SimpleRNN/main.jl

Lines changed: 14 additions & 4 deletions

```diff
@@ -150,22 +150,32 @@ function main(model_type)
 
     for epoch in 1:25
         ## Train the model
+        total_loss = 0.0f0
+        total_samples = 0
         for (x, y) in train_loader
             (_, loss, _, train_state) = Training.single_train_step!(
                 ad, lossfn, (x, y), train_state
             )
-            @printf "Epoch [%3d]: Loss %4.5f\n" epoch loss
+            total_loss += loss * length(y)
+            total_samples += length(y)
         end
+        @printf "Epoch [%3d]: Loss %4.5f\n" epoch (total_loss/total_samples)
 
         ## Validate the model
+        total_acc = 0.0f0
+        total_loss = 0.0f0
+        total_samples = 0
+
         st_ = Lux.testmode(train_state.states)
         for (x, y) in val_loader
             ŷ, st_ = model_compiled(x, train_state.parameters, st_)
             ŷ, y = cdev(ŷ), cdev(y)
-            loss = lossfn(ŷ, y)
-            acc = accuracy(ŷ, y)
-            @printf "Validation: Loss %4.5f Accuracy %4.5f\n" loss acc
+            total_acc += accuracy(ŷ, y) * length(y)
+            total_loss += lossfn(ŷ, y) * length(y)
+            total_samples += length(y)
         end
+
+        @printf "Validation:\tLoss %4.5f\tAccuracy %4.5f\n" (total_loss/total_samples) (total_acc/total_samples)
     end
 
     return (train_state.parameters, train_state.states) |> cpu_device()
```
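The change above replaces per-batch `@printf` logging with sample-weighted running sums, so each epoch reports a single average over the whole loader. A minimal standalone sketch of that averaging pattern (the `weighted_average` helper and its input are illustrative, not part of the example):

```julia
# Sketch: average a per-batch metric weighted by batch size, mirroring what
# the loops above do with `loss * length(y)` and `length(y)`.
function weighted_average(batches)
    total, n = 0.0f0, 0
    for (metric, batch_size) in batches
        total += metric * batch_size  # weight each batch by its sample count
        n += batch_size
    end
    return total / n
end

weighted_average([(0.5f0, 32), (0.7f0, 16)])  # (16 + 11.2) / 48 ≈ 0.5667f0
```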

ext/LuxReactantExt/training.jl

Lines changed: 6 additions & 15 deletions

```diff
@@ -13,22 +13,14 @@ function wrapped_objective_function(
 end
 
 function compute_gradients_internal(objective_function::F, model, data, ps, st) where {F}
-    # XXX: Hacky workaround for https://github.com/LuxDL/Lux.jl/issues/1186
-    # stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
-    # res = Enzyme.gradient(
-    #     Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-    #     Const(wrapped_objective_function), Const(objective_function),
-    #     Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
-    # )
-    # loss, dps = res.val, res.derivs[3]
-    # return dps, loss, stats_wrapper.stats, stats_wrapper.st
+    stats_wrapper = StatsAndNewStateWrapper(nothing, nothing)
     res = Enzyme.gradient(
         Enzyme.set_abi(Enzyme.ReverseWithPrimal, Reactant.ReactantABI),
-        Const(objective_function), Const(model), ps, Const(st), Const(data)
+        Const(wrapped_objective_function), Const(objective_function),
+        Const(model), ps, Const(st), Const(data), Const(stats_wrapper)
     )
-    (loss, new_st, stats) = res.val
-    (_, dps, _, _) = res.derivs
-    return dps, loss, stats, new_st
+    loss, dps = res.val, res.derivs[3]
+    return dps, loss, stats_wrapper.stats, stats_wrapper.st
 end
 
 function maybe_dump_to_mlir_file!(f::F, args...) where {F}
@@ -98,8 +90,7 @@ for inplace in ("!", "")
         return ts
     end
 
-    # XXX: Should we add a check to ensure the inputs to this function is same as the one
-    #      used in the compiled function? We can re-trigger the compilation with a warning
+    # XXX: recompile with a warning if new input types are used
     @eval function Lux.Training.$(fname)(backend::ReactantBackend, objective_function::F,
             data, ts::Training.TrainState) where {F}
         maybe_dump_to_mlir_file!($(internal_fn), objective_function, ts.model, data,
```
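`compute_gradients_internal` reinstates the previously commented-out `StatsAndNewStateWrapper` workaround for LuxDL/Lux.jl#1186: the objective is wrapped so that `Enzyme.gradient` differentiates a function returning only the scalar loss, while the auxiliary stats and updated state escape through a mutable wrapper. A rough sketch of that side-channel pattern, with illustrative stand-ins for the real `wrapped_objective_function` and wrapper defined earlier in this file:

```julia
# Illustrative stand-ins; the real definitions live near the top of
# ext/LuxReactantExt/training.jl and may differ in detail.
mutable struct StatsStateSketch
    stats::Any
    st::Any
end

function wrapped_objective_sketch(objective_function, model, ps, st, data, wrapper)
    # Lux objective functions return (loss, updated_state, stats).
    loss, new_st, stats = objective_function(model, ps, st, data)
    # Stash the non-differentiated outputs on the mutable wrapper so the
    # wrapped function returns only the scalar loss Enzyme must differentiate.
    wrapper.stats = stats
    wrapper.st = new_st
    return loss
end
```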

lib/LuxCore/Project.toml

Lines changed: 1 addition & 1 deletion

```diff
@@ -1,7 +1,7 @@
 name = "LuxCore"
 uuid = "bb33d45b-7691-41d6-9220-0943567d0623"
 authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
-version = "1.2.1"
+version = "1.2.2"
 
 [deps]
 Compat = "34da2185-b29b-5c13-b0c7-acf172513d20"
```

lib/LuxCore/ext/LuxCoreReactantExt.jl

Lines changed: 4 additions & 1 deletion

```diff
@@ -1,6 +1,6 @@
 module LuxCoreReactantExt
 
-using LuxCore: AbstractLuxLayer
+using LuxCore: AbstractLuxLayer, LuxCore
 using Reactant: Reactant
 
 # Avoid tracing though models since it won't contain anything useful
@@ -10,4 +10,7 @@ function Reactant.make_tracer(
     return model
 end
 
+LuxCore.replicate(rng::Reactant.TracedRNG) = copy(rng)
+LuxCore.replicate(rng::Reactant.ConcreteRNG) = copy(rng)
+
 end
```
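`LuxCore.replicate` is how Lux obtains an independent copy of an RNG (e.g. for per-layer state); the new methods make it a plain `copy` for Reactant's `TracedRNG` and `ConcreteRNG` rather than the generic fallback, which does not understand these types. A sketch of the same dispatch-extension pattern with a made-up RNG type (all names below are illustrative):

```julia
import LuxCore

# Hypothetical RNG type standing in for Reactant's TracedRNG/ConcreteRNG.
mutable struct DeviceRNGSketch
    seed::UInt64
end
Base.copy(rng::DeviceRNGSketch) = DeviceRNGSketch(rng.seed)

# When a shallow `copy` already yields an independent stream, forwarding
# `replicate` to `copy` (as this commit does for Reactant) is enough:
LuxCore.replicate(rng::DeviceRNGSketch) = copy(rng)
```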
