Commit 307d8f1

[pre-commit.ci] auto fixes from pre-commit.com hooks
for more information, see https://pre-commit.ci
1 parent 650af68 commit 307d8f1

3 files changed: +103 −78 lines

Diff for: colossalai/checkpoint_io/__init__.py (+3 −2)
@@ -1,7 +1,8 @@
 from .checkpoint_io_base import CheckpointIO
 from .general_checkpoint_io import GeneralCheckpointIO
 from .hybrid_parallel_checkpoint_io import HybridParallelCheckpointIO
-from.distributed_checkpoint_io import DistributedCheckpointIO
+
+from .distributed_checkpoint_io import DistributedCheckpointIO
 from .index_file import CheckpointIndexFile
 from .moe_checkpoint import MoECheckpointIO
 

@@ -11,5 +12,5 @@
     "GeneralCheckpointIO",
     "HybridParallelCheckpointIO",
     "MoECheckpointIO",
-    "DistributedCheckpointIO"
+    "DistributedCheckpointIO",
 ]
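
Note: with the re-export above, the new class is importable from the package root. A minimal sketch (assumes a ColossalAI install that includes this commit; only the import path is taken from the diff):

# Hedged usage sketch, not part of the commit.
from colossalai.checkpoint_io import DistributedCheckpointIO, GeneralCheckpointIO

print(DistributedCheckpointIO.__name__)  # -> "DistributedCheckpointIO"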

Diff for: colossalai/checkpoint_io/distributed_checkpoint_io.py (+61 −67)
@@ -1,51 +1,30 @@
-import copy
+import json
 import logging
 import os
-from functools import reduce
 from pathlib import Path
-from shutil import rmtree
 from typing import Dict, Iterator, Optional, OrderedDict, Tuple
-import json
 
 import torch
 import torch.distributed as dist
 import torch.nn as nn
 from torch.distributed import ProcessGroup
-from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
-from torch.utils._pytree import tree_map
+from torch.distributed.distributed_c10d import _get_default_group
 
 from colossalai.cluster import DistCoordinator
-from colossalai.interface import ModelWrapper, OptimizerWrapper
-from colossalai.tensor.padded_tensor import (
-    init_as_padded_tensor,
-    is_padded_tensor,
-    to_padded_tensor,
-    to_unpadded_tensor,
-)
-from colossalai.utils import get_current_device, get_non_persistent_buffers_set
-from torch.distributed.distributed_c10d import _get_default_group
+from colossalai.interface import ModelWrapper
+from colossalai.utils import get_non_persistent_buffers_set
 
 from .general_checkpoint_io import GeneralCheckpointIO
 from .index_file import CheckpointIndexFile
 from .utils import (
     StateDictSharder,
     async_save_state_dict_shards,
     create_pinned_state_dict,
-    gather_distributed_param,
     get_model_base_filenames,
-    get_optimizer_base_filenames,
-    is_safetensors_available,
-    load_shard_state_dict,
     load_state_dict,
-    load_state_dict_into_model,
-    load_states_into_optimizer,
-    save_config_file,
-    save_param_groups,
     save_state_dict,
     save_state_dict_shards,
-    search_padding_dim,
     search_tp_partition_dim,
-    sharded_optimizer_loading_epilogue,
 )
 
 try:
@@ -97,7 +76,6 @@ def __init__(
         self.model_metadata = None
         self.optimizer_metadata = None
         self.global_rank = dist.get_rank(_get_default_group())
-
 
     @staticmethod
     def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
@@ -106,13 +84,13 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            destination[prefix+name] = param
+            destination[prefix + name] = param
         # Save buffers.
         non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 buffer = buf if keep_vars else buf.detach()
-                destination[prefix+name] = buffer
+                destination[prefix + name] = buffer
 
         # Save extra states.
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -123,22 +101,24 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
             extra_state = model.get_extra_state()
             destination[extra_state_key] = extra_state
         return destination
-
+
     @staticmethod
-    def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False):
+    def load_state_dict(
+        model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
+    ):
         destination = dict()
         # Save parameters.
         for name, param in model.named_parameters():
             if param is None:
                 continue
             with torch.no_grad():
-                param.copy_(state_dict[prefix+name])
+                param.copy_(state_dict[prefix + name])
         # Save buffers.
         non_persist_buffers_set = get_non_persistent_buffers_set(model)
         for name, buf in model.named_buffers():
             if buf is not None and name not in non_persist_buffers_set:
                 with torch.no_grad():
-                    buf.copy_(state_dict[prefix+name])
+                    buf.copy_(state_dict[prefix + name])
 
         # Save extra states.
         extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -151,26 +131,33 @@ def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_v
                 extra_state.copy_(state_dict[extra_state_key])
         return destination
 
-    def create_model_metadata(self, model: nn.Module, prefix: str = "",):
+    def create_model_metadata(
+        self,
+        model: nn.Module,
+        prefix: str = "",
+    ):
         param_origin_shape = model.param_origin_shape
         model = model.unwrap()
         self.model_metadata = {}
         for name, param in model.named_parameters():
             if param is None:
                 continue
-            self.model_metadata[prefix+name] = {}
+            self.model_metadata[prefix + name] = {}
             original_shape = param_origin_shape[name]
-            tp_partition_dim = search_tp_partition_dim(current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size)
-            self.model_metadata[prefix+name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
-            self.model_metadata[prefix+name]["lengths"] = list(param.shape)
-            self.model_metadata[prefix+name]["global_shape"] = list(original_shape)
+            tp_partition_dim = search_tp_partition_dim(
+                current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size
+            )
+            self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+            self.model_metadata[prefix + name]["lengths"] = list(param.shape)
+            self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
             if tp_partition_dim is not None:
                 partition_size = param.shape[tp_partition_dim]
-                self.model_metadata[prefix+name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
+                self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
                 if self.tp_rank == self.tp_size - 1:
-                    self.model_metadata[prefix+name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (partition_size * (self.tp_size -1))
+                    self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
+                        tp_partition_dim
+                    ] - (partition_size * (self.tp_size - 1))
 
-
     def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
         metadata_dicts = {
             "checkpoint_version": "1.0",
@@ -188,7 +175,7 @@ def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
             metadata_dicts["metadata"][name]["rank"] = self.global_rank
         with open(metadata_file, "w") as json_file:
             json.dump(metadata_dicts, json_file, indent=4)
-
+
     def save_unsharded_model(
         self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
     ):
@@ -249,13 +236,13 @@ def load_metadata(self, checkpoint: str):
             try:
                 with open(file_path, "r") as f:
                     metadata_json = json.load(f)
-                    for name, item in metadata_json['metadata'].items():
+                    for name, item in metadata_json["metadata"].items():
                         if name not in metadata_dict:
                             metadata_dict[name] = {}
-                            metadata_dict[name]["global_shape"] = item['global_shape']
+                            metadata_dict[name]["global_shape"] = item["global_shape"]
                             metadata_dict[name]["shards"] = {}
                         else:
-                            assert metadata_dict[name]["global_shape"] == item['global_shape']
+                            assert metadata_dict[name]["global_shape"] == item["global_shape"]
                         shard = {}
                         shard[item["rank"]] = {}
                         shard[item["rank"]]["file"] = item["file"]
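
For orientation, a sketch of what one rank's metadata file could look like, pieced together from save_metadata, create_model_metadata, and the load_metadata reads above. The parameter name, file name, and numbers are made up, and any field not visible in this diff is an assumption:

# Illustrative only -- not produced by running the library.
example_metadata = {
    "checkpoint_version": "1.0",
    "metadata": {
        "module.linear.weight": {
            "offsets": [25, 0],         # start of this rank's slice in the global tensor
            "lengths": [25, 4],         # extent of the slice per dimension
            "global_shape": [100, 4],   # full, unsharded shape
            "rank": 1,                  # global rank that wrote the slice
            "file": "model-00001.bin",  # shard file holding the weight (hypothetical name)
        }
    },
}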
@@ -304,7 +291,7 @@ def find_covering_shards(self, shards, target_offsets, target_lengths):
 
         assert total_lengths == global_shape
         return covering_shards
-
+
     def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths):
         """
         Extract the target range of weights from shard data, supporting partial overlap.
@@ -314,14 +301,16 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
         param target_lengths: A 1D array indicating the length of the target tensor in each dimension.
         return: The extracted sub-tensor of the target weights and its position within the target range.
         """
-        shard_offsets = shard['offsets']
-        shard_lengths = shard['lengths']
-        weight = shard['weight']
+        shard_offsets = shard["offsets"]
+        shard_lengths = shard["lengths"]
+        weight = shard["weight"]
 
         slices = []
         target_slices = []
 
-        for dim, (t_offset, t_length, s_offset, s_length) in enumerate(zip(target_offsets, target_lengths, shard_offsets, shard_lengths)):
+        for dim, (t_offset, t_length, s_offset, s_length) in enumerate(
+            zip(target_offsets, target_lengths, shard_offsets, shard_lengths)
+        ):
             intersection_start = max(t_offset, s_offset)
             intersection_end = min(t_offset + t_length, s_offset + s_length)
 
@@ -339,7 +328,6 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
         target_weight = weight[tuple(slices)]
         return target_weight, target_slices
 
-
     def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype):
         target_tensor = torch.zeros(target_lengths, dtype=dtype)
 
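
The two hunks above bracket the per-dimension overlap computation of extract_weight_from_shard_partial; the lines between them are outside this diff, so the following is a hedged reconstruction of that slice arithmetic rather than the file's exact code:

import torch

def extract_overlap(weight, shard_offsets, shard_lengths, target_offsets, target_lengths):
    slices, target_slices = [], []
    for t_offset, t_length, s_offset, s_length in zip(target_offsets, target_lengths, shard_offsets, shard_lengths):
        start = max(t_offset, s_offset)
        end = min(t_offset + t_length, s_offset + s_length)
        if start >= end:
            return None, None  # no overlap along this dimension
        slices.append(slice(start - s_offset, end - s_offset))          # index into the shard
        target_slices.append(slice(start - t_offset, end - t_offset))   # index into the target
    return weight[tuple(slices)], target_slices

# A 4x4 shard starting at row 4 of the global tensor; the target wants rows 2..8:
piece, dst = extract_overlap(torch.ones(4, 4), [4, 0], [4, 4], [2, 0], [6, 4])
# piece.shape == (4, 4); dst == [slice(2, 6), slice(0, 4)]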
@@ -351,15 +339,14 @@ def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_len
 
         return target_tensor
 
-
-    def load_unsharded_model(
+    def load_unsharded_model(
         self,
         model: ModelWrapper,
         checkpoint: str,
         strict: bool = False,
         low_cpu_mem_mode: bool = True,
         num_threads: int = 1,
-        ):
+    ):
         """
         Load model from a single file with the given path of checkpoint.
 
@@ -390,30 +377,34 @@ def load_unsharded_model(
         for key, item in self.model_metadata.items():
             offsets = item["offsets"]
             lengths = item["lengths"]
-            assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+            assert (
+                item["global_shape"] == metadata_loaded[key]["global_shape"]
+            ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
             shards = metadata_loaded[key]["shards"]
             covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
             covered_shards[key] = covering_shards
             # load_files.update({rank: shard['file'] for rank, shard in covering_shards.items()})
             for rank, shard in covering_shards.items():
                 if rank not in load_files:
                     load_files[rank] = set()
-                load_files[rank].add(shard['file'])
+                load_files[rank].add(shard["file"])
 
         dtype = None
         for rank, files in load_files.items():
             for file in files:
                 file_path = os.path.join(checkpoint, file)
                 state_dict_shard = load_state_dict(file_path)
-                for key, weight in state_dict_shard.items():
+                for key, weight in state_dict_shard.items():
                     if key not in covered_shards:
                         continue
                     if dtype == None:
                         dtype = weight.dtype
                     covered_shards[key][rank]["weight"] = weight
         state_dict = {}
         for key, shards in covered_shards.items():
-            state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+            state = self.assemble_tensor_from_shards_partial(
+                shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+            )
             state_dict[key] = state
 
         if not low_cpu_mem_mode:
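
load_unsharded_model then hands each parameter's covering shards to assemble_tensor_from_shards_partial. A simplified sketch of that assembly step, reusing the hypothetical extract_overlap helper from the earlier sketch (again an illustration, not the library code):

import torch

def assemble_from_shards(shards, target_offsets, target_lengths, dtype):
    # Every covering shard copies its overlapping block into a zero-initialized target.
    target = torch.zeros(target_lengths, dtype=dtype)
    for shard in shards.values():  # keyed by writer rank, as in the metadata
        piece, dst = extract_overlap(
            shard["weight"], shard["offsets"], shard["lengths"], target_offsets, target_lengths
        )
        if piece is not None:
            target[tuple(dst)] = piece
    return target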
@@ -424,7 +415,6 @@ def load_unsharded_model(
         # Update master params if mixed-precision training is enabled.
         model_before_wrapping.update_master_params()
 
-
     @staticmethod
     def _model_sharder(
         model: nn.Module,
@@ -571,7 +561,7 @@ def save_sharded_model(
         )
         for k, _ in self.model_metadata.items():
             self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
-
+
         self.save_metadata(metadata_file, total_size=total_size)
 
     def load_sharded_model(
@@ -606,30 +596,34 @@ def load_sharded_model(
         for key, item in self.model_metadata.items():
             offsets = item["offsets"]
             lengths = item["lengths"]
-            assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+            assert (
+                item["global_shape"] == metadata_loaded[key]["global_shape"]
+            ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
             shards = metadata_loaded[key]["shards"]
             covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
             covered_shards[key] = covering_shards
             for rank, shard in covering_shards.items():
                 if rank not in load_files:
                     load_files[rank] = set()
-                load_files[rank].add(shard['file'])
-
+                load_files[rank].add(shard["file"])
+
         dtype = None
         for rank, files in load_files.items():
             for file in files:
                 file_path = os.path.join(checkpoint, file)
                 state_dict_shard = load_state_dict(file_path)
-                for key, weight in state_dict_shard.items():
+                for key, weight in state_dict_shard.items():
                     if key not in covered_shards:
                         continue
                     if dtype == None:
                         dtype = weight.dtype
                     covered_shards[key][rank]["weight"] = weight
-
+
         state_dict = {}
         for key, shards in covered_shards.items():
-            state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+            state = self.assemble_tensor_from_shards_partial(
+                shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+            )
             state_dict[key] = state
 
         if not low_cpu_mem_mode:
@@ -638,4 +632,4 @@ def load_sharded_model(
         DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict)
 
         # Update master params if mixed-precision training is enabled.
-        model_before_wrapping.update_master_params()
+        model_before_wrapping.update_master_params()
