Skip to content

Commit b79515c

Browse files
committed
feat: add float8 dtype support
1 parent 9627c49 commit b79515c

File tree

2 files changed

+126
-33
lines changed

2 files changed

+126
-33
lines changed

tester/api_config/config_analyzer.py

Lines changed: 69 additions & 15 deletions
Original file line number | Diff line number | Diff line change
@@ -125,6 +125,14 @@ def convert_dtype_to_torch_type(self, dtype):
125125
return torch.complex64
126126
elif dtype in ["complex128", numpy.complex128]:
127127
return torch.complex128
128+
elif dtype in ["float8_e4m3fn"]:
129+
if hasattr(torch, "float8_e4m3fn"):
130+
return torch.float8_e4m3fn
131+
return torch.float32 # fallback
132+
elif dtype in ["float8_e5m2"]:
133+
if hasattr(torch, "float8_e5m2"):
134+
return torch.float8_e5m2
135+
return torch.float32 # fallback
128136
else:
129137
raise ValueError(f"Unsupport dtype: {dtype}")
130138

@@ -206,12 +214,15 @@ def get_numpy_tensor(self, api_config, index=None, key=None, **kwargs):
206214
if key is not None:
207215
self.key = key
208216

209-
if self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
210-
print("Warning ", self.dtype, "not supported")
211-
return
217+
# if self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
218+
# print("Warning ", self.dtype, "not supported")
219+
# return
212220

213221
original_dtype = self.dtype
214-
self.dtype = "float32" if self.dtype == "bfloat16" else self.dtype
222+
if self.dtype == "bfloat16":
223+
self.dtype = "float32"
224+
elif self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
225+
self.dtype = "float32"
215226

216227
if self.numpy_tensor is None:
217228
if api_config.api_name in not_zero_apis:
@@ -2621,39 +2632,67 @@ def get_exponent_max(value, dtype_max, default_max=5):
26212632
self.dtype
26222633
)
26232634

2635+
if original_dtype == "float8_e4m3fn":
2636+
self.numpy_tensor = numpy.clip(self.numpy_tensor, -448, 448)
2637+
elif original_dtype == "float8_e5m2":
2638+
self.numpy_tensor = numpy.clip(self.numpy_tensor, -57344, 57344)
2639+
26242640
self.dtype = original_dtype
26252641
return self.numpy_tensor
26262642

26272643
def get_paddle_tensor(self, api_config):
2628-
if self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
2629-
print("Warning ", self.dtype, "not supported")
2630-
return
2644+
# if self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
2645+
# print("Warning ", self.dtype, "not supported")
2646+
# return
26312647

26322648
if self.paddle_tensor is None:
2649+
np_tensor = self.get_numpy_tensor(api_config)
2650+
#print(f"[DEBUG] Numpy Tensor for {self.dtype}: {np_tensor} dtype={np_tensor.dtype}")
2651+
2652+
# Use float32 as intermediate for float8
2653+
intermediate_dtype = self.dtype
2654+
if self.dtype == "bfloat16":
2655+
intermediate_dtype = "float32"
2656+
elif self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
2657+
intermediate_dtype = "float32"
2658+
26332659
self.paddle_tensor = paddle.to_tensor(
2634-
self.get_numpy_tensor(api_config),
2635-
dtype="float32" if self.dtype == "bfloat16" else self.dtype,
2660+
np_tensor,
2661+
dtype=intermediate_dtype,
26362662
place=self.place,
26372663
)
26382664

