Skip to content

[DispatchCreation] Hoist scalar tensor.extract and tensor.extract_slice#24552

Open
juanigp wants to merge 3 commits into
iree-org:mainfrom
juanigp:fix-hoist-encoding-mask-slice
Open

[DispatchCreation] Hoist scalar tensor.extract and tensor.extract_slice#24552
juanigp wants to merge 3 commits into
iree-org:mainfrom
juanigp:fix-hoist-encoding-mask-slice

Conversation

@juanigp

@juanigp juanigp commented May 29, 2026

Copy link
Copy Markdown

This PR aims to :

  • Enable hoisting tensor.extract ops that read from a scalar tensor already outside the dispatch.
  • Loosen the requirement to hoist tensor.extract_slice ops.

Context

I encountered this error when working on a modified version of LFM2.5. I attach a reproducer which captures the idea.

A causal mask path stayed fused inside a QK matmul dispatch, which produced large vectors due to tile size propagation in an mmt4d ukernel.

mask_slice_qk_repro.mlir:45:18: error: One or more operations with large vector sizes (32768 bytes) were found:

	%scores_3d = torch.aten.bmm %q, %k : !torch.vtensor<[16,?,64],f32>, !torch.vtensor<[16,64,256],f32> -> !torch.vtensor<[16,?,256],f32>
                 ^
<unknown>:0: note:   %cst = arith.constant dense<0xFF800000> : vector<16x16x16x16xf32>

<unknown>:0: note:   %cst_0 = arith.constant dense<0.000000e+00> : vector<16x16x16x16xf32>

<unknown>:0: note:   %cst_1 = arith.constant dense<0> : vector<16x16x16x16xi8>

This originated due to the mask's slicing offset coming from a tensor.extract and subsequently producing a scalar metadata chain that remained inside the dispatch.

Proposed fix

  • HoistUniformScalarComputePass can accept more "candidate ops" than just arith ops. I just included tensor.extract since its what I ran into. isUniformScalarForDispatch is still in charge to verify that the candidate op is hoistable, so i added the logic to check the tensor.extract ops.
  • IREE::Flow::isOffsetSizeAndStrideMappableToFlow got split into two: isOffsetSizeAndStrideStructurallyMappableToFlow just checks if the slice can be represented as one flat contiguous byte range, and isOffsetSizeAndStrideMappableToFlow checks for that and the additional tensor.extract provenance.
  • isHoistableOp in HoistEncodingOps.cpp was rejecting extract slice ops whose offset, size, and stride where produced by an extract op due to calling isOffsetSizeAndStrideMappableToFlow on them. Now it calls isOffsetSizeAndStrideStructurallyMappableToFlow.

Additional Notes:

  • I tried to not interfere with the codebase's original intentions.
  • Since the extract and extract_slice make it out of the dispatch, the large vectors never occur. I thought this was the right way to address the root cause of the problem.
  • Inspecting the mmt4d ukernel tile size propagation, it seems that the problematic large vectors originated due to propagating a pack op tiling config to the outer dims of an accumulator, which should not happen afaiu. I could work on that separate issue if it is of interest.
