-
Notifications
You must be signed in to change notification settings - Fork 924
Expand file tree
/
Copy pathtest_trtllm_cutlass_fused_moe.py
More file actions
1924 lines (1645 loc) · 62.6 KB
/
test_trtllm_cutlass_fused_moe.py
File metadata and controls
1924 lines (1645 loc) · 62.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
"""
Copyright (c) 2025 by FlashInfer team.
Licensed under the Apache License, Version 2.0 (the "License");
you may not use this file except in compliance with the License.
You may obtain a copy of the License at
http://www.apache.org/licenses/LICENSE-2.0
Unless required by applicable law or agreed to in writing, software
distributed under the License is distributed on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
See the License for the specific language governing permissions and
limitations under the License.
"""
from contextlib import nullcontext
import pytest
from flashinfer.fused_moe.core import ActivationType
import torch
from torch.nn import functional as F
import flashinfer.fused_moe as fused_moe
from flashinfer.utils import is_sm100a_supported
from flashinfer import (
autotune,
fp4_quantize,
mxfp4_dequantize,
mxfp4_quantize,
mxfp8_dequantize_host,
mxfp8_quantize,
mxfp4_dequantize_host,
)
FLOAT4_E2M1_MAX = 6.0
FLOAT8_E4M3_MAX = torch.finfo(torch.float8_e4m3fn).max
FP8_DTYPE = torch.float8_e4m3fn
def dynamic_per_tensor_fp8_quant(x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
    """Quantize ``x`` to FP8 (e4m3fn) with one dynamic per-tensor scale.

    The scale maps the largest absolute value of ``x`` onto FLOAT8_E4M3_MAX.

    Args:
        x: Tensor to quantize (any shape, floating dtype).

    Returns:
        ``(out, scale)``: ``out`` has dtype FP8_DTYPE and the shape of ``x``;
        ``scale`` is a float32 tensor of shape ``(1,)`` such that
        ``out.float() * scale`` approximates ``x``. An all-zero ``x`` yields
        ``scale == 0`` and an all-zero ``out`` (instead of NaNs).
    """
    fp8_traits_max = FLOAT8_E4M3_MAX
    fp8_traits_min = -FLOAT8_E4M3_MAX
    fp8_max = torch.tensor(fp8_traits_max).float()
    x_max = x.abs().max().float()
    scale = x_max / fp8_max
    # For an all-zero tensor scale == 0; 1 / 0 = inf and 0 * inf = NaN, so use
    # an inverse scale of 0 in that case (the quantized output is all zeros).
    iscale = torch.where(scale > 0, 1.0 / scale, torch.zeros_like(scale))
    out = (x.float() * iscale).clamp(fp8_traits_min, fp8_traits_max).to(FP8_DTYPE)
    return out, scale.view((1,))
def gen_tensor(shape, dtype, stype=None, scale=1.0):
    """Draw a standard-normal tensor on the GPU, then rescale and optionally recast.

    Args:
        shape: Output shape.
        dtype: dtype in which the random values are sampled.
        stype: Optional dtype to cast the scaled tensor to (skipped if falsy).
        scale: Multiplier applied to the samples.
    """
    samples = torch.randn(*shape, dtype=dtype).cuda()
    scaled = samples * scale
    if not stype:
        return scaled
    return scaled.to(stype)
def cast_to_representable(x):
    """Snap ``x`` onto the per-tensor FP8 grid and return it in ``x``'s dtype.

    The tensor is quantized with ``dynamic_per_tensor_fp8_quant`` and
    immediately dequantized, so the result only holds values that the FP8
    per-tensor quantization can express.
    """
    target_dtype = x.dtype
    quantized, per_tensor_scale = dynamic_per_tensor_fp8_quant(x)
    return quantized.to(target_dtype) * per_tensor_scale.to(target_dtype)
def convert_swizzled_to_linear(a_sf_swizzled: torch.Tensor, m, k, block_size):
    """Undo the 128 x 4 swizzled layout of block scale factors.

    Args:
        a_sf_swizzled: Swizzled scale factors holding exactly
            ``ceil(m / 128) * 128 * ceil(k / (4 * block_size)) * 4`` elements.
        m: Number of (unpadded) rows of the quantized tensor.
        k: Number of (unpadded) elements per row of the quantized tensor.
        block_size: Number of elements sharing one scale factor.

    Returns:
        Row-major scale factors of shape ``(m, ceil(k / block_size))`` with the
        row and column padding removed.
    """
    m_tiles = (m + 128 - 1) // 128
    f = block_size * 4
    k_tiles = (k + f - 1) // f
    tmp = torch.reshape(a_sf_swizzled, (1, m_tiles, k_tiles, 32, 4, 4))
    tmp = torch.permute(tmp, (0, 1, 4, 3, 2, 5))
    out = tmp.reshape(m_tiles * 128, k_tiles * f // block_size)
    # Columns index scale-factor blocks, not elements. Slicing by ``k`` would
    # keep the padding columns whenever k / block_size is not a multiple of 4.
    num_sf_cols = (k + block_size - 1) // block_size
    return out[0:m, 0:num_sf_cols]
def dequantize_nvfp4_to_dtype(
    tensor_fp4, tensor_sf, global_scale, dtype, device, block_size=16
):
    """Dequantize a packed NVFP4 tensor back to ``dtype``.

    Args:
        tensor_fp4: uint8 tensor of shape ``(m, k // 2)``; each byte packs two
            FP4 (e2m1) values.
        tensor_sf: Swizzled block scale factors, reinterpreted as FP8 e4m3fn.
        global_scale: Global scale; the block scales are divided by it.
        dtype: Output dtype.
        device: Unused; kept for call-site compatibility.
        block_size: Number of FP4 values sharing one block scale.

    Returns:
        Tensor of shape ``(m, k)`` and dtype ``dtype``.
    """
    assert tensor_fp4.dtype == torch.uint8
    rows, packed_cols = tensor_fp4.shape
    cols = 2 * packed_cols  # two FP4 values per byte
    values = break_fp4_bytes(tensor_fp4, dtype).reshape(
        rows, cols // block_size, block_size
    )
    linear_sf = convert_swizzled_to_linear(
        tensor_sf.view(torch.float8_e4m3fn), rows, cols, block_size
    )
    block_scales = linear_sf.to(torch.float32) / global_scale
    # Broadcast one scale over each block of ``block_size`` values.
    dequantized = (values * block_scales.unsqueeze(-1)).reshape(rows, cols)
    return dequantized.to(dtype=dtype)
def break_fp4_bytes(a, dtype):
    """Unpack a uint8 tensor of FP4 (e2m1) pairs into ``dtype`` values.

    Each byte holds two FP4 codes and the low nibble is emitted first. Bit 3 of
    a code is the sign; bits 0-2 index the e2m1 magnitude table.

    Args:
        a: uint8 tensor of shape ``(m, n)``.
        dtype: Output dtype.

    Returns:
        Tensor of shape ``(m, 2 * n)`` and dtype ``dtype``.
    """
    assert a.dtype == torch.uint8
    rows, cols = a.shape
    flat = a.flatten()
    # Interleave nibbles as [low0, high0, low1, high1, ...].
    nibbles = torch.stack((flat & 0x0F, (flat & 0xF0) >> 4), dim=1).flatten()
    is_negative = (nibbles & 0x08).to(torch.bool)
    magnitude_idx = (nibbles & 0x07).to(torch.long)
    e2m1_table = torch.tensor(
        [0.0, 0.5, 1.0, 1.5, 2.0, 3.0, 4.0, 6.0], dtype=torch.float32
    ).to(device=a.device)
    sign = torch.where(is_negative, -1.0, 1.0)
    decoded = e2m1_table[magnitude_idx] * sign
    return decoded.reshape(rows, cols * 2).to(dtype=dtype)
def break_int4_bytes_to_int8(packed):
    """Unpack two signed int4 values from every byte into int8.

    The low nibble of each byte is emitted before the high nibble, and each
    nibble is sign-extended from 4 bits (two's complement, range [-8, 7]).

    Args:
        packed: Integer tensor; its first dimension is kept.

    Returns:
        int8 tensor of shape ``(packed.shape[0], -1)``.
    """
    nibbles = [
        (packed & 0x0F).to(torch.int8),
        ((packed >> 4) & 0x0F).to(torch.int8),
    ]
    signed = [torch.where(n >= 8, n - 16, n) for n in nibbles]
    return torch.stack(signed, dim=-1).reshape(packed.shape[0], -1)
def dequantize_int4_to_dtype(
    packed_weight: torch.Tensor,
    weight_scale: torch.Tensor,
    group_size: int,
    dtype: torch.dtype,
    weight_scale_2: torch.Tensor = None,
) -> torch.Tensor:
    """Dequantize a packed int4 weight with per-group scales.

    Args:
        packed_weight: Packed int4 weight of shape ``[N, K // 2]``.
        weight_scale: Per-group scales; each column is repeated over
            ``group_size`` consecutive unpacked columns.
        group_size: Number of unpacked columns sharing one scale.
        dtype: Output dtype.
        weight_scale_2: Optional extra scale the result is divided by.

    Returns:
        Dequantized weight of shape ``[N, K]`` in ``dtype``.
    """
    weight_int8 = break_int4_bytes_to_int8(packed_weight)  # [N, K]
    per_column_scale = weight_scale.repeat_interleave(group_size, dim=1).float()
    result = weight_int8.float() * per_column_scale
    if weight_scale_2 is not None:
        result = result / weight_scale_2.float()
    return result.to(dtype)
def compute_routing(
    router_logits: torch.Tensor, top_k: int
) -> tuple[torch.Tensor, torch.Tensor]:
    """Select the top-k experts per token and renormalize their weights.

    Args:
        router_logits: Router logits of shape ``[batch_size, num_experts]``.
        top_k: Number of experts each token is routed to.

    Returns:
        ``(routing_weights, selected_experts)``, both of shape
        ``[batch_size, top_k]``. ``routing_weights`` is float32 and sums to 1
        along the last dimension; ``selected_experts`` holds expert indices.
    """
    probs = F.softmax(router_logits, dim=1, dtype=torch.float)
    top_probs, top_ids = probs.topk(top_k, dim=-1)
    top_probs = top_probs / top_probs.sum(dim=-1, keepdim=True)
    return top_probs.float(), top_ids
def torch_moe_nvfp4(a, w1, w2, topk, topk_weight, topk_ids, activation_type):
    """Reference NVFP4 MoE forward pass in plain PyTorch.

    Each token is replicated ``topk`` times and pushed through its selected
    expert's FC1 and activation. The intermediate activation is round-tripped
    through NVFP4 with a global scale of 1.0. The FC2 outputs are then combined
    with the routing weights.

    Args:
        a: Dequantized input of shape ``[B, D]``.
        w1: Dequantized FC1 weights ``[E, N1, D]``. For SwiGLU, SiLU is applied
            to the second half of dim 1 (gate) and multiplied by the first half
            (up).
        w2: Dequantized FC2 weights ``[E, D_out, N]``.
        topk: Experts per token.
        topk_weight: Routing weights, reshapeable to ``[B, topk]``.
        topk_ids: Expert ids, reshapeable to ``[B * topk]``.
        activation_type: ``ActivationType.Swiglu`` or ``ActivationType.Relu2``.

    Returns:
        Tensor of shape ``[B, D_out]``.

    Raises:
        ValueError: For any other activation type.
    """
    num_tokens, hidden = a.shape
    # Row r of ``tokens`` is token r // topk paired with its (r % topk)-th expert.
    tokens = a.view(num_tokens, -1, hidden).repeat(1, topk, 1).reshape(-1, hidden)
    out = torch.zeros(
        num_tokens * topk, w2.shape[1], dtype=tokens.dtype, device=tokens.device
    )
    flat_weight = topk_weight.view(-1)
    flat_ids = topk_ids.view(-1)

    if activation_type == ActivationType.Swiglu:

        def act(weight, mask):
            rows = weight.shape[0]
            assert rows % 2 == 0
            half = rows // 2
            # The fused w1 stores [up; gate], so the halves are swapped here.
            gate_w, up_w = weight[half:, :], weight[:half, :]
            selected = tokens[mask]
            return F.silu(selected @ gate_w.t()) * (selected @ up_w.t())

    elif activation_type == ActivationType.Relu2:

        def act(weight, mask):
            return F.relu(tokens[mask] @ weight.t()) ** 2

    else:
        raise ValueError(f"Unsupported activation type {activation_type}")

    for expert_idx in range(w1.shape[0]):
        mask = flat_ids == expert_idx
        if not mask.sum():
            continue
        inter = act(w1[expert_idx], mask)
        # Round-trip the intermediate activation through NVFP4 (global scale 1).
        unit_gs = torch.tensor(1.0).cuda()
        inter_q, inter_sf = fp4_quantize(inter, unit_gs)
        inter = dequantize_nvfp4_to_dtype(
            inter_q,
            inter_sf,
            unit_gs,
            dtype=inter.dtype,
            device=inter.device,
            block_size=16,
        ).cuda()
        out[mask] = inter @ w2[expert_idx].transpose(0, 1)

    weighted = out.view(num_tokens, -1, w2.shape[1]) * flat_weight.view(
        num_tokens, -1, 1
    ).to(out.dtype)
    return weighted.sum(dim=1)
def torch_moe_w4a8(
num_experts,
x,
w31_weight,
w2_weight,
selected_experts,
routing_weights,
fc1_input_scale,
fc2_input_scale,
fc1_pre_quant_scale,
fc2_pre_quant_scale,
fc1_weight_scale_2,
fc2_weight_scale_2,
):
dtype = x.dtype
results = torch.zeros_like(x)
for expert_id in range(num_experts):
mask = selected_experts == expert_id
if not mask.sum():
continue
batch_idx, nth_expert = torch.where(mask)
w31_expert = w31_weight[expert_id] # [2N, K]
w2_expert = w2_weight[expert_id] # [K, N]
w3_expert, w1_expert = torch.chunk(w31_expert, 2, dim=0)
expert_inputs = x[batch_idx]
if fc1_input_scale is not None:
scale1 = fc1_input_scale[expert_id]
if fc1_pre_quant_scale is not None:
expert_inputs_scaled = expert_inputs * fc1_pre_quant_scale[expert_id]
else:
expert_inputs_scaled = expert_inputs
inp_q = (
torch.clamp(expert_inputs_scaled / scale1, -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(dtype)
)
x1 = (inp_q @ w1_expert.t()) * scale1
x2 = (inp_q @ w3_expert.t()) * scale1
if fc1_weight_scale_2 is not None:
ws2 = fc1_weight_scale_2[expert_id]
x1 = x1 * ws2.to(dtype)
x2 = x2 * ws2.to(dtype)
inter = F.silu(x1) * x2
if fc2_input_scale is not None:
scale2 = fc2_input_scale[expert_id]
if fc2_pre_quant_scale is not None:
inter_scaled = inter * fc2_pre_quant_scale[expert_id]
else:
inter_scaled = inter
inter_q = (
torch.clamp(inter_scaled / scale2, -448.0, 448.0)
.to(torch.float8_e4m3fn)
.to(dtype)
)
output = (inter_q @ w2_expert.t()) * scale2
if fc2_weight_scale_2 is not None:
ws2 = fc2_weight_scale_2[expert_id]
output = output * ws2.to(dtype)
results[batch_idx] += routing_weights[batch_idx, nth_expert, None] * output
return results.view_as(x)
def compute_with_experts(
    num_experts,
    x,
    w31_weight,
    w2_weight,
    selected_experts,
    routing_weights,
    alpha=None,
    beta=None,
    limit=None,
):
    """Reference (unquantized) MoE forward pass.

    Args:
        num_experts: Number of experts to iterate over.
        x: Input of shape ``[batch_size, hidden_size]``.
        w31_weight: Fused FC1 weights
            ``[num_experts, 2 * intermediate_size, hidden_size]``; the first
            half of dim 1 is w3, the second half w1.
        w2_weight: FC2 weights ``[num_experts, hidden_size, intermediate_size]``.
        selected_experts: Expert ids per token, ``[batch_size, top_k]``.
        routing_weights: Routing weights per token, ``[batch_size, top_k]``.
        alpha, beta, limit: When all three are given, use the clamped
            "SwiGLU with bias" activation; otherwise plain SwiGLU.

    Returns:
        Tensor shaped like ``x``.
    """
    results = torch.zeros_like(x)
    use_swiglu_bias = not (alpha is None or limit is None or beta is None)
    for expert_id in range(num_experts):
        mask = selected_experts == expert_id
        if not mask.sum():
            continue
        token_idx, slot_idx = torch.where(mask)
        w3_expert, w1_expert = torch.chunk(w31_weight[expert_id], 2, dim=0)
        tokens = x[token_idx]
        gate = tokens @ w1_expert.t()
        linear = tokens @ w3_expert.t()
        if use_swiglu_bias:
            # SwiGLUBias: gate clamped from above, linear clamped to +-limit.
            gate = gate.clamp_(min=None, max=limit)
            inter = (gate * torch.sigmoid(alpha * gate)) * (
                linear.clamp_(min=-limit, max=limit) + beta
            )
        else:
            inter = F.silu(gate) * linear
        expert_out = inter @ w2_weight[expert_id].t()
        results[token_idx] += routing_weights[token_idx, slot_idx, None] * expert_out
    return results.view_as(x)
# Test configurations shared by the parametrized tests below.
BATCH_SIZES = [1]
HIDDEN_SIZES = [128]
NUM_EXPERTS = [2]
TOP_K_VALUES = [2]
INTERMEDIATE_SIZES = [128]
# Expert count / top-k used by the expert-parallel tests.
EP_NUM_EXPERTS = [8]
EP_TOP_K = [2]
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe(batch_size, hidden_size, num_experts, top_k, intermediate_size):
    """Unquantized fp16 ``cutlass_fused_moe`` must match the PyTorch reference."""
    if top_k > num_experts:
        pytest.skip(
            f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
        )
    torch.manual_seed(42)
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda() / 5
    router_logits = torch.randn(batch_size, num_experts, dtype=torch.float32).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 5
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 5
    )

    routing_weights, selected_experts = compute_routing(router_logits, top_k)
    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    output_buffer = torch.empty_like(ref_output)
    result = fused_moe.cutlass_fused_moe(
        x,
        selected_experts.to(torch.int),
        routing_weights,
        w31_weight,
        w2_weight,
        output_buffer.dtype,
        output=output_buffer,
        quant_scales=None,
    )
    # The kernel returns a sequence whose first element is the output tensor.
    torch.testing.assert_close(ref_output, result[0], rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize("otype, wtype", [(torch.float16, torch.float8_e4m3fn)])
def test_moe_fp8(
    batch_size, hidden_size, num_experts, top_k, intermediate_size, otype, wtype
):
    """Per-tensor FP8 ``cutlass_fused_moe`` must match a dequantized reference.

    Each expert's w31 and w2 are quantized with one FP8 scale each. The input
    is quantized with one per-tensor scale and passed to the kernel already
    quantized.
    """
    if top_k > num_experts:
        pytest.skip(
            f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
        )
    torch.manual_seed(42)
    input_shape = (batch_size, hidden_size)
    w31_shape = (num_experts, 2 * intermediate_size, hidden_size)
    w2_shape = (num_experts, hidden_size, intermediate_size)
    x = cast_to_representable(gen_tensor(input_shape, otype))
    router_logits = gen_tensor((batch_size, num_experts), otype)
    # Buffers; every expert slice is overwritten in the loop below. They are
    # still filled via gen_tensor, which advances the RNG, so keep them to
    # preserve the random stream used for the per-expert weights.
    w31_weight = gen_tensor(w31_shape, otype, wtype)
    w2_weight = gen_tensor(w2_shape, otype, wtype)
    w31_scales = torch.empty(num_experts, 2, dtype=otype).cuda()
    w2_scales = torch.empty(num_experts, 1, dtype=otype).cuda()
    w31_dequantized = gen_tensor(w31_shape, otype)
    w2_dequantized = gen_tensor(w2_shape, otype)
    for expert_id in range(num_experts):
        w31 = cast_to_representable(gen_tensor(w31_shape[1:], otype, scale=0.1))
        w2 = cast_to_representable(gen_tensor(w2_shape[1:], otype, scale=0.09))
        w31_quant, s31 = dynamic_per_tensor_fp8_quant(w31)
        w2_quant, s2 = dynamic_per_tensor_fp8_quant(w2)
        w31_weight.data[expert_id].copy_(w31_quant)
        w2_weight.data[expert_id].copy_(w2_quant)
        # s31 has shape (1,) and is broadcast into both w31 scale slots.
        w31_scales.data[expert_id].copy_(s31)
        w2_scales.data[expert_id].copy_(s2)
        w31_dequantized.data[expert_id].copy_(torch.mul(w31_quant.to(dtype=otype), s31))
        w2_dequantized.data[expert_id].copy_(torch.mul(w2_quant.to(dtype=otype), s2))

    routing_weights, selected_experts = compute_routing(router_logits, top_k)
    ref_output = compute_with_experts(
        num_experts,
        x,
        w31_dequantized,
        w2_dequantized,
        selected_experts,
        routing_weights,
    )

    flash_output = torch.empty_like(ref_output)
    # For fp8, the hidden_state is passed to the kernel already quantized.
    _, w1_scales = torch.chunk(w31_scales, 2, dim=-1)
    x_quant, hidden_states_scale = dynamic_per_tensor_fp8_quant(x)
    # 0-dim copy of the per-tensor scale (already on the input's device).
    # torch.tensor(tensor) would emit a copy-construct UserWarning.
    hidden_states_scale = hidden_states_scale[0].clone()
    quant_scales = [
        torch.squeeze(w1_scales * hidden_states_scale).float(),
        torch.tensor(1.0).cuda(),
        torch.squeeze(1.0 * w2_scales).float(),
        hidden_states_scale,
    ]
    fused_moe.cutlass_fused_moe(
        x_quant,
        selected_experts.to(torch.int),
        routing_weights,
        w31_weight,
        w2_weight,
        otype,
        quant_scales=quant_scales,
        output=flash_output,
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("top_k", TOP_K_VALUES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
@pytest.mark.parametrize(
    "otype, wtype",
    [(torch.float16, torch.float8_e4m3fn), (torch.bfloat16, torch.float8_e4m3fn)],
)
@pytest.mark.parametrize("quantized_input", [False, True])
@pytest.mark.parametrize(
    "activation_type",
    [ActivationType.Swiglu, ActivationType.Relu2],
    ids=["swiglu", "relu2"],
)
@pytest.mark.skipif(
    torch.cuda.get_device_capability()[0] not in [10, 11, 12],
    reason="NVFP4 is only supported on SM100, SM110 and SM120/SM121",
)
def test_moe_nvfp4(
    batch_size,
    hidden_size,
    num_experts,
    top_k,
    intermediate_size,
    otype,
    wtype,
    quantized_input,
    activation_type,
):
    """NVFP4 ``cutlass_fused_moe`` must match a dequantize-then-PyTorch reference.

    Weights are quantized to NVFP4 with a per-expert global scale. Activations
    use a global scale of 1.0. With ``quantized_input`` the kernel receives an
    FP4-quantized input plus its block scales instead of ``x``.
    """
    if top_k > num_experts:
        pytest.skip(
            f"top_k ({top_k}) cannot be greater than num_experts ({num_experts})"
        )
    torch.manual_seed(42)
    quant_blocksize = 16

    def round_up(value, multiple):
        return (value + multiple - 1) // multiple * multiple

    e = num_experts
    m = batch_size
    n = intermediate_size
    k = hidden_size
    # SwiGLU fuses two projections into FC1, doubling its output width.
    w1_n = 2 * n if activation_type == ActivationType.Swiglu else n

    w1 = torch.randn((e, w1_n, k), device="cuda", dtype=otype) / 10
    w2 = torch.randn((e, k, n), device="cuda", dtype=otype) / 10

    # Swizzled block-scale buffers: rows padded to 128, scale columns to 4.
    sf_w1_2n = round_up(w1_n, 128)
    sf_w1_k = round_up(k // quant_blocksize, 4)
    sf_w2_k = round_up(k, 128)
    sf_w2_n = round_up(n // quant_blocksize, 4)
    w1_blockscale = torch.empty(
        (e, sf_w1_2n, sf_w1_k), device="cuda", dtype=torch.float8_e4m3fn
    )
    w2_blockscale = torch.empty(
        (e, sf_w2_k, sf_w2_n), device="cuda", dtype=torch.float8_e4m3fn
    )
    w1_q = torch.empty((e, w1_n, k // 2), device="cuda", dtype=torch.uint8)
    w2_q = torch.empty((e, k, n // 2), device="cuda", dtype=torch.uint8)
    w1_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    w2_gs = torch.empty((e,), device="cuda", dtype=torch.float32)
    for expert in range(e):
        # Per-expert global scale from this expert's own amax.
        w1_amax = torch.abs(w1[expert]).max().to(torch.float32)
        w2_amax = torch.abs(w2[expert]).max().to(torch.float32)
        w1_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w1_amax
        w2_gs[expert] = FLOAT8_E4M3_MAX * FLOAT4_E2M1_MAX / w2_amax
        w1_q[expert], w1_sf = fp4_quantize(w1[expert], w1_gs[expert])
        w2_q[expert], w2_sf = fp4_quantize(w2[expert], w2_gs[expert])
        # Reinterpret (rather than numerically convert) the scale-factor bytes.
        # Assigning a uint8 tensor straight into the float8 buffer would
        # convert its integer values; .view() is a no-op if already float8.
        w1_blockscale[expert] = w1_sf.view(torch.float8_e4m3fn)
        w2_blockscale[expert] = w2_sf.view(torch.float8_e4m3fn)

    x = torch.randn(m, k, dtype=otype).cuda()
    # Both FC1 and FC2 activations use a unit global scale.
    a1_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    a2_gs = torch.tensor(1.0, device="cuda", dtype=torch.float32)
    router_logits = torch.randn(m, e, dtype=otype).cuda()
    routing_weights, selected_experts = compute_routing(router_logits, top_k)

    # quant_scales layout expected by the kernel:
    #   [fc1_act_global, fc1_weight_block, fc1_global,
    #    fc2_act_global, fc2_weight_block, fc2_global]
    quant_scales = [
        a1_gs,
        w1_blockscale.view(torch.int32),
        1.0 / (a1_gs * w1_gs),
        a2_gs,
        w2_blockscale.view(torch.int32),
        1.0 / (a2_gs * w2_gs),
    ]
    flash_output = torch.zeros_like(x)
    hidden_states = x
    input_sf = None
    if quantized_input:
        hidden_states, input_sf = fp4_quantize(x, a1_gs)
    fused_moe.cutlass_fused_moe(
        hidden_states,
        selected_experts.to(torch.int),
        routing_weights,
        w1_q.contiguous().view(torch.long),
        w2_q.contiguous().view(torch.long),
        otype,
        quant_scales=quant_scales,
        input_sf=input_sf,
        output=flash_output,
        activation_type=activation_type,
    )

    # Reference: dequantize input and weights, then run the PyTorch MoE.
    a_fp4, a_scale_interleaved = fp4_quantize(x, a1_gs)
    a_in_dtype = dequantize_nvfp4_to_dtype(
        a_fp4,
        a_scale_interleaved,
        a1_gs,
        dtype=otype,
        device=x.device,
        block_size=quant_blocksize,
    )
    w1_d = torch.empty((e, w1_n, k), device="cuda", dtype=otype)
    w2_d = torch.empty((e, k, n), device="cuda", dtype=otype)
    for idx in range(e):
        w1_d[idx] = dequantize_nvfp4_to_dtype(
            w1_q[idx],
            w1_blockscale[idx],
            w1_gs[idx],
            dtype=w1.dtype,
            device=w1.device,
            block_size=quant_blocksize,
        )
        w2_d[idx] = dequantize_nvfp4_to_dtype(
            w2_q[idx],
            w2_blockscale[idx],
            w2_gs[idx],
            dtype=w2.dtype,
            device=w2.device,
            block_size=quant_blocksize,
        )
    ref_output = torch_moe_nvfp4(
        a_in_dtype,
        w1_d,
        w2_d,
        top_k,
        routing_weights,
        selected_experts,
        activation_type,
    )
    torch.testing.assert_close(ref_output, flash_output, rtol=2e-1, atol=2e-1)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS)
@pytest.mark.parametrize("top_k", EP_TOP_K)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_expert_parallel(
    batch_size, hidden_size, num_experts, top_k, intermediate_size
):
    """Simulate expert parallelism on a single device.

    The experts are split evenly across ``ep_size = num_experts // 2`` ranks,
    two experts per rank. Each rank runs ``cutlass_fused_moe`` on its slice of
    the weights with the global routing tensors. The per-rank outputs are
    summed to emulate the reduction across GPUs.

    Args:
        batch_size: Batch size for the input.
        hidden_size: Hidden dimension size.
        num_experts: Total number of experts (assumed even).
        top_k: Number of experts each token is routed to.
        intermediate_size: Intermediate dimension size.
    """
    ep_size = num_experts // 2
    torch.manual_seed(42)
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    # Every token picks ``top_k`` distinct experts.
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = F.softmax(torch.randn((batch_size, top_k)).cuda(), dim=1)

    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    experts_per_rank = num_experts // ep_size
    flash_output = torch.zeros_like(ref_output)
    for ep_rank in range(ep_size):
        first_expert = ep_rank * experts_per_rank
        local_experts = slice(first_expert, first_expert + experts_per_rank)
        rank_output = torch.zeros_like(x)
        fused_moe.cutlass_fused_moe(
            x.contiguous(),
            selected_experts.to(torch.int),
            routing_weights,
            w31_weight[local_experts].contiguous(),
            w2_weight[local_experts].contiguous(),
            x.dtype,
            ep_size=ep_size,
            ep_rank=ep_rank,
            quant_scales=None,
            output=rank_output,
        )
        # Emulate the reduction across expert-parallel ranks.
        flash_output += rank_output
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-1, atol=1e-1)
# Tensor-parallel degrees to simulate.
TP_SIZES = [2, 4]


@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", NUM_EXPERTS)
@pytest.mark.parametrize("tp_size", TP_SIZES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_tensor_parallel(
    batch_size, hidden_size, num_experts, tp_size, intermediate_size
):
    """Simulate tensor parallelism on a single device.

    Each rank keeps a contiguous ``intermediate_size // tp_size`` slice of the
    w3 and w1 halves of w31 (dim 1, non-contracting) and the matching slice of
    w2 (dim 2, contracting). The per-rank outputs are summed to emulate the
    all-reduce. Every token is routed to ``top_k = 2`` distinct experts.

    Args:
        batch_size: Batch size for the input.
        hidden_size: Hidden dimension size.
        num_experts: Number of experts.
        tp_size: Number of simulated tensor-parallel ranks.
        intermediate_size: Intermediate dimension size.
    """
    torch.manual_seed(42)
    top_k = 2
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = F.softmax(torch.randn((batch_size, top_k)).cuda(), dim=1)

    # Reference without any parallelism.
    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    shard = intermediate_size // tp_size
    # Split the fused FC1 weight into its w3 (first) and w1 (second) halves.
    w3_weight, w1_weight = torch.chunk(w31_weight, 2, dim=1)
    outputs = []
    for tp_rank in range(tp_size):
        cols = slice(tp_rank * shard, (tp_rank + 1) * shard)
        w31_weight_local = torch.cat(
            [w3_weight[:, cols, :], w1_weight[:, cols, :]], dim=1
        )
        w2_weight_local = w2_weight[:, :, cols]
        rank_output = torch.zeros_like(x)
        fused_moe.cutlass_fused_moe(
            x.contiguous(),
            selected_experts.to(torch.int),
            routing_weights,
            w31_weight_local.contiguous(),
            w2_weight_local.contiguous(),
            x.dtype,
            tp_size=tp_size,
            tp_rank=tp_rank,
            quant_scales=None,
            output=rank_output,
        )
        outputs.append(rank_output)
    # Emulate the all-reduce of the partial results.
    flash_output = sum(outputs)
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2)
@pytest.mark.parametrize("batch_size", BATCH_SIZES)
@pytest.mark.parametrize("hidden_size", HIDDEN_SIZES)
@pytest.mark.parametrize("num_experts", EP_NUM_EXPERTS)
@pytest.mark.parametrize("top_k", EP_TOP_K)
@pytest.mark.parametrize("tp_size", TP_SIZES)
@pytest.mark.parametrize("intermediate_size", INTERMEDIATE_SIZES)
def test_moe_tensor_expert_parallel(
    batch_size, hidden_size, num_experts, top_k, tp_size, intermediate_size
):
    """Simulate combined expert and tensor parallelism on a single device.

    Experts are split across ``ep_size = num_experts // 2`` ranks. Inside each
    expert-parallel rank, the intermediate dimension is further split across
    ``tp_size`` ranks: the w3/w1 halves of w31 along dim 1 and w2 along dim 2.
    The outputs of all ``ep_size * tp_size`` simulated GPUs are summed to
    emulate the all-reduce.

    Args:
        batch_size: Batch size for the input.
        hidden_size: Hidden dimension size.
        num_experts: Total number of experts (assumed even).
        top_k: Number of experts each token is routed to.
        tp_size: Number of simulated tensor-parallel ranks.
        intermediate_size: Intermediate dimension size.
    """
    torch.manual_seed(42)
    x = torch.randn(batch_size, hidden_size, dtype=torch.float16).cuda()
    w31_weight = (
        torch.randn(
            num_experts, 2 * intermediate_size, hidden_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    w2_weight = (
        torch.randn(
            num_experts, hidden_size, intermediate_size, dtype=torch.float16
        ).cuda()
        / 10
    )
    selected_experts = torch.stack(
        [torch.randperm(num_experts)[:top_k] for _ in range(batch_size)]
    ).cuda()
    routing_weights = F.softmax(torch.randn((batch_size, top_k)).cuda(), dim=1)

    # Reference without any parallelism.
    ref_output = compute_with_experts(
        num_experts, x, w31_weight, w2_weight, selected_experts, routing_weights
    )

    ep_size = num_experts // 2
    experts_per_rank = num_experts // ep_size
    shard = intermediate_size // tp_size
    outputs = []
    for ep_rank in range(ep_size):
        local_experts = slice(
            ep_rank * experts_per_rank, (ep_rank + 1) * experts_per_rank
        )
        # w31 of this rank's experts, split into its w3 (first) and w1 halves.
        w3_weight, w1_weight = torch.chunk(w31_weight[local_experts], 2, dim=1)
        w2_weight_ep = w2_weight[local_experts]
        for tp_rank in range(tp_size):
            cols = slice(tp_rank * shard, (tp_rank + 1) * shard)
            w31_weight_local = torch.cat(
                [w3_weight[:, cols, :], w1_weight[:, cols, :]], dim=1
            )
            w2_weight_local = w2_weight_ep[:, :, cols]
            result = fused_moe.cutlass_fused_moe(
                x.contiguous(),
                selected_experts.to(torch.int),
                routing_weights,
                w31_weight_local.contiguous(),
                w2_weight_local.contiguous(),
                x.dtype,
                tp_size=tp_size,
                tp_rank=tp_rank,
                ep_size=ep_size,
                ep_rank=ep_rank,
                quant_scales=None,
            )
            # The kernel returns a sequence whose first element is the output.
            outputs.append(result[0])
    # Emulate the all-reduce across every simulated GPU.
    flash_output = sum(outputs)
    torch.testing.assert_close(ref_output, flash_output, rtol=1e-2, atol=1e-2)
def ceil_div(a: int, b: int) -> int:
    """Return ``ceil(a / b)`` using integer arithmetic only."""
    quotient, remainder = divmod(a, b)
    return quotient + 1 if remainder else quotient
def per_block_cast_to_fp8(
x: torch.Tensor, block_size_n: int = 128
) -> tuple[torch.Tensor, torch.Tensor]:
assert x.dim() == 2
m, n = x.shape
x_padded = torch.zeros(
(ceil_div(m, 128) * 128, ceil_div(n, block_size_n) * block_size_n),
dtype=x.dtype,
device=x.device,
)
x_padded[:m, :n] = x
x_view = x_padded.view(-1, 128, x_padded.size(1) // 128, block_size_n)
x_amax = x_view.abs().float().amax(dim=(1, 3), keepdim=True).clamp(1e-4)
x_scaled = (x_view * (448.0 / x_amax)).to(torch.float8_e4m3fn)
x_scaled_sub = x_scaled.view_as(x_padded)[:m, :n].contiguous()
scales = (x_amax / 448.0).view(x_view.size(0), x_view.size(2))
return x_scaled_sub, scales
def per_token_group_quant_fp8(x, group_size, eps=1e-10, dtype=torch.float8_e4m3fn):
"""Function to perform per-token-group quantization on an input tensor
`x` using native torch."""
assert x.shape[-1] % group_size == 0, (
"the last dimension of `x` cannot be divisible by `group_size`"
)
assert x.is_contiguous(), "`x` is not contiguous"
finfo = torch.finfo(dtype)
fp8_min = finfo.min
fp8_max = finfo.max
x_ = x.reshape(x.numel() // group_size, group_size)
amax = x_.abs().max(dim=-1, keepdim=True)[0].clamp(min=eps).to(torch.float32)
x_s = amax / fp8_max
x_q = (x_ / x_s).clamp(min=fp8_min, max=fp8_max).to(dtype)
x_q = x_q.reshape(x.shape)
x_s = x_s.reshape(x.shape[:-1] + (x.shape[-1] // group_size,))