[DispatchCreation] Hoist scalar tensor.extract and tensor.extract_slice#24552
[DispatchCreation] Hoist scalar tensor.extract and tensor.extract_slice#24552juanigp wants to merge 3 commits into
Conversation
36225a8 to
bbe0ed0
Compare
Signed-off-by: default <pisula@roofline.ai> Signed-off-by: Juan Ignacio Pisula <pisula@roofline.ai>
Signed-off-by: default <pisula@roofline.ai> Signed-off-by: Juan Ignacio Pisula <pisula@roofline.ai>
bbe0ed0 to
8126c87
Compare
|
Generally looks good to me! One thing i noticed while inspecting the dump (with this patch) from the reproducer is that there is probably some follow up work we could do here. This is the IR before the tensor-to-flow conversion %extracted_3 = tensor.extract %6[] : tensor<i64>
%c0_i64_4 = arith.constant 0 : i64
%32 = arith.cmpi sge, %extracted_3, %c0_i64_4 : i64
%33 = arith.addi %extracted_3, %c256_i64_2 : i64
%34 = arith.select %32, %extracted_3, %33 : i64
%35 = arith.cmpi slt, %34, %c0_i64_4 : i64
%36 = arith.select %35, %c0_i64_4, %34 : i64
%37 = arith.cmpi sgt, %36, %c256_i64_2 : i64
%38 = arith.select %37, %c256_i64_2, %36 : i64
%39 = arith.index_cast %38 : i64 to index
%c1 = arith.constant 1 : index
%dim = tensor.dim %30, %c1 : tensor<16x?x64xf32>
%c1_5 = arith.constant 1 : index
%dim_6 = tensor.dim %30, %c1_5 : tensor<16x?x64xf32>
%c1_7 = arith.constant 1 : index
%dim_8 = tensor.dim %30, %c1_7 : tensor<16x?x64xf32>
%40 = flow.tensor.encode %30 : tensor<16x?x64xf32>{%dim_8} -> tensor<16x?x64xf32, #encoding>{%dim_8}
%41 = flow.tensor.encode %31 : tensor<16x64x256xf32> -> tensor<16x64x256xf32, #encoding1>
%extracted_slice = tensor.extract_slice %4[%39, 0] [%29, 256] [1, 1] : tensor<256x256xi8> to tensor<?x256xi8>
%42 = flow.tensor.encode %extracted_slice : tensor<?x256xi8>{%29} -> tensor<?x256xi8, #encoding3>{%29}
This is the IR after the tensor-to-flow conversion: %9 = flow.tensor.load %5 : tensor<i64>
%10 = arith.addi %9, %8 : i64
%11 = arith.addi %9, %c256_i64 : i64
%12 = arith.cmpi sge, %9, %c0_i64 : i64
%13 = arith.select %12, %9, %11 : i64
%14 = arith.cmpi slt, %13, %c0_i64 : i64
%15 = arith.select %14, %c0_i64, %13 : i64
%16 = arith.cmpi sgt, %15, %c256_i64 : i64
%17 = arith.select %16, %c256_i64, %15 : i64
%18 = arith.index_cast %17 : i64 to index
%19 = arith.index_cast %10 : i64 to index
%20 = arith.cmpi slt, %19, %c0 : index
%21 = arith.addi %19, %c256 : index
%22 = arith.select %20, %21, %19 : index
%23 = arith.cmpi slt, %22, %c0 : index
%24 = arith.select %23, %c-1, %22 : index
%25 = arith.cmpi sgt, %24, %c256 : index
%26 = arith.select %25, %c256, %24 : index
%27 = arith.subi %26, %18 : index
%28 = arith.cmpi slt, %27, %c0 : index
%29 = arith.select %28, %c0, %27 : index
%30 = flow.tensor.reshape %7 : tensor<1x16x?x64xf32>{%6} -> tensor<16x?x64xf32>{%6}
%31 = flow.tensor.reshape %2 : tensor<1x16x64x256xf32> -> tensor<16x64x256xf32>
%32 = flow.tensor.encode %30 : tensor<16x?x64xf32>{%6} -> tensor<16x?x64xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iteration_sizes = [16, ?, 256, 64]>>{%6}
%33 = flow.tensor.encode %31 : tensor<16x64x256xf32> -> tensor<16x64x256xf32, #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iteration_sizes = [16, ?, 256, 64]>>
%34 = flow.dispatch.workgroups(%5, %3, %29, %29) : (tensor<i64>, tensor<256x256xi8>, index, index) -> tensor<?x256xi8>{%29} =
(%arg6: !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>>, %arg7: !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x256xi8>>, %arg8: index, %arg9: index, %arg10: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x256xi8>>) {
%c0_i64_0 = arith.constant 0 : i64
%c256_i64_1 = arith.constant 256 : i64
%40 = iree_tensor_ext.dispatch.tensor.load %arg6, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
%extracted = tensor.extract %40[] : tensor<i64>
%41 = arith.addi %extracted, %c256_i64_1 : i64
%42 = arith.cmpi sge, %extracted, %c0_i64_0 : i64
%43 = arith.select %42, %extracted, %41 : i64
%44 = arith.cmpi slt, %43, %c0_i64_0 : i64
%45 = arith.select %44, %c0_i64_0, %43 : i64
%46 = arith.cmpi sgt, %45, %c256_i64_1 : i64
%47 = arith.select %46, %c256_i64_1, %45 : i64
%48 = arith.index_cast %47 : i64 to index
%49 = iree_tensor_ext.dispatch.tensor.load %arg7, offsets = [%48, 0], sizes = [%arg9, 256], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<?x256xi8>
iree_tensor_ext.dispatch.tensor.store %49, %arg10, offsets = [0, 0], sizes = [%arg9, 256], strides = [1, 1] : tensor<?x256xi8> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x256xi8>>{%arg9}
flow.return
}The first "unit" extract_slice As far as i understand |
This PR aims to :
Context
I encountered this error when working on a modified version of LFM2.5. I attach a reproducer which captures the idea.
A causal mask path stayed fused inside a QK matmul dispatch, which produced large vectors due to tile size propagation in an mmt4d ukernel.
This originated due to the mask's slicing offset coming from a tensor.extract and subsequently producing a scalar metadata chain that remained inside the dispatch.
Proposed fix
HoistUniformScalarComputePasscan accept more "candidate ops" than just arith ops. I just included tensor.extract since its what I ran into.isUniformScalarForDispatchis still in charge to verify that the candidate op is hoistable, so i added the logic to check the tensor.extract ops.IREE::Flow::isOffsetSizeAndStrideMappableToFlowgot split into two:isOffsetSizeAndStrideStructurallyMappableToFlowjust checks if the slice can be represented as one flat contiguous byte range, andisOffsetSizeAndStrideMappableToFlowchecks for that and the additional tensor.extract provenance.isHoistableOpin HoistEncodingOps.cpp was rejecting extract slice ops whose offset, size, and stride where produced by an extract op due to callingisOffsetSizeAndStrideMappableToFlowon them. Now it callsisOffsetSizeAndStrideStructurallyMappableToFlow.Additional Notes:
mlir reproducer
Compile command:
Assisted by Codex 5.5