[BUG] orders are the same, despite whether A tensor is k_major #2874

@Willie-Qu

CuTe DSL

Bug Report

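Hi, something looks wrong in `make_smem_layout_a`. The `order` argument passed to `cute.tile_to_shape` is `(0, 1, 2)` in both branches of the conditional, so the result is the same whether or not the A tensor is K-major.
Please check it, thank you.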

order=(0, 1, 2) if is_k_major else (0, 1, 2),

@dsl_user_op
def make_smem_layout_a(
    a_layout: LayoutEnum,
    mma_tiler_mnk: cute.Tile,
    a_dtype: Type[Numeric],
    num_stages: int,
    *,
    loc=None,
    ip=None,
) -> Union[cute.Layout, cute.ComposedLayout]:
    """This function helps with:

    1. Get the partitioned shape of the A tensor based on the MMA tiler.
    2. Select the heuristic SMEM layout atom based on the A tensor's majorness, the data type, and the major mode size.
    3. cute.Tile the SMEM layout atom to the MMA tile shape.
    4. Stage the SMEM layout based on the number of stages.

    :param a_layout: The layout enum for tensor A
    :type a_layout: LayoutEnum
    :param mma_tiler_mnk: The MMA tile shape
    :type mma_tiler_mnk: cute.cute.Tile
    :param a_dtype: The element type for tensor A
    :type a_dtype: Type[Numeric]
    :param num_stages: The number of pipeline stages for tensor A
    :type num_stages: int

    :return: SMEM layout for tensor A
    :rtype: Union[cute.Layout, cute.ComposedLayout]
    """
    # Extract A tensor shape from the MMA tiler (M dimension)
    a_tile_shape_mnk = mma_tiler_mnk
    a_smem_shape = cute.slice_(a_tile_shape_mnk, (None, 0, None), loc=loc, ip=ip)

    # Determine if K is the major mode and get the major mode size
    is_k_major = a_layout.is_k_major_a()
    a_major_mode_size = a_tile_shape_mnk[2] if is_k_major else a_tile_shape_mnk[0]

    # Create SMEM layout atom for A tensor based on major mode and data type
    a_smem_layout_atom = make_smem_layout_atom(
        get_smem_layout_atom(a_layout, a_dtype, a_major_mode_size, loc=loc, ip=ip),
        a_dtype,
        loc=loc,
        ip=ip,
    )

    # Tile the SMEM layout atom to the A tensor shape and add staging dimension
    a_smem_layout_staged = cute.tile_to_shape(
        a_smem_layout_atom,
        cute.append(a_smem_shape, num_stages),
        order=(0, 1, 2) if is_k_major else (0, 1, 2),
        loc=loc,
        ip=ip,
    )

    return a_smem_layout_staged
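
To make the report concrete, here is a minimal pure-Python sketch that does not call CuTe. `reported_order` just wraps the expression quoted above, and `ordered_strides` is a hypothetical helper that mimics how an `order` tuple could decide which mode of the tile grid varies fastest. The tile counts and the alternative order `(1, 0, 2)` are assumptions for illustration only, not a confirmed fix.

```python
from typing import Sequence, Tuple


def reported_order(is_k_major: bool) -> Tuple[int, int, int]:
    # Expression copied from the snippet above: both branches are identical,
    # so is_k_major has no influence on the result.
    return (0, 1, 2) if is_k_major else (0, 1, 2)


def ordered_strides(shape: Sequence[int], order: Sequence[int]) -> Tuple[int, ...]:
    """Hypothetical helper (not a CuTe API): compact strides in which the mode
    with the smallest order value varies fastest."""
    strides = [0] * len(shape)
    running = 1
    for mode in sorted(range(len(shape)), key=lambda i: order[i]):
        strides[mode] = running
        running *= shape[mode]
    return tuple(strides)


# Made-up tile counts along (M, K, stage), purely for illustration.
tile_counts = (2, 4, 3)

print(reported_order(True) == reported_order(False))  # True
print(ordered_strides(tile_counts, (0, 1, 2)))        # (1, 2, 8)
print(ordered_strides(tile_counts, (1, 0, 2)))        # (4, 1, 8)
```

If the order really is meant to depend on majorness, the two cases would presumably need different tuples, as the last two lines illustrate. Whether `(1, 0, 2)` is the intended value for the non-K-major case is for the maintainers to confirm.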
