
[QST][CuteDSL] Why does cute.make_layout throw an error about dynamic/static shape? #2730

@zfan2356

What is your question?
In the code below, I encounter an error from cute.make_layout:

import cutlass
import torch
from cutlass import cute

def convert_from_dlpack(
    t: torch.Tensor,
    dynamic_dims: int | None = None,
) -> cute.Tensor:
    # Import the torch tensor via DLPack, then mark the extent of the
    # given mode as dynamic (i.e. known only at run time).
    cute_tensor = cute.runtime.from_dlpack(t.detach())
    cute_tensor = cute_tensor.mark_compact_shape_dynamic(mode=dynamic_dims)
    return cute_tensor

@cute.jit
def _test_bug_func(
    a: cute.Tensor,
    b: cute.Tensor,
    m: cutlass.Int32,
    n: int,
    k: int,
):
    shape_a = (cutlass.Int32(m), n)

    new_stride = lambda t: (
        # Tell the compiler the leading stride is divisible by
        # 128 bits' worth of elements.
        cute.assume(cutlass.Int32(t.stride[0]), divby=128 // t.element_type.width),
        *t.stride[1:],
    )
    a = cute.make_tensor(a.iterator, layout=cute.make_layout(shape=shape_a, stride=new_stride(a))) # created successfully

    shape_b = (cutlass.Int32(m), k)
    b = cute.make_tensor(b.iterator, layout=cute.make_layout(shape=shape_b, stride=new_stride(b))) # failed

def test_bug(
    a: torch.Tensor,
    b: torch.Tensor,
):
    m, n = a.shape
    _, k = b.shape

    a_ = convert_from_dlpack(a, dynamic_dims=0)
    b_ = convert_from_dlpack(b, dynamic_dims=0)
    _test_bug_func(a_, b_, m, n, k)


if __name__ == "__main__":
    a = torch.randn((128, 256), device="cuda:0", dtype=torch.bfloat16)
    b = torch.randn((128, 512), device="cuda:0", dtype=torch.bfloat16)
    test_bug(a, b)

But when I execute this file, the CuTe DSL throws the error below:

 Traceback (most recent call last):
  File "/data/home/report_bug.py", line 48, in <module>
    test_bug(a, b)
  File "/data/home/report_bug.py", line 42, in test_bug
    _test_bug_func(a_, b_, m, n, k)
  File "/data/home/report_bug.py", line 31, in _test_bug_func
    b = cute.make_tensor(b.iterator, layout=cute.make_layout(shape=shape_b, stride=new_stride(b)))
                                            ^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^
  File "/data/home/envs/dev/lib/python3.12/site-packages/nvidia_cutlass_dsl/python_packages/cutlass/base_dsl/typing.py", line 1098, in __bool__
    raise DSLRuntimeError(
cutlass.base_dsl.common.DSLRuntimeError: DSLRuntimeError: Unable to convert dynamic `Boolean` value to bool at compile time.
💡 Suggestions:
 Decorate the parent function with `jit` decorator and with `preprocess` enabled.
 Ensure not using patterns that DSL does not support.
 Otherwise, please file a bug report.

It seems I'm using a dynamic dimension in an otherwise static layout, with only the first dimension dynamic.
I used cutlass.Int32 to mark that dimension as dynamic; the layout for tensor a was created successfully,
but the layout creation for tensor b failed.
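
For context, my reading of the traceback (just an assumption on my part) is that somewhere inside make_layout a dynamic Boolean is being forced into a Python bool at trace time. This minimal sketch reproduces that failure mode directly:

import cutlass
from cutlass import cute

@cute.jit
def bool_of_dynamic(m: cutlass.Int32):
    flag = m > 0     # a dynamic Boolean; its value is unknown until run time
    ok = bool(flag)  # compile-time bool() conversion -> same DSLRuntimeError

That matches the __bool__ frame in the traceback, so I suspect some internal consistency check inside make_layout is comparing dynamic shape/stride values.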

However, the code below runs successfully, simply because we pass m into the function twice. Why is that? Is m somehow consumed (its ownership moved) when constructing shape_a?

@cute.jit
def _test_bug_func(
    a: cute.Tensor,
    b: cute.Tensor,
    m: cutlass.Int32,
    m_: cutlass.Int32,
    n: int,
    k: int,
):
    shape_a = (cutlass.Int32(m), n)

    new_stride = lambda t: (
        cute.assume(cutlass.Int32(t.stride[0]), divby=128 // t.element_type.width),
        *t.stride[1:],
    )
    a = cute.make_tensor(a.iterator, layout=cute.make_layout(shape=shape_a, stride=new_stride(a)))

    shape_b = (cutlass.Int32(m_), k)
    b = cute.make_tensor(b.iterator, layout=cute.make_layout(shape=shape_b, stride=new_stride(b)))

def test_bug(
    a: torch.Tensor,
    b: torch.Tensor,
):
    m, n = a.shape
    _, k = b.shape

    a_ = convert_from_dlpack(a, dynamic_dims=0)
    b_ = convert_from_dlpack(b, dynamic_dims=0)
    _test_bug_func(a_, b_, m, m, n, k)


if __name__ == "__main__":
    a = torch.randn((128, 256), device="cuda:0", dtype=torch.bfloat16)
    b = torch.randn((128, 256), device="cuda:0", dtype=torch.bfloat16)
    test_bug(a, b)
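
To narrow it down, here is a minimal isolation sketch I would try next (under my assumptions: that make_layout falls back to a compact stride when none is given, and that cute.printf can print layouts). If this passes, the failure presumably comes from combining the assume()-wrapped dynamic stride with the dynamic shape value:

@cute.jit
def _isolate(m: cutlass.Int32, n: int, k: int):
    # Two layouts sharing the same dynamic first mode, default strides.
    l_a = cute.make_layout(shape=(cutlass.Int32(m), n))
    l_b = cute.make_layout(shape=(cutlass.Int32(m), k))
    cute.printf("{} {}", l_a, l_b)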

Thanks for your help!
