Skip to content

Commit e19678b

Browse files
committed
Enable tcgen05 blockscaled ops on Thor SM110
Edge-LLM NvFP4 MoE CuTeDSL kernels on Thor use tcgen05 blockscaled MMA and SMEM-to-TMEM scale-factor copies. The existing checks only admitted the SM100/SM103 paths, so source-built CuTeDSL rejected SM110. Admit Thor's blockscaled MMA arch aliases sm_101a and sm_110a, and allow the SM110f family for S2T tcgen05 copy ops. Validation: - git diff --check - python3 -m py_compile python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py - DKG grouped_blockscaled_gemm.py documented 4-group example on Thor SM110: PASS - Edge-LLM nvfp4_moe AOT for sm_110/aarch64: 12/12 variants PASS
1 parent 9c1d096 commit e19678b

2 files changed

Lines changed: 9 additions & 3 deletions

File tree

python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/copy.py

Lines changed: 7 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -786,16 +786,20 @@ class _S2TCopyBase(CopyOp):
786786
787787
:param cta_group: Cooperative Thread Array (CTA) group configuration
788788
:type cta_group: CtaGroup
789-
:raises OpError: If the current architecture is not SM100f family or if invalid parameters are provided
789+
:raises OpError: If the current architecture is not SM100f or SM110f
790+
family or if invalid parameters are provided
790791
"""
791792

792793
cta_group: CtaGroup
793794

794795
def __post_init__(self) -> None:
795796
# Arch verification
796797
arch = BaseDSL._get_dsl().get_arch_enum()
797-
if not arch.is_family_of(Arch.sm_100f):
798-
supported = Arch.filter(lambda a: a.is_family_of(Arch.sm_100f))
798+
# S2T tcgen05 copy encodings are valid on both SM100 and Thor SM110.
799+
if not (arch.is_family_of(Arch.sm_100f) or arch.is_family_of(Arch.sm_110f)):
800+
supported = Arch.filter(
801+
lambda a: a.is_family_of(Arch.sm_100f) or a.is_family_of(Arch.sm_110f)
802+
)
799803
raise OpError(
800804
self,
801805
f"expects arch to be one of {supported}, but got {arch}",

python/CuTeDSL/cutlass/cute/nvgpu/tcgen05/mma.py

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -386,7 +386,9 @@ class BlockScaledMmaOp(Tcgen05MmaOp):
386386

387387
admissible_archs = [
388388
Arch.sm_100a,
389+
Arch.sm_101a,
389390
Arch.sm_103a,
391+
Arch.sm_110a,
390392
]
391393

392394
def __post_init__(self) -> None:

0 commit comments

Comments
 (0)