CuTe DSL
Bug Report
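Hi, something looks off here. The `order` argument passed to `cute.tile_to_shape` is the same whether or not the A tensor is K-major: both branches of the conditional evaluate to `(0, 1, 2)`, so `is_k_major` has no effect on it.
Please check it, thank you.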
order=(0, 1, 2) if is_k_major else (0, 1, 2),
@dsl_user_op
def make_smem_layout_a(
    a_layout: LayoutEnum,
    mma_tiler_mnk: cute.Tile,
    a_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """This function helps with:
    1. Get the partitioned shape of the A tensor based on the MMA tiler.
    2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size.
    3. cute.Tile the SMEM layout atom to the MMA tile shape.
    4. Stage the SMEM layout based on the number of stages.

    :param a_layout: The layout enum for tensor A
    :type a_layout: LayoutEnum
    :param mma_tiler_mnk: The MMA tile shape
    :type mma_tiler_mnk: cute.cute.Tile
    :param a_dtype: The element type for tensor A
    :type a_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for tensor A
    :type num_stages: int
    :return: SMEM layout for tensor A
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """
    # Extract A tensor shape from the MMA tiler (M dimension)
    a_tile_shape_mnk = mma_tiler_mnk
    a_smem_shape = cute.slice_(a_tile_shape_mnk, (None, 0, None), loc=loc, ip=ip)
    # Determine if K is the major mode and get the major mode size
    is_k_major = a_layout.is_k_major_a()
    a_major_mode_size = a_tile_shape_mnk[2] if is_k_major else a_tile_shape_mnk[0]
    # Create SMEM layout atom for A tensor based on major mode and data type
    a_smem_layout_atom = make_smem_layout_atom(
        get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size, loc=loc, ip=ip),
        a_dtype,
        loc=loc,
        ip=ip,
    )
    # Tile the SMEM layout atom to the A tensor shape and add staging dimension
    a_smem_layout_staged = cute.tile_to_shape(
        a_smem_layout_atom,
        cute.append(a_smem_shape, num_stages),
        order=(0, 1, 2) if is_k_major else (0, 1, 2),
        loc=loc,
        ip=ip,
    )
    return a_smem_layout_staged
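To make the point concrete, here is a minimal standalone sketch that needs no CUTLASS import. `select_order` is a hypothetical helper written only for this illustration; it copies the `order=` expression from the function above and shows that the flag cannot change the result. Whether the non-K-major branch is meant to use a different order is something I am assuming, not something I can confirm from the code.

```python
def select_order(is_k_major: bool) -> tuple:
    # Hypothetical helper: same conditional expression as the `order=`
    # argument in make_smem_layout_a, isolated from the rest of the function.
    return (0, 1, 2) if is_k_major else (0, 1, 2)


# Evaluate the expression for both possible values of the flag.
orders = {flag: select_order(flag) for flag in (True, False)}
print(orders)

# Both branches are identical, so the conditional is effectively a constant.
assert orders[True] == orders[False]
```

If the identical branches are intentional, the conditional could simply be replaced by the constant `(0, 1, 2)`; if not, the non-K-major branch presumably needs a different tuple.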