Skip to content

Commit 31a6581

Browse files
Enable FSDP upcast (#2280)
1 parent 1862be8 commit 31a6581

File tree

2 files changed

+15
-257
lines changed

2 files changed

+15
-257
lines changed

optimum/habana/accelerate/accelerator.py

Lines changed: 7 additions & 253 deletions
Original file line numberDiff line numberDiff line change
@@ -15,10 +15,7 @@
1515

1616
from __future__ import annotations
1717

18-
import functools
1918
import os
20-
import re
21-
from types import MethodType
2219

2320
import accelerate
2421
import torch
@@ -27,15 +24,7 @@
2724
from accelerate.logging import get_logger
2825
from accelerate.utils import (
2926
DistributedType,
30-
DynamoBackend,
31-
apply_fp8_autowrap,
32-
convert_outputs_to_fp32,
33-
ensure_weights_retied,
34-
get_mixed_precision_context_manager,
35-
model_has_dtensor,
3627
)
37-
from accelerate.utils.dataclasses import FP8BackendType
38-
from accelerate.utils.other import compile_regions, is_compiled_module
3928

4029
from ..distributed import parallel_state
4130
from .utils.dataclasses import GaudiTERecipeKwargs
@@ -77,249 +66,14 @@ def __init__(
7766
if self.has_fp8_handler:
7867
self.fp8_recipe = get_fp8_recipe(self.te_recipe_handler or self.fp8_recipe_handler)
7968

80-
# NOTE: this is only kept here until FSDP upcast is fixed
81-
def prepare_model(
82-
self, model: torch.nn.Module, device_placement: bool | None = None, evaluation_mode: bool = False
83-
):
84-
if device_placement is None:
85-
device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP
86-
87-
self._models.append(model)
88-
89-
# TODO: Look at enabling native TP training directly with a proper config
90-
if (
91-
self.verify_device_map(model)
92-
and self.distributed_type != DistributedType.NO
93-
and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true"
94-
):
95-
raise ValueError(
96-
"You can't train a model that has been loaded with `device_map='auto'` in any distributed mode."
97-
" Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`."
98-
)
99-
100-
if self.native_amp:
101-
model._original_forward = model.forward
102-
autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler)
103-
# NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward`
104-
if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"):
105-
model_forward_func = model.forward
106-
model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func))
107-
else:
108-
model_forward_func = model.forward.__func__
109-
new_forward = autocast_context(model_forward_func)
110-
model.forward = MethodType(new_forward, model)
111-
model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model)
112-
113-
# We prepare TE after, allowing for bf16 autocast to happen first
114-
if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast:
115-
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
116-
117-
if device_placement and not self.verify_device_map(model):
118-
model = model.to(self.device)
119-
if not evaluation_mode and self.distribution_strategy != "fast_ddp":
120-
if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled):
121-
if model_has_dtensor(model):
122-
raise ValueError(
123-
"Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead."
124-
)
125-
if any(p.requires_grad for p in model.parameters()):
126-
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {}
127-
# TODO: Look at enabling native TP training directly with a proper config
128-
if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true":
129-
if self.device.type == "hpu":
130-
device_ids, output_device = [self.device.index], self.device.index
131-
else:
132-
device_ids, output_device = [self.local_process_index], self.local_process_index
133-
else:
134-
device_ids, output_device = None, None
135-
model = torch.nn.parallel.DistributedDataParallel(
136-
model, device_ids=device_ids, output_device=output_device, **kwargs
137-
)
138-
if self.ddp_handler is not None:
139-
self.ddp_handler.register_comm_hook(model)
140-
elif self.parallelism_config and self.parallelism_config.tp_enabled:
141-
if not hasattr(model, "tp_size"):
142-
raise NotImplementedError(
143-
"Model should undergo tensor parallel before passing it to accelerate."
144-
"You can use .from_pretrained(..., tp_plan='auto') if the model supports"
145-
)
146-
if model.tp_size != self.parallelism_config.tp_size:
147-
raise ValueError(
148-
f"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}"
149-
)
150-
elif self.is_fsdp2:
151-
raise ValueError(
152-
"FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an optimizer."
153-
)
154-
155-
elif self.distributed_type == DistributedType.FSDP:
156-
# We need to fix the optimizer *before* sharding the model
157-
from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP
158-
159-
# Check if the model is already a FSDP model due to `Manual Wrapping` and if so,
160-
# don't wrap it again
161-
# In case the model is already compiled using PyTorch 2.0 and the wrapped model in it
162-
# is a FSDP model, don't wrap it again
163-
is_type_fsdp = isinstance(model, FSDP) or (
164-
is_compiled_module(model) and isinstance(model._orig_mod, FSDP)
165-
)
166-
167-
if not is_type_fsdp:
168-
self.state.fsdp_plugin.set_auto_wrap_policy(model)
169-
fsdp_plugin = self.state.fsdp_plugin
170-
171-
# need to ensure that params are re-tied after running
172-
# param_init_fn
173-
fsdp_plugin.param_init_fn = ensure_weights_retied(
174-
fsdp_plugin.param_init_fn,
175-
model,
176-
self.device,
177-
)
178-
179-
kwargs = {
180-
# We fallback to reshard_after_forward if sharding_strategy is not set.
181-
# We prefer sharding_strategy to not break the behavior of the existing code.
182-
# Deprecation warning has already been issued in `utils.dataclasses.py`
183-
"sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward,
184-
"cpu_offload": fsdp_plugin.cpu_offload,
185-
"auto_wrap_policy": fsdp_plugin.auto_wrap_policy,
186-
"mixed_precision": fsdp_plugin.mixed_precision_policy,
187-
"sync_module_states": fsdp_plugin.sync_module_states,
188-
"backward_prefetch": fsdp_plugin.backward_prefetch,
189-
"forward_prefetch": fsdp_plugin.forward_prefetch,
190-
"use_orig_params": fsdp_plugin.use_orig_params,
191-
"param_init_fn": fsdp_plugin.param_init_fn,
192-
"ignored_modules": fsdp_plugin.ignored_modules,
193-
"limit_all_gathers": fsdp_plugin.limit_all_gathers,
194-
"device_id": self.device,
195-
}
196-
197-
if isinstance(kwargs["ignored_modules"], str):
198-
reg = re.compile(kwargs["ignored_modules"])
199-
ignored = []
200-
for name, module in model.named_modules():
201-
if reg.fullmatch(name):
202-
# ensure that the device for these modules is still set correctly
203-
module.to(self.device)
204-
ignored.append(module)
205-
kwargs["ignored_modules"] = ignored
206-
207-
model = FSDP(model, **kwargs)
208-
if fsdp_plugin.activation_checkpointing:
209-
from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import (
210-
CheckpointImpl,
211-
apply_activation_checkpointing,
212-
checkpoint_wrapper,
213-
)
214-
215-
apply_activation_checkpointing(
216-
model,
217-
checkpoint_wrapper_fn=functools.partial(
218-
checkpoint_wrapper,
219-
checkpoint_impl=CheckpointImpl.NO_REENTRANT,
220-
),
221-
auto_wrap_policy=fsdp_plugin.auto_wrap_policy,
222-
)
223-
224-
# TODO: starting from transformers 4.43 and accelerate 0.33, upcasting was added for FSDP in mixed precision
225-
# https://github.com/huggingface/accelerate/pull/2674 making FSDP training more stable,
226-
# but was kept disabled in optimum-habana temporarily due to some failing tests.
227-
"""
228-
# In the event the model had been loaded in low precision, but
229-
# mixed precision had also been activated, then we follow DeepSpeed's
230-
# strategy to hold the parameters in full precision.
231-
# - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against
232-
# fsdp_plugin.mixed_precision_policy.
233-
# - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper.
234-
# * this attribute will always be set by init_utils.init_core_state so it's always not None.
235-
# * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype
236-
# * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None,
237-
# we still want to upcast the flat_param.
238-
if self.mixed_precision != "no": # if mixed precision is set
239-
upcasted_log = []
240-
for module in FSDP.fsdp_modules(model):
241-
# Referencing DeepSpeed Zero3
242-
# - in Init, params are converted to 16bit while partitioning.
243-
# - in accelerator.prepare, deepspeed.initialize is called to:
244-
# * creates the DeepSpeedEngine.
245-
# * since zero_optimization() is True , calls engine._configure_zero_optimizer.
246-
#
247-
# Inside the DeepSpeed Zero3 optimizer configuration, which initializes
248-
# DeepSpeedZeroOptimizer_Stage3, during which:
249-
# * trainable_param_groups are obtained from the attached optimizer
250-
# (already partitioned in 16bit).
251-
# * then _setup_for_real_optimizer -> _create_fp32_partitions
252-
# which performs the fp32 upcasting.
253-
254-
# To mimic DeepSpeed's casting in FSDP, we look at the (single) FlatParameter held
255-
# within an FSDP wrapper. This FlatParameter will be seen by the optimizer.
256-
# - even though there is a torch.device('meta') guard below, we
257-
# expect _init_utils._init_param_handle_from_module to already
258-
# sync the parameter.
259-
260-
if not module._has_params:
261-
continue # skip if FSDP module not managing parameters
262-
param = module._flat_param
263-
if (
264-
param.dtype != torch.float32
265-
and param.device != torch.device("meta")
266-
and param.requires_grad
267-
):
268-
# keep log of names_params that was upcasted
269-
# NOTE: resorted to this because warnings.simplefilter("once") is somehow not working
270-
name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns))
271-
if name_param_log not in upcasted_log:
272-
upcasted_log.append(name_param_log)
273-
274-
# this works because of FSDP's _runtime_utils.lazy_init.
275-
# Have to be careful not to call anything before this that
276-
# triggers lazy_init (e.g., _is_fsdp_root).
277-
param.data = param.data.to(torch.float32) # upcasting
278-
module._handle._orig_param_dtype = torch.float32 # update
279-
280-
# report the warnings
281-
# some messages can be quite repetitive, especially when reporting about layers that have identical architecture.
282-
if self.is_main_process:
283-
for name_log, param_log in upcasted_log:
284-
warnings.warn(
285-
f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. "
286-
f"Affects: {param_log}."
287-
)
288-
289-
if len(upcasted_log) > 0:
290-
warnings.warn(
291-
"FSDP upcast of low precision parameters may affect the precision of model checkpoints."
292-
)
293-
"""
294-
295-
# if the previous and current models are same, delete the previous one
296-
if len(self._models) > 1 and (self._models[-2] is self._models[-1]):
297-
del self._models[-2]
298-
self._models[-1] = model
299-
elif self.distributed_type == DistributedType.MULTI_CPU:
300-
kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {}
301-
model = torch.nn.parallel.DistributedDataParallel(model, **kwargs)
302-
if self.ddp_handler is not None:
303-
self.ddp_handler.register_comm_hook(model)
304-
# Now we can apply the FP8 autocast
305-
if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast:
306-
model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler)
307-
# torch.compile should be called last and only if the model isn't already compiled
308-
if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model):
309-
if self.state.dynamo_plugin.use_regional_compilation:
310-
model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs())
311-
else:
312-
model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs())
313-
return model
314-
31569
# INFO: this adds support for fast_ddp by not applying DDP wrapper
316-
# def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
317-
# if self.distribution_strategy == "fast_ddp":
318-
# # with fast_ddp, we just skip ddp and fsdp model preparation
319-
# model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=True)
320-
# else:
321-
# model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
322-
# return model
70+
def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False):
71+
if self.distribution_strategy == "fast_ddp":
72+
# with fast_ddp, we just skip ddp and fsdp model preparation
73+
model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=True)
74+
else:
75+
model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode)
76+
return model
32377

