Custom modeling for training #801

Merged: 74 commits merged on May 15, 2025
Commits (74)
b131882
[WIP] modeling
michaelbenayoun Mar 3, 2025
447cadd
[WIP] modeling
michaelbenayoun Mar 3, 2025
c4107ba
[WIP] modeling
michaelbenayoun Mar 4, 2025
93ce12b
[WIP] from_pretrained
michaelbenayoun Mar 10, 2025
7410117
[WIP] from_pretrained
michaelbenayoun Mar 12, 2025
eef5e0e
[WIP] from_pretrained
michaelbenayoun Mar 13, 2025
f67a31f
Incomplete styling
michaelbenayoun Mar 13, 2025
53f6900
Support flash_attention_v2
michaelbenayoun Mar 13, 2025
54819bc
Merge branch 'main' into custom_modeling_introduction
michaelbenayoun Mar 13, 2025
3cd352c
Support for GQA QKV
michaelbenayoun Mar 14, 2025
bdc65e0
[WIP] test
michaelbenayoun Mar 18, 2025
f5d0214
[WIP] test
michaelbenayoun Mar 21, 2025
c137748
[WIP] test
michaelbenayoun Mar 25, 2025
f4f0d8d
WIP
michaelbenayoun Mar 31, 2025
7fb574b
WIP
michaelbenayoun Apr 2, 2025
20245a8
Refactor
michaelbenayoun Apr 3, 2025
a8b247c
[WIP] save_pretrained
michaelbenayoun Apr 7, 2025
22974a0
[WIP] save_pretrained
michaelbenayoun Apr 7, 2025
7d80a8c
Merge branch 'main' into custom_modeling_introduction
michaelbenayoun Apr 10, 2025
a329249
[WIP]
michaelbenayoun Apr 14, 2025
6b486b4
Fix
michaelbenayoun Apr 16, 2025
88ae7ea
Fix
michaelbenayoun Apr 16, 2025
91119e7
Gradient checkpointing
michaelbenayoun Apr 16, 2025
39e5002
Merge branch 'main' into custom_modeling_introduction
michaelbenayoun Apr 16, 2025
8b55f4d
Styling
michaelbenayoun Apr 16, 2025
8d057ab
[WIP] consolidate
michaelbenayoun Apr 16, 2025
3557f3c
[WIP] consolidate
michaelbenayoun Apr 23, 2025
e8752c2
[WIP] consolidate
michaelbenayoun Apr 23, 2025
5e78a7a
styling
michaelbenayoun Apr 23, 2025
7bf94d6
Cleanup
michaelbenayoun Apr 24, 2025
bc31a51
Refactor
michaelbenayoun Apr 24, 2025
3ce97cc
Cleanup
michaelbenayoun Apr 24, 2025
8ac7420
Fix
michaelbenayoun Apr 24, 2025
3e0f2d9
Merge branch 'main' into custom_modeling_introduction
michaelbenayoun Apr 24, 2025
546bd14
Disable PP tests since it is broken
michaelbenayoun Apr 24, 2025
c38a87e
Fix import
michaelbenayoun Apr 24, 2025
73e3e0d
Fixes
michaelbenayoun Apr 24, 2025
52ea95e
Fixes
michaelbenayoun Apr 24, 2025
efc11c4
Fixes
michaelbenayoun Apr 24, 2025
b9e8dc6
Fixes
michaelbenayoun Apr 24, 2025
574afb2
Fixes
michaelbenayoun Apr 24, 2025
2936011
Fixes
michaelbenayoun Apr 24, 2025
226b4c7
Fixes
michaelbenayoun Apr 24, 2025
abd42e0
Fixes
michaelbenayoun Apr 24, 2025
a631c14
Add independant Llama implementation from Transformers
michaelbenayoun Apr 28, 2025
3dbc63d
Remove fake support for cache
michaelbenayoun Apr 28, 2025
b164f5d
Raising an error if intermediate size is not divisible by tp size
michaelbenayoun Apr 28, 2025
5a84a5b
Add the CustomModule class to explicitly mark which submodules need t…
michaelbenayoun Apr 28, 2025
2c05d46
Remove transformers code that we do not use
michaelbenayoun Apr 28, 2025
218aa46
[WIP] from_pretrained in details
michaelbenayoun Apr 28, 2025
549cec2
[WIP] from_pretrained in details
michaelbenayoun Apr 29, 2025
8f40973
[WIP] from_pretrained in details
michaelbenayoun Apr 29, 2025
e05c477
from_pretrained done
michaelbenayoun Apr 29, 2025
9f7eea3
Add comment explaining what transformation_utils.py is about
michaelbenayoun Apr 30, 2025
3fae083
Add comment explaining what transformation_utils.py is about
michaelbenayoun Apr 30, 2025
aad7cad
Change sharding
michaelbenayoun Apr 30, 2025
7b37f12
Remove sharding.py
michaelbenayoun Apr 30, 2025
50df135
Restore sharding.py, this can be removed in the Granite PR
michaelbenayoun Apr 30, 2025
f2a8023
Combime parallel linear tests
michaelbenayoun Apr 30, 2025
109571c
Test with bigger sequence length
michaelbenayoun Apr 30, 2025
a542978
Remove duplicate flash attention test
michaelbenayoun Apr 30, 2025
be5c2db
Add recompute causal mask option
michaelbenayoun Apr 30, 2025
a1fa3c0
Remove _tp_plan and _pp_plan
michaelbenayoun Apr 30, 2025
0fd9f7c
Remove commented code
michaelbenayoun Apr 30, 2025
10f82b4
Remove from_tf and from_flax artifacts
michaelbenayoun Apr 30, 2025
7c45001
Fix comparison
michaelbenayoun Apr 30, 2025
0d8abe6
Tiny changes
michaelbenayoun Apr 30, 2025
f567b75
Add overfitting test
michaelbenayoun Apr 30, 2025
edcd8f0
Fix tests using all NCs
michaelbenayoun May 13, 2025
78123e7
Styling
michaelbenayoun May 13, 2025
7f00b7d
Remove tests from the former approach
michaelbenayoun May 14, 2025
d92f62e
Styling
michaelbenayoun May 14, 2025
6ef46db
Remove tests from the former approach
michaelbenayoun May 14, 2025
1746f76
Remove tests from the former approach
michaelbenayoun May 14, 2025
6 changes: 3 additions & 3 deletions docs/source/guides/distributed_training.mdx
@@ -183,11 +183,11 @@ Just as for ZeRO-1, it is possible to wrap the optimizer class to make it lazy.
```python
from torch.optim import AdamW
from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import ModelParallelismPlugin
from optimum.neuron.accelerate.utils import ModelParallelismConfig
from optimum.neuron.distributed import lazy_load_for_parallelism

tensor_parallel_size = 8
mp_plugin = ModelParallelismPlugin(
mp_config = ModelParallelismConfig(
tensor_parallel_size,
parallelize_embeddings=True,
sequence_parallel_enabled=True,
@@ -196,7 +196,7 @@ mp_plugin = ModelParallelismPlugin(

accelerator = NeuronAccelerator(
...
mp_plugin=mp_plugin,
mp_config=mp_config,
)

with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
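For reference, the renamed `ModelParallelismConfig` slots into the same training setup this guide describes. A minimal end-to-end sketch, assuming a causal-LM checkpoint under `torchrun` (the checkpoint name and hyperparameters are placeholders, and `NeuronAccelerator` otherwise takes the standard Accelerate arguments):

```python
from torch.optim import AdamW
from transformers import AutoModelForCausalLM
from optimum.neuron import NeuronAccelerator
from optimum.neuron.accelerate.utils import ModelParallelismConfig
from optimum.neuron.distributed import lazy_load_for_parallelism

tensor_parallel_size = 8

# Describes how the model is sharded across Neuron cores.
mp_config = ModelParallelismConfig(
    tensor_parallel_size,
    parallelize_embeddings=True,
    sequence_parallel_enabled=True,
)

accelerator = NeuronAccelerator(mp_config=mp_config)

# Lazy loading avoids materializing the full model on every rank before sharding.
with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
    model = AutoModelForCausalLM.from_pretrained("my-org/my-llama-checkpoint")  # placeholder checkpoint

optimizer = AdamW(model.parameters(), lr=5e-5)
model, optimizer = accelerator.prepare(model, optimizer)
```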
2 changes: 1 addition & 1 deletion docs/source/training_tutorials/sft_lora_finetune_llm.py
@@ -37,7 +37,7 @@ def training_function(script_args, training_args):
r=16,
lora_alpha=16,
lora_dropout=0.05,
target_modules=["q_proj", "gate_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
target_modules=["q_proj", "v_proj", "o_proj", "k_proj", "up_proj", "down_proj"],
bias="none",
task_type="CAUSAL_LM",
)
4 changes: 2 additions & 2 deletions optimum/neuron/__init__.py
@@ -70,7 +70,7 @@
"NeuronAccelerator",
"NeuronAcceleratorState",
"NeuronPartialState",
"ModelParallelismPlugin",
"ModelParallelismConfig",
],
"pipelines": ["pipeline"],
"utils": ["NeuronSFTConfig", "NeuronORPOConfig", "get_peft_model"],
@@ -90,7 +90,7 @@
_import_structure["models.yolos"] = ["NeuronYolosForObjectDetection"]

if TYPE_CHECKING:
from .accelerate import ModelParallelismPlugin, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState
from .accelerate import ModelParallelismConfig, NeuronAccelerator, NeuronAcceleratorState, NeuronPartialState
from .hf_argparser import NeuronHfArgumentParser
from .modeling import (
NeuronModelForAudioClassification,
2 changes: 1 addition & 1 deletion optimum/neuron/accelerate/__init__.py
@@ -15,4 +15,4 @@

from .accelerator import NeuronAccelerator
from .state import NeuronAcceleratorState, NeuronPartialState
from .utils.dataclasses import ModelParallelismPlugin, NeuronDistributedType
from .utils.dataclasses import ModelParallelismConfig, NeuronDistributedType
35 changes: 27 additions & 8 deletions optimum/neuron/accelerate/accelerator.py
@@ -16,6 +16,7 @@

import collections
import contextlib
import inspect
import os
import re
import shutil
@@ -56,7 +57,7 @@
from .state import NeuronAcceleratorState
from .utils import (
AutocastBackend,
ModelParallelismPlugin,
ModelParallelismConfig,
NeuronDistributedType,
patch_accelerate_is_torch_xla_available,
)
@@ -99,7 +100,7 @@ class NeuronAccelerator(Accelerator):
def __init__(
self,
*args,
mp_plugin: Optional[ModelParallelismPlugin] = None,
mp_config: Optional[ModelParallelismConfig] = None,
zero_1: bool = False,
autocast_backend: Union[str, AutocastBackend] = "xla",
**kwargs,
@@ -147,7 +148,7 @@ def patched_is_torch_xla_available(check_is_tpu: bool = False, check_is_gpu: boo
accelerate.state.is_torch_xla_available = patched_is_torch_xla_available

patched_accelerator_state = partial(
NeuronAcceleratorState, mp_plugin=mp_plugin, autocast_backend=autocast_backend
NeuronAcceleratorState, mp_config=mp_config, autocast_backend=autocast_backend
)
with Patcher([("accelerate.accelerator.AcceleratorState", patched_accelerator_state)]):
super().__init__(**full_kwargs)
@@ -226,7 +227,7 @@ def prepare_data_loader(
data_loader, num_replicas=num_replicas, rank=rank, force_drop_last=force_drop_last
)
# No need to wrap the dataloader if we are using pipeline parallelism.
if use_mp_device_loader and self.state.mp_plugin.pipeline_parallel_size == 1:
if use_mp_device_loader and self.state.mp_config.pipeline_parallel_size == 1:
data_loader = MpDeviceLoader(data_loader, self.device)
return data_loader

@@ -302,6 +303,14 @@ def _prepare_optimizer_for_zero_1(self, optimizer: torch.optim.Optimizer, device

@patch_within_function(("accelerate.accelerator.AcceleratedOptimizer", NeuronAcceleratedOptimizer))
def prepare_optimizer(self, optimizer: torch.optim.Optimizer, device_placement: Optional[bool] = None):
# If we use custom modeling, we do not have to do anything for now.
# We will have to do some work when supporting ZeRO-1.
model = self._models[0] if len(self._models) == 1 else None
if model is not None and inspect.getmodule(model.__class__).__name__.startswith(
"optimum.neuron.models.training"
):
return super().prepare_optimizer(optimizer, device_placement=device_placement)
[Review comment — Collaborator] Does this work with Granite?

[Reply — Member, Author] It should, since it is defined in optimum/neuron/models/training. We skip the rest because with custom modeling we do not need lazy loading, so the optimizer is already created with the proper parameters.


if self.distributed_type is NeuronDistributedType.MODEL_PARALLELISM:
optimizer = self._prepare_optimizer_for_mp(optimizer, device_placement=device_placement)
if self.zero_1:
@@ -385,7 +394,7 @@ def _prepare_model_for_mp(

tied_parameters_dict = get_tied_parameters_dict(model)
model_main_input_name = getattr(model, "main_input_name", None)
model = self.state.mp_plugin.parallelize_model(model, device=self.device)
model = self.state.mp_config.parallelize_model(model, device=self.device)

if model_main_input_name is not None:
setattr(model, "main_input_name", model_main_input_name)
@@ -444,6 +453,16 @@ def prepare_model(
# we get access to the model, we simply check if the flags are the best and notify the user otherwise.
check_neuron_cc_flags_for_model(model)

if inspect.getmodule(model.__class__).__name__.startswith("optimum.neuron.models.training"):
# We do not want to use the cache, or output unused tensors as it would imply more communication that we do not
# need.
model.config.use_cache = False
model.config.output_attentions = False
model.config.output_hidden_states = False
move_model_to_device(model, self.device)
model = super().prepare_model(model, device_placement=False, evaluation_mode=evaluation_mode)
return model

model = self.patch_model_for_neuron(model)

if self.state.mixed_precision == "bf16":
@@ -629,9 +648,9 @@ def save_optimizer_func(accelerator, optimizer, model, output_dir, i):
model,
output_dir,
optimizer=optimizer,
use_xser=self.state.mp_plugin.use_xser,
async_save=self.state.mp_plugin.async_save,
num_local_ranks_per_step=self.state.mp_plugin.num_local_ranks_per_step,
use_xser=self.state.mp_config.use_xser,
async_save=self.state.mp_config.async_save,
num_local_ranks_per_step=self.state.mp_config.num_local_ranks_per_step,
)
logger.info(f"Parallel model and optimizer saved to the directory {output_dir}")

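Both `prepare_optimizer` and `prepare_model` above branch on whether the model class is defined under `optimum.neuron.models.training`. A minimal sketch of that check in isolation (the helper name is hypothetical, not part of this PR):

```python
import inspect


def is_custom_training_model(model) -> bool:
    """Return True when the model class comes from optimum.neuron.models.training.

    Such models are built already parallelized, so the accelerator can skip
    lazy loading, re-parallelization, and optimizer re-creation for them.
    """
    module = inspect.getmodule(model.__class__)
    return module is not None and module.__name__.startswith("optimum.neuron.models.training")
```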
16 changes: 8 additions & 8 deletions optimum/neuron/accelerate/state.py
@@ -38,7 +38,7 @@
set_neuron_cc_flags_for_torch_amp,
)
from .utils import NeuronDistributedType
from .utils.dataclasses import AutocastBackend, ModelParallelismPlugin
from .utils.dataclasses import AutocastBackend, ModelParallelismConfig


if is_torch_xla_available():
@@ -128,7 +128,7 @@ def __init__(
deepspeed_plugin=None,
fsdp_plugin=None,
megatron_lm_plugin=None,
mp_plugin: Optional[ModelParallelismPlugin] = None,
mp_config: Optional[ModelParallelismConfig] = None,
autocast_backend: Optional[Union[str, AutocastBackend]] = None,
_from_accelerator: bool = False,
**kwargs,
@@ -184,18 +184,18 @@ def __init__(
if mixed_precision == "bf16":
os.environ["NEURON_RT_STOCHASTIC_ROUNDING_EN"] = "1"

if mp_plugin is None:
mp_plugin = ModelParallelismPlugin()
if mp_config is None:
mp_config = ModelParallelismConfig()

if mp_plugin.should_parallelize:
if mp_config.should_parallelize:
self.distributed_type = NeuronDistributedType.MODEL_PARALLELISM

self.mp_plugin = mp_plugin
self.mp_config = mp_config

if torch.distributed.is_initialized() and not parallel_state.model_parallel_is_initialized():
parallel_state.initialize_model_parallel(
tensor_model_parallel_size=self.mp_plugin.tensor_parallel_size,
pipeline_model_parallel_size=self.mp_plugin.pipeline_parallel_size,
tensor_model_parallel_size=self.mp_config.tensor_parallel_size,
pipeline_model_parallel_size=self.mp_config.pipeline_parallel_size,
)

if self.distributed_type is DistributedType.NO:
2 changes: 1 addition & 1 deletion optimum/neuron/accelerate/utils/__init__.py
@@ -15,7 +15,7 @@

from .dataclasses import (
AutocastBackend,
ModelParallelismPlugin,
ModelParallelismConfig,
NeuronDistributedType,
)
from .misc import patch_accelerate_is_torch_xla_available
29 changes: 28 additions & 1 deletion optimum/neuron/accelerate/utils/dataclasses.py
@@ -22,6 +22,12 @@
import torch

from ...distributed import ParallelizersManager
from ...utils import is_neuronx_distributed_available
from ...utils.torch_xla_and_neuronx_initialization import init_process_group


if is_neuronx_distributed_available():
from neuronx_distributed.parallel_layers import parallel_state


if TYPE_CHECKING:
@@ -49,7 +55,7 @@ class AutocastBackend(str, enum.Enum):


@dataclass
class ModelParallelismPlugin:
class ModelParallelismConfig:
tensor_parallel_size: int = 1
parallelize_embeddings: bool = True
sequence_parallel_enabled: bool = False
@@ -62,6 +68,9 @@ class ModelParallelismPlugin:
num_local_ranks_per_step: int = 8
use_xser: bool = True
async_save: bool = False
fuse_qkv: bool = False
use_flash_attention: bool = True
recompute_causal_mask: bool = True

def __post_init__(self):
if self.tensor_parallel_size < 1:
@@ -73,6 +82,24 @@ def __post_init__(self):
if isinstance(self.checkpoint_dir, str):
self.checkpoint_dir = Path(self.checkpoint_dir)

if not torch.distributed.is_initialized():
init_process_group()

if not parallel_state.model_parallel_is_initialized():
parallel_state.initialize_model_parallel(
tensor_model_parallel_size=self.tensor_parallel_size,
pipeline_model_parallel_size=self.pipeline_parallel_size,
)

def auto_kv_size_multiplier(self, num_key_value_heads: int) -> int:
kv_size_multiplier = max(1, self.tensor_parallel_size // num_key_value_heads)
if self.kv_size_multiplier is not None and self.kv_size_multiplier != kv_size_multiplier:
raise ValueError(
"A kv size multiplier was already specified and is different from the inferred one: "
f"{self.kv_size_multiplier}"
)
return kv_size_multiplier

@property
def should_parallelize(self):
return self.tensor_parallel_size > 1 or self.pipeline_parallel_size > 1
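To make the inferred GQA replication factor concrete, here is the `auto_kv_size_multiplier` arithmetic restated as a standalone function (a sketch only; the real method lives on `ModelParallelismConfig`, whose construction now also initializes the process group):

```python
def auto_kv_size_multiplier(tensor_parallel_size: int, num_key_value_heads: int) -> int:
    # Each key/value head must be replicated enough times for every
    # tensor-parallel rank to own a KV shard.
    return max(1, tensor_parallel_size // num_key_value_heads)


print(auto_kv_size_multiplier(32, 8))  # 4 -> each of the 8 KV heads is replicated 4 times
print(auto_kv_size_multiplier(8, 8))   # 1 -> exactly one KV head per rank
print(auto_kv_size_multiplier(2, 8))   # 1 -> more KV heads than ranks, no replication needed
```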
5 changes: 5 additions & 0 deletions optimum/neuron/distributed/base.py
@@ -16,6 +16,7 @@

import contextlib
import gc
import inspect
import math
from abc import ABC, abstractclassmethod
from collections import defaultdict
@@ -559,6 +560,10 @@ def parallelize(
orig_model, peft_prefix = get_base_model_and_peft_prefix(model)
model_class = orig_model.__class__

# We skip parallelization if the model is coming from a custom modeling since it is already parallelized.
if inspect.getmodule(orig_model.__class__).__name__.startswith("optimum.neuron.models.training"):
return orig_model

if peft_prefix:
# We update the weight_map to contain both the original parameter names, and the ones in the PeftModel.
# The reason we keep both is because depending on the context during parallelization one or the other name
61 changes: 55 additions & 6 deletions optimum/neuron/distributed/checkpointing.py
@@ -16,6 +16,7 @@

import json
import os
from functools import partial
from pathlib import Path
from typing import Any, Callable, Dict, List, Literal, Union

Expand Down Expand Up @@ -103,7 +104,7 @@ def create_gqa_query_or_output_projection_weight_from_full_weight(
return full_weight


def consolidate_tensor_parallel_checkpoints(
def old_consolidate_tensor_parallel_checkpoints(
sharded_checkpoints: List[Path],
load_function: Callable[[Union[str, Path]], Dict[str, Any]],
metadata: Dict[str, Any],
@@ -205,6 +206,43 @@ def consolidate_tensor_parallel_checkpoints(
return consolidated_state_dict


def consolidate_tensor_parallel_checkpoints(
sharded_checkpoints: List[Path],
load_function: Callable[[Union[str, Path]], Dict[str, Any]],
metadata: Dict[str, Any],
) -> Dict[str, "torch.Tensor"]:
from ..models.training import ModelWeightTransformationSpecs, to_original_weights

state_dicts = []
sharded_checkpoints = sorted(sharded_checkpoints)
for sharded_checkpoint in sharded_checkpoints:
if not sharded_checkpoint.is_file():
continue
state_dicts.append(load_function(sharded_checkpoint.as_posix()))

parameters_metadata = metadata["parameters"]
transformation_specs_metadata = metadata["model_weight_transformation_specs"]

# We recreate the transformation specs from the metadata.
transformations_specs = []
for specs_metadata in transformation_specs_metadata:
specs = ModelWeightTransformationSpecs.from_metadata(specs_metadata)
transformations_specs.append(specs)

# We transform the sharded state dicts as follows:
# [state_dict_tp_rank_0, state_dict_tp_rank_1, ...]
# -> {
# key: [state_dict_tp_rank_0[key], state_dict_tp_rank_1[key], ...],
# for key in state_dict_tp_rank_0.keys()
# }
paramater_names = state_dicts[0].keys()
sharded_state_dicts = {name: [state_dict[name] for state_dict in state_dicts] for name in paramater_names}

consolidated_state_dict = to_original_weights(transformations_specs, sharded_state_dicts, parameters_metadata)

return consolidated_state_dict


@requires_neuronx_distributed
def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "torch.Tensor"]:
model_checkpoint_dir = checkpoint_dir / "model"
@@ -221,26 +259,37 @@ def consolidate_model_parallel_checkpoints(checkpoint_dir: Path) -> Dict[str, "t
# Case 2: If no file was found, maybe the checkpoint was saved without xser.
if not sharded_checkpoints:
sharded_checkpoints = list(model_checkpoint_dir.glob("dp_rank_*.pt"))
load_function = torch.load
load_function = partial(torch.load, weights_only=True)

if not sharded_checkpoints:
raise ValueError(f"Could not find any sharded checkpoint in {model_checkpoint_dir.as_posix()}")

pp_size = max((int(checkpoint_path.stem[-2:]) for checkpoint_path in sharded_checkpoints)) + 1
checkpoints_grouped_by_pp_ranks = [[] for _ in range(pp_size)]
metadatas = []
is_old_metadata = False
for pp_rank in range(pp_size):
for checkpoint_path in sharded_checkpoints:
checkpoint_name = checkpoint_path.stem
if int(checkpoint_name[-2:]) == pp_rank:
checkpoints_grouped_by_pp_ranks[pp_rank].append(checkpoint_path)
metadatas.append(torch.load(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt"))
if (checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt").is_file():
is_old_metadata = True
metadatas.append(torch.load(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.pt"))
else:
with open(checkpoint_dir / f"mp_metadata_pp_rank_{pp_rank}.json") as fp:
metadatas.append(json.load(fp))

consolidated_state_dict = {}
for pp_rank, checkpoint_group_for_pp_rank in enumerate(checkpoints_grouped_by_pp_ranks):
consolidated_for_pp_rank = consolidate_tensor_parallel_checkpoints(
checkpoint_group_for_pp_rank, load_function, metadatas[pp_rank]
)
if is_old_metadata:
consolidated_for_pp_rank = old_consolidate_tensor_parallel_checkpoints(
checkpoint_group_for_pp_rank, load_function, metadatas[pp_rank]
)
else:
consolidated_for_pp_rank = consolidate_tensor_parallel_checkpoints(
checkpoint_group_for_pp_rank, load_function, metadatas[pp_rank]
)
consolidated_state_dict.update(**consolidated_for_pp_rank)

for key, tensor in consolidated_state_dict.items():
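The comment inside the new `consolidate_tensor_parallel_checkpoints` describes how per-rank state dicts are regrouped by parameter name before `to_original_weights` undoes the sharding. A toy illustration of just that regrouping step, assuming for simplicity that both weights were sharded along dim 0:

```python
import torch

# One state dict per tensor-parallel rank, as loaded from the sharded checkpoint files.
state_dicts = [
    {"embed_tokens.weight": torch.zeros(2, 4), "lm_head.weight": torch.zeros(4, 2)},  # tp rank 0
    {"embed_tokens.weight": torch.ones(2, 4), "lm_head.weight": torch.ones(4, 2)},    # tp rank 1
]

# [state_dict_tp_rank_0, state_dict_tp_rank_1, ...]
#   -> {name: [shard_from_rank_0, shard_from_rank_1, ...]}
parameter_names = state_dicts[0].keys()
sharded_state_dicts = {name: [sd[name] for sd in state_dicts] for name in parameter_names}

# In the real code, to_original_weights uses the saved transformation specs to decide how
# each parameter is reassembled; for dim-0 sharding that is a simple concatenation.
consolidated = {name: torch.cat(shards, dim=0) for name, shards in sharded_state_dicts.items()}
print({name: tuple(t.shape) for name, t in consolidated.items()})
# {'embed_tokens.weight': (4, 4), 'lm_head.weight': (8, 2)}
```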