Skip to content

Commit b486970

Browse files
committed
fix: mooncake lux
1 parent ab2fcf2 commit b486970

File tree

3 files changed

+6
-18
lines changed

3 files changed

+6
-18
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.31.1"
4+
version = "1.31.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -127,7 +127,7 @@ MLUtils = "0.4.4"
127127
MPI = "0.20.19"
128128
MacroTools = "0.5.13"
129129
Markdown = "1.10"
130-
Mooncake = "0.4.148, 0.5"
130+
Mooncake = "0.5"
131131
NCCL = "0.1.2"
132132
NNlib = "0.9.27"
133133
Optimisers = "0.4.6"

ext/MooncakeExt/training.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@ function Lux.Training.compute_gradients_impl(
33
) where {F}
44
config = get_config(ad)
55
pullback_cache = prepare_pullback_cache(
6-
objective_function,
7-
ts.model,
8-
ts.parameters,
9-
ts.states,
10-
data;
11-
debug_mode=config.debug_mode,
12-
silence_debug_messages=config.silence_debug_messages,
6+
objective_function, ts.model, ts.parameters, ts.states, data; config
137
)
148
# evaluate once to get the correct types
159
loss, stₙ, stats = objective_function(ts.model, ts.parameters, ts.states, data)
Lines changed: 3 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,17 +1,11 @@
1-
using OneHotArrays, LuxLib, Test, BenchmarkTools
1+
using OneHotArrays, LuxLib, Test
22

33
@testset "Specialized OneHotArrays Dispatch" begin
44
x = onehotbatch("aabc", "abcdefghijklmnopqrstuv")
55
weight = reshape(collect(Float32, 1:(1024 * 22)), 1024, 22)
66

7-
dense_res_time = @belapsed fused_dense_bias_activation(
8-
identity, $(weight), $(Array(x)), nothing
9-
)
10-
onehot_res_time = @belapsed fused_dense_bias_activation(
11-
identity, $(weight), $(x), nothing
12-
)
13-
14-
@test onehot_res_time < dense_res_time / 5
157
@test fused_dense_bias_activation(identity, weight, x, nothing)
168
fused_dense_bias_activation(identity, weight, Array(x), nothing)
9+
10+
@test LuxLib.Utils.force_3arg_mul!_dispatch(weight, weight, x)
1711
end

0 commit comments

Comments
 (0)