diff --git a/src/accelerate/utils/dataclasses.py b/src/accelerate/utils/dataclasses.py index d5055d3bc0b..dd3c340f8b4 100644 --- a/src/accelerate/utils/dataclasses.py +++ b/src/accelerate/utils/dataclasses.py @@ -1911,6 +1911,9 @@ def __post_init__(self): self.cpu_ram_efficient_loading = ( str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) == 1 ) + else: + # We still need to set it for transformers + os.environ[env_prefix + "CPU_RAM_EFFICIENT_LOADING"] = str(self.cpu_ram_efficient_loading) # There's no need to specify sync_module_states in FSDP2 if self.fsdp_version == 1 and self.cpu_ram_efficient_loading and not self.sync_module_states: warnings.warn( @@ -1918,17 +1921,6 @@ def __post_init__(self): "Setting sync_module_states to True." ) self.sync_module_states = True - - if self.cpu_ram_efficient_loading != bool( - str_to_bool(os.environ.get(env_prefix + "CPU_RAM_EFFICIENT_LOADING", "False")) - ): - env_var = env_prefix + "CPU_RAM_EFFICIENT_LOADING" - warnings.warn( - f"The `cpu_ram_efficient_loading` flag for `FullyShardedDataParallelPlugin` does not match the environment variable {env_var}. " - "Setting environment variable to match `cpu_ram_efficient_loading`." - ) - os.environ[env_var] = str(self.cpu_ram_efficient_loading) - if isinstance(self.mixed_precision_policy, str): # override is True since self.mixed_precision_policy is not None # has to be overwritten with the correct mixed precision object