[Shardformer] support dualpipe schedule #6229

Open · wants to merge 14 commits into base: feature/dualpipe
@@ -257,7 +257,7 @@ def is_master()
)

torch.set_default_dtype(torch.float)
- booster.load_model(model, args.pretrained)
+ booster.load_model(model, args.pretrained, low_cpu_mem_mode=False, num_threads=8)

coordinator.print_on_master(
f"Booster init max device memory: {accelerator.max_memory_allocated() / 1024 ** 2:.2f} MB"
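Note on the hunk above: the two new keyword arguments belong to `Booster.load_model`'s checkpoint-loading options; judging by their names, disabling `low_cpu_mem_mode` keeps the full state dict staged in (pinned) host memory and `num_threads` parallelizes the copy, trading CPU RAM for load speed. A minimal, hedged driver sketch — the plugin choice and path handling are illustrative assumptions, not part of this PR:

```python
# Hedged sketch, not part of the diff: exercising the updated load_model call.
# Assumes colossalai.launch_from_torch() has already initialized the process group.
import torch
from colossalai.booster import Booster
from colossalai.booster.plugin import GeminiPlugin

def load_pretrained(model: torch.nn.Module, pretrained_path: str) -> torch.nn.Module:
    booster = Booster(plugin=GeminiPlugin())
    model, *_ = booster.boost(model)
    # Skip the low-CPU-memory path and copy weights with 8 threads, as in the example above.
    booster.load_model(model, pretrained_path, low_cpu_mem_mode=False, num_threads=8)
    return model
```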
14 changes: 7 additions & 7 deletions colossalai/booster/plugin/gemini_plugin.py
@@ -85,11 +85,11 @@ def save_unsharded_model(
if use_async:
from colossalai.utils.safetensors import save

- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for k, v in state_dict.items():
- self.pinned_state_dicts[id(model)][k].copy_(v)
- state_dict[k] = self.pinned_state_dicts[id(model)][k]
+ self.pinned_state_dicts[hash(model)][k].copy_(v)
+ state_dict[k] = self.pinned_state_dicts[hash(model)][k]
writer = save(checkpoint, state_dict)
self.async_writers.append(writer)
else:
@@ -172,9 +172,9 @@ def save_sharded_model(
Path(checkpoint_path).mkdir(parents=True, exist_ok=True)

if use_async and self.coordinator.is_master():
- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = {}
- pinned_state_dicts = self.pinned_state_dicts[id(model)]
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = {}
+ pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = model.state_dict_shard(
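Why `id()` becomes `hash()` here and in the other checkpoint-IO files below: `unwrap()` now returns a freshly constructed `PeftUnwrapMixin` for PEFT models, and a new wrapper object means a new `id()` on every save, so an identity-keyed cache of pinned buffers would never be reused. `PeftUnwrapMixin` defines `__hash__` as the hash of its base model (see `colossalai/interface/model.py` below), which makes `hash(model)` a stable key; for plain `nn.Module`s the default hash is identity-based anyway, so nothing changes there. A self-contained sketch of the caching pattern, with `make_pinned_copy` standing in for the real `create_pinned_state_dict` helper:

```python
from typing import Dict

import torch

def make_pinned_copy(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Stand-in for create_pinned_state_dict: page-locked CPU buffers for async writes.
    return {k: torch.empty(v.shape, dtype=v.dtype, device="cpu", pin_memory=True)
            for k, v in state_dict.items()}

pinned_state_dicts: Dict[int, Dict[str, torch.Tensor]] = {}

def stage_for_async_save(model: torch.nn.Module) -> Dict[str, torch.Tensor]:
    state_dict = model.state_dict()
    key = hash(model)  # stable across wrapper objects that hash to their base model
    if key not in pinned_state_dicts:
        pinned_state_dicts[key] = make_pinned_copy(state_dict)
    for name, tensor in state_dict.items():
        pinned_state_dicts[key][name].copy_(tensor)        # device -> pinned host buffer
        state_dict[name] = pinned_state_dicts[key][name]   # hand the async writer the pinned copy
    return state_dict
```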
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/hybrid_parallel_plugin.py
@@ -26,6 +26,7 @@
from colossalai.checkpoint_io import CheckpointIO, HybridParallelCheckpointIO
from colossalai.cluster import ProcessGroupMesh
from colossalai.interface import AMPModelMixin, ModelWrapper, OptimizerWrapper
+ from colossalai.interface.model import PeftUnwrapMixin
from colossalai.interface.optimizer import DistributedOptim
from colossalai.logging import get_dist_logger
from colossalai.nn.optimizer import DistGaloreAwamW, cast_to_distributed
@@ -225,7 +226,7 @@ def unwrap(self, unwrap_peft: bool = True):
if isinstance(model, DDP):
model = model.module
if unwrap_peft and isinstance(model, PeftModel):
- model = model.get_base_model()
+ model = PeftUnwrapMixin(model)
return model

def _force_wait_all_gather(self):
3 changes: 2 additions & 1 deletion colossalai/booster/plugin/torch_ddp_plugin.py
@@ -12,6 +12,7 @@
from colossalai.checkpoint_io import CheckpointIO, GeneralCheckpointIO
from colossalai.cluster import DistCoordinator
from colossalai.interface import ModelWrapper, OptimizerWrapper
+ from colossalai.interface.model import PeftUnwrapMixin
from colossalai.logging import get_dist_logger
from colossalai.quantization import BnbQuantizationConfig, quantize_model
from colossalai.utils import get_current_device
@@ -201,7 +202,7 @@ def __init__(self, module: nn.Module, *args, **kwargs) -> None:
def unwrap(self, unwrap_peft: bool = True) -> nn.Module:
model = self.module.module
if unwrap_peft and isinstance(model, PeftModel):
- model = model.get_base_model()
+ model = PeftUnwrapMixin(model)
return model


14 changes: 7 additions & 7 deletions colossalai/booster/plugin/torch_fsdp_plugin.py
@@ -103,11 +103,11 @@ def save_unsharded_model(
if use_async:
from colossalai.utils.safetensors import save

- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = create_pinned_state_dict(full_model_state)
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(full_model_state)
for k, v in full_model_state.items():
- self.pinned_state_dicts[id(model)][k].copy_(v)
- full_model_state[k] = self.pinned_state_dicts[id(model)][k]
+ self.pinned_state_dicts[hash(model)][k].copy_(v)
+ full_model_state[k] = self.pinned_state_dicts[hash(model)][k]
writer = save(checkpoint, full_model_state)
self.async_writers.append(writer)
else:
@@ -186,9 +186,9 @@ def save_sharded_model(
state_dict = model.unwrap().state_dict()

if use_async and self.coordinator.is_master():
- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = {}
- pinned_state_dicts = self.pinned_state_dicts[id(model)]
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = {}
+ pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = utils.shard_model_checkpoint(
10 changes: 5 additions & 5 deletions colossalai/checkpoint_io/general_checkpoint_io.py
@@ -60,9 +60,9 @@ def save_unsharded_model(
if use_async:
from colossalai.utils.safetensors import move_and_save

- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
- writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[id(model)])
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
+ writer = move_and_save(checkpoint, state_dict, self.pinned_state_dicts[hash(model)])
self.async_writers.append(writer)
else:
# save the checkpoint
@@ -234,7 +234,7 @@ def save_sharded_model(
index_file = CheckpointIndexFile(checkpoint_path)

if use_async:
- pinned_state_dict = self.pinned_state_dicts.get(id(model), None)
+ pinned_state_dict = self.pinned_state_dicts.get(hash(model), None)
total_size, new_pinned_state_dict, writers = async_move_save_state_dict_shards(
sharded_state_dict=state_dict_shard,
checkpoint=checkpoint_path,
@@ -243,7 +243,7 @@
is_master=True,
pinned_state_dict=pinned_state_dict,
)
- self.pinned_state_dicts[id(model)] = new_pinned_state_dict
+ self.pinned_state_dicts[hash(model)] = new_pinned_state_dict
self.async_writers.extend(writers)
else:
# Save shards of optimizer states.
22 changes: 11 additions & 11 deletions colossalai/checkpoint_io/hybrid_parallel_checkpoint_io.py
@@ -249,9 +249,9 @@ def save_sharded_model(
# Only devices with tp_rank == 0 are responsible for model saving.
control_saving = self.tp_rank == 0 and self.sp_rank == 0
if control_saving and use_async:
- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = {}
- pinned_state_dicts = self.pinned_state_dicts[id(model)]
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = {}
+ pinned_state_dicts = self.pinned_state_dicts[hash(model)]
else:
pinned_state_dicts = None
state_dict_shard = HybridParallelCheckpointIO._model_sharder(
@@ -789,11 +789,11 @@ def save_unsharded_model(
if use_async:
from colossalai.utils.safetensors import save

- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = create_pinned_state_dict(state_dict)
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(state_dict)
for name, param in state_dict.items():
- self.pinned_state_dicts[id(model)][name].copy_(param)
- state_dict[name] = self.pinned_state_dicts[id(model)][name]
+ self.pinned_state_dicts[hash(model)][name].copy_(param)
+ state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=state_dict)
self.async_writers.append(writer)
else:
@@ -811,11 +811,11 @@ def save_unsharded_model(
if use_async:
from colossalai.utils.safetensors import save

- if id(model) not in self.pinned_state_dicts:
- self.pinned_state_dicts[id(model)] = create_pinned_state_dict(complete_state_dict)
+ if hash(model) not in self.pinned_state_dicts:
+ self.pinned_state_dicts[hash(model)] = create_pinned_state_dict(complete_state_dict)
for name, param in complete_state_dict.items():
- self.pinned_state_dicts[id(model)][name].copy_(param)
- complete_state_dict[name] = self.pinned_state_dicts[id(model)][name]
+ self.pinned_state_dicts[hash(model)][name].copy_(param)
+ complete_state_dict[name] = self.pinned_state_dicts[hash(model)][name]
writer = save(path=checkpoint, state_dict=complete_state_dict)
self.async_writers.append(writer)
else:
9 changes: 6 additions & 3 deletions colossalai/checkpoint_io/moe_checkpoint.py
@@ -701,15 +701,18 @@ def pre_save_model(self, model: nn.Module) -> dict:
all_param = None
# gather param from every ep rank
# dist.all_gather(all_param, param, group=ep_group)
- dist.gather(param, all_param, group=ep_group)
+ dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
if ep_rank == 0:
all_param = torch.cat(all_param, dim=0)
state_dict[name] = all_param.cpu()

if self.pp_size > 1:
if self.dp_rank == 0:
- out = [None for _ in range(self.pp_size)]
- dist.gather_object(state_dict, out, group=self.pp_group)
+ if self.pp_rank == 0:
+ out = [None for _ in range(self.pp_size)]
+ else:
+ out = None
+ dist.gather_object(state_dict, out, dst=dist.get_global_rank(self.pp_group, 0), group=self.pp_group)
if self.pp_rank == 0:
new_state_dict = {}
for o in out:
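Both fixes above concern how `torch.distributed` gather collectives address their root: `dst` is a global rank, so the group-local root has to be translated with `dist.get_global_rank`, and only the destination rank may supply a gather list (all other ranks pass `None`). A hedged, self-contained sketch of the corrected pattern for the expert-parallel gather; the same reasoning applies to the `gather_object` call on the pipeline group:

```python
from typing import Optional

import torch
import torch.distributed as dist

def gather_expert_param(param: torch.Tensor, ep_group: dist.ProcessGroup) -> Optional[torch.Tensor]:
    # Assumes the process groups have already been initialized (e.g. via colossalai.launch).
    ep_rank = dist.get_rank(ep_group)
    ep_size = dist.get_world_size(ep_group)
    # Only the receiving rank allocates destination buffers; every other rank must pass None.
    all_param = [torch.empty_like(param) for _ in range(ep_size)] if ep_rank == 0 else None
    # dst expects a *global* rank, hence the translation from group-local rank 0.
    dist.gather(param, all_param, dst=dist.get_global_rank(ep_group, 0), group=ep_group)
    return torch.cat(all_param, dim=0) if ep_rank == 0 else None
```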
6 changes: 6 additions & 0 deletions colossalai/checkpoint_io/utils.py
@@ -20,6 +20,7 @@
from torch.utils._pytree import tree_flatten, tree_map, tree_unflatten

from colossalai.accelerator import get_accelerator
+ from colossalai.interface.model import PeftUnwrapMixin
from colossalai.tensor.d_tensor import (
is_customized_distributed_tensor,
is_distributed_tensor,
@@ -554,6 +555,8 @@ def save_config_file(model: nn.Module, checkpoint_path: str, is_master: bool = T
from transformers.modeling_utils import unwrap_model as unwrap_huggingface_model
except ImportError:
return
+ if isinstance(model, PeftUnwrapMixin):
+ model = model.base_model
if not isinstance(model, PreTrainedModel):
return

@@ -692,6 +695,9 @@ def load_state_dict_into_model(
state_dict (dict): a dict containing parameters and
persistent buffers.
"""
+ if isinstance(model, PeftUnwrapMixin):
+ state_dict = model.patch_state_dict(state_dict)
+ model = model.base_model
if not isinstance(state_dict, Mapping):
raise TypeError("Expected state_dict to be dict-like, got {}.".format(type(state_dict)))

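For context on the `load_state_dict_into_model` change: a checkpoint written through `PeftUnwrapMixin` carries the original (pre-LoRA) parameter names, while the live LoRA-injected module stores the same tensor under a `base_layer` sub-key, so keys must be translated back before loading. A toy, self-contained sketch of that renaming — the parameter name and mapping below are illustrative; the real mapping is built in `PeftUnwrapMixin.__init__`:

```python
from typing import Dict

import torch

# Illustrative origin-name -> LoRA ("base_layer") name mapping.
origin_to_lora = {
    "transformer.h.0.attn.c_attn.weight": "transformer.h.0.attn.c_attn.base_layer.weight",
}

def patch_state_dict(state_dict: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    # Rename keys the LoRA model expects under base_layer; leave everything else untouched.
    return {origin_to_lora.get(k, k): v for k, v in state_dict.items()}

ckpt = {"transformer.h.0.attn.c_attn.weight": torch.zeros(1)}
print(list(patch_state_dict(ckpt)))  # ['transformer.h.0.attn.c_attn.base_layer.weight']
```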
101 changes: 99 additions & 2 deletions colossalai/interface/model.py
@@ -1,5 +1,102 @@
+ import re
+ from typing import Dict, Set
+
+ import torch
import torch.nn as nn
- from peft import PeftModel
+ from peft import PeftModel, PeftType


def extract_lora_layers(model: PeftModel, names: Set[str], adapter_name: str = "default"):
config = model.peft_config[adapter_name]
if config.peft_type != PeftType.LORA:
raise ValueError(f"Adapter {adapter_name} is not a LORA adapter.")
# to_return = lora_state_dict(model, bias=model.peft_config.bias)
# adapted from `https://github.com/microsoft/LoRA/blob/main/loralib/utils.py`
# to be used directly with the state dict which is necessary when using DeepSpeed or FSDP
bias = config.bias
if bias == "none":
to_return = {k for k in names if "lora_" in k}
elif bias == "all":
to_return = {k for k in names if "lora_" in k or "bias" in k}
elif bias == "lora_only":
to_return = set()
for k in names:
if "lora_" in k:
to_return.add(k)
bias_name = k.split("lora_")[0] + "bias"
if bias_name in names:
to_return.add(bias_name)
else:
raise NotImplementedError
to_return = {k for k in to_return if (("lora_" in k and adapter_name in k) or ("bias" in k))}
if config.use_dora:
# Here we take care of a refactor of DoRA which changed lora_magnitude_vector from a ParameterDict to a
# ModuleDict with a DoraLayer instance. The old parameter is now the "weight" attribute of that layer. Since
# we want the state_dict format not to change, we remove the "weight" part.
new_dora_suffix = f"lora_magnitude_vector.{adapter_name}.weight"

def renamed_dora_weights(k):
if k.endswith(new_dora_suffix):
k = k[:-7] # remove ".weight"
return k

to_return = {renamed_dora_weights(k) for k in to_return}

to_return = {re.sub(f"lora_\S\.{adapter_name}\.(weight|bias)", "base_layer", k) for k in to_return}
return to_return


class PeftUnwrapMixin:
def __init__(self, peft_model: PeftModel):
self.base_model = peft_model.get_base_model()
# peft does not affect buffers
self.lora_layers = extract_lora_layers(peft_model, set(n for n, p in self.base_model.named_parameters()))
potential_lora_weights = set()
for n in self.lora_layers:
potential_lora_weights.add(f"{n}.weight")
potential_lora_weights.add(f"{n}.bias")
self.lora_param_to_origin_param = {n: n.replace("base_layer.", "") for n in potential_lora_weights}
self.origin_param_to_lora_param = {v: k for k, v in self.lora_param_to_origin_param.items()}

def named_parameters(self):
for n, p in self.base_model.named_parameters():
if n in self.lora_param_to_origin_param:
n = self.lora_param_to_origin_param[n]
yield n, p

def named_buffers(self):
return self.base_model.named_buffers()

@property
def _modules(self):
return self.base_model._modules

@property
def _non_persistent_buffers_set(self):
return self.base_model._non_persistent_buffers_set

def patch_state_dict(self, state_dict: Dict[str, torch.Tensor]):
new_state_dict = {}
for k, v in state_dict.items():
if k in self.origin_param_to_lora_param:
k = self.origin_param_to_lora_param[k]
new_state_dict[k] = v
return new_state_dict

def state_dict(self):
state_dict = {}
for k, v in self.base_model.state_dict().items():
if k in self.lora_param_to_origin_param:
k = self.lora_param_to_origin_param[k]
state_dict[k] = v
return state_dict

def load_state_dict(self, state_dict, strict: bool = True, assign: bool = False):
state_dict = self.patch_state_dict(state_dict)
self.base_model.load_state_dict(state_dict, strict=strict, assign=assign)

def __hash__(self):
return hash(self.base_model)


class ModelWrapper(nn.Module):
@@ -23,7 +120,7 @@ def unwrap(self, unwrap_peft: bool = True):
else:
model = self.module
if unwrap_peft and isinstance(model, PeftModel):
- model = model.get_base_model()
+ model = PeftUnwrapMixin(model)
return model

def forward(self, *args, **kwargs):
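Finally, a hedged end-to-end sketch (not part of the PR) of what `PeftUnwrapMixin` provides: a state dict under the original parameter names, and a hash that is stable across wrapper instances, which is what the `id()` to `hash()` changes in the checkpoint IO above rely on. The gpt2 / LoRA-on-`c_attn` configuration is an illustrative assumption:

```python
from peft import LoraConfig, get_peft_model
from transformers import AutoModelForCausalLM  # illustrative base model

from colossalai.interface.model import PeftUnwrapMixin

base = AutoModelForCausalLM.from_pretrained("gpt2")
peft_model = get_peft_model(base, LoraConfig(target_modules=["c_attn"]))

raw_keys = peft_model.get_base_model().state_dict().keys()
mixin = PeftUnwrapMixin(peft_model)
clean_keys = mixin.state_dict().keys()

print(any(k.endswith("base_layer.weight") for k in raw_keys))    # True: LoRA-injected names
print(any(k.endswith("base_layer.weight") for k in clean_keys))  # False: original names restored

# A fresh wrapper hashes to the same key, so pinned-buffer caches keyed by hash(model) are reused.
print(hash(PeftUnwrapMixin(peft_model)) == hash(mixin))           # True
```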