Skip to content

Commit c6a0560

Browse files
committed
fix: missing onehotarrays dispatch for cpu matmul
1 parent 6ab9a57 commit c6a0560

File tree

6 files changed

+36
-1
lines changed

6 files changed

+36
-1
lines changed

lib/LuxLib/Project.toml

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,6 @@
11
name = "LuxLib"
22
uuid = "82251201-b29d-42c6-8e01-566dec8acb11"
3-
version = "1.15.4"
3+
version = "1.15.5"
44
authors = ["Avik Pal <avikpal@mit.edu> and contributors"]
55

66
[deps]
@@ -36,6 +36,7 @@ ForwardDiff = "f6369f11-7733-5829-9624-2563aa707210"
3636
LoopVectorization = "bdcacae8-1622-11e9-2a5c-532679323890"
3737
MKL = "33e6dc65-8f57-5167-99aa-e5a354878fb2"
3838
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
39+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
3940
Polyester = "f517fe37-dbe3-4b94-8317-1923a5111588"
4041
Reactant = "3c362404-f566-11ee-1572-e11a4b42c853"
4142
ReactantCore = "a3311ec8-5e00-46d5-b541-4f83e724a433"
@@ -58,6 +59,7 @@ ForwardDiffExt = "ForwardDiff"
5859
LoopVectorizationExt = ["LoopVectorization", "Polyester"]
5960
MKLExt = "MKL"
6061
OctavianExt = ["Octavian", "LoopVectorization"]
62+
OneHotArraysExt = ["OneHotArrays"]
6163
ReactantExt = ["Reactant", "ReactantCore"]
6264
ReverseDiffExt = "ReverseDiff"
6365
SLEEFPiratesExt = "SLEEFPirates"
@@ -88,6 +90,7 @@ MLDataDevices = "1.17.1"
8890
Markdown = "1.10"
8991
NNlib = "0.9.27"
9092
Octavian = "0.3.28"
93+
OneHotArrays = "0.2.5"
9194
Polyester = "0.7.18"
9295
Preferences = "1.4.3"
9396
Random = "1.10"

lib/LuxLib/ext/OneHotArraysExt.jl

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,8 @@
1+
module OneHotArraysExt
2+
3+
using LuxLib: Utils
4+
using OneHotArrays: OneHotLike
5+
6+
Utils.force_3arg_mul!_dispatch(::AbstractMatrix, ::AbstractMatrix, ::OneHotLike) = true
7+
8+
end

lib/LuxLib/src/impl/matmul.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -109,6 +109,10 @@ function matmuladd!(
109109
end
110110

111111
function matmul!(C::AbstractMatrix, A::AbstractMatrix, B::AbstractMatrix)
112+
if Utils.force_3arg_mul!_dispatch(C, A, B)
113+
mul!(C, A, B)
114+
return nothing
115+
end
112116
matmul!(C, internal_operation_mode((C, A, B)), A, B)
113117
return nothing
114118
end

lib/LuxLib/src/utils.jl

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -396,4 +396,8 @@ function maybe_reshape(x::AbstractArray{T,N}, dims::Dims{N}) where {T,N}
396396
end
397397
maybe_reshape(x::AbstractArray{T}, dims::Dims) where {T} = reshape(x, dims)
398398

399+
force_3arg_mul!_dispatch(::AbstractMatrix, ::AbstractMatrix, ::AbstractMatrix) = false
400+
401+
CRC.@non_differentiable force_3arg_mul!_dispatch(::Any...)
402+
399403
end

lib/LuxLib/test/Project.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ Metal = "dde4c033-4e86-420c-a63e-0dd931031962"
2222
Mooncake = "da2b9cff-9c12-43a0-ae48-6db2b0edb7d6"
2323
NNlib = "872c559c-99b0-510c-b3b7-b6c96a88d5cd"
2424
Octavian = "6fd5a793-0b7e-452c-907f-f8bfe9c57db4"
25+
OneHotArrays = "0b1bfda6-eb8a-41d2-88d8-f5af5cad476f"
2526
ParallelTestRunner = "d3525ed8-44d0-4b2c-a655-542cee43accc"
2627
Pkg = "44cfe95a-1eb2-52ea-b672-e2afdf69b78f"
2728
Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c"
Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,15 @@
1+
using OneHotArrays, LuxLib, Test, BenchmarkTools
2+
3+
@testset "Specialized OneHotArrays Dispatch" begin
4+
x = onehotbatch("aabc", "abcdefghijklmnopqrstuv")
5+
weight = reshape(collect(Float32, 1:(1024 * 22)), 1024, 22)
6+
7+
dense_res_time = @belapsed fused_dense_bias_activation(
8+
identity, weight, Array(x), nothing
9+
)
10+
onehot_res_time = @belapsed fused_dense_bias_activation(identity, weight, x, nothing)
11+
12+
@test onehot_res_time < dense_res_time / 5
13+
@test fused_dense_bias_activation(identity, weight, x, nothing)
14+
fused_dense_bias_activation(identity, weight, Array(x), nothing)
15+
end

0 commit comments

Comments
 (0)