[BMG][pytorch upstream] pointwise op is slow when using 2D grid #5761

@jianyizh

Description

Describe the issue

PyTorch Inductor changed its memory-coalescing analysis, which leads to different tiling for this kernel (see intel/torch-xpu-ops#2655). I wonder why the Triton performance is so different (3.765 ms vs 24.192 ms), since the two kernels use similar configs: XBLOCK: 1024, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None for the fast kernel, and XBLOCK: 1, YBLOCK: 1024, num_warps: 4, num_ctas: 1, num_stages: 1, maxnreg: None for the slow one.
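
To make the launch shapes concrete, here is a minimal sketch (my own arithmetic, using the numels and block sizes from the kernels below) of the grids the two configs imply. The total program count is identical, which suggests the slowdown comes from the per-program memory access pattern rather than from how many programs are launched:

import math

# fast kernel: Grid1D, xnumel = 805306368, XBLOCK = 1024
grid_1d = (math.ceil(805306368 / 1024),)                   # (786432,)

# slow kernel: Grid2DWithYZOverflow, ynumel = 12582912, xnumel = 64,
# YBLOCK = 1024, XBLOCK = 1 (12288 fits in the y dimension, so z stays 1)
grid_2d = (math.ceil(64 / 1), math.ceil(12582912 / 1024))  # (64, 12288)

print(grid_1d, grid_2d)  # 786432 programs vs 64 * 12288 = 786432 programs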

previous fast kernel:

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

from torch._dynamo.testing import rand_strided
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
import torch

@triton_heuristics.pointwise(
    size_hints={'x': 1073741824},
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*fp16', 'out_ptr0': '*fp16', 'xnumel': 'i32', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=20, cc={'architecture': 21479031808, 'device_id': 57867, 'driver_version': '1.14.36511', 'gpu_eu_count': 160, 'gpu_subslice_count': 20, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 160, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) B580 Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero V2', 'sub_group_sizes': [16, 32], 'total_memory': 12168933376, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '20.1.0'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid1D', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '863CB82B179F018B50BDF9839C4D7F48AB1E094E88F811294C2A303D4612552E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'kernel_num_gb': 2.466250752, 'kernel_flop': 0},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14(in_ptr0, in_ptr1, out_ptr0, xnumel, XBLOCK : tl.constexpr):
    xnumel = 805306368
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[:]
    xmask = tl.full([XBLOCK], True, tl.int1)[:]
    x1 = ((xindex // 64) % 128)
    x2 = ((xindex // 8192) % 64)
    x5 = xindex // 524288
    x0 = (xindex % 64)
    x3 = ((xindex // 524288) % 12)
    x4 = xindex // 6291456
    x7 = xindex // 8192
    x8 = xindex
    tmp0 = x1
    tmp1 = tl.full([1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1], 64, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = x2
    tmp6 = tl.full([1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1], 1, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tmp9 & tmp4
    tmp11 = tl.load(in_ptr0 + (4032 + 64*(x2) + 4096*x5 + (x1)), tmp10, eviction_policy='evict_last', other=0.0)
    tmp12 = tl.full([1], 4096, tl.int64)
    tmp13 = (tmp11 % tmp12)
    tmp14 = tl.full([1], 0, tl.int32)
    tmp15 = tmp13 != tmp14
    tmp16 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
    tmp17 = (libdevice.signbit(tmp12) != 0) if (tmp12).dtype is tl.float32 else tmp12 < 0
    tmp18 = tmp16 != tmp17
    tmp19 = tmp15 & tmp18
    tmp20 = tmp13 + tmp12
    tmp21 = tl.where(tmp19, tmp20, tmp13)
    tmp22 = tl.full([XBLOCK], 4096, tl.int32)
    tmp23 = tmp21 + tmp22
    tmp24 = tmp21 < 0
    tmp25 = tl.where(tmp24, tmp23, tmp21)
    tl.device_assert(((0 <= tl.broadcast_to(tmp25, [XBLOCK])) & (tl.broadcast_to(tmp25, [XBLOCK]) < 4096)) | ~(tmp10), "index out of bounds: 0 <= tl.broadcast_to(tmp25, [XBLOCK]) < 4096")
    tmp27 = tl.load(in_ptr1 + (x0 + 64*x3 + 768*tmp25 + 3145728*x4), tmp10, other=0.0).to(tl.float32)
    tmp28 = tmp5 >= tmp8
    tmp29 = tl.full([1], 64, tl.int64)
    tmp30 = tmp5 < tmp29
    tmp31 = tmp28 & tmp4
    tmp32 = tl.load(in_ptr0 + (64*((-1) + x2) + 4096*x5 + (x1)), tmp31, eviction_policy='evict_last', other=0.0)
    tmp33 = tl.full([1], 4096, tl.int64)
    tmp34 = (tmp32 % tmp33)
    tmp35 = tl.full([1], 0, tl.int32)
    tmp36 = tmp34 != tmp35
    tmp37 = (libdevice.signbit(tmp34) != 0) if (tmp34).dtype is tl.float32 else tmp34 < 0
    tmp38 = (libdevice.signbit(tmp33) != 0) if (tmp33).dtype is tl.float32 else tmp33 < 0
    tmp39 = tmp37 != tmp38
    tmp40 = tmp36 & tmp39
    tmp41 = tmp34 + tmp33
    tmp42 = tl.where(tmp40, tmp41, tmp34)
    tmp43 = tl.full([XBLOCK], 4096, tl.int32)
    tmp44 = tmp42 + tmp43
    tmp45 = tmp42 < 0
    tmp46 = tl.where(tmp45, tmp44, tmp42)
    tl.device_assert(((0 <= tl.broadcast_to(tmp46, [XBLOCK])) & (tl.broadcast_to(tmp46, [XBLOCK]) < 4096)) | ~(tmp31), "index out of bounds: 0 <= tl.broadcast_to(tmp46, [XBLOCK]) < 4096")
    tmp48 = tl.load(in_ptr1 + (x0 + 64*x3 + 768*tmp46 + 3145728*x4), tmp31, other=0.0).to(tl.float32)
    tmp49 = tl.where(tmp9, tmp27, tmp48)
    tmp50 = tl.full(tmp49.shape, 0.0, tmp49.dtype)
    tmp51 = tl.where(tmp4, tmp49, tmp50)
    tmp52 = tmp0 >= tmp3
    tmp53 = tl.full([1], 128, tl.int64)
    tmp54 = tmp0 < tmp53
    tmp55 = tl.load(in_ptr0 + (64*x7 + ((-64) + x1)), tmp52, eviction_policy='evict_last', other=0.0)
    tmp56 = tl.full([1], 4096, tl.int64)
    tmp57 = (tmp55 % tmp56)
    tmp58 = tl.full([1], 0, tl.int32)
    tmp59 = tmp57 != tmp58
    tmp60 = (libdevice.signbit(tmp57) != 0) if (tmp57).dtype is tl.float32 else tmp57 < 0
    tmp61 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
    tmp62 = tmp60 != tmp61
    tmp63 = tmp59 & tmp62
    tmp64 = tmp57 + tmp56
    tmp65 = tl.where(tmp63, tmp64, tmp57)
    tmp66 = tl.full([XBLOCK], 4096, tl.int32)
    tmp67 = tmp65 + tmp66
    tmp68 = tmp65 < 0
    tmp69 = tl.where(tmp68, tmp67, tmp65)
    tl.device_assert(((0 <= tl.broadcast_to(tmp69, [XBLOCK])) & (tl.broadcast_to(tmp69, [XBLOCK]) < 4096)) | ~(tmp52), "index out of bounds: 0 <= tl.broadcast_to(tmp69, [XBLOCK]) < 4096")
    tmp71 = tl.load(in_ptr1 + (x0 + 64*x3 + 768*tmp69 + 3145728*x4), tmp52, other=0.0).to(tl.float32)
    tmp72 = tl.where(tmp4, tmp51, tmp71)
    tl.store(out_ptr0 + (x8), tmp72, None)


def get_args():
    arg_0 = rand_strided((128, 12, 4096), (49152, 4096, 1), device='xpu:0', dtype=torch.int64)
    arg_1 = rand_strided((524288, 768), (768, 1), device='xpu:0', dtype=torch.float16)
    arg_2 = rand_strided((128, 12, 64, 128, 64), (6291456, 524288, 8192, 64, 1), device='xpu:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, 805306368,


def call(args):
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14.run(*args, stream=stream0)


def benchmark_all_configs(args):
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        return triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14.benchmark_all_configs(*args)


if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker

    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='xpu', rep=40)
    num_gb = 2.466250752
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")

current slow kernel:

import triton
import triton.language as tl

from torch._inductor.runtime import triton_helpers, triton_heuristics
from torch._inductor.runtime.triton_helpers import libdevice, math as tl_math
from torch._inductor.runtime.hints import AutotuneHint, ReductionHint, TileHint, DeviceProperties
triton_helpers.set_driver_to_gpu()

from torch._dynamo.testing import rand_strided
from torch._C import _xpu_getCurrentRawStream as get_raw_stream
import torch

@triton_heuristics.pointwise(
    size_hints={'y': 16777216, 'x': 64}, tile_hint=TileHint.DEFAULT,
    filename=__file__,
    triton_meta={'signature': {'in_ptr0': '*i64', 'in_ptr1': '*fp16', 'out_ptr0': '*fp16', 'ynumel': 'i32', 'xnumel': 'i32', 'YBLOCK': 'constexpr', 'XBLOCK': 'constexpr'}, 'device': DeviceProperties(type='xpu', index=0, multi_processor_count=20, cc={'architecture': 21479031808, 'device_id': 57867, 'driver_version': '1.14.36511', 'gpu_eu_count': 160, 'gpu_subslice_count': 20, 'has_atomic64': True, 'has_bfloat16_conversions': True, 'has_fp16': True, 'has_fp64': True, 'has_subgroup_2d_block_io': True, 'has_subgroup_matrix_multiply_accumulate': True, 'has_subgroup_matrix_multiply_accumulate_tensor_float32': False, 'max_compute_units': 160, 'max_num_sub_groups': 64, 'max_work_group_size': 1024, 'name': 'Intel(R) Arc(TM) B580 Graphics', 'platform_name': 'Intel(R) oneAPI Unified Runtime over Level-Zero V2', 'sub_group_sizes': [16, 32], 'total_memory': 12168933376, 'type': 'gpu', 'vendor': 'Intel(R) Corporation', 'version': '20.1.0'}, major=None, regs_per_multiprocessor=None, max_threads_per_multi_processor=None, max_threads_per_block=1024, warp_size=32), 'constants': {}, 'native_matmul': False, 'configs': [{(0,): [['tt.divisibility', 16]], (1,): [['tt.divisibility', 16]], (2,): [['tt.divisibility', 16]], (3,): [['tt.divisibility', 16]], (4,): [['tt.divisibility', 16]]}], 'enable_fp_fusion': True},
    inductor_meta={'grid_type': 'Grid2DWithYZOverflow', 'autotune_hints': set(), 'kernel_name': 'triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14', 'mutated_arg_names': [], 'optimize_mem': True, 'no_x_dim': False, 'atomic_add_found': False, 'num_load': 3, 'num_store': 1, 'num_reduction': 0, 'backend_hash': '863CB82B179F018B50BDF9839C4D7F48AB1E094E88F811294C2A303D4612552E', 'assert_indirect_indexing': True, 'autotune_local_cache': True, 'autotune_pointwise': True, 'autotune_remote_cache': None, 'force_disable_caches': False, 'dynamic_scale_rblock': True, 'max_autotune': False, 'max_autotune_pointwise': False, 'min_split_scan_rblock': 256, 'spill_threshold': 16, 'store_cubin': False, 'deterministic': False, 'force_filter_reduction_configs': False, 'are_deterministic_algorithms_enabled': False, 'tiling_scores': {'y': 100663296, 'x': 3221225472}, 'kernel_num_gb': 2.466250752, 'kernel_flop': 0},
    min_elem_per_thread=0
)
@triton.jit
def triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14(in_ptr0, in_ptr1, out_ptr0, ynumel, xnumel, YBLOCK : tl.constexpr, XBLOCK : tl.constexpr):
    ynumel = 12582912
    xnumel = 64
    yoffset = (tl.program_id(1) + tl.program_id(2) * tl.num_programs(1)) * YBLOCK
    yindex = yoffset + tl.arange(0, YBLOCK)[:, None]
    ymask = yindex < ynumel
    xoffset = tl.program_id(0) * XBLOCK
    xindex = xoffset + tl.arange(0, XBLOCK)[None, :]
    xmask = xindex < xnumel
    y0 = (yindex % 128)
    y1 = ((yindex // 128) % 64)
    y5 = yindex // 8192
    x4 = xindex
    y2 = ((yindex // 8192) % 12)
    y3 = yindex // 98304
    y7 = yindex // 128
    y8 = yindex
    tmp0 = y0
    tmp1 = tl.full([1, 1], 0, tl.int64)
    tmp2 = tmp0 >= tmp1
    tmp3 = tl.full([1, 1], 64, tl.int64)
    tmp4 = tmp0 < tmp3
    tmp5 = tl.broadcast_to(y1, [YBLOCK, XBLOCK])
    tmp6 = tl.full([1, 1], 0, tl.int64)
    tmp7 = tmp5 >= tmp6
    tmp8 = tl.full([1, 1], 1, tl.int64)
    tmp9 = tmp5 < tmp8
    tmp10 = tmp9 & tmp4
    tmp11 = tl.load(in_ptr0 + (tl.broadcast_to(4032 + 64*(y1) + 4096*y5 + (y0), [YBLOCK, XBLOCK])), tmp10 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp12 = tl.full([1, 1], 4096, tl.int64)
    tmp13 = (tmp11 % tmp12)
    tmp14 = tl.full([1, 1], 0, tl.int32)
    tmp15 = tmp13 != tmp14
    tmp16 = (libdevice.signbit(tmp13) != 0) if (tmp13).dtype is tl.float32 else tmp13 < 0
    tmp17 = (libdevice.signbit(tmp12) != 0) if (tmp12).dtype is tl.float32 else tmp12 < 0
    tmp18 = tmp16 != tmp17
    tmp19 = tmp15 & tmp18
    tmp20 = tmp13 + tmp12
    tmp21 = tl.where(tmp19, tmp20, tmp13)
    tmp22 = tl.full([1, 1], 4096, tl.int32)
    tmp23 = tmp21 + tmp22
    tmp24 = tmp21 < 0
    tmp25 = tl.where(tmp24, tmp23, tmp21)
    tl.device_assert(((0 <= tl.broadcast_to(tmp25, [YBLOCK, XBLOCK])) & (tl.broadcast_to(tmp25, [YBLOCK, XBLOCK]) < 4096)) | ~(tmp10 & xmask & ymask), "index out of bounds: 0 <= tl.broadcast_to(tmp25, [YBLOCK, XBLOCK]) < 4096")
    tmp27 = tl.load(in_ptr1 + (x4 + 64*y2 + 768*tmp25 + 3145728*y3), tmp10 & xmask & ymask, other=0.0).to(tl.float32)
    tmp28 = tmp5 >= tmp8
    tmp29 = tl.full([1, 1], 64, tl.int64)
    tmp30 = tmp5 < tmp29
    tmp31 = tmp28 & tmp4
    tmp32 = tl.load(in_ptr0 + (tl.broadcast_to(64*((-1) + y1) + 4096*y5 + (y0), [YBLOCK, XBLOCK])), tmp31 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp33 = tl.full([1, 1], 4096, tl.int64)
    tmp34 = (tmp32 % tmp33)
    tmp35 = tl.full([1, 1], 0, tl.int32)
    tmp36 = tmp34 != tmp35
    tmp37 = (libdevice.signbit(tmp34) != 0) if (tmp34).dtype is tl.float32 else tmp34 < 0
    tmp38 = (libdevice.signbit(tmp33) != 0) if (tmp33).dtype is tl.float32 else tmp33 < 0
    tmp39 = tmp37 != tmp38
    tmp40 = tmp36 & tmp39
    tmp41 = tmp34 + tmp33
    tmp42 = tl.where(tmp40, tmp41, tmp34)
    tmp43 = tl.full([1, 1], 4096, tl.int32)
    tmp44 = tmp42 + tmp43
    tmp45 = tmp42 < 0
    tmp46 = tl.where(tmp45, tmp44, tmp42)
    tl.device_assert(((0 <= tl.broadcast_to(tmp46, [YBLOCK, XBLOCK])) & (tl.broadcast_to(tmp46, [YBLOCK, XBLOCK]) < 4096)) | ~(tmp31 & xmask & ymask), "index out of bounds: 0 <= tl.broadcast_to(tmp46, [YBLOCK, XBLOCK]) < 4096")
    tmp48 = tl.load(in_ptr1 + (x4 + 64*y2 + 768*tmp46 + 3145728*y3), tmp31 & xmask & ymask, other=0.0).to(tl.float32)
    tmp49 = tl.where(tmp9, tmp27, tmp48)
    tmp50 = tl.full(tmp49.shape, 0.0, tmp49.dtype)
    tmp51 = tl.where(tmp4, tmp49, tmp50)
    tmp52 = tmp0 >= tmp3
    tmp53 = tl.full([1, 1], 128, tl.int64)
    tmp54 = tmp0 < tmp53
    tmp55 = tl.load(in_ptr0 + (tl.broadcast_to(64*y7 + ((-64) + y0), [YBLOCK, XBLOCK])), tmp52 & xmask & ymask, eviction_policy='evict_last', other=0.0)
    tmp56 = tl.full([1, 1], 4096, tl.int64)
    tmp57 = (tmp55 % tmp56)
    tmp58 = tl.full([1, 1], 0, tl.int32)
    tmp59 = tmp57 != tmp58
    tmp60 = (libdevice.signbit(tmp57) != 0) if (tmp57).dtype is tl.float32 else tmp57 < 0
    tmp61 = (libdevice.signbit(tmp56) != 0) if (tmp56).dtype is tl.float32 else tmp56 < 0
    tmp62 = tmp60 != tmp61
    tmp63 = tmp59 & tmp62
    tmp64 = tmp57 + tmp56
    tmp65 = tl.where(tmp63, tmp64, tmp57)
    tmp66 = tl.full([1, 1], 4096, tl.int32)
    tmp67 = tmp65 + tmp66
    tmp68 = tmp65 < 0
    tmp69 = tl.where(tmp68, tmp67, tmp65)
    tl.device_assert(((0 <= tl.broadcast_to(tmp69, [YBLOCK, XBLOCK])) & (tl.broadcast_to(tmp69, [YBLOCK, XBLOCK]) < 4096)) | ~(tmp52 & xmask & ymask), "index out of bounds: 0 <= tl.broadcast_to(tmp69, [YBLOCK, XBLOCK]) < 4096")
    tmp71 = tl.load(in_ptr1 + (x4 + 64*y2 + 768*tmp69 + 3145728*y3), tmp52 & xmask & ymask, other=0.0).to(tl.float32)
    tmp72 = tl.where(tmp4, tmp51, tmp71)
    tl.store(out_ptr0 + (x4 + 64*y8), tmp72, xmask & ymask)


def get_args():
    arg_0 = rand_strided((128, 12, 4096), (49152, 4096, 1), device='xpu:0', dtype=torch.int64)
    arg_1 = rand_strided((524288, 768), (768, 1), device='xpu:0', dtype=torch.float16)
    arg_2 = rand_strided((128, 12, 64, 128, 64), (6291456, 524288, 8192, 64, 1), device='xpu:0', dtype=torch.float16)
    return arg_0, arg_1, arg_2, 12582912, 64,


def call(args):
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        stream0 = get_raw_stream(0)
        triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14.run(*args, stream=stream0)


def benchmark_all_configs(args):
    with torch.xpu._DeviceGuard(0):
        torch.xpu.set_device(0)
        return triton_poi_fused__unsafe_view_cat_expand_gather_remainder_repeat_slice_transpose_unsqueeze_view_14.benchmark_all_configs(*args)


if __name__ == '__main__':
    from torch._inductor.runtime.benchmarking import benchmarker

    args = get_args()
    ms = benchmarker.benchmark(lambda: call(args), device='xpu', rep=40)
    num_gb = 2.466250752
    gb_per_s = num_gb / (ms / 1e3)
    print(f"{ms:.3f}ms    {num_gb:.3f}GB    {gb_per_s:.2f}GB/s")

Environment details

device: B580
triton-xpu 3.6.0+git225cdbde
torch 2.11.0a0+git1897172
ii libigc-dev 1:2.25.31484.20243-main amd64 Intel graphics compiler for OpenCL -- core development files
ii intel-level-zero-gpu 1.14.036511.20243-main amd64 Intel(R) Graphics Compute Runtime for oneAPI Level Zero.
