
Commit a4c818c

Merge remote-tracking branch 'upstream/main' into insop/kld
2 parents: 2ed4e91 + cce8ef6

18 files changed: +877 −196 lines

recipes/dev/early_exit_finetune_distributed.py (+3 −3)

@@ -556,7 +556,6 @@ def _setup_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )

@@ -757,7 +756,7 @@ def save_checkpoint(
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-    self._model.state_dict(),
+    self._model,
     self._is_rank_zero,
     device=self._device,
 )

@@ -773,6 +772,7 @@ def save_checkpoint(
 log.info("Getting optimizer state dict...")
 if not self._optimizer_in_bwd:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,

@@ -781,7 +781,7 @@ def save_checkpoint(
 opt_state_dict = {}
 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-        opt, self._is_rank_zero, device=self._device
+        self._model, opt, self._is_rank_zero, device=self._device
     )
 if self._is_rank_zero:
     log.info(
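
Every hunk in this recipe follows the same pattern: the torchtune `training` checkpoint helpers now take the (possibly sharded) module itself instead of a pre-materialized state dict, and the optimizer-state helpers gain the model as a new leading argument. A minimal sketch of the updated save path, with the helper function and variable names being illustrative rather than copied from the recipe:

    from torchtune import training

    def gather_for_checkpoint(model, optimizer, device, is_rank_zero):
        # Consolidate sharded model weights on CPU for rank 0
        # (previously the recipes passed model.state_dict() here).
        cpu_state_dict = training.gather_cpu_state_dict(
            model,
            is_rank_zero,
            device=device,
        )
        # The optimizer helper now also receives the model as its first argument.
        opt_state_dict = training.get_full_optimizer_state_dict(
            model,
            optimizer,
            is_rank_zero,
            device=device,
        )
        return cpu_state_dict, opt_state_dict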

recipes/full_finetune_distributed.py (+2 −1)

@@ -547,7 +547,6 @@ def _setup_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )

@@ -602,6 +601,7 @@ def _setup_optimizer(
 for param in opt_state_dict.keys():
     try:
         training.load_from_full_optimizer_state_dict(
+            self._model,
             self._optim_ckpt_wrapper.state_dict()[param],
             opt_state_dict[param],
             self._device,

@@ -617,6 +617,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,
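
The load side changes symmetrically: `load_from_full_model_state_dict` drops the `is_rank_zero` argument, and `load_from_full_optimizer_state_dict` now takes the model first. A sketch of resuming from a full checkpoint under the new signatures (function and variable names here are illustrative, not the recipe's):

    from torchtune import training

    def load_checkpoint(model, optimizer, model_sd, opt_sd, device, cpu_offload):
        # Load full weights into the sharded model; the rank flag is no longer passed.
        training.load_from_full_model_state_dict(
            model,
            model_sd,
            device,
            strict=True,
            cpu_offload=cpu_offload,
        )
        # The optimizer loader now also takes the model as its first argument.
        training.load_from_full_optimizer_state_dict(
            model,
            optimizer,
            opt_sd,
            device,
        )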

recipes/knowledge_distillation_distributed.py (+3 −4)

@@ -461,7 +461,6 @@ def _setup_model(
         model,
         lora_weights_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 else:

@@ -486,7 +485,6 @@ def _setup_model(
         model,
         base_model_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 for m in model.modules():

@@ -574,7 +572,6 @@ def _setup_teacher_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )

@@ -611,6 +608,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,

@@ -705,13 +703,14 @@ def save_checkpoint(self, epoch: int) -> None:
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-    self._model.state_dict(),
+    self._model,
     self._is_rank_zero,
     device=self._device,
 )

 if intermediate_checkpoint:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,

recipes/lora_dpo_distributed.py (+4 −7)

@@ -385,7 +385,6 @@ def _setup_model(
         model,
         lora_weights_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 else:

@@ -410,7 +409,6 @@ def _setup_model(
         model,
         base_model_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 is_dora = False

@@ -458,6 +456,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,

@@ -546,17 +545,15 @@ def save_checkpoint(
 intermediate_checkpoint = epoch + 1 < self.total_epochs
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
-state_dict = self._model.state_dict()
-if self._save_adapter_weights_only:
-    state_dict = get_adapter_state_dict(state_dict, device=None)
-
 cpu_state_dict = training.gather_cpu_state_dict(
-    state_dict,
+    self._model,
     self._is_rank_zero,
     device=self._device,
+    adapter_weights_only=self._save_adapter_weights_only,
 )
 if intermediate_checkpoint:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,
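
In the LoRA recipes, the manual `get_adapter_state_dict(...)` filtering before the gather is replaced by a new `adapter_weights_only` flag on `gather_cpu_state_dict`. A sketch of the consolidated call (helper and variable names illustrative):

    from torchtune import training

    def gather_adapter_or_full(model, is_rank_zero, device, save_adapter_weights_only):
        # A single call now covers both cases: with adapter_weights_only=True,
        # only the adapter weights are consolidated on CPU for rank 0.
        return training.gather_cpu_state_dict(
            model,
            is_rank_zero,
            device=device,
            adapter_weights_only=save_adapter_weights_only,
        )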

recipes/lora_finetune_distributed.py (+4 −7)

@@ -480,7 +480,6 @@ def _setup_model(
         model,
         lora_weights_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 else:

@@ -505,7 +504,6 @@ def _setup_model(
         model,
         base_model_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 for m in model.modules():

@@ -549,6 +547,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,

@@ -656,14 +655,11 @@ def save_checkpoint(

 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
-state_dict = self._model.state_dict()
-if self._save_adapter_weights_only:
-    state_dict = get_adapter_state_dict(state_dict, device=None)
-
 cpu_state_dict = training.gather_cpu_state_dict(
-    state_dict,
+    self._model,
     self._is_rank_zero,
     device=self._device,
+    adapter_weights_only=self._save_adapter_weights_only,
 )
 utils.log_rank_zero(
     log,

@@ -673,6 +669,7 @@ def save_checkpoint(
 if intermediate_checkpoint:
     utils.log_rank_zero(log, "Retrieving optimizer state dict...")
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,

recipes/lora_finetune_distributed_multi_dataset.py (−2)

@@ -473,7 +473,6 @@ def _setup_model(
         model,
         lora_weights_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 else:

@@ -500,7 +499,6 @@ def _setup_model(
         model,
         base_model_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 for m in model.modules():

recipes/qat_distributed.py (+5 −3)

@@ -508,7 +508,6 @@ def _setup_model(
     model,
     model_state_dict,
     self._device,
-    self._is_rank_zero,
     strict=True,
     cpu_offload=fsdp_cpu_offload,
 )

@@ -562,6 +561,7 @@ def _setup_optimizer(
 for param in opt_state_dict.keys():
     try:
         training.load_from_full_optimizer_state_dict(
+            self._model,
             self._optim_ckpt_wrapper.state_dict()[param],
             opt_state_dict[param],
             self._device,

@@ -577,6 +577,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,

@@ -667,7 +668,7 @@ def save_checkpoint(
 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
 cpu_state_dict = training.gather_cpu_state_dict(
-    self._model.state_dict(),
+    self._model,
     self._is_rank_zero,
     device=self._device,
 )

@@ -682,6 +683,7 @@ def save_checkpoint(
 utils.log_rank_zero(log, "Getting optimizer state dict...")
 if not self._optimizer_in_bwd:
     opt_state_dict = training.get_full_optimizer_state_dict(
+        self._model,
         self._optimizer,
         self._is_rank_zero,
         device=self._device,

@@ -690,7 +692,7 @@ def save_checkpoint(
 opt_state_dict = {}
 for param, opt in self._optim_ckpt_wrapper.optim_map.items():
     opt_state_dict[param] = training.get_full_optimizer_state_dict(
-        opt, self._is_rank_zero, device=self._device
+        self._model, opt, self._is_rank_zero, device=self._device
     )
 utils.log_rank_zero(
     log,
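
The optimizer-in-backward branch changes the same way: each per-parameter optimizer in the checkpoint wrapper's `optim_map` is now gathered with the model passed ahead of it. A sketch of that loop, assuming a wrapper object exposing `optim_map` as in the recipes (function name illustrative):

    from torchtune import training

    def gather_in_bwd_optim_state(model, optim_ckpt_wrapper, is_rank_zero, device):
        # One full optimizer state dict per parameter, keyed as in the recipe's wrapper.
        return {
            param: training.get_full_optimizer_state_dict(
                model, opt, is_rank_zero, device=device
            )
            for param, opt in optim_ckpt_wrapper.optim_map.items()
        }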

recipes/qat_lora_finetune_distributed.py (+4 −7)

@@ -525,7 +525,6 @@ def _setup_model(
         model,
         lora_weights_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 else:

@@ -550,7 +549,6 @@ def _setup_model(
         model,
         base_model_state_dict,
         self._device,
-        self._is_rank_zero,
         cpu_offload=fsdp_cpu_offload,
     )
 validate_missing_and_unexpected_for_lora(

@@ -589,6 +587,7 @@ def _setup_optimizer(
 optimizer = config.instantiate(cfg_optimizer, self._model.parameters())
 if opt_state_dict:
     training.load_from_full_optimizer_state_dict(
+        self._model,
         optimizer,
         opt_state_dict,
         self._device,

@@ -699,14 +698,11 @@ def save_checkpoint(

 # To prevent GPU memory from spiking during checkpoint save,
 # we consolidate the full model and optim state dicts on CPU for rank 0
-state_dict = self._model.state_dict()
-if self._save_adapter_weights_only:
-    state_dict = get_adapter_state_dict(state_dict, device=None)
-
 cpu_state_dict = training.gather_cpu_state_dict(
-    state_dict,
+    self._model,
     self._is_rank_zero,
     device=self._device,
+    adapter_weights_only=self._save_adapter_weights_only,
 )
 if self._is_rank_zero:
     log.info(

@@ -717,6 +713,7 @@ def save_checkpoint(
 if self._is_rank_zero:
     log.info("Retrieving optimizer state dict...")
 opt_state_dict = training.get_full_optimizer_state_dict(
+    self._model,
     self._optimizer,
     self._is_rank_zero,
     device=self._device,

New file (+5)

@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the BSD-style license found in the
+# LICENSE file in the root directory of this source tree.