26392665
self.paddle_tensor.stop_gradient = False
26402666
if self.dtype == "bfloat16":
26412667
self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="bfloat16")
2668+
elif self.dtype == "float8_e4m3fn":
2669+
#print(f"[DEBUG] Before Paddle Cast (float8_e4m3fn): {self.paddle_tensor}\n[DEBUG] dtype check: {self.paddle_tensor.dtype}", flush=True)
2670+
self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="float8_e4m3fn")
2671+
#print(f"[DEBUG] Forward Paddle Input Tensor (float8_e4m3fn): {self.paddle_tensor}\n[DEBUG] dtype check: {self.paddle_tensor.dtype}", flush=True)
2672+
elif self.dtype == "float8_e5m2":
2673+
#print(f"[DEBUG] Before Paddle Cast (float8_e5m2): {self.paddle_tensor}\n[DEBUG] dtype check: {self.paddle_tensor.dtype}", flush=True)
2674+
self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype="float8_e5m2")
2675+
#print(f"[DEBUG] Forward Paddle Input Tensor (float8_e5m2): {self.paddle_tensor}\n[DEBUG] dtype check: {self.paddle_tensor.dtype}", flush=True)
26422676
return self.paddle_tensor
26432677

26442678
def get_torch_tensor(self, api_config):
2645-
if self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
2646-
print("Warning ", self.dtype, "not supported")
2647-
return
2679+
# if self.dtype in ["float8_e5m2"]:
2680+
# print("Warning ", self.dtype, "not supported")
2681+
# return
26482682

