Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
30 changes: 21 additions & 9 deletions flashmask/flash_mask/cute/blackwell_helpers.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@ def gemm_w_idx(
B_idx: Optional[Int32] = None,
zero_init: bool | Boolean = False,
swap_AB: bool = False,
num_unroll_groups: int = 1,
) -> None:
if const_expr(swap_AB):
return gemm_w_idx(
Expand All @@ -30,7 +31,9 @@ def gemm_w_idx(
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
rB = tCrB if const_expr(B_idx is None) else tCrB[None, None, None, B_idx]
mma_atom = cute.make_mma_atom(tiled_mma.op)
for k in cutlass.range_constexpr(cute.size(tCrA.shape[2])):
for k in cutlass.range(
cute.size(tCrA.shape[2]), unroll=cute.size(tCrA.shape[2]) // num_unroll_groups
):
mma_atom.set(tcgen05.Field.ACCUMULATE, not zero_init or k != 0)
cute.gemm(mma_atom, acc, rA[None, None, k], rB[None, None, k], acc)

Expand All @@ -46,6 +49,7 @@ def gemm_ptx_w_idx(
A_idx: Optional[Int32] = None,
B_idx: Optional[Int32] = None,
zero_init: bool | Boolean = False,
cta_group: int = 1,
**kwargs,
) -> None:
rA = tCrA if const_expr(A_idx is None) else tCrA[None, None, None, A_idx]
Expand All @@ -57,10 +61,17 @@ def gemm_ptx_w_idx(
mma_atom = cute.make_mma_atom(tiled_mma.op)
acc_tmem_addr = acc.iterator.toint()
gemm_ptx_partial(
mma_atom.op, acc_tmem_addr, rA, rB, sA_cur, sB_cur, zero_init=zero_init, **kwargs
mma_atom.op,
acc_tmem_addr,
rA,
rB,
sA_cur,
sB_cur,
zero_init=zero_init,
cta_group=cta_group,
**kwargs,
)


@cute.jit
def gemm(
tiled_mma: cute.TiledMma,
Expand Down Expand Up @@ -372,6 +383,7 @@ def gemm_ptx_partial(
# sA_offset: Int32 = 0,
# acc_offset: Int32 = 0,
tA_addr: Optional[Int32] = None,
cta_group: int = 1,
) -> None:
# acc_tmem_addr += acc_offset
is_ts = op.a_src == cute.nvgpu.tcgen05.OperandSource.TMEM
Expand Down Expand Up @@ -463,7 +475,7 @@ def gemm_ptx_partial(
f"mov.b64 smem_desc_a, {{smem_desc_a_lo_start, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $2, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, {pred_str};\n\t"
+ "".join(
(
# f"add.u32 smem_desc_a_lo, smem_desc_a_lo, {hex(offset_a_diff[k - 1])};\n\t"
Expand All @@ -472,7 +484,7 @@ def gemm_ptx_partial(
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
f"mov.b64 smem_desc_a, {{smem_desc_a_lo, smem_desc_a_hi}};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], smem_desc_a, smem_desc_b, idesc, 1;\n\t"
)
for k in range(1, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -536,15 +548,15 @@ def gemm_ptx_partial(
f"mov.b32 smem_desc_b_hi, {hex(smem_desc_b_hi)};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo_start, smem_desc_b_hi}};\n\t"
"setp.ne.b32 p, $2, 0;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, {pred_str};\n\t"
+ "".join(
(
# f"add.u32 tmem_a, tmem_a, {hex(offset_a_diff[k - 1])};\n\t"
# f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"add.u32 smem_desc_b_lo, smem_desc_b_lo_start, {hex(offset_b[k])};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
# f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
)
for k in range(
1,
Expand All @@ -559,7 +571,7 @@ def gemm_ptx_partial(
(
f"add.u32 smem_desc_b_lo, smem_desc_b_lo, {hex(offset_b_diff[k - 1])};\n\t"
f"mov.b64 smem_desc_b, {{smem_desc_b_lo, smem_desc_b_hi}};\n\t"
f"@leader_thread tcgen05.mma.cta_group::1.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
f"@leader_thread tcgen05.mma.cta_group::{cta_group}.kind::f16 [tmem_acc], [tmem_a + {hex(offset_a[k])}], smem_desc_b, idesc, 1;\n\t"
)
for k in range(cute.size(tCrA.shape[2]) // 4 * 3, cute.size(tCrA.shape[2]))
)
Expand Down Expand Up @@ -750,4 +762,4 @@ def gemm_ptx_partial1(
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)
)
32 changes: 32 additions & 0 deletions flashmask/flash_mask/cute/copy_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,6 +207,38 @@ def store_shared_remote_fp32x4(
)


@dsl_user_op
def cpasync_bulk_s2cluster(
smem_src_ptr: cute.Pointer,
smem_dst_ptr: cute.Pointer,
mbar_ptr: cute.Pointer,
size: int | Int32,
peer_cta_rank_in_cluster: Int32,
*,
loc=None,
ip=None,
):
smem_src_ptr_i32 = smem_src_ptr.toint(loc=loc, ip=ip).ir_value()
smem_dst_ptr_i32 = set_block_rank(
smem_dst_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip
).ir_value()
mbar_ptr_i32 = set_block_rank(mbar_ptr, peer_cta_rank_in_cluster, loc=loc, ip=ip).ir_value()
llvm.inline_asm(
None,
[
smem_dst_ptr_i32,
smem_src_ptr_i32,
mbar_ptr_i32,
Int32(size).ir_value(loc=loc, ip=ip),
],
"cp.async.bulk.shared::cluster.shared::cta.mbarrier::complete_tx::bytes [$0], [$1], $3, [$2];",
"r,r,r,r",
has_side_effects=True,
is_align_stack=False,
asm_dialect=llvm.AsmDialect.AD_ATT,
)


@dsl_user_op
def cpasync_bulk_g2s(
gmem_ptr: cute.Pointer,
Expand Down
Loading