
Commit 08efaed

Refactor Recipe State Dict Code (#1964)
1 parent 550163b commit 08efaed

18 files changed, +226 -129 lines changed

docs/source/api_ref_modules.rst

Lines changed: 1 addition & 0 deletions
@@ -75,6 +75,7 @@ PEFT Components
     peft.AdapterModule
     peft.get_adapter_params
     peft.set_trainable_params
+    peft.get_adapter_state_dict
     peft.validate_missing_and_unexpected_for_lora
     peft.validate_state_dict_for_lora
     peft.disable_adapter

docs/source/api_ref_training.rst

Lines changed: 1 addition & 0 deletions
@@ -56,6 +56,7 @@ Utilities for enabling and working with distributed training.
     get_world_size_and_rank
     get_full_finetune_fsdp_wrap_policy
     lora_fsdp_wrap_policy
+    gather_cpu_state_dict
 
 .. _ac_label:
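These two entries are the new public helpers that the recipe changes below standardize on. As a quick orientation, here is a minimal sketch of how the distributed recipes in this commit compose them; the signatures are taken from the call sites in this diff, while the wrapper function name and the behavior notes are illustrative assumptions, not part of the commit:

from typing import Any, Dict

import torch

from torchtune import training
from torchtune.modules.peft import get_adapter_state_dict


def consolidated_checkpoint_state(
    model: torch.nn.Module,
    is_rank_zero: bool,
    device: torch.device,
    save_adapter_weights_only: bool = False,
) -> Dict[str, Any]:
    # Sketch of the shared save path: optionally pre-filter to adapter keys,
    # then consolidate the (possibly sharded) state dict on CPU for rank 0.
    state_dict = model.state_dict()
    if save_adapter_weights_only:
        # device=None leaves tensors where they are so they can still be
        # gathered across ranks below (as in the distributed LoRA recipes).
        state_dict = get_adapter_state_dict(state_dict, device=None)
    return training.gather_cpu_state_dict(state_dict, is_rank_zero, device=device)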

recipes/full_finetune_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -645,8 +645,8 @@ def save_checkpoint(
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
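A variant of this change appears in every distributed recipe: the consolidation helper now takes a plain state dict instead of the module itself. For intuition only, here is a hypothetical sketch of what such a gather step has to do; this is not torchtune's implementation, just an illustration built on the public DTensor API (the import path may vary across torch versions):

from typing import Any, Dict

import torch
from torch.distributed.tensor import DTensor  # older torch: torch.distributed._tensor


def gather_cpu_state_dict_sketch(
    sharded_sd: Dict[str, Any], is_rank_zero: bool
) -> Dict[str, torch.Tensor]:
    # Illustrative only: materialize each sharded tensor and keep a CPU copy
    # on rank 0, returning an empty dict on the other ranks.
    cpu_state_dict: Dict[str, torch.Tensor] = {}
    for key, value in sharded_sd.items():
        if isinstance(value, DTensor):
            # Every rank participates in the collective; only rank 0 keeps it.
            value = value.full_tensor()
        if is_rank_zero:
            cpu_state_dict[key] = value.cpu()
    return cpu_state_dict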

recipes/knowledge_distillation_distributed.py

Lines changed: 4 additions & 6 deletions
@@ -25,6 +25,7 @@
 from torchtune.modules.peft import (
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None:
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
@@ -728,10 +729,7 @@ def save_checkpoint(self, epoch: int) -> None:
 
             # Filter out the adapter keys and weights from the model state dict. These will
             # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
+            adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
             checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
 
             # merge the adapter weights and base weights to create the model checkpoint
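The deleted adapter_key_filter lambda is the pattern this commit centralizes in get_adapter_state_dict. A hypothetical stand-in that matches the removed behavior is sketched below; how the real torchtune helper identifies adapter keys (here, a substring match) and its default device are assumptions:

from typing import Dict, Optional

import torch


def adapter_state_dict_sketch(
    state_dict: Dict[str, torch.Tensor], device: Optional[str] = "cpu"
) -> Dict[str, torch.Tensor]:
    # Keep only adapter (LoRA/DoRA) entries, optionally moving them to `device`;
    # device=None leaves them untouched so they can still be gathered later.
    def is_adapter_key(key: str) -> bool:
        return "lora" in key or "magnitude" in key  # DoRA magnitude vectors

    return {
        k: (v.to(device) if device is not None else v)
        for k, v in state_dict.items()
        if is_adapter_key(k)
    }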

recipes/knowledge_distillation_single_device.py

Lines changed: 2 additions & 4 deletions
@@ -23,6 +23,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None:
         ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
 
         # Construct the adapter weights
-        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
-        }
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         adapter_config = {
             "r": self._lora_rank,

recipes/lora_dpo_distributed.py

Lines changed: 16 additions & 13 deletions
@@ -25,6 +25,7 @@
     disable_adapter,
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
     LoRALinear,
@@ -504,8 +505,12 @@ def save_checkpoint(
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        state_dict = self._model.state_dict()
+        if self._save_adapter_weights_only:
+            state_dict = get_adapter_state_dict(state_dict, device=None)
+
+        cpu_state_dict = training.gather_cpu_state_dict(
+            state_dict,
             self._is_rank_zero,
             device=self._device,
         )
@@ -521,23 +526,21 @@ def save_checkpoint(
         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
-
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # merge the adapter weights and base weights to create the model checkpoint
-            if not self._save_adapter_weights_only:
+            if self._save_adapter_weights_only:
+                adapter_state_dict = cpu_state_dict
+            else:
+                # Filter out the adapter keys and weights from the model state dict. These will
+                # be saved separately
+                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+                # merge the adapter weights and base weights to create the model checkpoint
                 merged_state_dict = get_merged_lora_ckpt(
                     cpu_state_dict,
                     rank=self._lora_rank,
                     alpha=self._lora_alpha,
                 )
                 checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
 
             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.
recipes/lora_dpo_single_device.py

Lines changed: 2 additions & 1 deletion
@@ -23,6 +23,7 @@
 from torchtune.modules.peft import (
     disable_adapter,
     get_adapter_params,
+    get_adapter_state_dict,
     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
@@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None:
                 }
             )
 
-        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         if not self._save_adapter_weights_only:
             # Construct the full state dict with LoRA weights merged into base LLM weights

recipes/lora_finetune_distributed.py

Lines changed: 17 additions & 14 deletions
@@ -26,6 +26,7 @@
 from torchtune.modules.peft import (
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -452,8 +453,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)
 
-        self.adapter_params = get_adapter_params(model)
-        set_trainable_params(model, self.adapter_params)
+        set_trainable_params(model, get_adapter_params(model))
 
         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
@@ -664,11 +664,14 @@ def save_checkpoint(
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        state_dict = self._model.state_dict()
+        if self._save_adapter_weights_only:
+            state_dict = get_adapter_state_dict(state_dict, device=None)
+
+        cpu_state_dict = training.gather_cpu_state_dict(
+            state_dict,
             self._is_rank_zero,
             device=self._device,
-            trainable_only=self._save_adapter_weights_only,
         )
         if self._is_rank_zero:
             log.info(
@@ -694,22 +697,22 @@ def save_checkpoint(
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
             start = time.perf_counter()
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
 
-            # merge the adapter weights and base weights to create the model checkpoint
-            if not self._save_adapter_weights_only:
+            if self._save_adapter_weights_only:
+                adapter_state_dict = cpu_state_dict
+            else:
+                # Filter out the adapter keys and weights from the model state dict. These will
+                # be saved separately
+                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+                # merge the adapter weights and base weights to create the model checkpoint
                 merged_state_dict = get_merged_lora_ckpt(
                     cpu_state_dict,
                     rank=self._lora_rank,
                     alpha=self._lora_alpha,
                 )
                 checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
 
             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.
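After this refactor, the rank-zero branch of the distributed LoRA recipes has a single shape: either the gathered dict already is the adapter state dict, or the adapter keys are split out and a merged full checkpoint is built alongside them. A condensed sketch of that flow follows; the function name and the omission of optimizer and recipe state are illustrative simplifications:

from typing import Any, Dict

from torchtune import training
from torchtune.modules.peft import get_adapter_state_dict, get_merged_lora_ckpt


def build_checkpoint_dict(
    cpu_state_dict: Dict[str, Any],
    save_adapter_weights_only: bool,
    lora_rank: int,
    lora_alpha: float,
) -> Dict[str, Any]:
    checkpoint_dict: Dict[str, Any] = {}
    if save_adapter_weights_only:
        # The dict was filtered to adapter keys before gathering, so it is
        # already the adapter state dict.
        adapter_state_dict = cpu_state_dict
    else:
        # Split out adapter weights, then merge them into the base weights to
        # produce the full model checkpoint.
        adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
        merged_state_dict = get_merged_lora_ckpt(
            cpu_state_dict, rank=lora_rank, alpha=lora_alpha
        )
        checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
    checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
    return checkpoint_dict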

recipes/lora_finetune_single_device.py

Lines changed: 2 additions & 1 deletion
@@ -24,6 +24,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -592,7 +593,7 @@ def save_checkpoint(self, epoch: int) -> None:
                 }
             )
 
-        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
 
         if not self._save_adapter_weights_only:
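In the single-device recipes the change is smaller: adapter weights are now read from model.state_dict() through the shared helper instead of being copied tensor-by-tensor from self.adapter_params with .cpu(). The sketch below assumes the helper returns CPU copies by default, which is why the explicit transfer disappears from the call site; the wrapper name is illustrative:

import torch

from torchtune.modules.peft import get_adapter_state_dict


def adapter_ckpt_from_model(model: torch.nn.Module) -> dict:
    # Assumed default: the helper moves adapter tensors to CPU, replacing the
    # old per-tensor `.cpu()` copy of self.adapter_params.
    return get_adapter_state_dict(model.state_dict())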

recipes/qat_distributed.py

Lines changed: 2 additions & 2 deletions
@@ -673,8 +673,8 @@ def save_checkpoint(
 
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
