
sparse_tensor dialect generating deallocs at end of function rather than at end of current region #51

Description

@vmiheer
The following input, input.mlir, reproduces the issue:
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
    (d0 : dense, d1 : compressed, d2 : dense) }>
#dense1d = #sparse_tensor.encoding<{ map = (d0) ->
    (d0 : dense) }>
#dense = #sparse_tensor.encoding<{ map = (d0, d1) ->
    (d0 : dense, d1 : dense) }>
#densev = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
    (d0 : dense, d1 : dense, d2 : dense) }>
#csr = #sparse_tensor.encoding<{ map = (d0, d1) ->
    (d0 : dense, d1 : compressed) }>
#partCsr = #part_tensor.encoding<{
  partConst = 1,
  sparseAttributes = #csr
}>
#partDense = #part_tensor.encoding<{
  partConst = 1,
  sparseAttributes = #dense
}>
#partDensev = #part_tensor.encoding<{
  partConst = 1,
  sparseAttributes = #densev
}>
#input_proj_map = {
  indexing_maps = [
    affine_map<(n, f, dh, nh) -> (n, f)>,  // X (in)
    affine_map<(n, f, dh, nh) -> (dh, nh, f)>,  // Q_Proj (in)
    affine_map<(n, f, dh, nh) -> (n, dh, nh)>  // Q (out)
  ],
  iterator_types = ["parallel", "reduction", "parallel", "parallel"]
}
#output_proj_map = {
  indexing_maps = [
    affine_map<(n, f, dh, nh) -> (n, dh, nh)>,  // Attn (in)
    affine_map<(n, f, dh, nh) -> (dh, nh, f)>,  // O_Proj (in)
    affine_map<(n, f, dh, nh) -> (n, f)>  // O (out)
  ],
  iterator_types = ["parallel", "parallel", "reduction", "reduction"]
}
#bsddmm_map = {
  indexing_maps = [
    affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)>,  // q (in)
    affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>,  // k (in)
    affine_map<(n1, n2, dh, nh) -> (n1, n2)>,  // A (in)
    affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)>   // attn (out)
  ],
  iterator_types = ["parallel", "parallel", "reduction", "parallel"],
  doc = "attn(n1, n2, nh) = q(n1, dh, nh) * k(n2, dh, nh)"
}
#bspmm_map = {
  indexing_maps = [
    affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)>,  // attn (in)
    affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>,  // v (in)
    affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)>   // out (out)
  ],
  iterator_types = ["parallel", "parallel", "reduction", "parallel"],
  doc = "out(n1, dh, nh) = attn(n1, n2, nh) * v(n2, dh, nh)"
}

module {
  func.func private @mpi_getRank() -> index attributes {llvm.emit_c_interface}

  func.func @pte_sparse_mha(%A: tensor<?x?xf32, #partCsr>,
    %X: tensor<?x?xf32, #partDense>,
    %Q_proj: tensor<?x?x?xf32, #densev>, %K_proj: tensor<?x?x?xf32, #densev>,
    %V_proj: tensor<?x?x?xf32, #densev>, %O_proj: tensor<?x?x?xf32, #densev>,
    %N: index, %Dh: index, %Nh: index, %Nf: index, %Tn: index)
      // ->  tensor<?x?xf32>
      ->  tensor<?x?xf32, #dense>
  {
    %c0_index = arith.constant 0 : index
    %c1_index = arith.constant 1 : index
    %c2_index = arith.constant 2 : index
    %c3_index = arith.constant 3 : index
    %c4_index = arith.constant 4 : index
    %c5_index = arith.constant 5 : index
    %c6_index = arith.constant 6 : index

    %a_partition_plan = part_tensor.get_partitions %A:
      tensor<?x?xf32, #partCsr> -> memref<?xindex>
    %mrank = call @mpi_getRank() : () -> index
    %a_pspec = memref.alloc(%c4_index) : memref<?xindex>
    %pspec_start0 = arith.muli %mrank, %c4_index : index
    %el00 = memref.load %a_partition_plan[%pspec_start0] : memref<?xindex>
    %pspec_start1 = arith.addi %pspec_start0, %c1_index : index
    %el10 = memref.load %a_partition_plan[%pspec_start1] : memref<?xindex>
    %pspec_start2 = arith.addi %pspec_start0, %c2_index : index
    %el20 = memref.load %a_partition_plan[%pspec_start2] : memref<?xindex>
    %pspec_start3 = arith.addi %pspec_start0, %c3_index : index
    %el30 = memref.load %a_partition_plan[%pspec_start3] : memref<?xindex>
    memref.store %el00, %a_pspec[%c0_index] : memref<?xindex> // n1_low
    memref.store %el10, %a_pspec[%c1_index] : memref<?xindex> // n2_low
    memref.store %el20, %a_pspec[%c2_index] : memref<?xindex> // n1_high
    memref.store %el30, %a_pspec[%c3_index] : memref<?xindex> // n2_high

    // get aslice just to get dims
    %A0_slice = part_tensor.get_slice %A, %a_pspec:
      tensor<?x?xf32, #partCsr>, memref<?xindex> -> tensor<?x?xf32, #csr>
    %lN10 = tensor.dim %A0_slice, %c0_index : tensor<?x?xf32, #csr>
    %O0 = tensor.empty(%lN10, %Nf) : tensor<?x?xf32>
    %c0_f32 = arith.constant 0.0 : f32
    %O4 = linalg.fill ins(%c0_f32: f32) outs(%O0 : tensor<?x?xf32>)
      -> tensor<?x?xf32>
    %O5 = bufferization.materialize_in_destination  %O4 in %O0 : (tensor<?x?xf32>,
      tensor<?x?xf32>) -> tensor<?x?xf32>

    %O2 = scf.for %iv = %c0_index to %N step %Tn
        iter_args(%O1 = %O5) -> (tensor<?x?xf32>) {
      memref.store %iv,      %a_pspec[%c1_index] : memref<?xindex> // n2_low
      %iv_next = arith.addi %iv, %Tn : index
      memref.store %iv_next, %a_pspec[%c3_index] : memref<?xindex> // n2_high
      %A_slice = part_tensor.get_slice %A, %a_pspec:
        tensor<?x?xf32, #partCsr>, memref<?xindex> -> tensor<?x?xf32, #csr>
      %a_am_0 = part_tensor.get_active_mask %A, %a_pspec, %c0_index:
        tensor<?x?xf32, #partCsr>, memref<?xindex> -> index
      %a_am_1 = part_tensor.get_active_mask %A, %a_pspec, %c1_index:
        tensor<?x?xf32, #partCsr>, memref<?xindex> -> index

      %el0 = memref.load %a_pspec[%c0_index] : memref<?xindex>
      %el1 = memref.load %a_pspec[%c1_index] : memref<?xindex>
      %el2 = memref.load %a_pspec[%c2_index] : memref<?xindex>
      %el3 = memref.load %a_pspec[%c3_index] : memref<?xindex>
      %Q_pspec = memref.alloc(%c4_index) : memref<?xindex>
      memref.store %el0, %Q_pspec[%c0_index] : memref<?xindex> // n1_low
      memref.store %c0_index, %Q_pspec[%c1_index] : memref<?xindex> // nf_low
      memref.store %el2, %Q_pspec[%c2_index] : memref<?xindex> // n1_high
      memref.store %Nf, %Q_pspec[%c3_index] : memref<?xindex> // nf_high
      %K_pspec = memref.alloc(%c4_index) : memref<?xindex>
      memref.store %el1, %K_pspec[%c0_index] : memref<?xindex> // n2_low
      memref.store %c0_index, %K_pspec[%c1_index] : memref<?xindex> // nf_low
      memref.store %el3, %K_pspec[%c2_index] : memref<?xindex> // n2_high
      memref.store %Nf, %K_pspec[%c3_index] : memref<?xindex> // nf_high

      %XforQ_slice = part_tensor.get_slice %X, %Q_pspec:
        tensor<?x?xf32, #partDense>, memref<?xindex> ->
          tensor<?x?xf32, #dense>
      %XforK_slice = part_tensor.get_slice_for_active_mask %X, %K_pspec,
          %a_am_1, %c0_index:
        tensor<?x?xf32, #partDense>, memref<?xindex>, index, index ->
          tensor<?x?xf32, #dense>
      %lN1 = tensor.dim %A_slice, %c0_index : tensor<?x?xf32, #csr>
      %lN2 = tensor.dim %A_slice, %c1_index : tensor<?x?xf32, #csr>
      %Q0 = tensor.empty(%lN1, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
      %Q_slice = linalg.generic #input_proj_map ins(%XforQ_slice, %Q_proj
          : tensor<?x?xf32, #dense>, tensor<?x?x?xf32, #densev>)
        outs(%Q0: tensor<?x?x?xf32, #densev>) {
        ^bb0(%x: f32, %q: f32, %a: f32):  // no predecessors
          %0 = arith.mulf %x, %q : f32
          %1 = arith.addf %0, %a : f32
          linalg.yield %1 : f32
      } -> tensor<?x?x?xf32, #densev>
      %K0 = tensor.empty(%lN2, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
      %K_slice = linalg.generic #input_proj_map ins(%XforK_slice, %K_proj
          : tensor<?x?xf32, #dense>, tensor<?x?x?xf32, #densev>)
        outs(%K0: tensor<?x?x?xf32, #densev>) {
        ^bb0(%x: f32, %q: f32, %a: f32):  // no predecessors
          %0 = arith.mulf %x, %q : f32
          %1 = arith.addf %0, %a : f32
          linalg.yield %1 : f32
      } -> tensor<?x?x?xf32, #densev>
      %attn0 = tensor.empty (%lN1, %lN2, %Nh) : tensor<?x?x?xf32, #csrv>
      %attn2 = linalg.generic #bsddmm_map
        ins(%Q_slice, %K_slice, %A_slice: tensor<?x?x?xf32, #densev>,
          tensor<?x?x?xf32, #densev>, tensor<?x?xf32, #csr>)
        outs(%attn0: tensor<?x?x?xf32, #csrv>) {
        ^bb0(%q: f32, %k: f32, %mask: f32, %attn: f32):  // no predecessors
          %0 = arith.mulf %q, %k : f32
          %1 = arith.mulf %0, %mask : f32
          %2 = arith.addf %1, %attn: f32
          linalg.yield %2 : f32
      } -> tensor<?x?x?xf32, #csrv>

      // attn = attn.softmax()  # (sparse) [N, N, nh]
      %sc0 = arith.constant 0 : index
      %sc1 = arith.constant 1 : index
      %sc1_i8 = arith.constant 1 : i8
      %sc2 = arith.constant 2 : index
      %scst = arith.constant 0.000000e+00 : f32
      %sdim = tensor.dim %attn2, %sc0 : tensor<?x?x?xf32, #csrv>
      %sdim_0 = tensor.dim %attn2, %sc1 : tensor<?x?x?xf32, #csrv>
      %sdim_1 = tensor.dim %attn2, %sc2 : tensor<?x?x?xf32, #csrv>
      %s0 = tensor.empty(%sdim, %sdim_0, %sdim_1) : tensor<?x?x?xf32, #csrv>
      %sc0_2 = arith.constant 0 : index
      %sdim_3 = tensor.dim %attn2, %sc0_2 : tensor<?x?x?xf32, #csrv>
      %sc1_4 = arith.constant 1 : index
      %sdim_5 = tensor.dim %attn2, %sc1_4 : tensor<?x?x?xf32, #csrv>
      %sc2_6 = arith.constant 2 : index
      %sdim_7 = tensor.dim %attn2, %sc2_6 : tensor<?x?x?xf32, #csrv>
      %s11 = tensor.empty(%sdim_3, %sdim_7) : tensor<?x?xf32>
      %sminus_inf = arith.constant -3.40282347E+38 : f32

      %s21 = linalg.fill ins(%sminus_inf : f32) outs(%s11 : tensor<?x?xf32>)
        -> tensor<?x?xf32>
      %s31 = linalg.generic {indexing_maps = [#map, #map1],
        iterator_types = ["parallel", "reduction", "parallel"]}
        ins(%attn2 : tensor<?x?x?xf32, #csrv>) outs(%s21 : tensor<?x?xf32>) {
          ^bb0(%sin: f32, %sout: f32):
            %sres = sparse_tensor.reduce %sin, %sout, %sminus_inf : f32 {
              ^bb0(%sx0: f32, %sx1: f32):
                %s00 = arith.maxnumf %sx0, %sx1 : f32
                sparse_tensor.yield %s00: f32
            }
            linalg.yield %sres : f32
      } -> tensor<?x?xf32>
      %s3 = linalg.generic {indexing_maps = [#map, #map],
        iterator_types = ["parallel", "parallel", "parallel"]}
        ins(%attn2 : tensor<?x?x?xf32, #csrv>)
        outs(%attn2 : tensor<?x?x?xf32, #csrv>) {
          ^bb0(%sin: f32, %sout: f32):
            %sx = linalg.index 0: index
            %sz = linalg.index 2: index
            %sresult = sparse_tensor.unary %sin : f32 to f32
            present={
            ^bb0(%sin1: f32):
              %smaxel = tensor.extract %s31[%sx, %sz]: tensor<?x?xf32>
              %s8 = arith.subf %sin1, %smaxel : f32
              %sret = math.exp %s8 : f32
              sparse_tensor.yield %sret : f32
            }
            absent={}
            linalg.yield %sresult : f32
      } -> tensor<?x?x?xf32, #csrv>
      %s1 = tensor.empty(%sdim_3, %sdim_7) : tensor<?x?xf32>
      %scst_8 = arith.constant 0. : f32
      %s2 = linalg.fill ins(%scst_8 : f32) outs(%s1 : tensor<?x?xf32>)
        -> tensor<?x?xf32>
      %s4 = linalg.generic {indexing_maps = [#map, #map1],
        iterator_types = ["parallel", "reduction", "parallel"]}
        ins(%s3 : tensor<?x?x?xf32, #csrv>) outs(%s2 : tensor<?x?xf32>) {
          ^bb0(%sin: f32, %sout: f32):
            %sres = sparse_tensor.reduce %sin, %sout, %scst_8 : f32 {
              ^bb0(%sx0: f32, %sx1: f32):
                %s00 = arith.addf %sx0, %sx1 : f32
                sparse_tensor.yield %s00: f32
            }
            linalg.yield %sres : f32
      } -> tensor<?x?xf32>
      %attn31  = linalg.generic {indexing_maps = [#map],
        iterator_types = ["parallel", "parallel", "parallel"]}
        outs(%s3: tensor<?x?x?xf32, #csrv>) {
          ^bb0(%sin: f32):
            %sx = linalg.index 0: index
            %sz = linalg.index 2: index
            %sresult = sparse_tensor.unary %sin : f32 to f32
            present={
            ^bb0(%sin1: f32):
              %sdenom = tensor.extract %s4[%sx, %sz]: tensor<?x?xf32>
              %sret = arith.divf %sin1, %sdenom : f32
              sparse_tensor.yield %sret : f32
            }
            absent={}
            linalg.yield %sresult : f32
      } -> tensor<?x?x?xf32, #csrv>
      %V0 = tensor.empty(%lN2, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
      %V_slice = linalg.generic #input_proj_map ins(%XforK_slice, %V_proj
          : tensor<?x?xf32, #dense>, tensor<?x?x?xf32, #densev>)
        outs(%V0: tensor<?x?x?xf32, #densev>) {
        ^bb0(%x: f32, %q: f32, %a: f32):  // no predecessors
          %0 = arith.mulf %x, %q : f32
          %1 = arith.addf %0, %a : f32
          linalg.yield %1 : f32
      } -> tensor<?x?x?xf32, #densev>
      // out = dglsp.bspmm(attn, v)  # [N, dh, nh]
      %spmm_in0 = tensor.empty (%lN1, %Dh, %Nh) : tensor<?x?x?xf32, #densev>
      %spmm_in1 = linalg.fill ins(%c0_f32: f32) outs(%spmm_in0 : tensor<?x?x?xf32, #densev>)
        -> tensor<?x?x?xf32, #densev>
      %attn32 = sparse_tensor.load %attn31 hasInserts : tensor<?x?x?xf32, #csrv>
      %attn4 = linalg.generic #bspmm_map
        ins(%attn32, %V_slice: tensor<?x?x?xf32, #csrv>, tensor<?x?x?xf32, #densev>)
        outs(%spmm_in1: tensor<?x?x?xf32, #densev>) {
        ^bb0(%q: f32, %k: f32, %attn: f32):  // no predecessors
          %0 = arith.mulf %q, %k : f32
          %1 = arith.addf %0, %attn: f32
          linalg.yield %1 : f32
      } -> tensor<?x?x?xf32, #densev>
      // %O0 = tensor.empty(%lN1, %Nf) : tensor<?x?xf32, #dense>
      %O = linalg.generic #output_proj_map ins(%attn4, %O_proj: tensor<?x?x?xf32, #densev>,
      tensor<?x?x?xf32, #densev>)
        outs(%O1: tensor<?x?xf32>) {
        ^bb0(%x: f32, %q: f32, %a: f32):  // no predecessors
          %0 = arith.mulf %x, %q : f32
          %1 = arith.addf %0, %a : f32
          linalg.yield %1 : f32
      } -> tensor<?x?xf32>
      // %O11 = bufferization.materialize_in_destination %O in %O1 : (tensor<?x?xf32>,
      //   tensor<?x?xf32>) -> tensor<?x?xf32>
      // return %attn4 : tensor<?x?x?xf32, #densev>
      // scf.yield %O : tensor<?x?xf32>

      scf.yield %O : tensor<?x?xf32>
    }
    %O3 = tensor.cast %O2 : tensor<?x?xf32> to tensor<?x?xf32, #dense>
    return %O3 : tensor<?x?xf32, #dense>
    // return %O2 : tensor<?x?xf32>
  }
}

// Local Variables:
// rmsbolt-command: "lapis-opt --sparse-compiler-kokkos='pt-backend=mpi'"
// rmsbolt-automatic-recompile: on-save
// End:

Compiling the above MLIR with lapis-opt --sparse-compiler-kokkos='pt-backend=mpi' (see the rmsbolt comment above) produces the following error:

../sparseMHA.3.dist.mlir:163:16: error: operand #0 does not dominate this use
      %attn2 = linalg.generic #bsddmm_map
               ^
../sparseMHA.3.dist.mlir:163:16: note: see current operation: "memref.dealloc"(%212) : (memref<?xf32>) -> ()
../sparseMHA.3.dist.mlir:163:16: note: operand defined here (op in a child region)
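
For context, the verifier complaint corresponds to IR shaped like the sketch below. This is a hand-written illustration, not the actual pass output; the pipeline builds this IR in memory, since the out-of-scope use of %buf would not even parse from text. A buffer materialized inside the scf.for body gets its memref.dealloc emitted at the end of the enclosing function, where the value defined in the child region does not dominate the use:

func.func @sketch(%n: index) {
  %c0 = arith.constant 0 : index
  %c1 = arith.constant 1 : index
  scf.for %iv = %c0 to %n step %c1 {
    // Intermediate buffer created inside the loop region.
    %buf = memref.alloc(%n) : memref<?xf32>
    // ... %buf used within the loop body ...
  }
  // Dealloc emitted at the end of the function instead of at the end of
  // the loop body; %buf is defined in a child region, hence
  // "operand #0 does not dominate this use" / "op in a child region".
  "memref.dealloc"(%buf) : (memref<?xf32>) -> ()
  return
}

The expected placement, per the issue title, would be a memref.dealloc at the end of the scf.for body, i.e. in the region that defines the buffer.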
