@@ -34,16 +34,9 @@ function selective_scan_reference(
3434 # # Perform selective scan with a sequential implementation for correctness verification
3535 x = fill!(similar(u, size(A, 1 ), size(u, 1 ), size(u, 3 )), 0 ) # # [n, d_in, b]
3636 y = similar(u)
37- @trace for i in 1 : size(u, 2 )
37+ @trace for i in Int32( 1 ) : Int32( size(u, 2 ) )
3838 @. x = ΔA[:, :, i, :] * x + ΔBu[:, :, i, :]
39- tmp = batched_matmul(
40- x,
41- reshape(C[:, i, :], size(C, 1 ), 1 , size(C, 3 ));
42- lhs_contracting_dim= 1 ,
43- rhs_contracting_dim= 1 ,
44- lhs_batching_dims= (3 ,),
45- rhs_batching_dims= (3 ,),
46- ) # # [d_in, 1, b]
39+ tmp = sum(x .* reshape(C[:, i, :], size(C, 1 ), 1 , size(C, 3 )); dims= 1 )
4740 y[:, i, :] = reshape(tmp, size(u, 1 ), size(u, 3 ))
4841 end
4942 @. y += u * D
@@ -52,15 +45,21 @@ function selective_scan_reference(
5245end
5346
5447#=
55- d_in, l, n, n = 3, 4, 5, 6
56- u = randn(Float32, d_in, l, n ) |> Reactant.to_rarray;
57- Δ = randn(Float32, d_in, l, n ) |> Reactant.to_rarray;
48+ d_in, l, n, b = 512, 128, 64, 32
49+ u = randn(Float32, d_in, l, b ) |> Reactant.to_rarray;
50+ Δ = randn(Float32, d_in, l, b ) |> Reactant.to_rarray;
5851A = randn(Float32, n, d_in) |> Reactant.to_rarray;
59- B = randn(Float32, n, l, n ) |> Reactant.to_rarray;
60- C = randn(Float32, n, l, n ) |> Reactant.to_rarray;
52+ B = randn(Float32, n, l, b ) |> Reactant.to_rarray;
53+ C = randn(Float32, n, l, b ) |> Reactant.to_rarray;
6154D = randn(Float32, d_in) |> Reactant.to_rarray;
6255
6356@code_hlo selective_scan_reference(u, Δ, A, B, C, D)
57+
58+ compiled_fn = @compile selective_scan_reference(u, Δ, A, B, C, D)
59+
60+ Reactant.with_profiler("./envs/traces") do
61+ compiled_fn(u, Δ, A, B, C, D)
62+ end
6463=#
6564
6665# ## Mamba Architecture
0 commit comments