@@ -508,7 +508,6 @@ def _setup_model(
             model,
             model_state_dict,
             self._device,
-            self._is_rank_zero,
             strict=True,
             cpu_offload=fsdp_cpu_offload,
         )
@@ -562,6 +561,7 @@ def _setup_optimizer(
                 for param in opt_state_dict.keys():
                     try:
                         training.load_from_full_optimizer_state_dict(
+                            self._model,
                             self._optim_ckpt_wrapper.state_dict()[param],
                             opt_state_dict[param],
                             self._device,
@@ -577,6 +577,7 @@ def _setup_optimizer(
             optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
             if opt_state_dict:
                 training.load_from_full_optimizer_state_dict(
+                    self._model,
                     optimizer,
                     opt_state_dict,
                     self._device,
@@ -667,7 +668,7 @@ def save_checkpoint(
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
         cpu_state_dict = training.gather_cpu_state_dict(
-            self._model.state_dict(),
+            self._model,
             self._is_rank_zero,
             device=self._device,
         )
@@ -682,6 +683,7 @@ def save_checkpoint(
             utils.log_rank_zero(log, "Getting optimizer state dict...")
             if not self._optimizer_in_bwd:
                 opt_state_dict = training.get_full_optimizer_state_dict(
+                    self._model,
                     self._optimizer,
                     self._is_rank_zero,
                     device=self._device,
@@ -690,7 +692,7 @@ def save_checkpoint(
                 opt_state_dict = {}
                 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
                     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-                        opt, self._is_rank_zero, device=self._device
+                        self._model, opt, self._is_rank_zero, device=self._device
                    )
                 utils.log_rank_zero(
                     log,
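
For context, a minimal sketch of the updated helper signatures as they are used in this patch. Only the training.* calls mirror the diff above; the wrapper function, its arguments, and the variable names are illustrative assumptions, not code from the recipe.

from torchtune import training


def checkpoint_roundtrip_sketch(model, optimizer, device, is_rank_zero):
    # Save path: consolidate the sharded model state on CPU (rank 0 holds the
    # full copy). The module itself is passed now, not model.state_dict().
    cpu_state_dict = training.gather_cpu_state_dict(
        model,
        is_rank_zero,
        device=device,
    )
    # The optimizer-state helpers now take the model as their first argument.
    opt_state_dict = training.get_full_optimizer_state_dict(
        model,
        optimizer,
        is_rank_zero,
        device=device,
    )

    # Load path: push a full (unsharded) optimizer state dict back into the
    # sharded optimizer, again passing the model first.
    training.load_from_full_optimizer_state_dict(
        model,
        optimizer,
        opt_state_dict,
        device,
    )
    return cpu_state_dict, opt_state_dict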