input.mlir:

```mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
(d0 : dense, d1 : compressed, d2 : dense) }>
#dense1d = #sparse_tensor.encoding<{ map = (d0) ->
(d0 : dense) }>
#dense = #sparse_tensor.encoding<{ map = (d0, d1) ->
(d0 : dense, d1 : dense) }>
#densev = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
(d0 : dense, d1 : dense, d2 : dense) }>
#csr = #sparse_tensor.encoding<{ map = (d0, d1) ->
(d0 : dense, d1 : compressed) }>
#partCsr = #part_tensor.encoding<{
partConst = 1,
sparseAttributes = #csr
}>
#partDense = #part_tensor.encoding<{
partConst = 1,
sparseAttributes = #dense
}>
#partDensev = #part_tensor.encoding<{
partConst = 1,
sparseAttributes = #densev
}>
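// Input projection trait: Q(n, dh, nh) = sum over f of X(n, f) * Q_proj(dh, nh, f).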
#input_proj_map = {
indexing_maps = [
affine_map<(n, f, dh, nh) -> (n, f)>, // X (in)
affine_map<(n, f, dh, nh) -> (dh, nh, f)>, // Q_Proj (in)
affine_map<(n, f, dh, nh) -> (n, dh, nh)> // Q (out)
],
iterator_types = ["parallel", "reduction", "parallel", "parallel"]
}
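// Output projection trait: O(n, f) = sum over dh, nh of Attn(n, dh, nh) * O_proj(dh, nh, f).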
#output_proj_map = {
indexing_maps = [
affine_map<(n, f, dh, nh) -> (n, dh, nh)>, // Attn (in)
affine_map<(n, f, dh, nh) -> (dh, nh, f)>, // O_Proj (in)
affine_map<(n, f, dh, nh) -> (n, f)> // O (out)
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]
}
#bsddmm_map = {
indexing_maps = [
affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)>, // q (in)
affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>, // k (in)
affine_map<(n1, n2, dh, nh) -> (n1, n2)>, // A (in)
affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)> // attn (out)
],
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
doc = "attn(n1, n2, nh) = q(n1, dh, nh) * k(n2, dh, nh)"
}
#bspmm_map = {
indexing_maps = [
affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)>, // attn (in)
affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>, // v (in)
affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)> // out (out)
],
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
doc = "out(n1, dh, nh) = attn(n1, n2, nh) * v(n2, dh, nh)"
}
module {
func.func private @mpi_getRank() -> index attributes {llvm.emit_c_interface}
func.func @pte_sparse_mha(%A: tensor<?x?xf32, #partCsr>,
%X: tensor<?x?xf32, #partDense>,
%Q_proj: tensor<?x?x?xf32, #densev>, %K_proj: tensor<?x?x?xf32, #densev>,
%V_proj: tensor<?x?x?xf32, #densev>, %O_proj: tensor<?x?x?xf32, #densev>,
%N: index, %Dh: index, %Nh: index, %Nf: index, %Tn: index)
// -> tensor<?x?xf32>
-> tensor<?x?xf32, #dense>
{
%c0_index = arith.constant 0 : index
%c1_index = arith.constant 1 : index
%c2_index = arith.constant 2 : index
%c3_index = arith.constant 3 : index
%c4_index = arith.constant 4 : index
%c5_index = arith.constant 5 : index
%c6_index = arith.constant 6 : index
%a_partition_plan = part_tensor.get_partitions %A:
tensor<?x?xf32, #partCsr> -> memref<?xindex>
%mrank = call @mpi_getRank() : () -> index
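// Copy this rank's 4-entry partition bounds [n1_low, n2_low, n1_high, n2_high] out of the plan.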
%a_pspec = memref.alloc(%c4_index) : memref<?xindex>
%pspec_start0 = arith.muli %mrank, %c4_index : index
%el00 = memref.load %a_partition_plan[%pspec_start0] : memref<?xindex>
%pspec_start1 = arith.addi %pspec_start0, %c1_index : index
%el10 = memref.load %a_partition_plan[%pspec_start1] : memref<?xindex>
%pspec_start2 = arith.addi %pspec_start0, %c2_index : index
%el20 = memref.load %a_partition_plan[%pspec_start2] : memref<?xindex>
%pspec_start3 = arith.addi %pspec_start0, %c3_index : index
%el30 = memref.load %a_partition_plan[%pspec_start3] : memref<?xindex>
memref.store %el00, %a_pspec[%c0_index] : memref<?xindex> // n1_low
memref.store %el10, %a_pspec[%c1_index] : memref<?xindex> // n2_low
memref.store %el20, %a_pspec[%c2_index] : memref<?xindex> // n1_high
memref.store %el30, %a_pspec[%c3_index] : memref<?xindex> // n2_high
// get a slice just to query the dims
%A0_slice = part_tensor.get_slice %A, %a_pspec:
tensor<?x?xf32, #partCsr>, memref<?xindex> -> tensor<?x?xf32, #csr>
%lN10 = tensor.dim %A0_slice, %c0_index : tensor<?x?xf32, #csr>
%O0 = tensor.empty(%lN10, %Nf) : tensor<?x?xf32>
%c0_f32 = arith.constant 0.0 : f32
%O4 = linalg.fill ins(%c0_f32: f32) outs(%O0 : tensor<?x?xf32>)
-> tensor<?x?xf32>
%O5 = bufferization.materialize_in_destination %O4 in %O0 : (tensor<?x?xf32>,
tensor<?x?xf32>) -> tensor<?x?xf32>
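// Tile the n2 dimension in steps of Tn; each iteration narrows the partition spec to one n2 window.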
%O2 = scf.for %iv = %c0_index to %N step %Tn
iter_args(%O1 = %O5) -> (tensor<?x?xf32>) {
memref.store %iv, %a_pspec[%c1_index] : memref<?xindex> // n2_low
%iv_next = arith.addi %iv, %Tn : index
memref.store %iv_next, %a_pspec[%c3_index] : memref<?xindex> // n2_high
%A_slice = part_tensor.get_slice %A, %a_pspec:
tensor<?x?xf32, #partCsr>, memref<?xindex> -> tensor<?x?xf32, #csr>
%a_am_0 = part_tensor.get_active_mask %A, %a_pspec, %c0_index:
tensor<?x?xf32, #partCsr>, memref<?xindex> -> index
%a_am_1 = part_tensor.get_active_mask %A, %a_pspec, %c1_index:
tensor<?x?xf32, #partCsr>, memref<?xindex> -> index
%el0 = memref.load %a_pspec[%c0_index] : memref<?xindex>
%el1 = memref.load %a_pspec[%c1_index] : memref<?xindex>
%el2 = memref.load %a_pspec[%c2_index] : memref<?xindex>
%el3 = memref.load %a_pspec[%c3_index] : memref<?xindex>
%Q_pspec = memref.alloc(%c4_index) : memref<?xindex>
memref.store %el0, %Q_pspec[%c0_index] : memref<?xindex> // n1_low
memref.store %c0_index, %Q_pspec[%c1_index] : memref<?xindex> // nf_low
memref.store %el2, %Q_pspec[%c2_index] : memref<?xindex> // n1_high
memref.store %Nf, %Q_pspec[%c3_index] : memref<?xindex> // nf_high
%K_pspec = memref.alloc(%c4_index) : memref<?xindex>
memref.store %el1, %K_pspec[%c0_index] : memref<?xindex> // n2_low
memref.store %c0_index, %K_pspec[%c1_index] : memref<?xindex> // nf_low
memref.store %el3, %K_pspec[%c2_index] : memref<?xindex> // n2_high
memref.store %Nf, %K_pspec[%c3_index] : memref<?xindex> // nf_high
%XforQ_slice = part_tensor.get_slice %X, %Q_pspec:
tensor<?x?xf32, #partDense>, memref<?xindex> ->
tensor<?x?xf32, #dense>
%XforK_slice = part_tensor.get_slice_for_active_mask %X, %K_pspec,
%a_am_1, %c0_index:
tensor<?x?xf32, #partDense>, memref<?xindex>, index, index ->
tensor<?x?xf32, #dense>
%lN1 = tensor.dim %A_slice, %c0_index : tensor<?x?xf32, #csr>
%lN2 = tensor.dim %A_slice, %c1_index : tensor<?x?xf32, #csr>
%Q0 = tensor.empty(%lN1, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
%Q_slice = linalg.generic #input_proj_map ins(%XforQ_slice, %Q_proj
: tensor<?x?xf32, #dense>, tensor<?x?x?xf32, #densev>)
outs(%Q0: tensor<?x?x?xf32, #densev>) {
^bb0(%x: f32, %q: f32, %a: f32): // no predecessors
%0 = arith.mulf %x, %q : f32
%1 = arith.addf %0, %a : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32, #densev>
%K0 = tensor.empty(%lN2, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
%K_slice = linalg.generic #input_proj_map ins(%XforK_slice, %K_proj
: tensor<?x?xf32, #dense>, tensor<?x?x?xf32, #densev>)
outs(%K0: tensor<?x?x?xf32, #densev>) {
^bb0(%x: f32, %q: f32, %a: f32): // no predecessors
%0 = arith.mulf %x, %q : f32
%1 = arith.addf %0, %a : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32, #densev>
%attn0 = tensor.empty (%lN1, %lN2, %Nh) : tensor<?x?x?xf32, #csrv>
%attn2 = linalg.generic #bsddmm_map
ins(%Q_slice, %K_slice, %A_slice: tensor<?x?x?xf32, #densev>,
tensor<?x?x?xf32, #densev>, tensor<?x?xf32, #csr>)
outs(%attn0: tensor<?x?x?xf32, #csrv>) {
^bb0(%q: f32, %k: f32, %mask: f32, %attn: f32): // no predecessors
%0 = arith.mulf %q, %k : f32
%1 = arith.mulf %0, %mask : f32
%2 = arith.addf %1, %attn: f32
linalg.yield %2 : f32
} -> tensor<?x?x?xf32, #csrv>
// attn = attn.softmax() # (sparse) [N, N, nh]
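// Numerically stable formulation: rowwise max, exp(x - max), then normalize by the rowwise sum.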
%sc0 = arith.constant 0 : index
%sc1 = arith.constant 1 : index
%sc1_i8 = arith.constant 1 : i8
%sc2 = arith.constant 2 : index
%scst = arith.constant 0.000000e+00 : f32
%sdim = tensor.dim %attn2, %sc0 : tensor<?x?x?xf32, #csrv>
%sdim_0 = tensor.dim %attn2, %sc1 : tensor<?x?x?xf32, #csrv>
%sdim_1 = tensor.dim %attn2, %sc2 : tensor<?x?x?xf32, #csrv>
%s0 = tensor.empty(%sdim, %sdim_0, %sdim_1) : tensor<?x?x?xf32, #csrv>
%sc0_2 = arith.constant 0 : index
%sdim_3 = tensor.dim %attn2, %sc0_2 : tensor<?x?x?xf32, #csrv>
%sc1_4 = arith.constant 1 : index
%sdim_5 = tensor.dim %attn2, %sc1_4 : tensor<?x?x?xf32, #csrv>
%sc2_6 = arith.constant 2 : index
%sdim_7 = tensor.dim %attn2, %sc2_6 : tensor<?x?x?xf32, #csrv>
%s11 = tensor.empty(%sdim_3, %sdim_7) : tensor<?x?xf32>
%sminus_inf = arith.constant -3.40282347E+38 : f32
%s21 = linalg.fill ins(%sminus_inf : f32) outs(%s11 : tensor<?x?xf32>)
-> tensor<?x?xf32>
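// Softmax step 1: rowwise max over n2 (the reduction dim), taken over stored entries only.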
%s31 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%attn2 : tensor<?x?x?xf32, #csrv>) outs(%s21 : tensor<?x?xf32>) {
^bb0(%sin: f32, %sout: f32):
%sres = sparse_tensor.reduce %sin, %sout, %sminus_inf : f32 {
^bb0(%sx0: f32, %sx1: f32):
%s00 = arith.maxnumf %sx0, %sx1 : f32
sparse_tensor.yield %s00: f32
}
linalg.yield %sres : f32
} -> tensor<?x?xf32>
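// Softmax step 2: exp(attn - rowmax), applied to the stored (present) entries via sparse_tensor.unary.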
%s3 = linalg.generic {indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%attn2 : tensor<?x?x?xf32, #csrv>)
outs(%attn2 : tensor<?x?x?xf32, #csrv>) {
^bb0(%sin: f32, %sout: f32):
%sx = linalg.index 0: index
%sz = linalg.index 2: index
%sresult = sparse_tensor.unary %sin : f32 to f32
present={
^bb0(%sin1: f32):
%smaxel = tensor.extract %s31[%sx, %sz]: tensor<?x?xf32>
%s8 = arith.subf %sin1, %smaxel : f32
%sret = math.exp %s8 : f32
sparse_tensor.yield %sret : f32
}
absent={}
linalg.yield %sresult : f32
} -> tensor<?x?x?xf32, #csrv>
%s1 = tensor.empty(%sdim_3, %sdim_7) : tensor<?x?xf32>
%scst_8 = arith.constant 0. : f32
%s2 = linalg.fill ins(%scst_8 : f32) outs(%s1 : tensor<?x?xf32>)
-> tensor<?x?xf32>
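// Softmax step 3: rowwise sum of the exponentials over n2.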
%s4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%s3 : tensor<?x?x?xf32, #csrv>) outs(%s2 : tensor<?x?xf32>) {
^bb0(%sin: f32, %sout: f32):
%sres = sparse_tensor.reduce %sin, %sout, %scst_8 : f32 {
^bb0(%sx0: f32, %sx1: f32):
%s00 = arith.addf %sx0, %sx1 : f32
sparse_tensor.yield %s00: f32
}
linalg.yield %sres : f32
} -> tensor<?x?xf32>
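// Softmax step 4: divide each stored entry by its row sum.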
%attn31 = linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel"]}
outs(%s3: tensor<?x?x?xf32, #csrv>) {
^bb0(%sin: f32):
%sx = linalg.index 0: index
%sz = linalg.index 2: index
%sresult = sparse_tensor.unary %sin : f32 to f32
present={
^bb0(%sin1: f32):
%sdenom = tensor.extract %s4[%sx, %sz]: tensor<?x?xf32>
%sret = arith.divf %sin1, %sdenom : f32
sparse_tensor.yield %sret : f32
}
absent={}
linalg.yield %sresult : f32
} -> tensor<?x?x?xf32, #csrv>
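// V projection reuses the K-side slice of X (rows n2_low..n2_high).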
%V0 = tensor.empty(%lN2, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
%V_slice = linalg.generic #input_proj_map ins(%XforK_slice, %V_proj
: tensor<?x?xf32, #dense>, tensor<?x?x?xf32, #densev>)
outs(%V0: tensor<?x?x?xf32, #densev>) {
^bb0(%x: f32, %q: f32, %a: f32): // no predecessors
%0 = arith.mulf %x, %q : f32
%1 = arith.addf %0, %a : f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32, #densev>
// out = dglsp.bspmm(attn, v) # [N, dh, nh]
%spmm_in0 = tensor.empty (%lN1, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
%spmm_in1 = linalg.fill ins(%c0_f32: f32) outs(%spmm_in0 : tensor<?x?x?xf32, #densev>)
-> tensor<?x?x?xf32, #densev>
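// Finalize pending insertions into attn31 before the bspmm kernel consumes it.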
%attn32 = sparse_tensor.load %attn31 hasInserts : tensor<?x?x?xf32, #csrv>
%attn4 = linalg.generic #bspmm_map
ins(%attn32, %V_slice: tensor<?x?x?xf32, #csrv>, tensor<?x?x?xf32, #densev>)
outs(%spmm_in1: tensor<?x?x?xf32, #densev>) {
^bb0(%q: f32, %k: f32, %attn: f32): // no predecessors
%0 = arith.mulf %q, %k : f32
%1 = arith.addf %0, %attn: f32
linalg.yield %1 : f32
} -> tensor<?x?x?xf32, #densev>
// %O0 = tensor.empty(%lN1, %Nf) : tensor<?x?xf32, #dense>
%O = linalg.generic #output_proj_map ins(%attn4, %O_proj: tensor<?x?x?xf32, #densev>,
tensor<?x?x?xf32, #densev>)
outs(%O1: tensor<?x?xf32>) {
^bb0(%x: f32, %q: f32, %a: f32): // no predecessors
%0 = arith.mulf %x, %q : f32
%1 = arith.addf %0, %a : f32
linalg.yield %1 : f32
} -> tensor<?x?xf32>
// %O11 = bufferization.materialize_in_destination %O in %O1 : (tensor<?x?xf32>,
// tensor<?x?xf32>) -> tensor<?x?xf32>
// return %attn4 : tensor<?x?x?xf32, #densev>
// scf.yield %O : tensor<?x?xf32>
scf.yield %O : tensor<?x?xf32>
}
%O3 = tensor.cast %O2 : tensor<?x?xf32> to tensor<?x?xf32, #dense>
return %O3 : tensor<?x?xf32, #dense>
// return %O2 : tensor<?x?xf32>
}
}
// Local Variables:
// rmsbolt-command: "lapis-opt --sparse-compiler-kokkos='pt-backend=mpi'"
// rmsbolt-automatic-recompile: on-save
// End:
```

Compiling the MLIR above with `lapis-opt --sparse-compiler-kokkos='pt-backend=mpi'` (the rmsbolt command embedded at the end of the file) produces the following error:

```
../sparseMHA.3.dist.mlir:163:16: error: operand #0 does not dominate this use
%attn2 = linalg.generic #bsddmm_map
^
../sparseMHA.3.dist.mlir:163:16: note: see current operation: "memref.dealloc"(%212) : (memref<?xf32>) -> ()
../sparseMHA.3.dist.mlir:163:16: note: operand defined here (op in a child region)
```
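The verifier is rejecting IR produced partway through the pipeline, not the input itself: per the notes, the failing op is a `memref.dealloc` whose operand is defined in a child region (somewhere under the `scf.for` body containing the `#bsddmm_map` generic), so the definition no longer dominates the use. Below is a minimal sketch of the same class of failure, with hypothetical names. The textual parser rejects this as an undefined-value use, but a pass that builds the equivalent IR in memory hits the dominance verifier instead, which is presumably what happens here when the dealloc for a loop-local buffer is placed outside the loop.

```mlir
func.func @dominance_sketch(%lb: index, %ub: index, %step: index) {
  scf.for %i = %lb to %ub step %step {
    // %buf is defined inside the loop's region ...
    %buf = memref.alloc() : memref<4xf32>
    scf.yield
  }
  // ... so a dealloc emitted after the loop uses a value whose definition
  // does not dominate it, the same shape as the error above.
  memref.dealloc %buf : memref<4xf32>
  return
}
```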