- import copy
+ import json
  import logging
  import os
- from functools import reduce
  from pathlib import Path
- from shutil import rmtree
  from typing import Dict, Iterator, Optional, OrderedDict, Tuple
- import json

  import torch
  import torch.distributed as dist
  import torch.nn as nn
  from torch.distributed import ProcessGroup
- from torch.optim.lr_scheduler import _LRScheduler as LRScheduler
- from torch.utils._pytree import tree_map
+ from torch.distributed.distributed_c10d import _get_default_group

  from colossalai.cluster import DistCoordinator
- from colossalai.interface import ModelWrapper, OptimizerWrapper
- from colossalai.tensor.padded_tensor import (
-     init_as_padded_tensor,
-     is_padded_tensor,
-     to_padded_tensor,
-     to_unpadded_tensor,
- )
- from colossalai.utils import get_current_device, get_non_persistent_buffers_set
- from torch.distributed.distributed_c10d import _get_default_group
+ from colossalai.interface import ModelWrapper
+ from colossalai.utils import get_non_persistent_buffers_set

  from .general_checkpoint_io import GeneralCheckpointIO
  from .index_file import CheckpointIndexFile
  from .utils import (
      StateDictSharder,
      async_save_state_dict_shards,
      create_pinned_state_dict,
-     gather_distributed_param,
      get_model_base_filenames,
-     get_optimizer_base_filenames,
-     is_safetensors_available,
-     load_shard_state_dict,
      load_state_dict,
-     load_state_dict_into_model,
-     load_states_into_optimizer,
-     save_config_file,
-     save_param_groups,
      save_state_dict,
      save_state_dict_shards,
-     search_padding_dim,
      search_tp_partition_dim,
-     sharded_optimizer_loading_epilogue,
  )

  try:
@@ -97,7 +76,6 @@ def __init__(
          self.model_metadata = None
          self.optimizer_metadata = None
          self.global_rank = dist.get_rank(_get_default_group())
-

      @staticmethod
      def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False):
@@ -106,13 +84,13 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
          for name, param in model.named_parameters():
              if param is None:
                  continue
-             destination[prefix + name] = param
+             destination[prefix + name] = param
          # Save buffers.
          non_persist_buffers_set = get_non_persistent_buffers_set(model)
          for name, buf in model.named_buffers():
              if buf is not None and name not in non_persist_buffers_set:
                  buffer = buf if keep_vars else buf.detach()
-                 destination[prefix + name] = buffer
+                 destination[prefix + name] = buffer

          # Save extra states.
          extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -123,22 +101,24 @@ def model_state_dict(model: nn.Module, prefix: str = "", keep_vars: bool = False
              extra_state = model.get_extra_state()
              destination[extra_state_key] = extra_state
          return destination
-
+
      @staticmethod
-     def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False):
+     def load_state_dict(
+         model: nn.Module, state_dict: Dict, prefix: str = "", keep_vars: bool = False, strict: bool = False
+     ):
          destination = dict()
          # Save parameters.
          for name, param in model.named_parameters():
              if param is None:
                  continue
              with torch.no_grad():
-                 param.copy_(state_dict[prefix + name])
+                 param.copy_(state_dict[prefix + name])
          # Save buffers.
          non_persist_buffers_set = get_non_persistent_buffers_set(model)
          for name, buf in model.named_buffers():
              if buf is not None and name not in non_persist_buffers_set:
                  with torch.no_grad():
-                     buf.copy_(state_dict[prefix + name])
+                     buf.copy_(state_dict[prefix + name])

          # Save extra states.
          extra_state_key = prefix + _EXTRA_STATE_KEY_SUFFIX
@@ -151,26 +131,33 @@ def load_state_dict(model: nn.Module, state_dict: Dict, prefix: str = "", keep_v
                  extra_state.copy_(state_dict[extra_state_key])
          return destination

-     def create_model_metadata(self, model: nn.Module, prefix: str = "",):
+     def create_model_metadata(
+         self,
+         model: nn.Module,
+         prefix: str = "",
+     ):
          param_origin_shape = model.param_origin_shape
          model = model.unwrap()
          self.model_metadata = {}
          for name, param in model.named_parameters():
              if param is None:
                  continue
-             self.model_metadata[prefix + name] = {}
+             self.model_metadata[prefix + name] = {}
              original_shape = param_origin_shape[name]
-             tp_partition_dim = search_tp_partition_dim(current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size)
-             self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
-             self.model_metadata[prefix + name]["lengths"] = list(param.shape)
-             self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
+             tp_partition_dim = search_tp_partition_dim(
+                 current_shape=param.shape, original_shape=original_shape, tp_size=self.tp_size
+             )
+             self.model_metadata[prefix + name]["offsets"] = torch.zeros(len(original_shape), dtype=torch.int)
+             self.model_metadata[prefix + name]["lengths"] = list(param.shape)
+             self.model_metadata[prefix + name]["global_shape"] = list(original_shape)
              if tp_partition_dim is not None:
                  partition_size = param.shape[tp_partition_dim]
-                 self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
+                 self.model_metadata[prefix + name]["offsets"][tp_partition_dim] = partition_size * self.tp_rank
                  if self.tp_rank == self.tp_size - 1:
-                     self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[tp_partition_dim] - (partition_size * (self.tp_size - 1))
+                     self.model_metadata[prefix + name]["lengths"][tp_partition_dim] = original_shape[
+                         tp_partition_dim
+                     ] - (partition_size * (self.tp_size - 1))

-
      def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
          metadata_dicts = {
              "checkpoint_version": "1.0",
@@ -188,7 +175,7 @@ def save_metadata(self, metadata_file, checkpoint_file=None, total_size=None):
              metadata_dicts["metadata"][name]["rank"] = self.global_rank
          with open(metadata_file, "w") as json_file:
              json.dump(metadata_dicts, json_file, indent=4)
-
+
      def save_unsharded_model(
          self, model: ModelWrapper, checkpoint: str, gather_dtensor: bool, use_safetensors: bool, use_async: bool = False
      ):
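save_metadata writes one JSON file per rank describing every parameter slice that rank owns. A rough sketch of what such a file could contain, using the field names visible in the diff; the parameter name, shard file name, metadata file name, and all values are invented for illustration:

```python
import json

# Hypothetical contents of one rank's metadata file (all names and values made up).
metadata_dicts = {
    "checkpoint_version": "1.0",
    "metadata": {
        "module.linear.weight": {
            "offsets": [3, 0],                  # where this rank's slice starts in the global tensor
            "lengths": [3, 8],                  # extent of the slice in each dimension
            "global_shape": [10, 8],            # shape of the full, unpartitioned parameter
            "file": "pytorch_model-00001.bin",  # shard file holding the slice
            "rank": 1,                          # global rank that owns this slice
        }
    },
}

with open("model.rank1.metadata", "w") as json_file:  # filename convention is an assumption
    json.dump(metadata_dicts, json_file, indent=4)
```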
@@ -249,13 +236,13 @@ def load_metadata(self, checkpoint: str):
          try:
              with open(file_path, "r") as f:
                  metadata_json = json.load(f)
-                 for name, item in metadata_json['metadata'].items():
+                 for name, item in metadata_json["metadata"].items():
                      if name not in metadata_dict:
                          metadata_dict[name] = {}
-                         metadata_dict[name]["global_shape"] = item['global_shape']
+                         metadata_dict[name]["global_shape"] = item["global_shape"]
                          metadata_dict[name]["shards"] = {}
                      else:
-                         assert metadata_dict[name]["global_shape"] == item['global_shape']
+                         assert metadata_dict[name]["global_shape"] == item["global_shape"]
                      shard = {}
                      shard[item["rank"]] = {}
                      shard[item["rank"]]["file"] = item["file"]
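load_metadata merges every rank's metadata file into a single lookup keyed by parameter name, asserting along the way that all ranks agree on the global shape. A rough picture of the merged structure this loop builds, with parameter names, file names, and shapes invented for illustration:

```python
# Hypothetical result of merging two ranks' metadata for one tensor-parallel parameter.
metadata_dict = {
    "module.linear.weight": {
        "global_shape": [10, 8],
        "shards": {
            0: {"file": "pytorch_model-00001.bin", "offsets": [0, 0], "lengths": [3, 8]},
            1: {"file": "pytorch_model-00002.bin", "offsets": [3, 0], "lengths": [3, 8]},
        },
    }
}
```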
@@ -304,7 +291,7 @@ def find_covering_shards(self, shards, target_offsets, target_lengths):

          assert total_lengths == global_shape
          return covering_shards
-
+
      def extract_weight_from_shard_partial(self, shard, target_offsets, target_lengths):
          """
          Extract the target range of weights from shard data, supporting partial overlap.
@@ -314,14 +301,16 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
          param target_lengths: A 1D array indicating the length of the target tensor in each dimension.
          return: The extracted sub-tensor of the target weights and its position within the target range.
          """
-         shard_offsets = shard['offsets']
-         shard_lengths = shard['lengths']
-         weight = shard['weight']
+         shard_offsets = shard["offsets"]
+         shard_lengths = shard["lengths"]
+         weight = shard["weight"]

          slices = []
          target_slices = []

-         for dim, (t_offset, t_length, s_offset, s_length) in enumerate(zip(target_offsets, target_lengths, shard_offsets, shard_lengths)):
+         for dim, (t_offset, t_length, s_offset, s_length) in enumerate(
+             zip(target_offsets, target_lengths, shard_offsets, shard_lengths)
+         ):
              intersection_start = max(t_offset, s_offset)
              intersection_end = min(t_offset + t_length, s_offset + s_length)

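The reformatted loop above computes, per dimension, the intersection of the shard's range with the target range and turns it into a pair of slices. A small self-contained version of the same intersection-and-slice idea; the `extract_overlap` helper is an illustrative stand-in, not the method itself:

```python
import torch

def extract_overlap(weight, shard_offsets, shard_lengths, target_offsets, target_lengths):
    """Return the part of `weight` inside the target box, plus where it lands in the target."""
    slices, target_slices = [], []
    for t_off, t_len, s_off, s_len in zip(target_offsets, target_lengths, shard_offsets, shard_lengths):
        start = max(t_off, s_off)
        end = min(t_off + t_len, s_off + s_len)
        if start >= end:  # no overlap along this dim means no overlap at all
            return None, None
        slices.append(slice(start - s_off, end - s_off))          # index into the shard
        target_slices.append(slice(start - t_off, end - t_off))   # index into the target
    return weight[tuple(slices)], target_slices

# A 4x4 shard that starts at row 2 of the global tensor, against a target covering rows 0..5:
shard = torch.arange(16.0).reshape(4, 4)
piece, where = extract_overlap(
    shard, shard_offsets=(2, 0), shard_lengths=(4, 4), target_offsets=(0, 0), target_lengths=(6, 4)
)
target = torch.zeros(6, 4)
target[tuple(where)] = piece  # rows 2..5 of the target are filled from the shard
```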
@@ -339,7 +328,6 @@ def extract_weight_from_shard_partial(self, shard, target_offsets, target_length
          target_weight = weight[tuple(slices)]
          return target_weight, target_slices

-
      def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_lengths, dtype):
          target_tensor = torch.zeros(target_lengths, dtype=dtype)

@@ -351,15 +339,14 @@ def assemble_tensor_from_shards_partial(self, shards, target_offsets, target_len

          return target_tensor

-
-     def load_unsharded_model(
+     def load_unsharded_model(
          self,
          model: ModelWrapper,
          checkpoint: str,
          strict: bool = False,
          low_cpu_mem_mode: bool = True,
          num_threads: int = 1,
-     ):
+     ):
          """
          Load model from a single file with the given path of checkpoint.

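Before the load path below, it may help to see the assembly step in isolation: the target slice is zero-initialized at its local shape, and every covering shard pastes its overlapping piece into place. A toy 1-D sketch of that idea, assuming the shard-dict layout used above (values are made up):

```python
import torch

# Two 1-D shards that together cover a target slice of length 10.
shards = {
    0: {"offsets": [0], "lengths": [6], "weight": torch.arange(0.0, 6.0)},
    1: {"offsets": [6], "lengths": [4], "weight": torch.arange(6.0, 10.0)},
}
target_offsets, target_lengths = [0], [10]

target = torch.zeros(target_lengths, dtype=torch.float32)
for shard in shards.values():
    # Overlap of [s_off, s_off + s_len) with [t_off, t_off + t_len), then paste it in.
    s_off, s_len = shard["offsets"][0], shard["lengths"][0]
    t_off, t_len = target_offsets[0], target_lengths[0]
    start, end = max(t_off, s_off), min(t_off + t_len, s_off + s_len)
    if start < end:
        target[start - t_off : end - t_off] = shard["weight"][start - s_off : end - s_off]

assert torch.equal(target, torch.arange(0.0, 10.0))
```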
@@ -390,30 +377,34 @@ def load_unsharded_model(
          for key, item in self.model_metadata.items():
              offsets = item["offsets"]
              lengths = item["lengths"]
-             assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+             assert (
+                 item["global_shape"] == metadata_loaded[key]["global_shape"]
+             ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
              shards = metadata_loaded[key]["shards"]
              covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
              covered_shards[key] = covering_shards
              # load_files.update({rank: shard['file'] for rank, shard in covering_shards.items()})
              for rank, shard in covering_shards.items():
                  if rank not in load_files:
                      load_files[rank] = set()
-                 load_files[rank].add(shard['file'])
+                 load_files[rank].add(shard["file"])

          dtype = None
          for rank, files in load_files.items():
              for file in files:
                  file_path = os.path.join(checkpoint, file)
                  state_dict_shard = load_state_dict(file_path)
-                 for key, weight in state_dict_shard.items():
+                 for key, weight in state_dict_shard.items():
                      if key not in covered_shards:
                          continue
                      if dtype == None:
                          dtype = weight.dtype
                      covered_shards[key][rank]["weight"] = weight
          state_dict = {}
          for key, shards in covered_shards.items():
-             state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+             state = self.assemble_tensor_from_shards_partial(
+                 shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+             )
              state_dict[key] = state

          if not low_cpu_mem_mode:
@@ -424,7 +415,6 @@ def load_unsharded_model(
          # Update master params if mixed-precision training is enabled.
          model_before_wrapping.update_master_params()

-
      @staticmethod
      def _model_sharder(
          model: nn.Module,
@@ -571,7 +561,7 @@ def save_sharded_model(
          )
          for k, _ in self.model_metadata.items():
              self.model_metadata[k]["file"] = index_file.get_checkpoint_file(k)
-
+
          self.save_metadata(metadata_file, total_size=total_size)

      def load_sharded_model(
@@ -606,30 +596,34 @@ def load_sharded_model(
          for key, item in self.model_metadata.items():
              offsets = item["offsets"]
              lengths = item["lengths"]
-             assert item["global_shape"] == metadata_loaded[key]["global_shape"], f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
+             assert (
+                 item["global_shape"] == metadata_loaded[key]["global_shape"]
+             ), f"{item['global_shape']}, {metadata_loaded[key]['global_shape']}"
              shards = metadata_loaded[key]["shards"]
              covering_shards = self.find_covering_shards(shards=shards, target_offsets=offsets, target_lengths=lengths)
              covered_shards[key] = covering_shards
              for rank, shard in covering_shards.items():
                  if rank not in load_files:
                      load_files[rank] = set()
-                 load_files[rank].add(shard['file'])
-
+                 load_files[rank].add(shard["file"])
+
          dtype = None
          for rank, files in load_files.items():
              for file in files:
                  file_path = os.path.join(checkpoint, file)
                  state_dict_shard = load_state_dict(file_path)
-                 for key, weight in state_dict_shard.items():
+                 for key, weight in state_dict_shard.items():
                      if key not in covered_shards:
                          continue
                      if dtype == None:
                          dtype = weight.dtype
                      covered_shards[key][rank]["weight"] = weight
-
+
          state_dict = {}
          for key, shards in covered_shards.items():
-             state = self.assemble_tensor_from_shards_partial(shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype)
+             state = self.assemble_tensor_from_shards_partial(
+                 shards, self.model_metadata[key]["offsets"], self.model_metadata[key]["lengths"], dtype=dtype
+             )
              state_dict[key] = state

          if not low_cpu_mem_mode:
@@ -638,4 +632,4 @@ def load_sharded_model(
          DistributedCheckpointIO.load_state_dict(model=model, state_dict=state_dict)

          # Update master params if mixed-precision training is enabled.
-         model_before_wrapping.update_master_params()
+         model_before_wrapping.update_master_params()