Commit 66d1977

Custom modeling for training (#801)
# What does this PR do?

### Custom modeling code for training

#### Features

This PR adds support for custom modeling code for training. Each custom modeling implementation lives under `optimum/neuron/models/training`. Custom modeling code lets us implement Neuron specificities more cleanly than dynamic patching. It becomes easy to:

- Fuse linear layers together for efficiency
- Use custom linear layers such as `GQAQKVColumnParallelLinear`, useful with high TP sizes
- Use custom kernels, such as the flash attention kernel

In this PR we provide a first full custom implementation with Llama.

#### Model weight transformations

Because custom modeling code can diverge from the vanilla Transformers implementation, we need a way to make sure that we can load checkpoints from Transformers and save checkpoints in the original format as well. To do that we provide an API around the `ModelWeightTransformationSpec` classes. These classes describe the transformation relative to the vanilla Transformers implementation and are attached directly to the modules containing these transformations. For now two exist:

- `FusedLinearsSpec`: represents a transformation where multiple linear layers are fused into a single linear layer (possibly a parallel linear)
- `GQAQKVColumnParallelLinearSpec`: represents the transformation of separate query, key, and value projections into a single `GQAQKVColumnParallelLinear` projection

During loading, saving, and consolidation, we use these specs to make sure every weight matches its Transformers counterpart. A toy sketch of the idea follows this description.

#### Known issues

- There seems to be an issue when saving a checkpoint with DP > 1 during training. Initial investigation points to a compiler bug, but it will require more work; I suggest addressing it in another PR.

### Training example

#### Specs

- Model: `meta-llama/Llama-3.2-3B-Instruct`
- Dataset: `databricks/databricks-dolly-15k`
- Trainer: `NeuronSFTTrainer`
- DP=4, TP=8
- Gradient accumulation steps = 16 => effective batch size = 4 x 16 = 64
- Sequence length = 2048 with packing = True
- 3 epochs
- Learning rate = 5e-4, warmup ratio = 0.3, lr scheduler type = "cosine"

#### Loss curve

![W&B loss chart (24/04/2025 16:12:30)](https://github.com/user-attachments/assets/cb4467cd-3e18-4560-86e0-d71d55bb08b7)

### To be done in later PRs

- Support for PP
- Support for LoRA
- Refactor `save_pretrained` the way `from_pretrained` was refactored in this PR
- Add a test that checks overfitting
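To make the transformation-spec idea concrete, here is a toy, self-contained sketch of what a fused-linears transformation has to do. The `ToyFusedLinearsSpec` class, its method names, and the `gate_up_proj` naming are illustrative only and do not reflect the actual `FusedLinearsSpec` API under `optimum/neuron/models/training`:

```python
import torch
from torch import nn


# Hypothetical stand-in for a FusedLinearsSpec-style transformation; the real
# classes live in optimum/neuron/models/training and their APIs may differ.
class ToyFusedLinearsSpec:
    def __init__(self, fused_name: str, original_names: list[str], output_sizes: list[int]):
        self.fused_name = fused_name          # e.g. "gate_up_proj"
        self.original_names = original_names  # e.g. ["gate_proj", "up_proj"]
        self.output_sizes = output_sizes      # number of output rows of each original weight

    def to_original_weights(self, fused_weight: torch.Tensor) -> dict[str, torch.Tensor]:
        # Split the fused weight back into per-projection weights so that a
        # checkpoint can be saved in the vanilla Transformers layout.
        chunks = torch.split(fused_weight, self.output_sizes, dim=0)
        return {name: chunk for name, chunk in zip(self.original_names, chunks)}

    def from_original_weights(self, weights: dict[str, torch.Tensor]) -> torch.Tensor:
        # Fuse the vanilla Transformers weights when loading into the custom module.
        return torch.cat([weights[name] for name in self.original_names], dim=0)


# Example: a fused gate/up projection as used in Llama-style MLP blocks.
hidden, intermediate = 16, 64
spec = ToyFusedLinearsSpec("gate_up_proj", ["gate_proj", "up_proj"], [intermediate, intermediate])
fused = nn.Linear(hidden, 2 * intermediate, bias=False)

original = spec.to_original_weights(fused.weight.detach())
assert original["gate_proj"].shape == (intermediate, hidden)
assert torch.equal(spec.from_original_weights(original), fused.weight.detach())
```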
1 parent 6edcf86 commit 66d1977

26 files changed, +3985 -741 lines changed

docs/source/guides/distributed_training.mdx

Lines changed: 3 additions & 3 deletions

```diff
@@ -183,11 +183,11 @@ Just as for ZeRO-1, it is possible to wrap the optimizer class to make it lazy.
 ```python
 from torch.optim import AdamW
 from optimum.neuron import NeuronAccelerator
-from optimum.neuron.accelerate.utils import ModelParallelismPlugin
+from optimum.neuron.accelerate.utils import ModelParallelismConfig
 from optimum.neuron.distributed import lazy_load_for_parallelism
 
 tensor_parallel_size = 8
-mp_plugin = ModelParallelismPlugin(
+mp_config = ModelParallelismConfig(
     tensor_parallel_size,
     parallelize_embeddings=True,
     sequence_parallel_enabled=True,
@@ -196,7 +196,7 @@ mp_plugin = ModelParallelismPlugin(
 
 accelerator = NeuronAccelerator(
     ...
-    mp_plugin=mp_plugin,
+    mp_config=mp_config,
 )
 
 with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
```
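Assembled from the hunks above, the updated guide snippet now reads roughly as follows. This is a sketch only: it assumes a Neuron/torch-xla environment, and the extra `NeuronAccelerator` arguments that the guide elides are simply omitted here.

```python
from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import ModelParallelismConfig
from optimum.neuron.distributed import lazy_load_for_parallelism

tensor_parallel_size = 8

# ModelParallelismConfig replaces ModelParallelismPlugin, and the accelerator
# keyword changes from mp_plugin to mp_config accordingly.
mp_config = ModelParallelismConfig(
    tensor_parallel_size,
    parallelize_embeddings=True,
    sequence_parallel_enabled=True,
)

accelerator = NeuronAccelerator(mp_config=mp_config)

with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
    pass  # instantiate the model here, as in the rest of the guide
```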

docs/source/training_tutorials/sft_lora_finetune_llm.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -37,7 +37,7 @@ def training_function(script_args, training_args):
         r=16,
         lora_alpha=16,
         lora_dropout=0.05,
-        target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
+        target_modules=["q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
         bias="none",
         task_type="CAUSAL_LM",
     )
```
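For reference, the resulting LoRA configuration in the tutorial reads as follows. This is a sketch assuming the tutorial builds a `peft.LoraConfig`; only this block is shown, and the rest of `training_function` is unchanged.

```python
from peft import LoraConfig

# After this change, "gate_proj" is no longer in target_modules.
config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=["q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
    bias="none",
    task_type="CAUSAL_LM",
)
```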

optimum/neuron/__init__.py

Lines changed: 2 additions & 2 deletions

```diff
@@ -70,7 +70,7 @@
         "NeuronAccelerator",
         "NeuronAcceleratorState",
         "NeuronPartialState",
-        "ModelParallelismPlugin",
+        "ModelParallelismConfig",
     ],
     "pipelines": ["pipeline"],
     "utils": ["NeuronSFTConfig", "NeuronORPOConfig", "get_peft_model"],
@@ -90,7 +90,7 @@
     _import_structure["models.yolos"] = ["NeuronYolosForObjectDetection"]
 
 if TYPE_CHECKING:
-    from .accelerate import ModelParallelismPlugin, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState
+    from .accelerate import ModelParallelismConfig, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState
     from .hf_argparser import NeuronHfArgumentParser
     from .modeling import (
         NeuronModelForAudioClassification,
```

optimum/neuron/accelerate/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,4 +15,4 @@
 
 from .accelerator import NeuronAccelerator
 from .state import NeuronAcceleratorState, NeuronPartialState
-from .utils.dataclasses import ModelParallelismPlugin, NeuronDistributedType
+from .utils.dataclasses import ModelParallelismConfig, NeuronDistributedType
```

optimum/neuron/accelerate/accelerator.py

Lines changed: 27 additions & 8 deletions

```diff
@@ -16,6 +16,7 @@
 
 import collections
 import contextlib
+import inspect
 import os
 import re
 import shutil
@@ -56,7 +57,7 @@
 from .state import NeuronAcceleratorState
 from .utils import (
     AutocastBackend,
-    ModelParallelismPlugin,
+    ModelParallelismConfig,
     NeuronDistributedType,
     patch_accelerate_is_torch_xla_available,
 )
@@ -99,7 +100,7 @@ class NeuronAccelerator(Accelerator):
     def __init__(
         self,
         *args,
-        mp_plugin: Optional[ModelParallelismPlugin] = None,
+        mp_config: Optional[ModelParallelismConfig] = None,
         zero_1: bool = False,
         autocast_backend: Union[str, AutocastBackend] = "xla",
         **kwargs,
@@ -147,7 +148,7 @@ def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: boo
         accelerate.state.is_torch_xla_available = patched_is_torch_xla_available
 
         patched_accelerator_state = partial(
-            NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
+            NeuronAcceleratorState, mp_config=mp_config, autocast_backend=autocast_backend
         )
         with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
             super().__init__(**full_kwargs)
@@ -226,7 +227,7 @@ def prepare_data_loader(
                 data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
             )
         # No need to wrap the dataloader if we are using pipeline parallelism.
-        if use_mp_device_loader and self.state.mp_plugin.pipeline_parallel_size == 1:
+        if use_mp_device_loader and self.state.mp_config.pipeline_parallel_size == 1:
             data_loader = MpDeviceLoader(data_loader, self.device)
         return data_loader
 
@@ -302,6 +303,14 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device
 
     @patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer))
     def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None):
+        # If we use custom modeling, we do not have to do anything for now.
+        # We will have to do some work when supporting ZeRO-1.
+        model = self._models[0] if len(self._models) == 1 else None
+        if model is not None and inspect.getmodule(model.__class__).__name__.startswith(
+            "optimum.neuron.models.training"
+        ):
+            return super().prepare_optimizer(optimizer, device_placement=device_placement)
+
         if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
             optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement)
         if self.zero_1:
@@ -385,7 +394,7 @@ def _prepare_model_for_mp(
 
         tied_parameters_dict = get_tied_parameters_dict(model)
         model_main_input_name = getattr(model, "main_input_name", None)
-        model = self.state.mp_plugin.parallelize_model(model, device=self.device)
+        model = self.state.mp_config.parallelize_model(model, device=self.device)
 
         if model_main_input_name is not None:
             setattr(model, "main_input_name", model_main_input_name)
@@ -444,6 +453,16 @@ def prepare_model(
         # we get access to the model, we simply check if the flags are the best and notify the user otherwise.
         check_neuron_cc_flags_for_model(model)
 
+        if inspect.getmodule(model.__class__).__name__.startswith("optimum.neuron.models.training"):
+            # We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
+            # need.
+            model.config.use_cache = False
+            model.config.output_attentions = False
+            model.config.output_hidden_states = False
+            move_model_to_device(model, self.device)
+            model = super().prepare_model(model, device_placement=False, evaluation_mode=evaluation_mode)
+            return model
+
         model = self.patch_model_for_neuron(model)
 
         if self.state.mixed_precision == "bf16":
@@ -629,9 +648,9 @@ def save_optimizer_func(accelerator, optimizer, model, output_dir, i):
                 model,
                 output_dir,
                 optimizer=optimizer,
-                use_xser=self.state.mp_plugin.use_xser,
-                async_save=self.state.mp_plugin.async_save,
-                num_local_ranks_per_step=self.state.mp_plugin.num_local_ranks_per_step,
+                use_xser=self.state.mp_config.use_xser,
+                async_save=self.state.mp_config.async_save,
+                num_local_ranks_per_step=self.state.mp_config.num_local_ranks_per_step,
             )
             logger.info(f"Parallel model and optimizer saved to the directory {output_dir}")
 
```
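Both `prepare_optimizer` and `prepare_model` branch on where the model class is defined. Below is a minimal standalone sketch of that test; the `is_custom_training_model` helper and `DummyModel` are illustrative, not part of the library.

```python
import inspect

from torch import nn


def is_custom_training_model(model: nn.Module) -> bool:
    # Same idea as in the accelerator: a model counts as "custom modeling" when
    # its class is defined under optimum.neuron.models.training.
    module = inspect.getmodule(model.__class__)
    return module is not None and module.__name__.startswith("optimum.neuron.models.training")


class DummyModel(nn.Module):
    def forward(self, x):
        return x


print(is_custom_training_model(DummyModel()))  # False: DummyModel is defined in this script
```
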
optimum/neuron/accelerate/state.py

Lines changed: 8 additions & 8 deletions

```diff
@@ -38,7 +38,7 @@
     set_neuron_cc_flags_for_torch_amp,
 )
 from .utils import NeuronDistributedType
-from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin
+from .utils.dataclasses import AutocastBackend, ModelParallelismConfig
 
 
 if is_torch_xla_available():
@@ -128,7 +128,7 @@ def __init__(
         deepspeed_plugin=None,
         fsdp_plugin=None,
         megatron_lm_plugin=None,
-        mp_plugin: Optional[ModelParallelismPlugin] = None,
+        mp_config: Optional[ModelParallelismConfig] = None,
         autocast_backend: Optional[Union[str, AutocastBackend]] = None,
         _from_accelerator: bool = False,
         **kwargs,
@@ -184,18 +184,18 @@ def __init__(
             if mixed_precision == "bf16":
                 os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"
 
-            if mp_plugin is None:
-                mp_plugin = ModelParallelismPlugin()
+            if mp_config is None:
+                mp_config = ModelParallelismConfig()
 
-            if mp_plugin.should_parallelize:
+            if mp_config.should_parallelize:
                 self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM
 
-            self.mp_plugin = mp_plugin
+            self.mp_config = mp_config
 
             if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
                 parallel_state.initialize_model_parallel(
-                    tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
-                    pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
+                    tensor_model_parallel_size=self.mp_config.tensor_parallel_size,
+                    pipeline_model_parallel_size=self.mp_config.pipeline_parallel_size,
                 )
 
             if self.distributed_type is DistributedType.NO:
```

optimum/neuron/accelerate/utils/__init__.py

Lines changed: 1 addition & 1 deletion

```diff
@@ -15,7 +15,7 @@
 
 from .dataclasses import (
     AutocastBackend,
-    ModelParallelismPlugin,
+    ModelParallelismConfig,
     NeuronDistributedType,
 )
 from .misc import patch_accelerate_is_torch_xla_available
```

optimum/neuron/accelerate/utils/dataclasses.py

Lines changed: 28 additions & 1 deletion

```diff
@@ -22,6 +22,12 @@
 import torch
 
 from ...distributed import ParallelizersManager
+from ...utils import is_neuronx_distributed_available
+from ...utils.torch_xla_and_neuronx_initialization import init_process_group
+
+
+if is_neuronx_distributed_available():
+    from neuronx_distributed.parallel_layers import parallel_state
 
 
 if TYPE_CHECKING:
@@ -49,7 +55,7 @@ class AutocastBackend(str, enum.Enum):
 
 
 @dataclass
-class ModelParallelismPlugin:
+class ModelParallelismConfig:
     tensor_parallel_size: int = 1
     parallelize_embeddings: bool = True
     sequence_parallel_enabled: bool = False
@@ -62,6 +68,9 @@ class ModelParallelismPlugin:
     num_local_ranks_per_step: int = 8
     use_xser: bool = True
     async_save: bool = False
+    fuse_qkv: bool = False
+    use_flash_attention: bool = True
+    recompute_causal_mask: bool = True
 
     def __post_init__(self):
         if self.tensor_parallel_size < 1:
@@ -73,6 +82,24 @@ def __post_init__(self):
         if isinstance(self.checkpoint_dir, str):
             self.checkpoint_dir = Path(self.checkpoint_dir)
 
+        if not torch.distributed.is_initialized():
+            init_process_group()
+
+        if not parallel_state.model_parallel_is_initialized():
+            parallel_state.initialize_model_parallel(
+                tensor_model_parallel_size=self.tensor_parallel_size,
+                pipeline_model_parallel_size=self.pipeline_parallel_size,
+            )
+
+    def auto_kv_size_multiplier(self, num_key_value_heads: int) -> int:
+        kv_size_multiplier = max(1, self.tensor_parallel_size // num_key_value_heads)
+        if self.kv_size_multiplier is not None and self.kv_size_multiplier != kv_size_multiplier:
+            raise ValueError(
+                "A kv size multiplier was already specified and is different from the inferred one: "
+                f"{self.kv_size_multiplier}"
+            )
+        return kv_size_multiplier
+
     @property
     def should_parallelize(self):
         return self.tensor_parallel_size > 1 or self.pipeline_parallel_size > 1
```
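The new `fuse_qkv`, `use_flash_attention`, and `recompute_causal_mask` fields expose the custom-modeling options described in the PR, and `auto_kv_size_multiplier` infers how many times the key/value heads must be replicated so that every tensor-parallel rank holds at least one. The arithmetic can be checked in isolation (plain Python; the real method lives on `ModelParallelismConfig`, whose `__post_init__` also initializes the distributed state, so it is not instantiated here):

```python
def auto_kv_size_multiplier(tensor_parallel_size: int, num_key_value_heads: int) -> int:
    # Mirrors the inference rule used by ModelParallelismConfig.auto_kv_size_multiplier.
    return max(1, tensor_parallel_size // num_key_value_heads)


# Llama-3.2-3B-Instruct uses 8 key/value heads (GQA).
print(auto_kv_size_multiplier(8, 8))   # 1 -> no replication needed at TP=8, the PR's training example
print(auto_kv_size_multiplier(32, 8))  # 4 -> each KV head replicated 4 times at TP=32
```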

optimum/neuron/distributed/base.py

Lines changed: 5 additions & 0 deletions

```diff
@@ -16,6 +16,7 @@
 
 import contextlib
 import gc
+import inspect
 import math
 from abc import ABC, abstractclassmethod
 from collections import defaultdict
@@ -559,6 +560,10 @@ def parallelize(
         orig_model, peft_prefix = get_base_model_and_peft_prefix(model)
         model_class = orig_model.__class__
 
+        # We skip parallelization if the model is coming from a custom modeling since it is already parallelized.
+        if inspect.getmodule(orig_model.__class__).__name__.startswith("optimum.neuron.models.training"):
+            return orig_model
+
         if peft_prefix:
             # We update the weight_map to contain both the original parameter names, and the ones in the PeftModel.
             # The reason we keep both is because depending on the context during parallelization one or the other name
```

optimum/neuron/distributed/checkpointing.py

Lines changed: 55 additions & 6 deletions

```diff
@@ -16,6 +16,7 @@
 
 import json
 import os
+from functools import partial
 from pathlib import Path
 from typing import Any, Callable, Dict, List, Literal, Union
 
@@ -103,7 +104,7 @@ def create_gqa_query_or_output_projection_weight_from_full_weight(
     return full_weight
 
 
-def consolidate_tensor_parallel_checkpoints(
+def old_consolidate_tensor_parallel_checkpoints(
     sharded_checkpoints: List[Path],
     load_function: Callable[[Union[str, Path]], Dict[str, Any]],
     metadata: Dict[str, Any],
@@ -205,6 +206,43 @@ def consolidate_tensor_parallel_checkpoints(
     return consolidated_state_dict
 
 
+def consolidate_tensor_parallel_checkpoints(
+    sharded_checkpoints: List[Path],
+    load_function: Callable[[Union[str, Path]], Dict[str, Any]],
+    metadata: Dict[str, Any],
+) -> Dict[str, "torch.Tensor"]:
+    from ..models.training import ModelWeightTransformationSpecs, to_original_weights
+
+    state_dicts = []
+    sharded_checkpoints = sorted(sharded_checkpoints)
+    for sharded_checkpoint in sharded_checkpoints:
+        if not sharded_checkpoint.is_file():
+            continue
+        state_dicts.append(load_function(sharded_checkpoint.as_posix()))
+
+    parameters_metadata = metadata["parameters"]
+    transformation_specs_metadata = metadata["model_weight_transformation_specs"]
+
+    # We recreate the transformation specs from the metadata.
+    transformations_specs = []
+    for specs_metadata in transformation_specs_metadata:
+        specs = ModelWeightTransformationSpecs.from_metadata(specs_metadata)
+        transformations_specs.append(specs)
+
+    # We transform the sharded state dicts as follows:
+    # [state_dict_tp_rank_0, state_dict_tp_rank_1, ...]
+    # -> {
+    #   key: [state_dict_tp_rank_0[key], state_dict_tp_rank_1[key], ...],
+    #   for key in state_dict_tp_rank_0.keys()
+    # }
+    paramater_names = state_dicts[0].keys()
+    sharded_state_dicts = {name: [state_dict[name] for state_dict in state_dicts] for name in paramater_names}
+
+    consolidated_state_dict = to_original_weights(transformations_specs, sharded_state_dicts, parameters_metadata)
+
+    return consolidated_state_dict
+
+
 @requires_neuronx_distributed
 def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "torch.Tensor"]:
     model_checkpoint_dir = checkpoint_dir / "model"
@@ -221,26 +259,37 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "t
     # Case 2: If no file was found, maybe the checkpoint was saved without xser.
     if not sharded_checkpoints:
         sharded_checkpoints = list(model_checkpoint_dir.glob("dp_rank_*.pt"))
-        load_function = torch.load
+        load_function = partial(torch.load, weights_only=True)
 
     if not sharded_checkpoints:
         raise ValueError(f"Could not find any sharded checkpoint in {model_checkpoint_dir.as_posix()}")
 
     pp_size = max((int(checkpoint_path.stem[-2:]) for checkpoint_path in sharded_checkpoints)) + 1
     checkpoints_grouped_by_pp_ranks = [[] for _ in range(pp_size)]
     metadatas = []
+    is_old_metadata = False
     for pp_rank in range(pp_size):
         for checkpoint_path in sharded_checkpoints:
             checkpoint_name = checkpoint_path.stem
             if int(checkpoint_name[-2:]) == pp_rank:
                 checkpoints_grouped_by_pp_ranks[pp_rank].append(checkpoint_path)
-        metadatas.append(torch.load(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt"))
+        if (checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt").is_file():
+            is_old_metadata = True
+            metadatas.append(torch.load(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt"))
+        else:
+            with open(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.json") as fp:
+                metadatas.append(json.load(fp))
 
     consolidated_state_dict = {}
     for pp_rank, checkpoint_group_for_pp_rank in enumerate(checkpoints_grouped_by_pp_ranks):
-        consolidated_for_pp_rank = consolidate_tensor_parallel_checkpoints(
-            checkpoint_group_for_pp_rank, load_function, metadatas[pp_rank]
-        )
+        if is_old_metadata:
+            consolidated_for_pp_rank = old_consolidate_tensor_parallel_checkpoints(
+                checkpoint_group_for_pp_rank, load_function, metadatas[pp_rank]
+            )
+        else:
+            consolidated_for_pp_rank = consolidate_tensor_parallel_checkpoints(
+                checkpoint_group_for_pp_rank, load_function, metadatas[pp_rank]
+            )
         consolidated_state_dict.update(**consolidated_for_pp_rank)
 
     for key, tensor in consolidated_state_dict.items():
```
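A consolidation call then looks like the sketch below. The checkpoint path is hypothetical; the directory is expected to contain the `model/` shards and the per-pipeline-rank metadata files handled above, and `neuronx_distributed` must be installed.

```python
from pathlib import Path

from optimum.neuron.distributed.checkpointing import consolidate_model_parallel_checkpoints

# Hypothetical checkpoint directory produced during training.
checkpoint_dir = Path("output/checkpoint-500")

# Returns a single state dict with weights mapped back to the original
# Transformers parameter names, whether the checkpoint used the old .pt
# metadata or the new .json metadata.
consolidated_state_dict = consolidate_model_parallel_checkpoints(checkpoint_dir)
print(list(consolidated_state_dict)[:5])
```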
