Description
Lowering softmax.mlir to Kokkos fails with the new lapis-opt. (@brian-kelley, maybe there are different options I should have been using.)
lapis-opt --sparse-compiler-kokkos='pt-backend=mpi' softmax.mlir -o softmax.scf.mlir
softmax.mlir
#map = affine_map<(d0, d1, d2) -> (d0, d1, d2)>
#map1 = affine_map<(d0, d1, d2) -> (d0, d2)>
#csrv = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
(d0 : dense, d1 : compressed, d2 : dense) }>
#dense1d = #sparse_tensor.encoding<{ map = (d0) ->
(d0 : dense) }>
#dense = #sparse_tensor.encoding<{ map = (d0, d1) ->
(d0 : dense, d1 : dense) }>
#densev = #sparse_tensor.encoding<{ map = (d0, d1, d2) ->
(d0 : dense, d1 : dense, d2 : dense) }>
#csr = #sparse_tensor.encoding<{ map = (d0, d1) ->
(d0 : dense, d1 : compressed) }>
#partCsr = #part_tensor.encoding<{
partConst = 1,
sparseAttributes = #csr
}>
#partDense = #part_tensor.encoding<{
partConst = 1,
sparseAttributes = #dense
}>
#partDensev = #part_tensor.encoding<{
partConst = 1,
sparseAttributes = #densev
}>
#input_proj_map = {
indexing_maps = [
affine_map<(n, f, dh, nh) -> (n, f)>, // X (in)
affine_map<(n, f, dh, nh) -> (dh, nh, f)>, // Q_Proj (in)
affine_map<(n, f, dh, nh) -> (n, dh, nh)> // Q (out)
],
iterator_types = ["parallel", "reduction", "parallel", "parallel"]
}
#output_proj_map = {
indexing_maps = [
affine_map<(n, f, dh, nh) -> (n, dh, nh)>, // Attn (in)
affine_map<(n, f, dh, nh) -> (dh, nh, f)>, // O_Proj (in)
affine_map<(n, f, dh, nh) -> (n, f)> // O (out)
],
iterator_types = ["parallel", "parallel", "reduction", "reduction"]
}
#bsddmm_map = {
indexing_maps = [
affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)>, // q (in)
affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>, // k (in)
affine_map<(n1, n2, dh, nh) -> (n1, n2)>, // A (in)
affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)> // attn (out)
],
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
doc = "attn(n1, n2, nh) = q(n1, dh, nh) * k(n2, dh, nh)"
}
#bspmm_map = {
indexing_maps = [
affine_map<(n1, n2, dh, nh) -> (n1, n2, nh)>, // attn (in)
affine_map<(n1, n2, dh, nh) -> (n2, dh, nh)>, // v (in)
affine_map<(n1, n2, dh, nh) -> (n1, dh, nh)> // out (out)
],
iterator_types = ["parallel", "parallel", "reduction", "parallel"],
doc = "out(n1, dh, nh) = attn(n1, n2, nh) * v(n2, dh, nh)"
}
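// Note: the #part_tensor encodings and the projection/SDDMM/SpMM traits
// above are not referenced by @pte_softmax below.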
module {
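// Softmax over the sparse attention tensor: for each (d0, d2) pair,
// normalize the values stored along the compressed dimension d1.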
func.func @pte_softmax(%attn2: tensor<?x?x?xf32, #csrv>)
-> tensor<?x?x?xf32, #csrv> {
%sc0 = arith.constant 0 : index
%sc1 = arith.constant 1 : index
%sc1_i8 = arith.constant 1 : i8
%sc2 = arith.constant 2 : index
%scst = arith.constant 0.000000e+00 : f32
%sdim = tensor.dim %attn2, %sc0 : tensor<?x?x?xf32, #csrv>
%sdim_0 = tensor.dim %attn2, %sc1 : tensor<?x?x?xf32, #csrv>
%sdim_1 = tensor.dim %attn2, %sc2 : tensor<?x?x?xf32, #csrv>
%sc0_2 = arith.constant 0 : index
%sdim_3 = tensor.dim %attn2, %sc0_2 : tensor<?x?x?xf32, #csrv>
%sc1_4 = arith.constant 1 : index
%sdim_5 = tensor.dim %attn2, %sc1_4 : tensor<?x?x?xf32, #csrv>
%sc2_6 = arith.constant 2 : index
%sdim_7 = tensor.dim %attn2, %sc2_6 : tensor<?x?x?xf32, #csrv>
%s11 = tensor.empty(%sdim_3, %sdim_7) : tensor<?x?xf32>
%sminus_inf = arith.constant -3.40282347E+38 : f32
%s21 = linalg.fill ins(%sminus_inf : f32) outs(%s11 : tensor<?x?xf32>)
-> tensor<?x?xf32>
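// Step 1: reduce over d1 to find the max stored value for each (d0, d2).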
%s31 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%attn2 : tensor<?x?x?xf32, #csrv>) outs(%s21 : tensor<?x?xf32>) {
^bb0(%sin: f32, %sout: f32):
%sres = sparse_tensor.reduce %sin, %sout, %sminus_inf : f32 {
^bb0(%sx0: f32, %sx1: f32):
%s00 = arith.maxnumf %sx0, %sx1 : f32
sparse_tensor.yield %s00: f32
}
linalg.yield %sres : f32
} -> tensor<?x?xf32>
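// Step 2: for each stored entry, subtract the max and exponentiate.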
%s3 = linalg.generic {indexing_maps = [#map, #map],
iterator_types = ["parallel", "parallel", "parallel"]}
ins(%attn2 : tensor<?x?x?xf32, #csrv>)
outs(%attn2 : tensor<?x?x?xf32, #csrv>) {
^bb0(%sin: f32, %sout: f32):
%sx = linalg.index 0: index
%sz = linalg.index 2: index
%sresult = sparse_tensor.unary %sin : f32 to f32
present={
^bb0(%sin1: f32):
%smaxel = tensor.extract %s31[%sx, %sz]: tensor<?x?xf32>
%s8 = arith.subf %sin1, %smaxel : f32
%sret = math.exp %s8 : f32
sparse_tensor.yield %sret : f32
}
absent={}
linalg.yield %sresult : f32
} -> tensor<?x?x?xf32, #csrv>
%s1 = tensor.empty(%sdim_3, %sdim_7) : tensor<?x?xf32>
%scst_8 = arith.constant 0. : f32
%s2 = linalg.fill ins(%scst_8 : f32) outs(%s1 : tensor<?x?xf32>)
-> tensor<?x?xf32>
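// Step 3: sum the exponentials over d1 for each (d0, d2).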
%s4 = linalg.generic {indexing_maps = [#map, #map1],
iterator_types = ["parallel", "reduction", "parallel"]}
ins(%s3 : tensor<?x?x?xf32, #csrv>) outs(%s2 : tensor<?x?xf32>) {
^bb0(%sin: f32, %sout: f32):
%sres = sparse_tensor.reduce %sin, %sout, %scst_8 : f32 {
^bb0(%sx0: f32, %sx1: f32):
%s00 = arith.addf %sx0, %sx1 : f32
sparse_tensor.yield %s00: f32
}
linalg.yield %sres : f32
} -> tensor<?x?xf32>
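// Step 4: divide each exponential by its (d0, d2) sum to normalize.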
%attn31 = linalg.generic {indexing_maps = [#map],
iterator_types = ["parallel", "parallel", "parallel"]}
outs(%s3: tensor<?x?x?xf32, #csrv>) {
^bb0(%sin: f32):
%sx = linalg.index 0: index
%sz = linalg.index 2: index
%sresult = sparse_tensor.unary %sin : f32 to f32
present={
^bb0(%sin1: f32):
%sdenom = tensor.extract %s4[%sx, %sz]: tensor<?x?xf32>
%sret = arith.divf %sin1, %sdenom : f32
sparse_tensor.yield %sret : f32
}
absent={}
linalg.yield %sresult : f32
} -> tensor<?x?x?xf32, #csrv>
bufferization.dealloc_tensor %s1: tensor<?x?xf32>
bufferization.dealloc_tensor %s11: tensor<?x?xf32>
return %attn31 : tensor<?x?x?xf32, #csrv>
}
}
// Local Variables:
// rmsbolt-command: "lapis-opt --sparse-compiler-kokkos='pt-backend=mpi'"
// rmsbolt-automatic-recompile: on-save
// End:
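For reference, the kernel implements a numerically stable softmax along the compressed dimension d1, independently for each (d0, d2) pair. Below is a minimal NumPy sketch of the intended semantics on a dense stand-in tensor; the helper name and shapes are illustrative only, and unlike the sparse kernel (where the empty absent region of sparse_tensor.unary leaves missing entries untouched), the dense version lets implicit zeros participate.

# Minimal NumPy sketch of what @pte_softmax computes, using a dense
# stand-in for the tensor<?x?x?xf32, #csrv> operand. Illustrative only.
import numpy as np

def softmax_reference(attn: np.ndarray) -> np.ndarray:
    """Softmax along axis 1 (the compressed dimension d1) for each (d0, d2)."""
    m = attn.max(axis=1, keepdims=True)   # step 1: max per (d0, d2)
    e = np.exp(attn - m)                  # step 2: shift and exponentiate
    s = e.sum(axis=1, keepdims=True)      # step 3: sum of exponentials
    return e / s                          # step 4: normalize

attn = np.random.rand(4, 8, 2).astype(np.float32)
out = softmax_reference(attn)
assert np.allclose(out.sum(axis=1), 1.0, atol=1e-6)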