-
Notifications
You must be signed in to change notification settings - Fork 665
Expand file tree
/
Copy pathtest_custom_call_compute.py
More file actions
1972 lines (1715 loc) · 78.7 KB
/
test_custom_call_compute.py
File metadata and controls
1972 lines (1715 loc) · 78.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
# Copyright (c) 2022-2026, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# See LICENSE for license information.
import jax
import jax.numpy as jnp
import pytest
from jax import jit, value_and_grad
from functools import reduce
from typing import Union
import operator
from utils import (
assert_allclose,
pytest_parametrize_wrapper,
use_jax_gemm,
)
from transformer_engine.jax.layernorm import layernorm
from transformer_engine.jax.layernorm_mlp import layernorm_mlp
from transformer_engine.jax.cpp_extensions.activation import _jax_act_lu, _jax_quantize_dact_dbias
from transformer_engine.jax.cpp_extensions.normalization import (
_jax_layernorm,
_jax_rmsnorm,
is_norm_zero_centered_gamma_in_weight_dtype,
)
from transformer_engine.jax.cpp_extensions.quantization import (
_jax_quantize,
_jax_quantize_dbias,
)
from transformer_engine.jax.cpp_extensions.misc import get_cudnn_version
from transformer_engine.jax import cpp_extensions as tex
from transformer_engine.jax.quantize import (
NoScaleTensor,
ScaledTensor,
ScaledTensor1x,
ScaledTensor2x,
GroupedScaledTensor1x,
GroupedNoScaleTensor,
ScalingMode,
QuantizerFactory,
QuantizeLayout,
noop_quantizer_set,
QuantizeMetaSet,
QuantizeMeta,
)
from transformer_engine.jax.quantize import helper
from transformer_engine.jax.activation import activation
from transformer_engine.jax.dense import dense, grouped_dense
from transformer_engine.jax.layernorm_dense import layernorm_dense
# Shape triples for the GEMM tests. NOTE(review): presumably (m, n, k) — the
# consumers are outside the visible part of this file, confirm there.
GEMM_CASES = [
    (256, 256, 512),
    (32, 32, 32),
    (2048, 1024, 2048),
    (2048, 2048, 1024),
    (2048, 1024, 1024),
]
# The two FP8 formats.
FP8_COMPUTE_TYPE = [jnp.float8_e4m3fn, jnp.float8_e5m2]
# (n, hidden) pairs; TestNorm is parametrized over them as "n, hidden".
LN_CASES = [(256, 128), (128, 256)]
# Input dtypes; TestNorm is parametrized over them as "inp_dtype".
DTYPES = [jnp.bfloat16, jnp.float32]

# TODO(Phuong): remove unnecessary pytest skips
# Each query is unpacked into a (supported, reason) pair. The pairs are fed to
# pytest.mark.skipif markers so a skipped test reports why it was skipped.
is_fp8_supported, fp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.DELAYED_TENSOR_SCALING
)
is_mxfp8_supported, mxfp8_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.MXFP8_1D_SCALING
)
is_fp4_supported, fp4_unsupported_reason = helper.is_scaling_mode_supported(
    ScalingMode.NVFP4_1D_SCALING
)

""" Find supported scaling modes"""
supported_scaling_modes = helper.get_supported_scaling_modes()
# Same list with every NVFP4 scaling mode filtered out.
non_fp4_supported_scaling_modes = [s for s in supported_scaling_modes if not s.is_nvfp4_scaling]
supported_recipes = helper.get_supported_quantization_recipes()
# Wrap each recipe in pytest.param so its test id is the recipe's class name.
supported_recipes = [pytest.param(r, id=r.__class__.__name__) for r in supported_recipes]
def is_shape_supported_by_mxfp8(input_shape):
    """Report whether MXFP8 1D scaling can derive scale shapes for ``input_shape``.

    Args:
        input_shape: A shape, or a ``pytest.param`` whose first value is a shape.

    Returns:
        True when ``ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x`` accepts the
        shape, False when it raises.
    """
    try:
        # Parametrized shapes may arrive wrapped in pytest.param; unwrap them first.
        if isinstance(input_shape, type(pytest.param(0))):
            input_shape = input_shape.values[0]
        ScalingMode.MXFP8_1D_SCALING.get_scale_shape_2x(input_shape)
        return True
    except Exception:  # pylint: disable=broad-except
        # get_scale_shape_2x raises when the shape is not supported. Catch
        # Exception rather than using a bare except so that KeyboardInterrupt
        # and SystemExit are not silently turned into "unsupported".
        return False
