File tree Expand file tree Collapse file tree 3 files changed +9
-10
lines changed
Expand file tree Collapse file tree 3 files changed +9
-10
lines changed Original file line number Diff line number Diff line change @@ -2039,11 +2039,7 @@ def _prepare_ao(self, *args):
20392039
20402040 # Invariant: with FSDP2, optimizer is always passed to `prepare()` together with model
20412041 # We only precompute scales if float8 all gather is enabled, possibly can add a flag for this later
2042- if (
2043- self .is_fsdp2
2044- and len (optimizers ) > 0
2045- and self .ao_recipe_handler .config .enable_fsdp_float8_all_gather
2046- ):
2042+ if self .is_fsdp2 and len (optimizers ) > 0 and self .ao_recipe_handler .config .enable_fsdp_float8_all_gather :
20472043 from torchao .float8 import precompute_float8_dynamic_scale_for_fsdp
20482044
20492045 optimizers [0 ].register_step_post_hook (
Original file line number Diff line number Diff line change 2828 is_musa_available ,
2929 is_npu_available ,
3030 is_sdaa_available ,
31+ is_torchao_available ,
3132 is_transformer_engine_available ,
3233 is_transformers_available ,
3334 is_xpu_available ,
@@ -794,7 +795,9 @@ def get_cluster_input():
794795 )
795796 if mixed_precision == "fp8" :
796797 if not is_fp8_available ():
797- raise ValueError ("FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine." )
798+ raise ValueError (
799+ "FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine."
800+ )
798801 fp8_config = {}
799802 fp8_config ["backend" ] = _ask_options (
800803 "Which FP8 backend do you want to use?" ,
@@ -870,9 +873,9 @@ def get_cluster_input():
870873 lambda x : "O1" if x == 0 else "O2" ,
871874 default = 1 ,
872875 )
873-
876+
874877 elif fp8_config ["backend" ] == "AO" :
875- if not is_torch_ao_available ():
878+ if not is_torchao_available ():
876879 raise ValueError ("torchao was selected, but it is not installed on this machine." )
877880 fp8_config ["enable_fsdp_float8_all_gather" ] = _ask_field (
878881 "Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: " ,
Original file line number Diff line number Diff line change @@ -321,7 +321,7 @@ class AORecipeKwargs(KwargsHandler):
321321 operations to prevent runtime errors.
322322 - `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth
323323 savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.
324-
324+
325325 You can override these defaults by providing your own `Float8LinearConfig` instance.
326326 module_filter_func (`Callable`, *optional*, default to `None`):
327327 Optional function that must take in a module and layer name, and returns a boolean indicating whether the
@@ -338,7 +338,7 @@ def __post_init__(self):
338338 env_prefix = "ACCELERATE_FP8_"
339339 if not is_torchao_available ():
340340 raise ImportError ("TorchAO is not available. Please install it or use a different backend." )
341-
341+
342342 if self .config is None :
343343 from torchao .float8 import Float8LinearConfig
344344
You can’t perform that action at this time.
0 commit comments