Skip to content

Commit 3a2f301

Browse files
committed
Fixed make style
1 parent 9bb447a commit 3a2f301

File tree

3 files changed

+9
-10
lines changed

3 files changed

+9
-10
lines changed

src/accelerate/accelerator.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -2039,11 +2039,7 @@ def _prepare_ao(self, *args):
20392039

20402040
# Invariant: with FSDP2, optimizer is always passed to `prepare()` together with model
20412041
# We only precompute scales if float8 all gather is enabled, possibly can add a flag for this later
2042-
if (
2043-
self.is_fsdp2
2044-
and len(optimizers) > 0
2045-
and self.ao_recipe_handler.config.enable_fsdp_float8_all_gather
2046-
):
2042+
if self.is_fsdp2 and len(optimizers) > 0 and self.ao_recipe_handler.config.enable_fsdp_float8_all_gather:
20472043
from torchao.float8 import precompute_float8_dynamic_scale_for_fsdp
20482044

20492045
optimizers[0].register_step_post_hook(

src/accelerate/commands/config/cluster.py

Lines changed: 6 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828
is_musa_available,
2929
is_npu_available,
3030
is_sdaa_available,
31+
is_torchao_available,
3132
is_transformer_engine_available,
3233
is_transformers_available,
3334
is_xpu_available,
@@ -794,7 +795,9 @@ def get_cluster_input():
794795
)
795796
if mixed_precision == "fp8":
796797
if not is_fp8_available():
797-
raise ValueError("FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine.")
798+
raise ValueError(
799+
"FP8 (either torchao, Transformer Engine or MSAMP) is not installed on this machine."
800+
)
798801
fp8_config = {}
799802
fp8_config["backend"] = _ask_options(
800803
"Which FP8 backend do you want to use?",
@@ -870,9 +873,9 @@ def get_cluster_input():
870873
lambda x: "O1" if x == 0 else "O2",
871874
default=1,
872875
)
873-
876+
874877
elif fp8_config["backend"] == "AO":
875-
if not is_torch_ao_available():
878+
if not is_torchao_available():
876879
raise ValueError("torchao was selected, but it is not installed on this machine.")
877880
fp8_config["enable_fsdp_float8_all_gather"] = _ask_field(
878881
"Do you want to enable FSDP2 float8 all gather? This is recommended for better performance if using FSDP2. [YES/no]: ",

src/accelerate/utils/dataclasses.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -321,7 +321,7 @@ class AORecipeKwargs(KwargsHandler):
321321
operations to prevent runtime errors.
322322
- `enable_fsdp_float8_all_gather=True`: Enables FP8 all-gather for FSDP2. This provides memory bandwidth
323323
savings by casting parameters before the all-gather operation, saving 50% bandwidth compared to BF16.
324-
324+
325325
You can override these defaults by providing your own `Float8LinearConfig` instance.
326326
module_filter_func (`Callable`, *optional*, default to `None`):
327327
Optional function that must take in a module and layer name, and returns a boolean indicating whether the
@@ -338,7 +338,7 @@ def __post_init__(self):
338338
env_prefix = "ACCELERATE_FP8_"
339339
if not is_torchao_available():
340340
raise ImportError("TorchAO is not available. Please install it or use a different backend.")
341-
341+
342342
if self.config is None:
343343
from torchao.float8 import Float8LinearConfig
344344

0 commit comments

Comments
 (0)