def assert_bitwise_scaled_tensors(
    a: ScaledTensor, b: ScaledTensor, precise_comparison: bool = True
):
    """Assert that two scaled tensors of the same kind match.

    For a pair of ScaledTensor1x the scaling mode, scale dtype, data layout,
    scales and quantized data are compared. With ``precise_comparison=False``
    non-NVFP4 tensors are only compared after dequantization, and NVFP4 tensors
    may differ in fewer than 5% of their quantized elements. A pair of
    ScaledTensor2x is compared rowwise and colwise. Any other combination fails
    the test.
    """
    if isinstance(a, ScaledTensor1x) and isinstance(b, ScaledTensor1x):
        if not (precise_comparison or a.scaling_mode.is_nvfp4_scaling):
            # Relaxed path: only the dequantized values have to agree.
            assert_allclose(a.dequantize(), b.dequantize(), dtype=a.data.dtype)
            return

        assert a.scaling_mode == b.scaling_mode
        assert a.scale_inv.dtype == b.scale_inv.dtype
        assert a.data_layout == b.data_layout

        mode = a.scaling_mode
        if mode.is_tensor_scaling():
            # Compare in dq_dtype: some unfused codepaths cast through the input
            # dtype, which loses precision relative to an all-fp32 computation.
            assert_allclose(a.scale_inv, b.scale_inv, dtype=a.dq_dtype)
        elif mode == ScalingMode.MXFP8_1D_SCALING:
            # MXFP8 scales are compared as uint8.
            scales_a = a.scale_inv.astype(jnp.uint8)
            scales_b = b.scale_inv.astype(jnp.uint8)
            assert_allclose(scales_a, scales_b)
        elif mode.is_nvfp4_scaling:
            assert_allclose(a.amax, b.amax)
            assert_allclose(a.scale_inv, b.scale_inv)
            if not precise_comparison:
                # Tolerate a small fraction of differing quantized elements
                # instead of requiring a bitwise match.
                differing = (a.data != b.data).astype(jnp.float32)
                mismatch_fraction = jnp.mean(differing)
                assert (
                    mismatch_fraction < 0.05
                ), f"Mismatch fraction {mismatch_fraction} is too high"
                return
        else:
            raise ValueError(f"Unsupported scaling mode {a.scaling_mode}")

        assert_allclose(a.data, b.data)
        return

    if isinstance(a, ScaledTensor2x) and isinstance(b, ScaledTensor2x):
        for side in ("rowwise_tensor", "colwise_tensor"):
            assert_bitwise_scaled_tensors(
                getattr(a, side), getattr(b, side), precise_comparison=precise_comparison
            )
        return

    pytest.fail("Unsupported input types")
def assert_dequantized_scaled_tensor(a: ScaledTensor, b: jnp.ndarray):
    """Assert that dequantizing ``a`` reproduces the reference array ``b``.

    For a ScaledTensor1x stored with the "T" layout, ``b`` is first transposed so
    that its trailing axes (from the flatten axis on) come first. For a
    ScaledTensor2x both the rowwise and the colwise tensor are checked.
    """
    if isinstance(a, ScaledTensor1x):
        expected = b
        if a.data_layout == "T":
            split_at = a.data.ndim - a.flatten_axis
            permutation = (*range(split_at, b.ndim), *range(split_at))
            expected = jnp.transpose(b, permutation)
        assert_allclose(a.dequantize(), expected, dtype=a.data.dtype)
        return

    if isinstance(a, ScaledTensor2x):
        for side in ("rowwise_tensor", "colwise_tensor"):
            assert_dequantized_scaled_tensor(getattr(a, side), b)
        return

    pytest.fail("a must be a ScaledTensor object")
def assert_dequantized_grouped_scaled_tensor(
    a: Union[GroupedScaledTensor1x, ScaledTensor2x], b: jnp.ndarray
):
    """Assert that each group of ``a`` dequantizes to the matching slice of ``b``.

    ``b`` is split along axis 0 at the cumulative group sizes of ``a``. Empty
    groups are skipped. For the "T" layout each reference slice is transposed
    before the comparison. A ScaledTensor2x must hold grouped rowwise and colwise
    tensors, and both are checked.
    """
    if isinstance(a, GroupedScaledTensor1x):
        assert a.group_sizes.sum() == b.shape[0]
        split_points = jnp.cumulative_sum(a.group_sizes)[:-1]
        reference_groups = jnp.split(b, split_points, axis=0)
        dequantized_groups = a.dequantize()

        for dq_group, ref_group in zip(dequantized_groups, reference_groups):
            if len(dq_group) == 0:
                continue
            if a.data_layout == "T":
                ndim = len(a.original_shape)
                axis = a.flatten_axis
                if ref_group.shape[0] == 1:
                    # A single-row group keeps its leading axis in place.
                    permutation = (0, *range(axis, ndim), *range(1, axis))
                else:
                    permutation = (*range(axis, ndim), *range(axis))
                ref_group = jnp.transpose(ref_group, permutation)
            dq_group = dq_group.reshape(ref_group.shape)
            assert_allclose(dq_group, ref_group, dtype=a.data.dtype)
        return

    if isinstance(a, ScaledTensor2x):
        assert isinstance(a.rowwise_tensor, GroupedScaledTensor1x)
        assert isinstance(a.colwise_tensor, GroupedScaledTensor1x)
        for part in (a.rowwise_tensor, a.colwise_tensor):
            assert_dequantized_grouped_scaled_tensor(part, b)
        return

    pytest.fail("a must be a GroupedScaledTensor object")