32478
# INFO: this adds support for autograd compilation to the deepspeed engine
32579
def _prepare_deepspeed(self, *args):

optimum/habana/transformers/trainer.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1085,8 +1085,13 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
10851085
# If the condition is true, we need to compute grad_norm, deepspeed does its own clipping
10861086
if _should_compute_grad_norm:
10871087
# Gradient clipping
1088-
if self.FusedNorm is not None:
1089-
# TODO: to merge self.accelerator.clip_grad_norm_ when HMP is removed
1088+
if (
1089+
self.FusedNorm is not None
1090+
and self.accelerator.distributed_type != DistributedType.FSDP
1091+
):
1092+
# when weights are sharded, fsdp.clip_grad_norm_ should be used
1093+
# https://docs.pytorch.org/docs/main/fsdp.html#torch.distributed.fsdp.FullyShardedDataParallel.clip_grad_norm_
1094+
# TODO: check if the fused norm is more performant than the torch.nn.utils.clip_grad_norm_
10901095
grad_norm = self.FusedNorm.clip_norm(model.parameters())
10911096
else:
10921097
grad_norm_context = contextlib.nullcontext
@@ -1096,8 +1101,7 @@ def hpu_deepspeed_checkpointing(function, *checkpoint_args, use_reentrant: Optio
10961101
grad_norm_context = implicit_replication
10971102
with grad_norm_context():
10981103
grad_norm = self.accelerator.clip_grad_norm_(
1099-
model.parameters(),
1100-
args.max_grad_norm,
1104+
model.parameters(), args.max_grad_norm
11011105
)
11021106

11031107
self.control = self.callback_handler.on_pre_optimizer_step(args, self.state, self.control)

0 commit comments

Comments
 (0)