@@ -252,11 +252,15 @@ def __init__(self,
         self.sub_group_size = sub_group_size

         self.sub_group_to_group_id = {}
-        see_memory_usage("Before creating fp16 partitions", force=False)
-        self._create_fp16_partitions_with_defragmentation()
+
+        # Trainable parameters
+        self.trainable_param_groups = self._get_trainable_parameter_groups()
+
+        see_memory_usage("Before creating fp16 partitions", force=True)
+        self._create_fp16_partitions_with_defragmentation(self.trainable_param_groups)
         num_fp16_subgroups = len(self.fp16_partitioned_groups_flat)
         see_memory_usage(f"After creating fp16 partitions: {num_fp16_subgroups}",
-                         force=False)
+                         force=True)

         # Optimizer tensor swapping
         if self.swap_optimizer:
@@ -350,19 +354,28 @@ def __init__(self,
     def destroy(self):
         self.parameter_offload.destroy()

+    def _get_trainable_parameter_groups(self):
+        param_groups = []
+        for param_group in self.optimizer.param_groups:
+            trainable_params = {
+                "params": [p for p in param_group["params"] if p.requires_grad]
+            }
+            param_groups.append(trainable_params)
+        return param_groups
+
     def _setup_for_real_optimizer(self):
-        see_memory_usage("Before creating fp32 partitions", force=False)
+        see_memory_usage("Before creating fp32 partitions", force=True)
         self._create_fp32_partitions()
-        see_memory_usage("After creating fp32 partitions", force=False)
+        see_memory_usage("After creating fp32 partitions", force=True)
         dist.barrier()

         # To support pipelined optimizer swapping
         self._create_next_swappable_fp32_groups()

-        see_memory_usage("Before initializing optimizer states", force=False)
+        see_memory_usage("Before initializing optimizer states", force=True)

         self.initialize_optimizer_states()
-        see_memory_usage("After initializing optimizer states", force=False)
+        see_memory_usage("After initializing optimizer states", force=True)
         dist.barrier()

         if dist.get_rank() == 0:
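Note on the new helper: `_get_trainable_parameter_groups` filters each optimizer param group down to the parameters with `requires_grad=True`, preserving the per-group structure. A minimal standalone sketch of the same filtering (using a plain SGD optimizer as a stand-in; the helper name and toy parameters below are illustrative, not part of this diff):

    import torch

    # One frozen and one trainable parameter in a single param group.
    frozen = torch.nn.Parameter(torch.zeros(4), requires_grad=False)
    trainable = torch.nn.Parameter(torch.zeros(4))
    optimizer = torch.optim.SGD([{"params": [frozen, trainable]}], lr=0.1)

    def get_trainable_parameter_groups(optimizer):
        # Same filtering as the new method: keep only params that require
        # gradients, one dict per original param group.
        return [{"params": [p for p in group["params"] if p.requires_grad]}
                for group in optimizer.param_groups]

    groups = get_trainable_parameter_groups(optimizer)
    assert len(groups[0]["params"]) == 1  # the frozen parameter is excluded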
@@ -523,7 +536,7 @@ def _create_param_groups_fp16_flat_cpu_memory(self):

         aggregate_params_count = 0

-        for j, param_group in enumerate(self.optimizer.param_groups):
+        for j, param_group in enumerate(self.trainable_param_groups):
             params_in_group = sum([p.partition_numel() for p in param_group['params']])

             flat_buffer_size = params_in_group
@@ -552,11 +565,12 @@ def _create_param_groups_fp16_flat_cpu_memory(self):
                     torch.empty(1,
                                 dtype=self.dtype))

-    def _create_fp16_partitions_with_defragmentation(self):
+    def _create_fp16_partitions_with_defragmentation(self, fp16_param_groups):
         dist.barrier()
+
         param_groups: List[List[Parameter]] = tuple(
             self._create_fp16_sub_groups(param_group["params"])
-            for param_group in self.optimizer.param_groups)
+            for param_group in fp16_param_groups)

         # bookkeeping related to param groups
         for param_group_idx, param_group in enumerate(param_groups):
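Note: `_create_fp16_partitions_with_defragmentation` now takes the pre-filtered groups as an argument instead of reading `self.optimizer.param_groups` directly, so frozen parameters never enter the fp16 flat partitions. For intuition only, a simplified sketch of the kind of sub-grouping `_create_fp16_sub_groups` performs, packing parameters into chunks bounded by `sub_group_size` elements (an illustration under assumptions, not the actual DeepSpeed implementation):

    import torch

    def create_fp16_sub_groups(params, sub_group_size):
        # Illustrative greedy packing: start a new sub-group whenever adding
        # the next parameter would exceed sub_group_size total elements.
        sub_groups, current, current_numel = [], [], 0
        for p in params:
            if current and current_numel + p.numel() > sub_group_size:
                sub_groups.append(current)
                current, current_numel = [], 0
            current.append(p)
            current_numel += p.numel()
        if current:
            sub_groups.append(current)
        return sub_groups

    params = [torch.nn.Parameter(torch.zeros(n)) for n in (3, 5, 2, 7)]
    print([len(g) for g in create_fp16_sub_groups(params, sub_group_size=8)])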
@@ -884,7 +898,6 @@ def initialize_optimizer_states(self):
                                       dtype=gradient_dtype,
                                       device=self.device)

-        timers = self.timers
         timer_names = set()

         if self.swap_optimizer:
@@ -2122,6 +2135,7 @@ def _get_param_groups(self):

     def _set_param_groups(self, value):
         self.optimizer.param_groups = value
+        self.trainable_param_groups = self._get_trainable_parameter_groups()

     param_groups = property(_get_param_groups, _set_param_groups)

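Note: since `param_groups` is exposed as a property, reassigning it through `_set_param_groups` now also refreshes the cached `trainable_param_groups`. A self-contained sketch of this keep-in-sync pattern (a toy wrapper class standing in for the ZeRO optimizer; names are illustrative):

    import torch

    class ToyZeroWrapper:
        def __init__(self, optimizer):
            self.optimizer = optimizer
            self.trainable_param_groups = self._get_trainable_parameter_groups()

        def _get_trainable_parameter_groups(self):
            return [{"params": [p for p in g["params"] if p.requires_grad]}
                    for g in self.optimizer.param_groups]

        def _get_param_groups(self):
            return self.optimizer.param_groups

        def _set_param_groups(self, value):
            # Mirrors the change above: keep the trainable view in sync
            # whenever the underlying param groups are replaced.
            self.optimizer.param_groups = value
            self.trainable_param_groups = self._get_trainable_parameter_groups()

        param_groups = property(_get_param_groups, _set_param_groups)

    w = torch.nn.Parameter(torch.zeros(2))
    wrapper = ToyZeroWrapper(torch.optim.SGD([w], lr=0.1))
    wrapper.param_groups = [{"params": [w], "lr": 0.1}]  # setter refreshes the cache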