26492683
device = torch.device("cuda:0") if torch.cuda.is_available() else torch.device("cpu")
26502684
torch.set_default_device(device)
26512685
if self.torch_tensor is None:
2686+
if self.dtype in ["bfloat16", "float8_e4m3fn", "float8_e5m2"]:
2687+
dtype_to_use = torch.float32
2688+
else:
2689+
dtype_to_use = self.convert_dtype_to_torch_type(self.dtype)
2690+
2691+
#print(f"[DEBUG] Preparing Torch Tensor for {self.dtype}, using initial dtype {dtype_to_use}")
2692+
26522693
self.torch_tensor = torch.tensor(
26532694
self.get_numpy_tensor(api_config),
2654-
dtype=self.convert_dtype_to_torch_type(self.dtype)
2655-
if self.dtype != "bfloat16"
2656-
else torch.float32,
2695+
dtype=dtype_to_use,
26572696
requires_grad=self.dtype
26582697
in [
26592698
"float32",
@@ -2666,6 +2705,21 @@ def get_torch_tensor(self, api_config):
26662705
)
26672706
if self.dtype == "bfloat16":
26682707
self.torch_tensor = self.torch_tensor.to(dtype=torch.bfloat16)
2708+
elif self.dtype == "float8_e4m3fn":
2709+
if hasattr(torch, "float8_e4m3fn"):
2710+
#print(f"[DEBUG] Before Torch Cast (float8_e4m3fn): {self.torch_tensor.dtype} data={self.torch_tensor}", flush=True)
2711+
self.torch_tensor = self.torch_tensor.to(dtype=torch.float8_e4m3fn)
2712+
# print(f"[DEBUG] Forward Torch Input Tensor (float8_e4m3fn): {self.torch_tensor}\n[DEBUG] dtype check: {self.torch_tensor.dtype}", flush=True)
2713+
else:
2714+
print("[DEBUG] Warning: Current torch version does not support float8_e4m3fn, keep float32/float16.", flush=True)
2715+
elif self.dtype == "float8_e5m2":
2716+
if hasattr(torch, "float8_e5m2"):
2717+
#print(f"[DEBUG] Before Torch Cast (float8_e5m2): {self.torch_tensor.dtype} data={self.torch_tensor}", flush=True)
2718+
self.torch_tensor = self.torch_tensor.to(dtype=torch.float8_e5m2)
2719+
# print(f"[DEBUG] Forward Torch Input Tensor (float8_e5m2): {self.torch_tensor}\n[DEBUG] dtype check: {self.torch_tensor.dtype}", flush=True)
2720+
else:
2721+
print("[DEBUG] Warning: Current torch version does not support float8_e5m2, keep float32/float16.", flush=True)
2722+
26692723
return self.torch_tensor
26702724

26712725
def clear_tensor(self):

tester/base.py

Lines changed: 57 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -82,7 +82,7 @@ def need_skip(self, paddle_only=False):
8282
return True
8383
for i in range(len(self.api_config.args)):
8484
if isinstance(self.api_config.args[i], TensorConfig):
85-
if self.api_config.args[i].dtype in ["float8_e5m2", "float8_e4m3fn"]:
85+
if self.api_config.args[i].dtype in ["float8_e5m2"]:
8686
return True
8787
elif isinstance(self.api_config.args[i], list) or isinstance(
8888
self.api_config.args[i], tuple
@@ -91,31 +91,26 @@ def need_skip(self, paddle_only=False):
9191
if isinstance(self.api_config.args[i][j], TensorConfig):
9292
if self.api_config.args[i][j].dtype in [
9393
"float8_e5m2",
94-
"float8_e4m3fn",
9594
]:
9695
return True
9796
elif self.api_config.args[i] in [
98-
paddle.base.core.DataType.FLOAT8_E4M3FN,
9997
paddle.base.core.DataType.FLOAT8_E5M2,
10098
"float8_e5m2",
101-
"float8_e4m3fn",
10299
]:
103100
return True
104101

105102
for _key, arg_config in self.api_config.kwargs.items():
106103
if isinstance(arg_config, TensorConfig):
107-
if arg_config.dtype in ["float8_e5m2", "float8_e4m3fn"]:
104+
if arg_config.dtype in ["float8_e5m2"]:
108105
return True
109106
elif isinstance(arg_config, (list, tuple)):
110107
for i in range(len(arg_config)):
111108
if isinstance(arg_config[i], TensorConfig):
112-
if arg_config[i].dtype in ["float8_e5m2", "float8_e4m3fn"]:
109+
if arg_config[i].dtype in ["float8_e5m2"]:
113110
return True
114111
elif arg_config in [
115-
paddle.base.core.DataType.FLOAT8_E4M3FN,
116112
paddle.base.core.DataType.FLOAT8_E5M2,
117113
"float8_e5m2",
118-
"float8_e4m3fn",
119114
]:
120115
return True
121116

@@ -617,26 +612,39 @@ def gen_paddle_output_and_output_grad(self, outputs):
617612
for output in result_outputs:
618613
dtype = str(output.dtype).split(".")[-1]
619614
if USE_CACHED_NUMPY:
620-
dtype = "float32" if dtype == "bfloat16" else dtype
615+
dtype = (
616+
"float32"
617+
if dtype in ["bfloat16", "float8_e4m3fn", "float8_e5m2"]
618+
else dtype
619+
)
621620
numpy_tensor = self.get_cached_numpy(dtype, output.shape)
622621
else:
623622
if "int" in dtype:
624623
numpy_tensor = (
625624
numpy.random.randint(-65535, 65535, size=output.shape)
626625
).astype(dtype)
627626
else:
628-
dtype = "float32" if dtype == "bfloat16" else dtype
627+
dtype = (
628+
"float32"
629+
if dtype in ["bfloat16", "float8_e4m3fn", "float8_e5m2"]
630+
else dtype
631+
)
629632
numpy_tensor = (numpy.random.random(output.shape) - 0.5).astype(dtype)
630633
self.outputs_grad_numpy.append(numpy_tensor)
631634
for i, numpy_tensor in enumerate(self.outputs_grad_numpy):
632635
dtype = str(result_outputs[i].dtype).split(".")[-1]
633636
result_output_grad = paddle.to_tensor(
634637
numpy_tensor,
635-
dtype=dtype if dtype != "bfloat16" else "float32",
638+
dtype=dtype
639+
if dtype not in ["bfloat16", "float8_e4m3fn", "float8_e5m2"]
640+
else "float32",
636641
)
637642
result_output_grad.stop_gradient = False
638643
if dtype == "bfloat16":
639644
result_output_grad = paddle.cast(result_output_grad, dtype="bfloat16")
645+
elif dtype == "float8_e4m3fn":
646+
result_output_grad = paddle.cast(result_output_grad, dtype="float8_e4m3fn")
647+
#print(f"[DEBUG] Backward Paddle Grad Tensor (float8_e4m3fn): {result_output_grad}\n[DEBUG] dtype check: {result_output_grad.dtype}", flush=True)
640648
result_outputs_grads.append(result_output_grad)
641649
return result_outputs, result_outputs_grads
642650

@@ -662,27 +670,43 @@ def gen_torch_output_and_output_grad(self, outputs):
662670
for output in result_outputs:
663671
dtype = str(output.dtype).split(".")[-1]
664672
if USE_CACHED_NUMPY:
665-
dtype = "float32" if dtype == "bfloat16" else dtype
673+
dtype = (
674+
"float32"
675+
if dtype in ["bfloat16", "float8_e4m3fn"]
676+
else dtype
677+
)
666678
numpy_tensor = self.get_cached_numpy(dtype, output.shape)
667679
else:
668680
if "int" in dtype:
669681
numpy_tensor = (
670682
numpy.random.randint(-65535, 65535, size=output.shape)
671683
).astype(dtype)
672684
else:
673-
dtype = "float32" if dtype == "bfloat16" else dtype
685+
dtype = (
686+
"float32"
687+
if dtype in ["bfloat16", "float8_e4m3fn"]
688+
else dtype
689+
)
674690
numpy_tensor = (numpy.random.random(output.shape) - 0.5).astype(dtype)
675691
self.outputs_grad_numpy.append(numpy_tensor)
676692
for i, numpy_tensor in enumerate(self.outputs_grad_numpy):
677693
dtype = str(result_outputs[i].dtype).split(".")[1]
694+
dtype_to_use = (
695+
torch.float32
696+
if dtype in ["bfloat16", "float8_e4m3fn"]
697+
else self.convert_dtype_to_torch_type(dtype)
698+
)
678699
result_output_grad = torch.tensor(
679700
numpy_tensor,
680-
dtype=self.convert_dtype_to_torch_type(dtype)
681-
if dtype != "bfloat16"
682-
else torch.float32,
701+
dtype=dtype_to_use,
683702
)
684703
if dtype == "bfloat16":
685704
result_output_grad = result_output_grad.to(dtype=torch.bfloat16)
705+
elif dtype == "float8_e4m3fn":
706+
if hasattr(torch, "float8_e4m3fn"):
707+
result_output_grad = result_output_grad.to(dtype=torch.float8_e4m3fn)
708+
else:
709+
result_output_grad = result_output_grad.to(dtype=torch.float16)
686710
result_outputs_grads.append(result_output_grad)
687711
return result_outputs, result_outputs_grads
688712

@@ -780,8 +804,14 @@ def convert_dtype_to_torch_type(self, dtype):
780804
complex,
781805
]:
782806
return torch.complex128
783-
elif dtype is None:
784-
return None
807+
elif dtype in ["float8_e4m3fn"]:
808+
if hasattr(torch, "float8_e4m3fn"):
809+
return torch.float8_e4m3fn
810+
return torch.float32
811+
elif dtype in ["float8_e5m2"]:
812+
if hasattr(torch, "float8_e5m2"):
813+
return torch.float8_e5m2
814+
return torch.float32
785815
else:
786816
raise ValueError(f"Unsupport dtype: {dtype}")
787817

@@ -1033,6 +1063,15 @@ def error_msg(msg):
10331063
is_backward = getattr(self, "is_backward", False)
10341064
if test_tol:
10351065
atol, rtol = 0.0, 0.0
1066+
1067+
# [DEBUG] Print tensors before assertion
1068+
if str(torch_tensor.dtype).endswith("float8_e4m3fn"):
1069+
print(f"\n[DEBUG] Comparing Float8 Tensors:", flush=True)
1070+
print(f"[DEBUG] Converted Paddle Tensor: {converted_paddle_tensor}", flush=True)
1071+
print(f"[DEBUG] Paddle dtype: {converted_paddle_tensor.dtype}", flush=True)
1072+
print(f"[DEBUG] Benchmark Torch Tensor: {torch_tensor}", flush=True)
1073+
print(f"[DEBUG] Torch dtype: {torch_tensor.dtype}", flush=True)
1074+
10361075
try:
10371076
torch.testing.assert_close(
10381077
converted_paddle_tensor,

0 commit comments

Comments
 (0)