# Input shapes for the activation tests. The tests insert an extra axis at -2
# and repeat the input len(activation_type) times along it.
ALL_ACTIVATION_SHAPES = [(32, 64), (16, 128, 256)]

# Every activation_type tuple exercised by the tests; each has one or two
# entries. NOTE(review): two-entry tuples presumably describe gated
# activations — confirm against transformer_engine.jax.activation.
ALL_ACTIVATION_TYPES = [
    ("gelu",),
    ("gelu", "linear"),
    ("silu",),
    ("silu", "linear"),
    ("relu",),
    ("relu", "linear"),
    ("quick_gelu",),
    ("quick_gelu", "linear"),
    ("squared_relu",),
    ("squared_relu", "linear"),
    ("clamped_silu", "clamped_linear"),
]

# Activation types keyed by "L0" / "L2" and passed to pytest_parametrize_wrapper.
# NOTE(review): the keys look like test levels that the wrapper selects
# between — confirm in utils.pytest_parametrize_wrapper.
ACTIVATION_TYPES = {
    "L0": [
        ("gelu",),
        ("gelu", "linear"),
    ],
    "L2": ALL_ACTIVATION_TYPES,
}
class TestActivation:
    """Compare the TE activation primitives with the pure-JAX reference `_jax_act_lu`."""

    @staticmethod
    def _make_act_params(activation_type):
        """Build the activation parameters for ``activation_type``.

        Only the ("clamped_silu", "clamped_linear") pair takes parameters
        (limit=0.75, alpha=1.702); every other activation type gets None.
        """
        if activation_type != ("clamped_silu", "clamped_linear"):
            return None
        return tex.activation.ActivationParams.create(
            activation_type=activation_type, limit=0.75, alpha=1.702
        )

    def ref_act(self, x, activation_type, act_params):
        """Return the `.data` of the reference activation computed without a quantizer."""
        return _jax_act_lu(x, activation_type, act_params=act_params).data

    def value_n_grad_ref_func(self, x, activation_type, act_params):
        """Return the mean of the reference activation and its gradient w.r.t. ``x``."""
        jitted_reference = jit(
            value_and_grad(
                lambda out: jnp.mean(self.ref_act(out, activation_type, act_params)), (0,)
            )
        )
        return jitted_reference(x)

    def primitive_func(self, inputs, activation_type, quantizer, act_params):
        """Return the mean of the TE ``activation`` output, used as a scalar loss."""
        out = activation(
            inputs, activation_type=activation_type, quantizer=quantizer, act_params=act_params
        )
        return jnp.mean(out)

    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    # Every activation type is exercised here to make sure each one is functional;
    # the other tests only use a subset to verify the remaining functionality.
    @pytest_parametrize_wrapper("activation_type", ALL_ACTIVATION_TYPES)
    def test_act_grad(self, shape, activation_type):
        """Value and gradient of the unquantized activation match the reference."""
        key = jax.random.PRNGKey(0)
        x = jax.random.uniform(key, shape, jnp.float32)
        # Insert the activation axis and replicate the input once per activation entry.
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)), static_argnums=(1, 3)
        )
        act_params = self._make_act_params(activation_type)

        prim_out, (prim_grad,) = value_n_grad_primitive_func(x, activation_type, None, act_params)
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=x.dtype)
        assert_allclose(prim_grad, ref_grad, dtype=x.dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_grad_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, scaling_mode
    ):
        """Value and gradient with a rowwise tensor-scaling quantizer match the reference."""
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        value_n_grad_primitive_func = jit(
            value_and_grad(self.primitive_func, (0,)),
            static_argnums=(1, 3),
        )
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=QuantizeLayout.ROWWISE,
        )
        act_params = self._make_act_params(activation_type)

        prim_out, (prim_grad,) = value_n_grad_primitive_func(
            x, activation_type, quantizer, act_params
        )
        ref_out, (ref_grad,) = self.value_n_grad_ref_func(x, activation_type, act_params)

        assert_allclose(prim_out, ref_out, dtype=output_type)
        assert_allclose(prim_grad, ref_grad, dtype=output_type)

    # This test only uses tensor scaling, so it is gated on plain FP8 support
    # (not MXFP8 support), like the other tensor-scaling tests.
    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", ALL_ACTIVATION_SHAPES)
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_act_forward_with_tensor_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout, scaling_mode
    ):
        """`tex.act_lu` with tensor scaling matches the quantized JAX reference."""
        x = random_inputs
        x = jnp.expand_dims(x, axis=-2)
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        # Two independent quantizers so the TE and reference paths do not share state.
        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2,
            scaling_mode=scaling_mode,
            q_dtype=output_type,
            q_layout=q_layout,
        )
        act_params = self._make_act_params(activation_type)

        te_output = tex.act_lu(x, activation_type, te_quantizer, act_params)
        jax_output = _jax_act_lu(x, activation_type, jax_quantizer, act_params)
        assert_bitwise_scaled_tensors(te_output, jax_output)

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest_parametrize_wrapper("shape", [(2, 64, 1, 256)])
    @pytest_parametrize_wrapper("activation_type", ACTIVATION_TYPES)
    @pytest_parametrize_wrapper("output_type", [jnp.float8_e4m3fn, jnp.float8_e5m2])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    def test_act_forward_with_block_scaling_fp8(
        self, random_inputs, activation_type, output_type, q_layout
    ):
        """`tex.act_lu` with MXFP8 block scaling dequantizes to the reference activation."""
        x = random_inputs
        # The parametrized shape already has a size-1 axis at -2, so no expand_dims here.
        x = jnp.repeat(x, len(activation_type), axis=-2)
        self.activation_type = activation_type

        quantizer = QuantizerFactory.create(
            scaling_mode=ScalingMode.MXFP8_1D_SCALING, q_dtype=output_type, q_layout=q_layout
        )
        act_params = self._make_act_params(activation_type)

        output = tex.act_lu(x, activation_type, quantizer, act_params)
        ref_out = self.ref_act(x, activation_type, act_params)
        assert_dequantized_scaled_tensor(output, ref_out)
