Skip to content

Commit d392203

Browse files
authored
[Blackwell] Support mixed precision in mxfp tutorial (#6204)
Now that TMA load for padded fp4 is supported, update the tutorial for mixed precision (`--format mixed`). ~Either cpasync or device TMA can be used for fp4. To use TMA, set `--mixed-fp4-tma`.~ (UPDATE: Now supports only TMA) Using TMA makes it significantly faster (up to 50%). If #6194 goes in first, I'll drop `_experimental_` here. --------- Co-authored-by: Masahiro Masuda <mmasuda@nvidia.com>
1 parent 8f5984c commit d392203

1 file changed

Lines changed: 107 additions & 44 deletions

File tree

python/tutorials/10-block-scaled-matmul.py

Lines changed: 107 additions & 44 deletions
Original file line numberDiff line numberDiff line change
@@ -88,10 +88,12 @@ def _matmul_launch_metadata(grid, kernel, args):
8888
ret = {}
8989
M, N, K = args["M"], args["N"], args["K"]
9090
kernel_name = kernel.name
91-
if "ELEM_PER_BYTE" and "VEC_SIZE" in args:
92-
if args["ELEM_PER_BYTE"] == 1:
91+
if "ELEM_PER_BYTE_A" and "ELEM_PER_BYTE_B" and "VEC_SIZE" in args:
92+
if args["ELEM_PER_BYTE_A"] == 1 and args["ELEM_PER_BYTE_B"] == 1:
9393
kernel_name += "_mxfp8"
94-
elif args["ELEM_PER_BYTE"] == 2:
94+
elif args["ELEM_PER_BYTE_A"] == 1 and args["ELEM_PER_BYTE_B"] == 2:
95+
kernel_name += "_mixed"
96+
elif args["ELEM_PER_BYTE_A"] == 2 and args["ELEM_PER_BYTE_B"] == 2:
9597
if args["VEC_SIZE"] == 16:
9698
kernel_name += "_nvfp4"
9799
elif args["VEC_SIZE"] == 32:
@@ -104,23 +106,29 @@ def _matmul_launch_metadata(grid, kernel, args):
104106
@triton.jit(launch_metadata=_matmul_launch_metadata)
105107
def block_scaled_matmul_kernel( #
106108
a_desc, a_scale, #
107-
b_desc, b_scale, #
109+
b_desc_or_tensor, b_scale, #
108110
c_desc, #
109111
M: tl.constexpr, N: tl.constexpr, K: tl.constexpr, #
110112
stride_sk: tl.constexpr, stride_sb: tl.constexpr, stride_sc: tl.constexpr, stride_sd: tl.constexpr,
111113
output_type: tl.constexpr, #
112-
ELEM_PER_BYTE: tl.constexpr, #
114+
ELEM_PER_BYTE_A: tl.constexpr, #
115+
ELEM_PER_BYTE_B: tl.constexpr, #
113116
VEC_SIZE: tl.constexpr, #
114117
BLOCK_M: tl.constexpr, #
115118
BLOCK_N: tl.constexpr, #
116119
BLOCK_K: tl.constexpr, #
117120
NUM_STAGES: tl.constexpr, #
118121
USE_2D_SCALE_LOAD: tl.constexpr): #
119122

120-
if ELEM_PER_BYTE == 1:
121-
dtype = tl.float8e4nv
122-
elif ELEM_PER_BYTE == 2:
123-
dtype = tl.dtype("uint8")
123+
if ELEM_PER_BYTE_A == 1:
124+
dtype_a = tl.float8e4nv
125+
elif ELEM_PER_BYTE_A == 2:
126+
dtype_a = tl.dtype("uint8")
127+
128+
if ELEM_PER_BYTE_B == 1:
129+
dtype_b = tl.float8e4nv
130+
elif ELEM_PER_BYTE_B == 2:
131+
dtype_b = tl.dtype("uint8")
124132

125133
if output_type == 0:
126134
output_dtype = tl.float32
@@ -129,25 +137,38 @@ def block_scaled_matmul_kernel( #
129137
elif output_type == 2:
130138
output_dtype = tl.float8e4nv
131139

132-
tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [a_desc], dtype=tl.int32, is_pure=False,
133-
pack=1)
134-
tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [b_desc], dtype=tl.int32, is_pure=False,
135-
pack=1)
136-
tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [c_desc], dtype=tl.int32, is_pure=False,
137-
pack=1)
138-
139140
pid = tl.program_id(axis=0)
140141
num_pid_m = tl.cdiv(M, BLOCK_M)
141142
pid_m = pid % num_pid_m
142143
pid_n = pid // num_pid_m
143144
offs_am = pid_m * BLOCK_M
144145
offs_bn = pid_n * BLOCK_N
145-
offs_k = 0
146+
offs_k_a = 0
147+
offs_k_b = 0
146148

