Skip to content

Commit b19e020

Browse files
committed
perf: use reduction
1 parent 8d9544e commit b19e020

File tree

1 file changed

+13
-14
lines changed

1 file changed

+13
-14
lines changed

examples/MinimalMamba/main.jl

Lines changed: 13 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -34,16 +34,9 @@ function selective_scan_reference(
3434
## Perform selective scan with a sequential implementation for correctness verification
3535
x = fill!(similar(u, size(A, 1), size(u, 1), size(u, 3)), 0) ## [n, d_in, b]
3636
y = similar(u)
37-
@trace for i in 1:size(u, 2)
37+
@trace for i in Int32(1):Int32(size(u, 2))
3838
@. x = ΔA[:, :, i, :] * x + ΔBu[:, :, i, :]
39-
tmp = batched_matmul(
40-
x,
41-
reshape(C[:, i, :], size(C, 1), 1, size(C, 3));
42-
lhs_contracting_dim=1,
43-
rhs_contracting_dim=1,
44-
lhs_batching_dims=(3,),
45-
rhs_batching_dims=(3,),
46-
) ## [d_in, 1, b]
39+
tmp = sum(x .* reshape(C[:, i, :], size(C, 1), 1, size(C, 3)); dims=1)
4740
y[:, i, :] = reshape(tmp, size(u, 1), size(u, 3))
4841
end
4942
@. y += u * D
@@ -52,15 +45,21 @@ function selective_scan_reference(
5245
end
5346

5447
#=
55-
d_in, l, n, n = 3, 4, 5, 6
56-
u = randn(Float32, d_in, l, n) |> Reactant.to_rarray;
57-
Δ = randn(Float32, d_in, l, n) |> Reactant.to_rarray;
48+
d_in, l, n, b = 512, 128, 64, 32
49+
u = randn(Float32, d_in, l, b) |> Reactant.to_rarray;
50+
Δ = randn(Float32, d_in, l, b) |> Reactant.to_rarray;
5851
A = randn(Float32, n, d_in) |> Reactant.to_rarray;
59-
B = randn(Float32, n, l, n) |> Reactant.to_rarray;
60-
C = randn(Float32, n, l, n) |> Reactant.to_rarray;
52+
B = randn(Float32, n, l, b) |> Reactant.to_rarray;
53+
C = randn(Float32, n, l, b) |> Reactant.to_rarray;
6154
D = randn(Float32, d_in) |> Reactant.to_rarray;
6255
6356
@code_hlo selective_scan_reference(u, Δ, A, B, C, D)
57+
58+
compiled_fn = @compile selective_scan_reference(u, Δ, A, B, C, D)
59+
60+
Reactant.with_profiler("./envs/traces") do
61+
compiled_fn(u, Δ, A, B, C, D)
62+
end
6463
=#
6564

6665
# ## Mamba Architecture

0 commit comments

Comments
 (0)