Skip to content

Commit fa3c732

Browse files
authored
fix: missing onehotarrays dispatch for cpu matmul (#1655)
* fix: missing onehotarrays dispatch for cpu matmul * fix: mooncake lux
1 parent 6ab9a57 commit fa3c732

File tree

8 files changed

+35
-10
lines changed

8 files changed

+35
-10
lines changed

Project.toml

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,7 +1,7 @@
11
name = "Lux"
22
uuid = "b2108857-7c20-44ae-9111-449ecde12c47"
33
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
4-
version = "1.31.1"
4+
version = "1.31.2"
55

66
[deps]
77
ADTypes = "47edcb42-4c32-4615-8424-f2b9edc5f35b"
@@ -127,7 +127,7 @@ MLUtils = "0.4.4"
127127
MPI = "0.20.19"
128128
MacroTools = "0.5.13"
129129
Markdown = "1.10"
130-
Mooncake = "0.4.148, 0.5"
130+
Mooncake = "0.5"
131131
NCCL = "0.1.2"
132132
NNlib = "0.9.27"
133133
Optimisers = "0.4.6"

ext/MooncakeExt/training.jl

Lines changed: 1 addition & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -3,13 +3,7 @@ function Lux.Training.compute_gradients_impl(
33
) where {F}
44
config = get_config(ad)
55
pullback_cache = prepare_pullback_cache(
6-
objective_function,
7-
ts.model,
8-
ts.parameters,
9-
ts.states,
10-
data;
11-
debug_mode=config.debug_mode,
12-
silence_debug_messages=config.silence_debug_messages,
6+
objective_function, ts.model, ts.parameters, ts.states, data; config
137
)
148
# evaluate once to get the correct types
159
loss, stₙ, stats = objective_function(ts.model, ts.parameters, ts.states, data)

lib/LuxLib/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LuxLib"
22
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
3-
version = "1.15.4"
3+
version = "1.15.3"
44
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
55

66
[deps]
@@ -36,6 +36,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3636
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
3737
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
3838
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
39+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3940
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
4041
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
4142
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
@@ -58,6 +59,7 @@ ForwardDiffExt = "ForwardDiff"
5859
LoopVectorizationExt = ["LoopVectorization", "Polyester"]
5960
MKLExt = "MKL"
6061
OctavianExt = ["Octavian", "LoopVectorization"]
62+
OneHotArraysExt = ["OneHotArrays"]
6163
ReactantExt = ["Reactant", "ReactantCore"]
6264
ReverseDiffExt = "ReverseDiff"
6365
SLEEFPiratesExt = "SLEEFPirates"
@@ -88,6 +90,7 @@ MLDataDevices = "1.17.1"
8890
Markdown = "1.10"
8991
NNlib = "0.9.27"
9092
Octavian = "0.3.28"
93+
OneHotArrays = "0.2.5"
9194
Polyester = "0.7.18"
9295
Preferences = "1.4.3"
9396
Random = "1.10"

lib/LuxLib/ext/OneHotArraysExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OneHotArraysExt
2+
3+
using LuxLib: Utils
4+
using OneHotArrays: OneHotLike
5+
6+
Utils.force_3arg_mul!_dispatch(::AbstractMatrix, ::AbstractMatrix, ::OneHotLike) = true
7+
8+
end

lib/LuxLib/src/impl/matmul.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ function matmuladd!(
109109
end
110110

111111
function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
112+
if Utils.force_3arg_mul!_dispatch(C, A, B)
113+
mul!(C, A, B)
114+
return nothing
115+
end
112116
matmul!(C, internal_operation_mode((C, A, B)), A, B)
113117
return nothing
114118
end

lib/LuxLib/src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,4 +396,8 @@ function maybe_reshape(x::AbstractArray{T,N}, dims::Dims{N}) where {T,N}
396396
end
397397
maybe_reshape(x::AbstractArray{T}, dims::Dims) where {T} = reshape(x, dims)
398398

399+
force_3arg_mul!_dispatch(::AbstractMatrix, ::AbstractMatrix, ::AbstractMatrix) = false
400+
401+
CRC.@non_differentiable force_3arg_mul!_dispatch(::Any...)
402+
399403
end

lib/LuxLib/test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2222
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2323
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2424
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
25+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
2526
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
2627
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2728
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,11 @@
1+
using OneHotArrays, LuxLib, Test
2+
3+
@testset "Specialized OneHotArrays Dispatch" begin
4+
x = onehotbatch("aabc", "abcdefghijklmnopqrstuv")
5+
weight = reshape(collect(Float32, 1:(1024 * 22)), 1024, 22)
6+
7+
@test fused_dense_bias_activation(identity, weight, x, nothing)
8+
fused_dense_bias_activation(identity, weight, Array(x), nothing)
9+
10+
@test LuxLib.Utils.force_3arg_mul!_dispatch(weight, weight, x)
11+
end

0 commit comments

Comments
 (0)