147149
## block scale offsets
148150
offs_sm = (pid_m * (BLOCK_M // 128) + tl.arange(0, BLOCK_M // 128)) % M
149151
offs_sn = (pid_n * (BLOCK_N // 128) + tl.arange(0, BLOCK_N // 128)) % N
150152

153+
MIXED_PREC: tl.constexpr = ELEM_PER_BYTE_A == 1 and ELEM_PER_BYTE_B == 2
154+
155+
if MIXED_PREC:
156+
b_desc = tl.make_tensor_descriptor(
157+
b_desc_or_tensor,
158+
shape=[N, K // ELEM_PER_BYTE_B],
159+
strides=[K // ELEM_PER_BYTE_B, 1],
160+
block_shape=[BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
161+
)
162+
else:
163+
b_desc = b_desc_or_tensor
164+
tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [b_desc], dtype=tl.int32,
165+
is_pure=False, pack=1)
166+
167+
tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [a_desc], dtype=tl.int32, is_pure=False,
168+
pack=1)
169+
tl.inline_asm_elementwise("prefetch.tensormap [$1]; // dummy $0", "=r,l", [c_desc], dtype=tl.int32, is_pure=False,
170+
pack=1)
171+
151172
# For now it is recommended to use 2D scale loads for better performance.
152173
# In the future we will bring additional optimizations to either allow 5D loads,
153174
# the use of TMAs for scale factors, or both.
@@ -171,26 +192,39 @@ def block_scaled_matmul_kernel( #
171192

172193
accumulator = tl.zeros((BLOCK_M, BLOCK_N), dtype=tl.float32)
173194
for k in tl.range(0, tl.cdiv(K, BLOCK_K), num_stages=NUM_STAGES):
174-
a = tl._experimental_descriptor_load(a_desc, [offs_am, offs_k], [BLOCK_M, BLOCK_K // ELEM_PER_BYTE], dtype)
175-
b = tl._experimental_descriptor_load(b_desc, [offs_bn, offs_k], [BLOCK_N, BLOCK_K // ELEM_PER_BYTE], dtype)
195+
a = tl._experimental_descriptor_load(a_desc, [offs_am, offs_k_a], [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A],
196+
dtype_a)
197+
198+
if MIXED_PREC:
199+
b = b_desc.load([offs_bn, offs_k_b])
200+
else:
201+
b = tl._experimental_descriptor_load(b_desc, [offs_bn, offs_k_b], [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B],
202+
dtype_b)
203+
176204
scale_a = tl.load(a_scale_ptr)
177205
scale_b = tl.load(b_scale_ptr)
178206
if USE_2D_SCALE_LOAD:
179207
scale_a = scale_a.reshape(BLOCK_M // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
180208
scale_b = scale_b.reshape(BLOCK_N // 128, BLOCK_K // VEC_SIZE // 4, 32, 4, 4)
181209
scale_a = scale_a.trans(0, 3, 2, 1, 4).reshape(BLOCK_M, BLOCK_K // VEC_SIZE)
182210
scale_b = scale_b.trans(0, 3, 2, 1, 4).reshape(BLOCK_N, BLOCK_K // VEC_SIZE)
183-
if ELEM_PER_BYTE == 2:
211+
212+
if MIXED_PREC:
213+
accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e2m1", accumulator)
214+
elif ELEM_PER_BYTE_A == 2 and ELEM_PER_BYTE_B == 2:
184215
accumulator = tl.dot_scaled(a, scale_a, "e2m1", b.T, scale_b, "e2m1", accumulator)
185216
else:
186217
accumulator = tl.dot_scaled(a, scale_a, "e4m3", b.T, scale_b, "e4m3", accumulator)
187-
offs_k += BLOCK_K // ELEM_PER_BYTE
218+
219+
offs_k_a += BLOCK_K // ELEM_PER_BYTE_A
220+
offs_k_b += BLOCK_K // ELEM_PER_BYTE_B
188221
a_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb
189222
b_scale_ptr += (BLOCK_K // VEC_SIZE // 4) * stride_sb
223+
190224
tl._experimental_descriptor_store(c_desc, accumulator.to(output_dtype), [offs_am, offs_bn])
191225

192226

193-
def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, configs):
227+
def block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, dtype_dst, M, N, K, configs):
194228
output = torch.empty((M, N), dtype=dtype_dst, device="cuda")
195229
if dtype_dst == torch.float32:
196230
dtype_dst = 0
@@ -205,11 +239,11 @@ def block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, dtype_dst, M, N, K, co
205239
output.element_size())
206240

207241
grid = (triton.cdiv(M, configs["BLOCK_SIZE_M"]) * triton.cdiv(N, configs["BLOCK_SIZE_N"]), 1)
208-
block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc, b_scale, c_desc, M, N, K, a_scale.stride(0),
242+
block_scaled_matmul_kernel[grid](a_desc, a_scale, b_desc_or_tensor, b_scale, c_desc, M, N, K, a_scale.stride(0),
209243
a_scale.stride(1), a_scale.stride(2), a_scale.stride(3), dtype_dst,
210-
configs["ELEM_PER_BYTE"], configs["VEC_SIZE"], configs["BLOCK_SIZE_M"],
211-
configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"], configs["num_stages"],
212-
USE_2D_SCALE_LOAD=True)
244+
configs["ELEM_PER_BYTE_A"], configs["ELEM_PER_BYTE_B"], configs["VEC_SIZE"],
245+
configs["BLOCK_SIZE_M"], configs["BLOCK_SIZE_N"], configs["BLOCK_SIZE_K"],
246+
configs["num_stages"], USE_2D_SCALE_LOAD=True)
213247
return output
214248

215249

@@ -218,8 +252,9 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
218252
BLOCK_N = 256
219253
BLOCK_K = 256 if "fp4" in block_scale_type else 128
220254
VEC_SIZE = 16 if block_scale_type == "nvfp4" else 32
221-
assert block_scale_type in ["nvfp4", "mxfp4", "mxfp8"], f"Invalid block scale type: {block_scale_type}"
222-
ELEM_PER_BYTE = 2 if "fp4" in block_scale_type else 1
255+
assert block_scale_type in ["nvfp4", "mxfp4", "mxfp8", "mixed"], f"Invalid block scale type: {block_scale_type}"
256+
ELEM_PER_BYTE_A = 2 if "fp4" in block_scale_type else 1
257+
ELEM_PER_BYTE_B = 1 if block_scale_type == "mxfp8" else 2
223258

224259
device = "cuda"
225260
a_ref = MXFP4Tensor(size=(M, K), device=device).random()
@@ -229,20 +264,32 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
229264
# the data is generated in col-major layout, packed along K for fp4, and then
230265
# logically transposed. Note that if one operand is of fp8 precision, unlike Hopper,
231266
# Blackwell supports both row-major and col-major layouts for the RHS matrix.
267+
# For the mixed-precision case, the fp4 RHS can be either in row or col-major layout.
268+
# But for performance reasons, it is recommended to use col-major layout. If TMA is used
269+
# for the fp4 RHS operand load in mixed-precision dot, as in this tutorial, it must be
270+
# in col-major layout.
232271
b_ref = MXFP4Tensor(size=(N, K), device=device).random()
233-
if block_scale_type == "mxfp8":
272+
if block_scale_type in ["mxfp8", "mixed"]:
234273
a_ref = a_ref.to(torch.float32)
235-
b_ref = b_ref.to(torch.float32)
236274
a = a_ref.to(torch.float8_e4m3fn)
237-
b = b_ref.to(torch.float8_e4m3fn)
238275
else:
239276
# Pack two fp4 elements per byte along K
240277
a = a_ref.to_packed_tensor(dim=1)
278+
279+
if block_scale_type == "mxfp8":
280+
b_ref = b_ref.to(torch.float32)
281+
b = b_ref.to(torch.float8_e4m3fn)
282+
else:
241283
b = b_ref.to_packed_tensor(dim=1)
284+
242285
b_ref = b_ref.to(torch.float32).T
243286

244-
a_desc = TmaDescKernelParam(a.data_ptr(), a.shape, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE], 1)
245-
b_desc = TmaDescKernelParam(b.data_ptr(), b.shape, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE], 1)
287+
a_desc = TmaDescKernelParam(a.data_ptr(), a.shape, [BLOCK_M, BLOCK_K // ELEM_PER_BYTE_A], 1)
288+
289+
if block_scale_type == "mixed":
290+
b_desc_or_tensor = b
291+
else:
292+
b_desc_or_tensor = TmaDescKernelParam(b.data_ptr(), b.shape, [BLOCK_N, BLOCK_K // ELEM_PER_BYTE_B], 1)
246293

247294
epsilon = 1e-8
248295
a_scale = torch.rand((M // 128, K // VEC_SIZE // 4, 32, 4, 4), device=device) + epsilon
@@ -252,7 +299,7 @@ def initialize_block_scaled(M, N, K, block_scale_type="nvfp4", compute_reference
252299
b_scale = b_scale.to(torch.float8_e4m3fn)
253300
a_scale_ref = a_scale
254301
b_scale_ref = b_scale
255-
elif block_scale_type in ["mxfp4", "mxfp8"]:
302+
elif block_scale_type in ["mxfp4", "mxfp8", "mixed"]:
256303
a_scale_ref = MXScaleTensor(a_scale)
257304
b_scale_ref = MXScaleTensor(b_scale)
258305
a_scale = a_scale_ref.data
@@ -276,16 +323,26 @@ def unpack_scale(packed):
276323
"BLOCK_SIZE_N": BLOCK_N,
277324
"BLOCK_SIZE_K": BLOCK_K,
278325
"num_stages": 4,
279-
"ELEM_PER_BYTE": ELEM_PER_BYTE,
326+
"ELEM_PER_BYTE_A": ELEM_PER_BYTE_A,
327+
"ELEM_PER_BYTE_B": ELEM_PER_BYTE_B,
280328
"VEC_SIZE": VEC_SIZE,
281329
}
282-
return a_desc, a_scale, b_desc, b_scale, configs, reference
330+
return a_desc, a_scale, b_desc_or_tensor, b_scale, configs, reference
283331

284332

285333
def validate_block_scaled(M, N, K, block_scale_type="nvfp4"):
286-
a_desc, a_scale, b_desc, b_scale, configs, reference = initialize_block_scaled(M, N, K, block_scale_type,
287-
compute_reference=True)
288-
output = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
334+
335+
def alloc_fn(size: int, align: int, _):
336+
return torch.empty(size, dtype=torch.int8, device="cuda")
337+
338+
if block_scale_type == "mixed":
339+
# This is needed for TMA with the descriptor created on the device.
340+
# TMA load for mixed-precision fp4 is supported only by device TMA.
341+
triton.set_allocator(alloc_fn)
342+
343+
a_desc, a_scale, b_desc_or_tensor, b_scale, configs, reference = initialize_block_scaled(
344+
M, N, K, block_scale_type, compute_reference=True)
345+
output = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
289346
torch.testing.assert_close(reference, output.to(torch.float32), atol=1e-3, rtol=1e-3)
290347
print(f"✅ (pass {block_scale_type})")
291348

@@ -296,13 +353,19 @@ def bench_block_scaled(K, block_scale_type="nvfp4", reps=10):
296353
N = 8192
297354
print(f"Problem Shape = {M}x{N}x{K}")
298355

299-
a_desc, a_scale, b_desc, b_scale, configs, _ = initialize_block_scaled(M, N, K, block_scale_type,
300-
compute_reference=False)
301-
_ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
356+
def alloc_fn(size: int, align: int, _):
357+
return torch.empty(size, dtype=torch.int8, device="cuda")
358+
359+
if block_scale_type == "mixed":
360+
triton.set_allocator(alloc_fn)
361+
362+
a_desc, a_scale, b_desc_or_tensor, b_scale, configs, _ = initialize_block_scaled(
363+
M, N, K, block_scale_type, compute_reference=False)
364+
_ = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
302365

303366
proton.activate(0)
304367
for _ in range(reps):
305-
_ = block_scaled_matmul(a_desc, a_scale, b_desc, b_scale, torch.float16, M, N, K, configs)
368+
_ = block_scaled_matmul(a_desc, a_scale, b_desc_or_tensor, b_scale, torch.float16, M, N, K, configs)
306369
proton.deactivate(0)
307370
print("Done benchmarking")
308371

@@ -321,7 +384,7 @@ def show_profile(profile_name):
321384
parser.add_argument("--K_range", type=int, nargs=2)
322385
parser.add_argument("--K_step", type=int, default=512)
323386
parser.add_argument("--bench", action="store_true")
324-
parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8"], default="nvfp4")
387+
parser.add_argument("--format", type=str, choices=["mxfp4", "nvfp4", "mxfp8", "mixed"], default="nvfp4")
325388
args = parser.parse_args()
326389

327390
if not supports_block_scaling():

0 commit comments

Comments
 (0)