Hi, I have a question about the class `LdMatrix16x16x8bOp`. `transpose` should not be asserted to be `True`, since PTX supports the optional `{.trans}` modifier for `ldmatrix`:
```python
@dataclass(frozen=True)
class LdMatrix16x16x8bOp(BaseOp):
    """
    16x16 8-bit ``ldmatrix`` Operation.

    See the `PTX documentation <https://docs.nvidia.com/cuda/parallel-thread-execution/#warp-level-matrix-load-instruction-ldmatrix>`__.
    This operation corresponds to the ``.m16n16`` and the ``.b8`` qualifiers.
    """

    def __init__(self, num_matrices: int) -> None:
        super().__init__(transpose=True, num_matrices=num_matrices)
        self._verify()

    def _verify(self):
        assert self.transpose, "transpose must be True"
        if self.num_matrices not in [1, 2]:
            raise OpError(
                self,
                "expects the 'num_matrices' Op parameter to be one of [1,2]",
            )

    def _make_trait(
        self, copy_internal_type: Type[Numeric], *, loc=None, ip=None, **kwargs
    ) -> "LdMatrix16x16x8bTrait":
        # Pack the (16, 16) tile shape into an IR attribute
        mode = _pack_shape((16, 16), loc=loc, ip=ip)
        # The u8 size pattern corresponds to the 8-bit (.b8) variant
        ty = _cute_nvgpu_ir.CopyAtomLdsmType.get(
            copy_internal_type.mlir_type,
            mode.type.attribute,
            _cute_nvgpu_ir.LdsmSzPattern.u8,
            self.num_matrices,
            ir.UnitAttr.get(),
        )
        return LdMatrix16x16x8bTrait(cute.make_atom(ty, loc=loc, ip=ip))
```
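
For context, here is a minimal sketch of how this surfaces at the call site. The import path is an assumption on my part; the behavior follows directly from the constructor quoted above:

```python
# Minimal sketch, assuming the class is exposed via the CuTe DSL warp module
# (the exact import path below is an assumption).
from cutlass.cute.nvgpu.warp import LdMatrix16x16x8bOp

op = LdMatrix16x16x8bOp(num_matrices=2)  # num_matrices must be 1 or 2, per _verify()
print(op.transpose)                      # always True: the constructor hard-codes it

# A non-transposed 16x16 8-bit ldmatrix cannot be requested through this class:
# transpose is not a constructor parameter, and _verify() asserts it is True.
```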