
Conversation

@chengjunlu (Contributor) commented Nov 4, 2025

In the DPAS layout, three data types are involved in the block load and dot product (DPAS) computation flow:

  1. load2DGenXType – represents the raw data loaded from memory.
  2. packedDPASOperandType – the packed form of data used as DPAS operands.
  3. unpackedType – the unpacked form used for intermediate transformations.

For non-DPAS layouts, only the first two types are used. For the DPAS layout, the data flow proceeds as follows:

  1. A 2D block load operation fetches data into load2DGenXType values.
  2. Vector shuffles reorganize the loaded data into DPAS operand fragments.
  3. Bitcasts convert between packed and unpacked representations to prepare data for computation.
  4. The tt.dot (DPAS) operation consumes packed operands to perform the dot product.

During optimization, redundant pack/unpack and bitcast operations are removed, resulting in a simplified sequence:

  • A single block load (load_2d)
  • Shuffle operations defining operand layout
  • A DPAS instruction consuming packed operands

Conceptually, the combination of packedDPASOperandType and the shufflevector operations determines how the input data maps onto the DPAS computation flow.
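
As a concrete illustration, here is a minimal standalone C++ sketch (not the PR's code; the vector shapes are taken from the bf16 IR quoted later in this thread) of how the three types relate for one OperandB fragment: load2DGenXType is the raw <64 x i16> block load, unpackedType is the <16 x bf16> fragment produced by a shufflevector, and packedDPASOperandType is the <8 x i32> form in which pairs of bf16 values occupy 32-bit DPAS channels.

  // Each 32-bit DPAS channel holds 32 / elemBits packed elements (2 for bf16).
  constexpr unsigned packedElemsPerDword(unsigned elemBits) { return 32u / elemBits; }

  int main() {
    constexpr unsigned elemBits = 16;      // bf16
    constexpr unsigned unpackedElems = 16; // one <16 x bf16> operand fragment
    static_assert(unpackedElems / packedElemsPerDword(elemBits) == 8,
                  "the fragment packs into an <8 x i32> DPAS operand");
    return 0;
  }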

Copilot AI left a comment

Pull Request Overview

This PR improves the 2D block IO lowering for DPAS (Dot Product Accumulate Systolic) and DotOp layouts by extending support from just OperandB to all DPAS operand types (OperandA, OperandB, and OperandC). The implementation refactors the existing if-else structure into a switch statement and adds detailed documentation explaining the data flow and optimization patterns for DPAS operations.

Key changes:

  • Extends DPAS layout handling to support OperandA and OperandC in addition to OperandB
  • Refactors conditional logic from if-else to switch statement with proper default case handling
  • Adds comprehensive inline documentation explaining the three-type system and data flow optimization

Reviewed Changes

Copilot reviewed 2 out of 2 changed files in this pull request and generated 2 comments.

  • LoadStoreOpToLLVM.cpp – Extends DPAS operand handling with switch statement and adds detailed documentation of data flow patterns
  • tensor-pointer-load-block-2d.mlir – Updates test expectations to include new shuffle vector and bitcast operations for DPAS operand handling


@chengjunlu force-pushed the chengjun/improve_2d_load_for_dpas branch from 05e4534 to 0aa1b3c on November 4, 2025 04:49
@chengjunlu requested a review from Copilot on November 4, 2025 04:52

Copilot AI left a comment

Pull Request Overview

Copilot reviewed 2 out of 2 changed files in this pull request and generated 3 comments.



@chengjunlu force-pushed the chengjun/improve_2d_load_for_dpas branch 2 times, most recently from 458d741 to 6705f55 on November 4, 2025 06:59
Comment on lines +2779 to +2780
default:
llvm_unreachable("unexpected OpIdx type.");

This should be unnecessary because DpasEncodingAttr::OpIdx is an enum class with 3 enumerators, each of which has a corresponding case in the switch statement.
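
As a minimal sketch of the pattern this comment suggests (the enumerator names come from this thread; the enum shape and the helper function are otherwise hypothetical): when every enumerator of the enum class has its own case, the switch is exhaustive for valid values, -Wswitch flags any enumerator added later, and an unreachable default adds nothing.

  #include <cstdlib>

  // Assumed shape of DpasEncodingAttr::OpIdx, for illustration only.
  enum class OpIdx { OperandA, OperandB, OperandC };

  // Hypothetical helper: an exhaustive switch over OpIdx needs no default case.
  unsigned operandIndex(OpIdx idx) {
    switch (idx) {
    case OpIdx::OperandA:
      return 0;
    case OpIdx::OperandB:
      return 1;
    case OpIdx::OperandC:
      return 2;
    }
    // Only reachable if idx carries a value outside the declared enumerators.
    std::abort();
  }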

@etiotto (Contributor) commented Nov 4, 2025

@chengjunlu does the PR improve any benchmark?

@chengjunlu (Contributor, Author)

@chengjunlu does the PR improve any benchmark?

I will collect the data.

@chengjunlu (Contributor, Author)

@chengjunlu does the PR improve any benchmark?

In the 4kx4kx4k case of the gemm_tensor_of_ptr_benchmark, the change reduces register spilling for this configuration:

Autotuning kernel matmul_kernel with config BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 4, grf_mode: large, num_warps: 32, num_ctas: 1, num_stages: 3, maxnreg: None,

Without the change, there are about 1.5k register spills for the same configuration:

(I): Detected 1472 spills for  "matmul_kernel"
Autotuning kernel matmul_kernel with config BLOCK_SIZE_M: 256, BLOCK_SIZE_N: 256, BLOCK_SIZE_K: 32, GROUP_SIZE_M: 4, grf_mode: large, num_warps: 32, num_ctas: 1, num_stages: 3, maxnreg: None,

The key difference is that the original code generates extra bitcasts from/to i32 between the load and the DPAS, which cannot be optimized away:

  %194 = call <64 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v64i16(i64 %190, i32 %193, i32 31, i32 8191, i32 %192, i32 0, i32 16, i32 16, i32 32, i32 2, i1 false, i1 false, i32 0)
  %bc207 = bitcast <64 x i16> %194 to <64 x bfloat>, !dbg !439
  %195 = shufflevector <64 x bfloat> %bc, <64 x bfloat> undef, <16 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7, i32 8, 
...
  %204 = bitcast <16 x bfloat> %195 to <8 x i32>, !dbg !443
...
  %205 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %vectorized_phi, <8 x i16> %203, <8 x i32> %204, i32 11, i32 11, i32 8, i32 8, i1 false) #5, !dbg !443

With this change, the bitcasts are eliminated:

  %170 = call <64 x i16> @llvm.genx.GenISA.LSC2DBlockRead.v64i16(i64 %166, i32 %169, i32 31, i32 8191, i32 %168, i32 0, i32 16, i32 16, i32 32, i32 2, i1 false, i1 false, i32 0)
  %171 = shufflevector <64 x i16> %170, <64 x i16> undef, <8 x i32> <i32 0, i32 1, i32 2, i32 3, i32 4, i32 5, i32 6, i32 7>, !dbg !438
 ...
  %211 = call <8 x float> @llvm.genx.GenISA.sub.group.dpas.v8f32.v8f32.v8i16.v8i32(<8 x float> %vectorized_phi, <8 x i16> %171, <8 x i32> %191, i32 11, i32 11, i32 8, i32 8, i1 false) #5, !dbg !443

IGC can generate more efficient code when the shufflevector sequences are not interleaved with useless bitcasts.

The performance improvement for the 4kx4kx4k case is:
143.4375 TFLOPS -> 176.915987 TFLOPS (roughly a 23% speedup)

The full benchmark run is still in progress:
https://github.com/intel/intel-xpu-backend-for-triton/actions/runs/19094440683
