@@ -125,6 +125,14 @@ def convert_dtype_to_torch_type(self, dtype):
125125 return torch .complex64
126126 elif dtype in ["complex128" , numpy .complex128 ]:
127127 return torch .complex128
128+ elif dtype in ["float8_e4m3fn" ]:
129+ if hasattr (torch , "float8_e4m3fn" ):
130+ return torch .float8_e4m3fn
131+ return torch .float32 # fallback
132+ elif dtype in ["float8_e5m2" ]:
133+ if hasattr (torch , "float8_e5m2" ):
134+ return torch .float8_e5m2
135+ return torch .float32 # fallback
128136 else :
129137 raise ValueError (f"Unsupport dtype: { dtype } " )
130138
@@ -206,12 +214,15 @@ def get_numpy_tensor(self, api_config, index=None, key=None, **kwargs):
206214 if key is not None :
207215 self .key = key
208216
209- if self .dtype in ["float8_e5m2" , "float8_e4m3fn" ]:
210- print ("Warning " , self .dtype , "not supported" )
211- return
217+ # if self.dtype in ["float8_e5m2", "float8_e4m3fn"]:
218+ # print("Warning ", self.dtype, "not supported")
219+ # return
212220
213221 original_dtype = self .dtype
214- self .dtype = "float32" if self .dtype == "bfloat16" else self .dtype
222+ if self .dtype == "bfloat16" :
223+ self .dtype = "float32"
224+ elif self .dtype in ["float8_e5m2" , "float8_e4m3fn" ]:
225+ self .dtype = "float32"
215226
216227 if self .numpy_tensor is None :
217228 if api_config .api_name in not_zero_apis :
@@ -2621,39 +2632,67 @@ def get_exponent_max(value, dtype_max, default_max=5):
26212632 self .dtype
26222633 )
26232634
2635+ if original_dtype == "float8_e4m3fn" :
2636+ self .numpy_tensor = numpy .clip (self .numpy_tensor , - 448 , 448 )
2637+ elif original_dtype == "float8_e5m2" :
2638+ self .numpy_tensor = numpy .clip (self .numpy_tensor , - 57344 , 57344 )
2639+
26242640 self .dtype = original_dtype
26252641 return self .numpy_tensor
26262642
def get_paddle_tensor(self, api_config):
    """Return the cached paddle input tensor, creating it on first use.

    numpy has no native bfloat16 / float8 representation, so for those
    dtypes the tensor is first materialized as float32 (the numpy data
    produced by ``get_numpy_tensor`` is already range-clipped for float8
    upstream) and then cast down to the requested dtype with
    ``paddle.cast``.

    Args:
        api_config: forwarded to ``self.get_numpy_tensor`` to generate
            the underlying numpy data on the first call.

    Returns:
        The paddle tensor for this input, with ``stop_gradient=False``.
    """
    # dtypes numpy cannot hold directly: build as float32, cast afterwards
    needs_cast = ("bfloat16", "float8_e4m3fn", "float8_e5m2")

    if self.paddle_tensor is None:
        np_tensor = self.get_numpy_tensor(api_config)
        intermediate_dtype = "float32" if self.dtype in needs_cast else self.dtype
        self.paddle_tensor = paddle.to_tensor(
            np_tensor,
            dtype=intermediate_dtype,
            place=self.place,
        )

    # NOTE(review): stop_gradient is (re)set and the cast re-applied even on
    # cache hits — matches the original control flow; the re-cast is a no-op
    # for an already-cast cached tensor.
    self.paddle_tensor.stop_gradient = False
    if self.dtype in needs_cast:
        self.paddle_tensor = paddle.cast(self.paddle_tensor, dtype=self.dtype)
    return self.paddle_tensor
26432677
26442678 def get_torch_tensor (self , api_config ):
2645- if self .dtype in ["float8_e5m2" , "float8_e4m3fn " ]:
2646- print ("Warning " , self .dtype , "not supported" )
2647- return
2679+ # if self.dtype in ["float8_e5m2"]:
2680+ # print("Warning ", self.dtype, "not supported")
2681+ # return
26482682
26492683 device = torch .device ("cuda:0" ) if torch .cuda .is_available () else torch .device ("cpu" )
26502684 torch .set_default_device (device )
26512685 if self .torch_tensor is None :
2686+ if self .dtype in ["bfloat16" , "float8_e4m3fn" , "float8_e5m2" ]:
2687+ dtype_to_use = torch .float32
2688+ else :
2689+ dtype_to_use = self .convert_dtype_to_torch_type (self .dtype )
2690+
2691+ #print(f"[DEBUG] Preparing Torch Tensor for {self.dtype}, using initial dtype {dtype_to_use}")
2692+
26522693 self .torch_tensor = torch .tensor (
26532694 self .get_numpy_tensor (api_config ),
2654- dtype = self .convert_dtype_to_torch_type (self .dtype )
2655- if self .dtype != "bfloat16"
2656- else torch .float32 ,
2695+ dtype = dtype_to_use ,
26572696 requires_grad = self .dtype
26582697 in [
26592698 "float32" ,
@@ -2666,6 +2705,21 @@ def get_torch_tensor(self, api_config):
26662705 )
26672706 if self .dtype == "bfloat16" :
26682707 self .torch_tensor = self .torch_tensor .to (dtype = torch .bfloat16 )
2708+ elif self .dtype == "float8_e4m3fn" :
2709+ if hasattr (torch , "float8_e4m3fn" ):
2710+ #print(f"[DEBUG] Before Torch Cast (float8_e4m3fn): {self.torch_tensor.dtype} data={self.torch_tensor}", flush=True)
2711+ self .torch_tensor = self .torch_tensor .to (dtype = torch .float8_e4m3fn )
2712+ # print(f"[DEBUG] Forward Torch Input Tensor (float8_e4m3fn): {self.torch_tensor}\n[DEBUG] dtype check: {self.torch_tensor.dtype}", flush=True)
2713+ else :
2714+ print ("[DEBUG] Warning: Current torch version does not support float8_e4m3fn, keep float32/float16." , flush = True )
2715+ elif self .dtype == "float8_e5m2" :
2716+ if hasattr (torch , "float8_e5m2" ):
2717+ #print(f"[DEBUG] Before Torch Cast (float8_e5m2): {self.torch_tensor.dtype} data={self.torch_tensor}", flush=True)
2718+ self .torch_tensor = self .torch_tensor .to (dtype = torch .float8_e5m2 )
2719+ # print(f"[DEBUG] Forward Torch Input Tensor (float8_e5m2): {self.torch_tensor}\n[DEBUG] dtype check: {self.torch_tensor.dtype}", flush=True)
2720+ else :
2721+ print ("[DEBUG] Warning: Current torch version does not support float8_e5m2, keep float32/float16." , flush = True )
2722+
26692723 return self .torch_tensor
26702724
26712725 def clear_tensor (self ):
0 commit comments