|
15 | 15 |
|
16 | 16 | from __future__ import annotations |
17 | 17 |
|
18 | | -import functools |
19 | 18 | import os |
20 | | -import re |
21 | | -from types import MethodType |
22 | 19 |
|
23 | 20 | import accelerate |
24 | 21 | import torch |
|
27 | 24 | from accelerate.logging import get_logger |
28 | 25 | from accelerate.utils import ( |
29 | 26 | DistributedType, |
30 | | - DynamoBackend, |
31 | | - apply_fp8_autowrap, |
32 | | - convert_outputs_to_fp32, |
33 | | - ensure_weights_retied, |
34 | | - get_mixed_precision_context_manager, |
35 | | - model_has_dtensor, |
36 | 27 | ) |
37 | | -from accelerate.utils.dataclasses import FP8BackendType |
38 | | -from accelerate.utils.other import compile_regions, is_compiled_module |
39 | 28 |
|
40 | 29 | from ..distributed import parallel_state |
41 | 30 | from .utils.dataclasses import GaudiTERecipeKwargs |
@@ -77,249 +66,14 @@ def __init__( |
77 | 66 | if self.has_fp8_handler: |
78 | 67 | self.fp8_recipe = get_fp8_recipe(self.te_recipe_handler or self.fp8_recipe_handler) |
79 | 68 |
|
80 | | - # NOTE: this is only kept here until FSDP upcast is fixed |
81 | | - def prepare_model( |
82 | | - self, model: torch.nn.Module, device_placement: bool | None = None, evaluation_mode: bool = False |
83 | | - ): |
84 | | - if device_placement is None: |
85 | | - device_placement = self.device_placement and self.distributed_type != DistributedType.FSDP |
86 | | - |
87 | | - self._models.append(model) |
88 | | - |
89 | | - # TODO: Look at enabling native TP training directly with a proper config |
90 | | - if ( |
91 | | - self.verify_device_map(model) |
92 | | - and self.distributed_type != DistributedType.NO |
93 | | - and os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true" |
94 | | - ): |
95 | | - raise ValueError( |
96 | | - "You can't train a model that has been loaded with `device_map='auto'` in any distributed mode." |
97 | | - " Please rerun your script specifying `--num_processes=1` or by launching with `python {{myscript.py}}`." |
98 | | - ) |
99 | | - |
100 | | - if self.native_amp: |
101 | | - model._original_forward = model.forward |
102 | | - autocast_context = get_mixed_precision_context_manager(self.native_amp, self.autocast_handler) |
103 | | - # NOTE: MS-AMP adds `__func__` already to `model.forward`, so we should always use `model.forward` |
104 | | - if self.fp8_backend == FP8BackendType.MSAMP or not hasattr(model.forward, "__func__"): |
105 | | - model_forward_func = model.forward |
106 | | - model.forward = convert_outputs_to_fp32(autocast_context(model_forward_func)) |
107 | | - else: |
108 | | - model_forward_func = model.forward.__func__ |
109 | | - new_forward = autocast_context(model_forward_func) |
110 | | - model.forward = MethodType(new_forward, model) |
111 | | - model.forward = MethodType(convert_outputs_to_fp32(model.forward.__func__), model) |
112 | | - |
113 | | - # We prepare TE after, allowing for bf16 autocast to happen first |
114 | | - if self.fp8_backend == FP8BackendType.TE and not self.delayed_fp8_autocast: |
115 | | - model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) |
116 | | - |
117 | | - if device_placement and not self.verify_device_map(model): |
118 | | - model = model.to(self.device) |
119 | | - if not evaluation_mode and self.distribution_strategy != "fast_ddp": |
120 | | - if self.multi_device and not (self.parallelism_config and self.parallelism_config.tp_enabled): |
121 | | - if model_has_dtensor(model): |
122 | | - raise ValueError( |
123 | | - "Your model contains `DTensor` parameters, which is incompatible with DDP. Maybe you loaded your model with `device_map='auto'`? Specify `device_map='cuda'` or 'cpu' instead." |
124 | | - ) |
125 | | - if any(p.requires_grad for p in model.parameters()): |
126 | | - kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler is not None else {} |
127 | | - # TODO: Look at enabling native TP training directly with a proper config |
128 | | - if os.environ.get("ACCELERATE_BYPASS_DEVICE_MAP", "false") != "true": |
129 | | - if self.device.type == "hpu": |
130 | | - device_ids, output_device = [self.device.index], self.device.index |
131 | | - else: |
132 | | - device_ids, output_device = [self.local_process_index], self.local_process_index |
133 | | - else: |
134 | | - device_ids, output_device = None, None |
135 | | - model = torch.nn.parallel.DistributedDataParallel( |
136 | | - model, device_ids=device_ids, output_device=output_device, **kwargs |
137 | | - ) |
138 | | - if self.ddp_handler is not None: |
139 | | - self.ddp_handler.register_comm_hook(model) |
140 | | - elif self.parallelism_config and self.parallelism_config.tp_enabled: |
141 | | - if not hasattr(model, "tp_size"): |
142 | | - raise NotImplementedError( |
143 | | - "Model should undergo tensor parallel before passing it to accelerate." |
144 | | - "You can use .from_pretrained(..., tp_plan='auto') if the model supports" |
145 | | - ) |
146 | | - if model.tp_size != self.parallelism_config.tp_size: |
147 | | - raise ValueError( |
148 | | - f"tp_size in the plugin {self.parallelism_config.tp_size} should be same as model's tp size {model.tp_size}" |
149 | | - ) |
150 | | - elif self.is_fsdp2: |
151 | | - raise ValueError( |
152 | | - "FSDP2 preparation should be done via `accelerate.prepare()`, as it requires a model and an optimizer." |
153 | | - ) |
154 | | - |
155 | | - elif self.distributed_type == DistributedType.FSDP: |
156 | | - # We need to fix the optimizer *before* sharding the model |
157 | | - from torch.distributed.fsdp.fully_sharded_data_parallel import FullyShardedDataParallel as FSDP |
158 | | - |
159 | | - # Check if the model is already a FSDP model due to `Manual Wrapping` and if so, |
160 | | - # don't wrap it again |
161 | | - # In case the model is already compiled using PyTorch 2.0 and the wrapped model in it |
162 | | - # is a FSDP model, don't wrap it again |
163 | | - is_type_fsdp = isinstance(model, FSDP) or ( |
164 | | - is_compiled_module(model) and isinstance(model._orig_mod, FSDP) |
165 | | - ) |
166 | | - |
167 | | - if not is_type_fsdp: |
168 | | - self.state.fsdp_plugin.set_auto_wrap_policy(model) |
169 | | - fsdp_plugin = self.state.fsdp_plugin |
170 | | - |
171 | | - # need to ensure that params are re-tied after running |
172 | | - # param_init_fn |
173 | | - fsdp_plugin.param_init_fn = ensure_weights_retied( |
174 | | - fsdp_plugin.param_init_fn, |
175 | | - model, |
176 | | - self.device, |
177 | | - ) |
178 | | - |
179 | | - kwargs = { |
180 | | - # We fallback to reshard_after_forward if sharding_strategy is not set. |
181 | | - # We prerfer sharding_strategy to not break the behavior of the existing code. |
182 | | - # Deprecation warning has already been issued in `utils.dataclasses.py` |
183 | | - "sharding_strategy": fsdp_plugin.sharding_strategy or fsdp_plugin.reshard_after_forward, |
184 | | - "cpu_offload": fsdp_plugin.cpu_offload, |
185 | | - "auto_wrap_policy": fsdp_plugin.auto_wrap_policy, |
186 | | - "mixed_precision": fsdp_plugin.mixed_precision_policy, |
187 | | - "sync_module_states": fsdp_plugin.sync_module_states, |
188 | | - "backward_prefetch": fsdp_plugin.backward_prefetch, |
189 | | - "forward_prefetch": fsdp_plugin.forward_prefetch, |
190 | | - "use_orig_params": fsdp_plugin.use_orig_params, |
191 | | - "param_init_fn": fsdp_plugin.param_init_fn, |
192 | | - "ignored_modules": fsdp_plugin.ignored_modules, |
193 | | - "limit_all_gathers": fsdp_plugin.limit_all_gathers, |
194 | | - "device_id": self.device, |
195 | | - } |
196 | | - |
197 | | - if isinstance(kwargs["ignored_modules"], str): |
198 | | - reg = re.compile(kwargs["ignored_modules"]) |
199 | | - ignored = [] |
200 | | - for name, module in model.named_modules(): |
201 | | - if reg.fullmatch(name): |
202 | | - # ensure that the device for these modules is still set correctly |
203 | | - module.to(self.device) |
204 | | - ignored.append(module) |
205 | | - kwargs["ignored_modules"] = ignored |
206 | | - |
207 | | - model = FSDP(model, **kwargs) |
208 | | - if fsdp_plugin.activation_checkpointing: |
209 | | - from torch.distributed.algorithms._checkpoint.checkpoint_wrapper import ( |
210 | | - CheckpointImpl, |
211 | | - apply_activation_checkpointing, |
212 | | - checkpoint_wrapper, |
213 | | - ) |
214 | | - |
215 | | - apply_activation_checkpointing( |
216 | | - model, |
217 | | - checkpoint_wrapper_fn=functools.partial( |
218 | | - checkpoint_wrapper, |
219 | | - checkpoint_impl=CheckpointImpl.NO_REENTRANT, |
220 | | - ), |
221 | | - auto_wrap_policy=fsdp_plugin.auto_wrap_policy, |
222 | | - ) |
223 | | - |
224 | | - # TODO: starting from transformers 4.43 and accelerate 0.33, upcasting was added for FSDP in mixed precision |
225 | | - # https://github.com/huggingface/accelerate/pull/2674 making FSDP training more stable, |
226 | | - # but was kept disabled in optimum-habana temporarily due to some failing tests. |
227 | | - """ |
228 | | - # In the event the model had been loaded in low precision, but |
229 | | - # mixed precision had also been activated, then we follow DeepSpeed's |
230 | | - # strategy to hold the parameters in full precision. |
231 | | - # - assume that trainer.args.bf16 and trainer.args.fp16 are already checked against |
232 | | - # fsdp_plugin.mixed_precision_policy. |
233 | | - # - NOTE: we do not check the mixed_precision attribute on the FSDP root wrapper. |
234 | | - # * this attribute will always set by init_utils.init_core_state so its always not None. |
235 | | - # * mixed_precision.param_dtype only regards _fwd_bwd_param_dtype |
236 | | - # * if model is loaded in 16bit, and even if mixed_precision.param_dtype is None, |
237 | | - # we still want to upcast the flat_param. |
238 | | - if self.mixed_precision != "no": # if mixed precision is set |
239 | | - upcasted_log = [] |
240 | | - for module in FSDP.fsdp_modules(model): |
241 | | - # Referencing DeepSpeed Zero3 |
242 | | - # - in Init, params are converted to 16bit while partitioning. |
243 | | - # - in accelerator.prepare, deepspeed.initialize is called to: |
244 | | - # * creates the DeepSpeedEngine. |
245 | | - # * since zero_optimization() is True , calls engine._configure_zero_optimizer. |
246 | | - # |
247 | | - # Inside the DeepSpeed Zero3 optimizer configuration, which initializes |
248 | | - # DeepSpeedZeroOptimizer_Stage3, during which: |
249 | | - # * trainable_param_groups are obtained from the attached optimizer |
250 | | - # (already partitioned in 16bit). |
251 | | - # * then _setup_for_real_optimizer -> _create_fp32_partitions |
252 | | - # which performs the fp32 upcasting. |
253 | | -
|
254 | | - # To mimic DeepSeepds's casting in FSDP, we look at the (single) FlatParameter held |
255 | | - # within an FSDP wrapper. This FlatParameter will be seen by the optimizer. |
256 | | - # - even though there is a torch.device('meta') guard below, we |
257 | | - # expect _init_utils._init_param_handle_from_module to already |
258 | | - # sync the parameter. |
259 | | -
|
260 | | - if not module._has_params: |
261 | | - continue # skip if FSDP module not managing parameters |
262 | | - param = module._flat_param |
263 | | - if ( |
264 | | - param.dtype != torch.float32 |
265 | | - and param.device != torch.device("meta") |
266 | | - and param.requires_grad |
267 | | - ): |
268 | | - # keep log of names_params that was upcasted |
269 | | - # NOTE: resorted to this because warnings.simplefilter("once") is somehow not working |
270 | | - name_param_log = (module.module.__class__.__name__, ", ".join(module._flat_param._fqns)) |
271 | | - if name_param_log not in upcasted_log: |
272 | | - upcasted_log.append(name_param_log) |
273 | | -
|
274 | | - # this works because of FSDP's _runtime_utils.lazy_init. |
275 | | - # Have to be careful not to call anything before this that |
276 | | - # triggers lazy_init (e.g., _is_fsdp_root). |
277 | | - param.data = param.data.to(torch.float32) # upcasting |
278 | | - module._handle._orig_param_dtype = torch.float32 # update |
279 | | -
|
280 | | - # report the warnings |
281 | | - # some messages can be quite repetitive, especially when reporting about layers that have identical architecture. |
282 | | - if self.is_main_process: |
283 | | - for name_log, param_log in upcasted_log: |
284 | | - warnings.warn( |
285 | | - f"Upcasted low precision parameters in {name_log} because mixed precision turned on in FSDP. " |
286 | | - f"Affects: {param_log}." |
287 | | - ) |
288 | | -
|
289 | | - if len(upcasted_log) > 0: |
290 | | - warnings.warn( |
291 | | - "FSDP upcast of low precision parameters may affect the precision of model checkpoints." |
292 | | - ) |
293 | | - """ |
294 | | - |
295 | | - # if the previous and current models are same, delete the previous one |
296 | | - if len(self._models) > 1 and (self._models[-2] is self._models[-1]): |
297 | | - del self._models[-2] |
298 | | - self._models[-1] = model |
299 | | - elif self.distributed_type == DistributedType.MULTI_CPU: |
300 | | - kwargs = self.ddp_handler.to_kwargs() if self.ddp_handler else {} |
301 | | - model = torch.nn.parallel.DistributedDataParallel(model, **kwargs) |
302 | | - if self.ddp_handler is not None: |
303 | | - self.ddp_handler.register_comm_hook(model) |
304 | | - # Now we can apply the FP8 autocast |
305 | | - if self.fp8_backend == FP8BackendType.TE and self.delayed_fp8_autocast: |
306 | | - model = apply_fp8_autowrap(model, self.te_recipe_handler or self.fp8_recipe_handler) |
307 | | - # torch.compile should be called last and only if the model isn't already compiled |
308 | | - if self.state.dynamo_plugin.backend != DynamoBackend.NO and not is_compiled_module(model): |
309 | | - if self.state.dynamo_plugin.use_regional_compilation: |
310 | | - model = compile_regions(model, **self.state.dynamo_plugin.to_kwargs()) |
311 | | - else: |
312 | | - model = torch.compile(model, **self.state.dynamo_plugin.to_kwargs()) |
313 | | - return model |
314 | | - |
315 | 69 | # INFO: this adds support for fast_ddp by not applying DDP wrapper |
316 | | - # def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False): |
317 | | - # if self.distribution_strategy == "fast_ddp": |
318 | | - # # with fast_ddp, we just skip ddp and fsdp model preparation |
319 | | - # model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=True) |
320 | | - # else: |
321 | | - # model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) |
322 | | - # return model |
| 70 | + def prepare_model(self, model: torch.nn.Module, device_placement: bool = None, evaluation_mode: bool = False): |
| 71 | + if self.distribution_strategy == "fast_ddp": |
| 72 | + # with fast_ddp, we just skip ddp and fsdp model preparation |
| 73 | + model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=True) |
| 74 | + else: |
| 75 | + model = super().prepare_model(model, device_placement=device_placement, evaluation_mode=evaluation_mode) |
| 76 | + return model |
323 | 77 |
|
324 | 78 | # INFO: this adds support for autograd compilation to the deepspeed engine |
325 | 79 | def _prepare_deepspeed(self, *args): |
|
0 commit comments