# Quantized output dtypes for normalization, keyed by "L0" / "L2".
# NOTE(review): not referenced in the visible part of this file; the keys look
# like test levels — confirm where (and whether) this table is consumed.
NORM_OUTPUT_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
@pytest_parametrize_wrapper("n, hidden", LN_CASES)
@pytest_parametrize_wrapper("inp_dtype", DTYPES)
@pytest_parametrize_wrapper("norm_type", ["layernorm", "rmsnorm"])
@pytest_parametrize_wrapper(
    "zero_centered_gamma",
    [
        pytest.param(True, id="zero_centered"),
        pytest.param(False, id="no_zero_centered"),
    ],
)
@pytest_parametrize_wrapper("epsilon", [1e-2, 1e-6])
class TestNorm:
    """
    Tests for the transformer_engine.jax.layernorm APIs
    """

    def _test_norm_grad(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer
    ):
        """Compare loss and gradients of ``layernorm`` with the pure-JAX reference."""

        def compute_loss(x):
            # Accumulate the loss in float32, then cast back to the incoming dtype
            squared = jnp.square(x.astype(jnp.float32))
            return jnp.mean(squared).astype(x.dtype)

        def reference_func(x, gamma, beta, norm_type, zero_centered_gamma, eps, quantizer):
            if norm_type == "rmsnorm":
                normed, _ = _jax_rmsnorm(x, gamma, zero_centered_gamma, eps, quantizer)
            else:
                normed, _, _ = _jax_layernorm(x, gamma, beta, zero_centered_gamma, eps, quantizer)
            # dequantize() is a no-op for non-quantized data
            return normed.dequantize()

        rng_keys = jax.random.split(jax.random.PRNGKey(0), 3)

        x = jax.random.uniform(rng_keys[0], (n, hidden), jnp.float32, -1, 1).astype(inp_dtype)

        if zero_centered_gamma:
            gamma_low, gamma_high = -1, 1
        else:
            gamma_low, gamma_high = 0, 2
        gamma = jnp.asarray(
            jax.random.uniform(rng_keys[1], (hidden,), jnp.float32, gamma_low, gamma_high),
            inp_dtype,
        )

        # Only layernorm has a bias term
        beta = None
        if norm_type == "layernorm":
            beta = jnp.asarray(
                jax.random.uniform(rng_keys[2], (hidden,), jnp.float32, -1, 1), inp_dtype
            )

        def reference_loss(x, gamma, beta):
            return compute_loss(
                reference_func(
                    x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer=None
                )
            )

        def primitive_loss(x, gamma, beta):
            return compute_loss(
                layernorm(x, gamma, beta, norm_type, zero_centered_gamma, epsilon, quantizer)
            )

        grad_argnums = (0, 1, 2)
        jitted_reference = jit(value_and_grad(reference_loss, grad_argnums))
        jitted_primitive = jit(value_and_grad(primitive_loss, grad_argnums))

        reference_out, (reference_dx, reference_dgamma, reference_dbeta) = jitted_reference(
            x, gamma, beta
        )
        primitive_out, (primitive_dx, primitive_dgamma, primitive_dbeta) = jitted_primitive(
            x, gamma, beta
        )

        out_dtype = quantizer.q_dtype if quantizer is not None else inp_dtype
        checks = (
            (primitive_out, reference_out),
            (primitive_dx, reference_dx),
            (primitive_dgamma, reference_dgamma),
        )
        for actual, expected in checks:
            assert_allclose(actual, expected, dtype=out_dtype)
        if beta is not None:
            assert_allclose(primitive_dbeta, reference_dbeta, dtype=out_dtype)

    def test_norm_grad(self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype):
        """
        Gradient test of transformer_engine.jax.layernorm.layernorm without a quantizer
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, quantizer=None
        )

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_grad_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """
        Gradient test of transformer_engine.jax.layernorm.layernorm with a tensor-scaling quantizer
        """
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        fp8_quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )
        self._test_norm_grad(
            n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, fp8_quantizer
        )

    def _test_norm_forward(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        scaling_mode,
        q_layout,
    ):
        """Compare the quantized forward output of the TE norm with the JAX reference."""
        rng_keys = jax.random.split(jax.random.PRNGKey(0), 3)

        x = jnp.asarray(
            jax.random.uniform(rng_keys[0], (n, hidden), inp_dtype, -1, 1), inp_dtype
        )

        if zero_centered_gamma:
            gamma_low, gamma_high = -1, 1
        else:
            gamma_low, gamma_high = 0, 2
        gamma = jnp.asarray(
            jax.random.uniform(rng_keys[1], (hidden,), jnp.float32, gamma_low, gamma_high),
            inp_dtype,
        )

        # Separate quantizers for the TE path and the reference path
        quantizer, ref_quantizer = QuantizerFactory.create(
            n_quantizers=2, scaling_mode=scaling_mode, q_dtype=out_dtype, q_layout=q_layout
        )

        if norm_type == "layernorm":
            beta = jnp.asarray(
                jax.random.uniform(rng_keys[2], (hidden,), jnp.float32, -1, 1), inp_dtype
            )
            output, mu, rsigma = tex.layernorm_fwd(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_mu, ref_rsigma = _jax_layernorm(
                x, gamma, beta, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
        else:
            output, rsigma = tex.rmsnorm_fwd(
                x, gamma, zero_centered_gamma, epsilon, quantizer=quantizer
            )
            ref_out, ref_rsigma = _jax_rmsnorm(
                x, gamma, zero_centered_gamma, epsilon, quantizer=ref_quantizer
            )
            ref_mu = None

        reduced_precision = (
            # Below cuDNN 9.10 there is no fused norm for MXFP8: the unfused norm and
            # quantize go through an intermediate cast to the input dtype, which can
            # reduce precision
            (get_cudnn_version() < (9, 10, 0) and scaling_mode == ScalingMode.MXFP8_1D_SCALING)
            # The JAX implementation _jax_*norm always uses the compute dtype float32
            # for zero-centered gamma, so larger tolerances are needed
            or is_norm_zero_centered_gamma_in_weight_dtype(scaling_mode)
            # Current Tensor Scaling performs an unfused layernorm and quantization and
            # writes intermediate results in the input dtype, which slightly reduces
            # precision when the input dtype is not float32
            or (scaling_mode == ScalingMode.CURRENT_TENSOR_SCALING and inp_dtype != jnp.float32)
        )

        assert_bitwise_scaled_tensors(
            output, ref_out, precise_comparison=not reduced_precision
        )
        assert_allclose(rsigma, ref_rsigma, dtype=inp_dtype)
        if norm_type == "layernorm":
            assert_allclose(mu, ref_mu, dtype=inp_dtype)

    @pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
    # No Norm FWD E5M2 in TE backend
    @pytest_parametrize_wrapper("out_dtype", [jnp.float8_e4m3fn])
    @pytest_parametrize_wrapper(
        "q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.ROWWISE_COLWISE]
    )
    @pytest_parametrize_wrapper(
        "scaling_mode", [ScalingMode.DELAYED_TENSOR_SCALING, ScalingMode.CURRENT_TENSOR_SCALING]
    )
    def test_norm_forward_with_tensor_scaling_fp8(
        self,
        n,
        hidden,
        norm_type,
        zero_centered_gamma,
        epsilon,
        inp_dtype,
        out_dtype,
        q_layout,
        scaling_mode,
    ):
        """Forward test with the delayed and current tensor-scaling modes."""
        if norm_type == "rmsnorm" and zero_centered_gamma is True:
            pytest.skip("RMSNorm and zero_centered_gamma is not supported!")

        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=scaling_mode,
            q_layout=q_layout,
        )

    @pytest.mark.skipif(not is_mxfp8_supported, reason=mxfp8_unsupported_reason)
    @pytest.mark.parametrize(
        "out_dtype",
        [
            jnp.float8_e4m3fn,
        ],
    )
    def test_norm_forward_with_block_scaling_fp8(
        self, n, hidden, norm_type, zero_centered_gamma, epsilon, inp_dtype, out_dtype
    ):
        """Forward test with MXFP8 1D block scaling and a rowwise+colwise layout."""
        self._test_norm_forward(
            n=n,
            hidden=hidden,
            norm_type=norm_type,
            zero_centered_gamma=zero_centered_gamma,
            epsilon=epsilon,
            inp_dtype=inp_dtype,
            out_dtype=out_dtype,
            scaling_mode=ScalingMode.MXFP8_1D_SCALING,
            q_layout=QuantizeLayout.ROWWISE_COLWISE,
        )
# FP8 quantized output dtypes, keyed by "L0" / "L2".
QUANTIZE_OUTPUT_FP8_DTYPES = {
    "L0": [jnp.float8_e4m3fn],
    "L2": [jnp.float8_e4m3fn, jnp.float8_e5m2],
}
# The FP8 dtypes of each key plus the FP4 dtype float4_e2m1fn.
QUANTIZE_OUTPUT_DTYPES = {
    test_level: QUANTIZE_OUTPUT_FP8_DTYPES[test_level] + [jnp.float4_e2m1fn]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}
# (q_dtype, scaling_mode) pairs per key, keeping only pairs whose dtype the
# scaling mode reports as compatible.
# NOTE(review): zip() pairs the i-th dtype with the i-th supported scaling mode
# and stops at the shorter list, so this is NOT a cartesian product of dtypes
# and scaling modes. If every compatible combination was intended,
# itertools.product would be needed — confirm against the tests that consume
# this table before changing it.
QUANTIZE_QDTYPE_AND_SCALING_MODES = {
    test_level: [
        (q_dtype, scaling_mode)
        for q_dtype, scaling_mode in zip(
            QUANTIZE_OUTPUT_FP8_DTYPES[test_level], supported_scaling_modes
        )
        if q_dtype in scaling_mode.get_compatible_q_dtypes()
    ]
    for test_level in QUANTIZE_OUTPUT_FP8_DTYPES
}

# (input_shape, flatten_axis) pairs; TestQuantize and TestStochasticRounding are
# parametrized over them as "input_shape,flatten_axis".
ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = [
    ((32, 64), -1),
    ((2, 64, 32), -1),
    ((64, 2, 32), -2),
    ((32, 256, 128), -1),
    ((32, 256, 128), -2),
    ((64, 32, 32, 256), -1),
    ((8192, 2, 4096), -2),
]

# Subset of the shape/flatten-axis pairs for "L0"; "L2" uses the full list.
QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES = {
    "L0": [
        ((32, 64), -1),
        ((2, 64, 32), -1),
        ((64, 2, 32), -2),
    ],
    "L2": ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES,
}

# Input dtypes for the quantization tests, keyed by "L0" / "L2".
QUANTIZATION_INPUT_DTYPE = {
    "L0": [jnp.bfloat16],
    "L2": [jnp.float32, jnp.float16, jnp.bfloat16],
}
@pytest.mark.skipif(not is_fp8_supported, reason=fp8_unsupported_reason)
@pytest_parametrize_wrapper("in_dtype", QUANTIZATION_INPUT_DTYPE)
@pytest_parametrize_wrapper("q_dtype", [jnp.float8_e4m3fn, jnp.float8_e5m2, jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper("scaling_mode", supported_scaling_modes)
@pytest_parametrize_wrapper(
    "q_layout",
    [
        QuantizeLayout.ROWWISE,
        QuantizeLayout.COLWISE,
        QuantizeLayout.ROWWISE_COLWISE,
    ],
)
class TestQuantize:
    """
    Quantization-only tests, always run over the wider set of dtypes and shapes
    """

    def _skip_unsupported_dtypes(self, q_dtype, scaling_mode):
        """Skip the test when the scaling mode cannot quantize to ``q_dtype``, e.g. NVFP4 only supports float4_e2m1 and no float8 dtype."""
        compatible = scaling_mode.get_compatible_q_dtypes()
        if q_dtype not in compatible:
            pytest.skip(f"Quantize dtype {q_dtype} is not supported by {scaling_mode}")

    def test_qdq(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
        """Quantize then dequantize and check the round trip."""
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)

        key = jax.random.PRNGKey(0)

        # A single quantizer serves the whole test because some approaches carry state
        # across iterations (e.g. delayed scaling)
        quantizer = QuantizerFactory.create(
            scaling_mode=scaling_mode,
            q_dtype=q_dtype,
            q_layout=q_layout,
        )

        if not scaling_mode.is_nvfp4_scaling:
            n_iterations = 3 if scaling_mode == ScalingMode.DELAYED_TENSOR_SCALING else 1
            for _ in range(n_iterations):
                x = jax.random.uniform(key, input_shape, in_dtype)
                scaled_tensor = quantizer.quantize(x, flatten_axis=flatten_axis)
                assert_dequantized_scaled_tensor(scaled_tensor, x)
            return

        if in_dtype != jnp.bfloat16:
            pytest.skip("NVFP4 scaling only supported with bfloat16 input dtype currently")

        # With NVFP4 scaling the error of a single dequantized value against the original
        # can be large. Rather than using a very high tolerance for every value, check
        # that re-quantizing the dequantized tensor is bitwise identical to the first
        # quantization.
        x = jax.random.uniform(key, input_shape, in_dtype) * 10
        q1 = _jax_quantize(x, quantizer=quantizer, flatten_axis=flatten_axis)
        q1_is_1x = isinstance(q1, ScaledTensor1x)

        dq_rowwise = None
        dq_colwise = None
        if q1_is_1x:
            dq = q1.dequantize()
            if q1.is_colwise:
                dq_colwise = dq
            else:
                dq_rowwise = dq
        elif isinstance(q1, ScaledTensor2x):
            dq_rowwise = q1.rowwise_tensor.dequantize()
            dq_colwise = q1.colwise_tensor.dequantize()
        else:
            raise ValueError(f"Unsupported output type {type(q1)}")

        # Q-DQ is only compared within one quantization layout. Re-quantizing colwise
        # after a rowwise QDQ, for example, gives a larger error and may not be bitwise
        # identical to the original colwise quantization.
        if dq_rowwise is not None:
            assert (
                dq_rowwise.shape == x.shape
            ), f"dq_rowwise shape {dq_rowwise.shape} != x shape {x.shape}"
            requantized = _jax_quantize(dq_rowwise, quantizer=quantizer, flatten_axis=flatten_axis)
            if not isinstance(requantized, ScaledTensor1x):
                requantized = requantized.rowwise_tensor
            q1_rowwise = q1 if q1_is_1x else q1.rowwise_tensor
            assert_bitwise_scaled_tensors(q1_rowwise, requantized)

        if dq_colwise is not None:
            # This path is NVFP4 only, so colwise is assumed to use the T layout; transpose
            # back to the original shape
            ndim = len(input_shape)
            flatten_axis = flatten_axis + ndim if flatten_axis < 0 else flatten_axis
            colwise_flatten_axis = ndim - flatten_axis
            restore_order = (
                *range(colwise_flatten_axis, dq_colwise.ndim),
                *range(colwise_flatten_axis),
            )
            dq_colwise = jnp.transpose(dq_colwise, restore_order)
            assert (
                dq_colwise.shape == x.shape
            ), f"dq_colwise shape {dq_colwise.shape} != x shape {x.shape}"
            requantized = _jax_quantize(dq_colwise, quantizer=quantizer, flatten_axis=flatten_axis)
            if not isinstance(requantized, ScaledTensor1x):
                requantized = requantized.colwise_tensor
            q1_colwise = q1 if q1_is_1x else q1.colwise_tensor
            assert_bitwise_scaled_tensors(q1_colwise, requantized)

        assert (
            dq_rowwise is not None or dq_colwise is not None
        ), "At least one of rowwise or colwise dq must be not None"

    def _should_use_precise_comparison(
        self, in_dtype, scaling_mode, quantizer, input_shape, flatten_axis
    ):
        """Return False only for NVFP4 scaling with a non-bfloat16 input dtype."""
        # TE kernels use bfloat16 internally for NVFP4 scaling, so any other input dtype
        # can give small numerical differences against the JAX implementation
        return not (scaling_mode.is_nvfp4_scaling and in_dtype != jnp.bfloat16)

    def test_quantize_bitwise(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """Compare the TE quantize kernel with the JAX implementation, without jit."""
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)

        x = jax.random.uniform(jax.random.PRNGKey(0), input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_output = _jax_quantize(x, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = tex.quantize(x, quantizer=te_quantizer, flatten_axis=flatten_axis)

        precise = self._should_use_precise_comparison(
            in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
        )
        assert_bitwise_scaled_tensors(te_output, jax_output, precise_comparison=precise)

    def test_quantize_bitwise_jitted(
        self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis
    ):
        """Compare the TE quantize kernel with the JAX implementation, both under jit."""
        self._skip_unsupported_dtypes(q_dtype, scaling_mode)

        x = jax.random.uniform(jax.random.PRNGKey(0), input_shape, in_dtype)

        te_quantizer, jax_quantizer = QuantizerFactory.create(
            n_quantizers=2, q_dtype=q_dtype, scaling_mode=scaling_mode, q_layout=q_layout
        )

        jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
        te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))

        jax_output = jax_impl_func_jit(x, quantizer=jax_quantizer, flatten_axis=flatten_axis)
        te_output = te_impl_func_jit(x, quantizer=te_quantizer, flatten_axis=flatten_axis)

        precise = self._should_use_precise_comparison(
            in_dtype, scaling_mode, te_quantizer, input_shape, flatten_axis
        )
        assert_bitwise_scaled_tensors(te_output, jax_output, precise_comparison=precise)
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper("input_shape,flatten_axis", ALL_QUANTIZE_TEST_SHAPES_AND_FLATTEN_AXES)
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s.is_nvfp4_scaling]
)
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE, QuantizeLayout.COLWISE, QuantizeLayout.ROWWISE_COLWISE]
)
class TestStochasticRounding:
def _dequantize(self, scaled_tensor) -> list[jnp.ndarray]:
    """Dequantize a ScaledTensor into a list of jnp.ndarrays: one entry for a ScaledTensor1x, two (rowwise then colwise) for a ScaledTensor2x. Tensors stored with the "T" layout are transposed back before being returned."""
    if isinstance(scaled_tensor, ScaledTensor1x):
        restored = scaled_tensor.dequantize()
        if scaled_tensor.data_layout == "T":
            axis = scaled_tensor.flatten_axis
            restored = jnp.transpose(restored, (*range(axis, restored.ndim), *range(axis)))
        return [restored]

    if isinstance(scaled_tensor, ScaledTensor2x):
        # Each half is a 1x tensor, so each recursive call yields exactly one array.
        (rowwise_dq,) = self._dequantize(scaled_tensor.rowwise_tensor)
        (colwise_dq,) = self._dequantize(scaled_tensor.colwise_tensor)
        return [rowwise_dq, colwise_dq]

    raise ValueError(
        "Unsupported ScaledTensor type, expected ScaledTensor but received"
        f" {type(scaled_tensor)}"
    )
def _sample_sr_qdq(
self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> list[jnp.ndarray]:
"""Samples num_samples quantize-dequantize operations with stochastic rounding enabled and returns the dequantized tensors."""
dq_tensors = []
key = jax.random.PRNGKey(0)
for i in range(num_samples):
iter_key = jax.random.fold_in(key, i)
sr_rng_state = jax.random.randint(
iter_key, (1, 4), minval=0, maxval=2**30 - 1, dtype=jnp.uint32
)
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
stochastic_rounding_rng_state=sr_rng_state,
)
q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
iter_dq = self._dequantize(q_output)
dq_tensors.extend(iter_dq)
avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors), axis=0)
assert avg_sr_tensor.shape == inputs.shape, (
f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
f" {inputs.shape}"
)
sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
dq_var = jnp.var(jnp.stack(dq_tensors))
assert (
dq_var > 0
), "Variance of dequantized tensors is zero, stochastic rounding may not be working"
return dq_tensors
def _round_nearest(
self, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> jnp.ndarray:
"""Quantizes and dequantizes the input tensor with round nearest quantization."""
quantizer = QuantizerFactory.create(
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
stochastic_rounding_rng_state=None,
)
q_output = q_func(inputs, quantizer=quantizer, flatten_axis=flatten_axis)
rn_dq = self._dequantize(q_output)[0]
return rn_dq
def _test_sr(
self, num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
) -> float:
"""Tests that the mean absolute error (MAE) of stochastic rounding is smaller than round nearest quantization over multiple samples."""
dq_tensors = self._sample_sr_qdq(
num_samples, q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
avg_sr_tensor = jnp.mean(jnp.stack(dq_tensors).astype(jnp.float32), axis=0)
assert avg_sr_tensor.shape == inputs.shape, (
f"Dequantized tensor shape {avg_sr_tensor.shape} does not match input shape"
f" {inputs.shape}"
)
round_nearest_tensor = self._round_nearest(
q_func, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
sr_mae = jnp.mean(jnp.abs(avg_sr_tensor - inputs))
rn_mae = jnp.mean(jnp.abs(round_nearest_tensor - inputs))
assert sr_mae < rn_mae, (
f"Mean absolute error of stochastic rounding ({sr_mae}) is not smaller than"
f" round nearest ({rn_mae})"
)
return sr_mae
def test_sr_nvfp4(self, in_dtype, input_shape, q_dtype, scaling_mode, q_layout, flatten_axis):
"""Tests that the mean absolute error of stochastic rounding is smaller than round nearest quantization over multiple samples for both TE and JAX implementations. Asserts that the MAE of both implementations is close to each other."""
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
NUM_SAMPLES = 10
te_mean_error = self._test_sr(
NUM_SAMPLES, tex.quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
jax_mean_error = self._test_sr(
NUM_SAMPLES, _jax_quantize, inputs, q_dtype, scaling_mode, q_layout, flatten_axis
)
assert_allclose(te_mean_error, jax_mean_error, rtol=0.2, atol=1e-4)
@pytest_parametrize_wrapper("in_dtype", [jnp.bfloat16])
@pytest_parametrize_wrapper("q_dtype", [jnp.float4_e2m1fn])
@pytest_parametrize_wrapper(
"scaling_mode", [s for s in supported_scaling_modes if s == ScalingMode.NVFP4_1D_SCALING]
)
class TestRandomizedHadamardTransform:
@pytest_parametrize_wrapper(
"q_layout", [QuantizeLayout.ROWWISE_COLWISE, QuantizeLayout.COLWISE]
)
@pytest_parametrize_wrapper("input_shape,flatten_axis", [((64, 128), -1)])
def test_rht_quantize_bitwise_jitted(
self, in_dtype, q_dtype, scaling_mode, q_layout, input_shape, flatten_axis
):
key = jax.random.PRNGKey(0)
inputs = jax.random.uniform(key, input_shape, in_dtype)
te_quantizer, jax_quantizer = QuantizerFactory.create(
n_quantizers=2,
q_dtype=q_dtype,
scaling_mode=scaling_mode,
q_layout=q_layout,
use_rht=True,
)
jax_impl_func_jit = jax.jit(_jax_quantize, static_argnums=(2, 3))
te_impl_func_jit = jax.jit(tex.quantize, static_argnums=(2,))