Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
14 changes: 3 additions & 11 deletions src/accelerate/utils/dataclasses.py
Original file line number Diff line number Diff line change
Expand Up @@ -1911,24 +1911,16 @@ def __post_init__(self):
self.cpu_ram_efficient_loading = (
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1
)
else:
# We still need to set it for transformers
os.environ[env_prefix + "CPU_RAM_EFFICIENT_LOADING"] = str(self.cpu_ram_efficient_loading)
# There's no need to specify sync_module_states in FSDP2
if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states:
warnings.warn(
"sync_module_states cannot be False since efficient cpu ram loading enabled. "
"Setting sync_module_states to True."
)
self.sync_module_states = True

if self.cpu_ram_efficient_loading != bool(
str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False"))
):
env_var = env_prefix + "CPU_RAM_EFFICIENT_LOADING"
warnings.warn(
f"The `cpu_ram_efficient_loading` flag for `FullyShardedDataParallelPlugin` does not match the environment variable {env_var}. "
"Setting environment variable to match `cpu_ram_efficient_loading`."
)
os.environ[env_var] = str(self.cpu_ram_efficient_loading)

if isinstance(self.mixed_precision_policy, str):
# override is True since self.mixed_precision_policy is not None
# has to be overwritten with the correct mixed precision object
Expand Down
Loading