mlir reproducer
module @module {
  func.func @forward(
      %query: !torch.vtensor<[1,16,?,64],f32>,
      %key: !torch.vtensor<[1,16,64,256],f32>,
      %mask: !torch.vtensor<[256,256],ui8>,
      %positions: !torch.vtensor<[1],si64>)
      -> !torch.vtensor<[1,16,?,256],f32>
      attributes {torch.assume_strict_symbolic_shapes} {
    %s = torch.symbolic_int "s" {min_val = 1, max_val = 256} : !torch.int
    torch.bind_symbolic_shape %query, [%s], affine_map<()[s0] -> (1, 16, s0, 64)> : !torch.vtensor<[1,16,?,64],f32>

    %int0 = torch.constant.int 0
    %int1 = torch.constant.int 1
    %int2 = torch.constant.int 2
    %int16 = torch.constant.int 16
    %int64 = torch.constant.int 64
    %int256 = torch.constant.int 256
    %int-1 = torch.constant.int -1

    %seq_len = torch.aten.size.int %query, %int2 : !torch.vtensor<[1,16,?,64],f32>, !torch.int -> !torch.int
    %pos_tensor = torch.aten.select.int %positions, %int0, %int-1 : !torch.vtensor<[1],si64>, !torch.int, !torch.int -> !torch.vtensor<[],si64>
    %pos = torch.aten.item %pos_tensor : !torch.vtensor<[],si64> -> !torch.int

    %end = torch.aten.add.int %pos, %seq_len : !torch.int, !torch.int -> !torch.int
    %bool_dtype = torch.constant.int 11
    %mask_bool = torch.prims.convert_element_type %mask, %bool_dtype : !torch.vtensor<[256,256],ui8>, !torch.int -> !torch.vtensor<[256,256],i1>
    %mask_slice = torch.aten.slice.Tensor %mask_bool, %int0, %pos, %end, %int1 : !torch.vtensor<[256,256],i1>, !torch.int, !torch.int, !torch.int, !torch.int -> !torch.vtensor<[?,256],i1>
    torch.bind_symbolic_shape %mask_slice, [%s], affine_map<()[s0] -> (s0, 256)> : !torch.vtensor<[?,256],i1>

    %float-Inf = torch.constant.float 0xFFF0000000000000
    %float0 = torch.constant.float 0.000000e+00
    %f32_dtype = torch.constant.int 6
    %none = torch.constant.none
    %cpu = torch.constant.device "cpu"
    %neg_inf = torch.aten.scalar_tensor %float-Inf, %f32_dtype, %none, %cpu, %none : !torch.float, !torch.int, !torch.none, !torch.Device, !torch.none -> !torch.vtensor<[],f32>
    %mask_bias = torch.aten.where.ScalarSelf %mask_slice, %float0, %neg_inf : !torch.vtensor<[?,256],i1>, !torch.float, !torch.vtensor<[],f32> -> !torch.vtensor<[?,256],f32>
    torch.bind_symbolic_shape %mask_bias, [%s], affine_map<()[s0] -> (s0, 256)> : !torch.vtensor<[?,256],f32>

    %q_shape = torch.prim.ListConstruct %int16, %seq_len, %int64 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %q = torch.aten.view %query, %q_shape : !torch.vtensor<[1,16,?,64],f32>, !torch.list<int> -> !torch.vtensor<[16,?,64],f32>
    torch.bind_symbolic_shape %q, [%s], affine_map<()[s0] -> (16, s0, 64)> : !torch.vtensor<[16,?,64],f32>

    %k_shape = torch.prim.ListConstruct %int16, %int64, %int256 : (!torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %k = torch.aten.view %key, %k_shape : !torch.vtensor<[1,16,64,256],f32>, !torch.list<int> -> !torch.vtensor<[16,64,256],f32>
    %scores_3d = torch.aten.bmm %q, %k : !torch.vtensor<[16,?,64],f32>, !torch.vtensor<[16,64,256],f32> -> !torch.vtensor<[16,?,256],f32>
    torch.bind_symbolic_shape %scores_3d, [%s], affine_map<()[s0] -> (16, s0, 256)> : !torch.vtensor<[16,?,256],f32>

    %scores_shape = torch.prim.ListConstruct %int1, %int16, %seq_len, %int256 : (!torch.int, !torch.int, !torch.int, !torch.int) -> !torch.list<int>
    %scores = torch.aten.view %scores_3d, %scores_shape : !torch.vtensor<[16,?,256],f32>, !torch.list<int> -> !torch.vtensor<[1,16,?,256],f32>
    torch.bind_symbolic_shape %scores, [%s], affine_map<()[s0] -> (1, 16, s0, 256)> : !torch.vtensor<[1,16,?,256],f32>

    %result = torch.aten.add.Tensor %scores, %mask_bias, %int1 : !torch.vtensor<[1,16,?,256],f32>, !torch.vtensor<[?,256],f32>, !torch.int -> !torch.vtensor<[1,16,?,256],f32>
    torch.bind_symbolic_shape %result, [%s], affine_map<()[s0] -> (1, 16, s0, 256)> : !torch.vtensor<[1,16,?,256],f32>
    return %result : !torch.vtensor<[1,16,?,256],f32>
  }
}

Compile command:

iree-compile \
  mask_slice_qk_repro.mlir  \
  -o mask_slice_qk_repro.vmfb  \
  --iree-input-type=auto \
  --iree-hal-target-device=local \
  --iree-opt-data-tiling=true \
  --iree-llvmcpu-enable-ukernels=all \
  --iree-hal-local-target-device-backends=llvm-cpu \
  --iree-hal-local-host-device-backends=llvm-cpu \
  --iree-llvmcpu-target-cpu-features=host 

Assisted by Codex 5.5

@juanigp juanigp force-pushed the fix-hoist-encoding-mask-slice branch 2 times, most recently from 36225a8 to bbe0ed0 Compare May 30, 2026 09:19
@egebeysel egebeysel self-requested a review June 3, 2026 11:38
juanigp added 3 commits June 3, 2026 13:44
Signed-off-by: default <pisula@roofline.ai>
Signed-off-by: Juan Ignacio Pisula <pisula@roofline.ai>
Signed-off-by: default <pisula@roofline.ai>
Signed-off-by: Juan Ignacio Pisula <pisula@roofline.ai>
Signed-off-by: default <pisula@roofline.ai>
Signed-off-by: Juan Ignacio Pisula <pisula@roofline.ai>
Signed-off-by: default <pisula@roofline.ai>
@juanigp juanigp force-pushed the fix-hoist-encoding-mask-slice branch from bbe0ed0 to 8126c87 Compare June 3, 2026 13:45
@AGindinson AGindinson requested a review from ziereis June 3, 2026 14:30
@ziereis

ziereis commented Jun 5, 2026

Copy link
Copy Markdown
Contributor

Generally looks good to me!

One thing i noticed while inspecting the dump (with this patch) from the reproducer is that there is probably some follow up work we could do here.

This is the IR before the tensor-to-flow conversion

    %extracted_3 = tensor.extract %6[] : tensor<i64>
    %c0_i64_4 = arith.constant 0 : i64
    %32 = arith.cmpi sge, %extracted_3, %c0_i64_4 : i64
    %33 = arith.addi %extracted_3, %c256_i64_2 : i64
    %34 = arith.select %32, %extracted_3, %33 : i64
    %35 = arith.cmpi slt, %34, %c0_i64_4 : i64
    %36 = arith.select %35, %c0_i64_4, %34 : i64
    %37 = arith.cmpi sgt, %36, %c256_i64_2 : i64
    %38 = arith.select %37, %c256_i64_2, %36 : i64
    %39 = arith.index_cast %38 : i64 to index
    %c1 = arith.constant 1 : index
    %dim = tensor.dim %30, %c1 : tensor<16x?x64xf32>
    %c1_5 = arith.constant 1 : index
    %dim_6 = tensor.dim %30, %c1_5 : tensor<16x?x64xf32>
    %c1_7 = arith.constant 1 : index
    %dim_8 = tensor.dim %30, %c1_7 : tensor<16x?x64xf32>
    %40 = flow.tensor.encode %30 : tensor<16x?x64xf32>{%dim_8} -> tensor<16x?x64xf32, #encoding>{%dim_8}
    %41 = flow.tensor.encode %31 : tensor<16x64x256xf32> -> tensor<16x64x256xf32, #encoding1>
    %extracted_slice = tensor.extract_slice %4[%39, 0] [%29, 256] [1, 1] : tensor<256x256xi8> to tensor<?x256xi8>
    %42 = flow.tensor.encode %extracted_slice : tensor<?x256xi8>{%29} -> tensor<?x256xi8, #encoding3>{%29}

This is the IR after the tensor-to-flow conversion:

  %9 = flow.tensor.load %5 : tensor<i64>
  %10 = arith.addi %9, %8 : i64
  %11 = arith.addi %9, %c256_i64 : i64
  %12 = arith.cmpi sge, %9, %c0_i64 : i64
  %13 = arith.select %12, %9, %11 : i64
  %14 = arith.cmpi slt, %13, %c0_i64 : i64
  %15 = arith.select %14, %c0_i64, %13 : i64
  %16 = arith.cmpi sgt, %15, %c256_i64 : i64
  %17 = arith.select %16, %c256_i64, %15 : i64
  %18 = arith.index_cast %17 : i64 to index
  %19 = arith.index_cast %10 : i64 to index
  %20 = arith.cmpi slt, %19, %c0 : index
  %21 = arith.addi %19, %c256 : index
  %22 = arith.select %20, %21, %19 : index
  %23 = arith.cmpi slt, %22, %c0 : index
  %24 = arith.select %23, %c-1, %22 : index
  %25 = arith.cmpi sgt, %24, %c256 : index
  %26 = arith.select %25, %c256, %24 : index
  %27 = arith.subi %26, %18 : index
  %28 = arith.cmpi slt, %27, %c0 : index
  %29 = arith.select %28, %c0, %27 : index
  %30 = flow.tensor.reshape %7 : tensor<1x16x?x64xf32>{%6} -> tensor<16x?x64xf32>{%6}
  %31 = flow.tensor.reshape %2 : tensor<1x16x64x256xf32> -> tensor<16x64x256xf32>
  %32 = flow.tensor.encode %30 : tensor<16x?x64xf32>{%6} -> tensor<16x?x64xf32, #iree_encoding.encoding<operand_index = 0 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iteration_sizes = [16, ?, 256, 64]>>{%6}
  %33 = flow.tensor.encode %31 : tensor<16x64x256xf32> -> tensor<16x64x256xf32, #iree_encoding.encoding<operand_index = 1 : index, op_type = matmul, element_types = [f32, f32, f32], user_indexing_maps = [affine_map<(d0, d1, d2, d3) -> (d0, d1, d3)>, affine_map<(d0, d1, d2, d3) -> (d0, d3, d2)>, affine_map<(d0, d1, d2, d3) -> (d0, d1, d2)>], iteration_sizes = [16, ?, 256, 64]>>
  %34 = flow.dispatch.workgroups(%5, %3, %29, %29) : (tensor<i64>, tensor<256x256xi8>, index, index) -> tensor<?x256xi8>{%29} =
      (%arg6: !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>>, %arg7: !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x256xi8>>, %arg8: index, %arg9: index, %arg10: !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x256xi8>>) {
    %c0_i64_0 = arith.constant 0 : i64
    %c256_i64_1 = arith.constant 256 : i64
    %40 = iree_tensor_ext.dispatch.tensor.load %arg6, offsets = [], sizes = [], strides = [] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<i64>> -> tensor<i64>
    %extracted = tensor.extract %40[] : tensor<i64>
    %41 = arith.addi %extracted, %c256_i64_1 : i64
    %42 = arith.cmpi sge, %extracted, %c0_i64_0 : i64
    %43 = arith.select %42, %extracted, %41 : i64
    %44 = arith.cmpi slt, %43, %c0_i64_0 : i64
    %45 = arith.select %44, %c0_i64_0, %43 : i64
    %46 = arith.cmpi sgt, %45, %c256_i64_1 : i64
    %47 = arith.select %46, %c256_i64_1, %45 : i64
    %48 = arith.index_cast %47 : i64 to index
    %49 = iree_tensor_ext.dispatch.tensor.load %arg7, offsets = [%48, 0], sizes = [%arg9, 256], strides = [1, 1] : !iree_tensor_ext.dispatch.tensor<readonly:tensor<256x256xi8>> -> tensor<?x256xi8>
    iree_tensor_ext.dispatch.tensor.store %49, %arg10, offsets = [0, 0], sizes = [%arg9, 256], strides = [1, 1] : tensor<?x256xi8> -> !iree_tensor_ext.dispatch.tensor<writeonly:tensor<?x256xi8>>{%arg9}
    flow.return
  }

The first "unit" extract_slice %extracted_3 = tensor.extract %6[] : tensor<i64> will get converted into %9 = flow.tensor.load %5 : tensor<i64>. However the second tensor.extract_slice %extracted_slice = tensor.extract_slice %4[%39, 0] [%29, 256] [1, 1] : tensor<256x256xi8> to tensor<?x256xi8> will not get converted into a flow.tensor.slice (even though structurally it could) because its blocked by the producedByValueExtract check and thus ends up as an extra dispatch.

As far as i understand producedByValueExtract is a heuristic which tries to prevent host<->device communication by assuming that if the offsets/sizes are produced by tensor.extract it will be on device and if it is on device its better to materialize the slice as a copy dispatch to avoid the device->host readback. However in this case of the "unit" extract it will be on host anyways so it introduced a unnecessary dispatch that could have just been a flow.tensor.slice.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment

Labels

None yet

Projects

None yet

Development

Successfully merging this pull request may close these issues.

2 participants