
Commit 08efaed

Refactor Recipe State Dict Code (#1964)
1 parent 550163b commit 08efaed

18 files changed (+226 -129 lines)

docs/source/api_ref_modules.rst (+1)

@@ -75,6 +75,7 @@ PEFT Components
 peft.AdapterModule
 peft.get_adapter_params
 peft.set_trainable_params
+peft.get_adapter_state_dict
 peft.validate_missing_and_unexpected_for_lora
 peft.validate_state_dict_for_lora
 peft.disable_adapter

docs/source/api_ref_training.rst (+1)

@@ -56,6 +56,7 @@ Utilities for enabling and working with distributed training.
 get_world_size_and_rank
 get_full_finetune_fsdp_wrap_policy
 lora_fsdp_wrap_policy
+gather_cpu_state_dict

 .. _ac_label:

recipes/full_finetune_distributed.py (+2 -2)

@@ -645,8 +645,8 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
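The change above is the core API shift repeated across every distributed recipe in this commit: training.gather_cpu_state_dict consumes a (possibly sharded) state dict rather than the module that get_full_model_state_dict used to take, which lets callers filter the dict before it is consolidated. A minimal usage sketch outside of a recipe, assuming an FSDP-style setup; model, rank, and the output path are illustrative placeholders, not part of this diff:

import torch
import torch.distributed as dist
from torchtune import training


def save_consolidated_checkpoint(model: torch.nn.Module, rank: int, path: str) -> None:
    # Gather the sharded model state dict onto CPU; only rank 0 ends up with the
    # full copy, so GPU memory does not spike during the save.
    cpu_state_dict = training.gather_cpu_state_dict(
        model.state_dict(),   # a state dict, not the module itself
        rank == 0,            # is_rank_zero
        device=torch.device("cuda"),
    )
    if rank == 0:
        torch.save(cpu_state_dict, path)
    dist.barrier()  # keep other ranks from moving on while rank 0 writes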

recipes/knowledge_distillation_distributed.py (+4 -6)

@@ -25,6 +25,7 @@
 from torchtune.modules.peft import (
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -707,8 +708,8 @@ def save_checkpoint(self, epoch: int) -> None:
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )
@@ -728,10 +729,7 @@

         # Filter out the adapter keys and weights from the model state dict. These will
         # be saved separately
-        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-        }
+        adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
         checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

         # merge the adapter weights and base weights to create the model checkpoint
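The lambda-based key filtering deleted here (and in the other recipes below) is what the new peft.get_adapter_state_dict helper centralizes: from the call sites in this diff it takes a full model state dict and returns only the adapter entries, with an optional device argument (device=None apparently leaving tensors in place, the default presumably producing CPU copies). A minimal sketch of such a helper, assuming adapter tensors can be recognized by name (e.g. keys containing "lora" or "magnitude"); the actual torchtune implementation may differ:

from typing import Any, Dict, Optional

import torch

# Name fragments that mark LoRA / DoRA adapter tensors -- an assumption made for
# illustration, not a guarantee about torchtune's internals.
_ADAPTER_KEY_MARKERS = ("lora", "magnitude")


def get_adapter_state_dict_sketch(
    state_dict: Dict[str, Any], device: Optional[str] = "cpu"
) -> Dict[str, Any]:
    """Return only the adapter weights from a full model state dict (sketch)."""
    adapter_sd = {}
    for key, value in state_dict.items():
        if any(marker in key for marker in _ADAPTER_KEY_MARKERS):
            if device is not None and isinstance(value, torch.Tensor):
                value = value.to(device)  # e.g. move to CPU for checkpointing
            adapter_sd[key] = value
    return adapter_sd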

recipes/knowledge_distillation_single_device.py (+2 -4)

@@ -23,6 +23,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -586,10 +587,7 @@ def save_checkpoint(self, epoch: int) -> None:
             ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

         # Construct the adapter weights
-        adapter_key_filter = lambda x: x in self.adapter_params
-        adapter_state_dict = {
-            k: v for k, v in self._model.state_dict().items() if adapter_key_filter(k)
-        }
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         adapter_config = {
             "r": self._lora_rank,

recipes/lora_dpo_distributed.py (+16 -13)

@@ -25,6 +25,7 @@
     disable_adapter,
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
     LoRALinear,
@@ -504,8 +505,12 @@ def save_checkpoint(
         intermediate_checkpoint = epoch + 1 < self.total_epochs
         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        state_dict = self._model.state_dict()
+        if self._save_adapter_weights_only:
+            state_dict = get_adapter_state_dict(state_dict, device=None)
+
+        cpu_state_dict = training.gather_cpu_state_dict(
+            state_dict,
             self._is_rank_zero,
             device=self._device,
         )
@@ -521,23 +526,21 @@
         # Now that we have the model and opt state dict, create the actual checkpoint dict
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
-
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
-
-            # merge the adapter weights and base weights to create the model checkpoint
-            if not self._save_adapter_weights_only:
+            if self._save_adapter_weights_only:
+                adapter_state_dict = cpu_state_dict
+            else:
+                # Filter out the adapter keys and weights from the model state dict. These will
+                # be saved separately
+                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+                # merge the adapter weights and base weights to create the model checkpoint
                 merged_state_dict = get_merged_lora_ckpt(
                     cpu_state_dict,
                     rank=self._lora_rank,
                     alpha=self._lora_alpha,
                 )
                 checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.
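Note the reordering at the bottom of this hunk: the merged full model is written under training.MODEL_KEY first, and training.ADAPTER_KEY is updated afterwards in both branches. The merge itself is unchanged and still goes through get_merged_lora_ckpt(cpu_state_dict, rank=..., alpha=...), which folds the LoRA factors into the base weights so the merged checkpoint loads like a plain model. A toy illustration of that merge for a single linear layer, assuming the standard LoRA formulation W_merged = W + (alpha / rank) * B @ A; the helper's internals are not shown in this diff:

import torch

# Standard LoRA shapes for one linear layer: W is (out_dim, in_dim),
# A is (rank, in_dim), B is (out_dim, rank).
rank, alpha = 8, 16
base_weight = torch.randn(32, 64)
lora_a = torch.randn(rank, 64)
lora_b = torch.zeros(32, rank)  # B is zero-initialized, so the merge is a no-op at init

merged_weight = base_weight + (alpha / rank) * (lora_b @ lora_a)
assert merged_weight.shape == base_weight.shape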

recipes/lora_dpo_single_device.py (+2 -1)

@@ -23,6 +23,7 @@
 from torchtune.modules.peft import (
     disable_adapter,
     get_adapter_params,
+    get_adapter_state_dict,
     get_merged_lora_ckpt,
     set_trainable_params,
     validate_missing_and_unexpected_for_lora,
@@ -407,7 +408,7 @@ def save_checkpoint(self, epoch: int) -> None:
                 }
             )

-        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
         if not self._save_adapter_weights_only:
             # Construct the full state dict with LoRA weights merged into base LLM weights

recipes/lora_finetune_distributed.py (+17 -14)

@@ -26,6 +26,7 @@
 from torchtune.modules.peft import (
     DoRALinear,
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -452,8 +453,7 @@ def _setup_model(
         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

-        self.adapter_params = get_adapter_params(model)
-        set_trainable_params(model, self.adapter_params)
+        set_trainable_params(model, get_adapter_params(model))

         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)
@@ -664,11 +664,14 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        state_dict = self._model.state_dict()
+        if self._save_adapter_weights_only:
+            state_dict = get_adapter_state_dict(state_dict, device=None)
+
+        cpu_state_dict = training.gather_cpu_state_dict(
+            state_dict,
             self._is_rank_zero,
             device=self._device,
-            trainable_only=self._save_adapter_weights_only,
         )
         if self._is_rank_zero:
             log.info(
@@ -694,22 +697,22 @@ def save_checkpoint(
         # to be sent to the checkpointer and ultimately written to file
         if self._is_rank_zero:
             start = time.perf_counter()
-            # Filter out the adapter keys and weights from the model state dict. These will
-            # be saved separately
-            adapter_key_filter = lambda x: x in self.adapter_params
-            adapter_state_dict = {
-                k: v for k, v in cpu_state_dict.items() if adapter_key_filter(k)
-            }
-            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

-            # merge the adapter weights and base weights to create the model checkpoint
-            if not self._save_adapter_weights_only:
+            if self._save_adapter_weights_only:
+                adapter_state_dict = cpu_state_dict
+            else:
+                # Filter out the adapter keys and weights from the model state dict. These will
+                # be saved separately
+                adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
+
+                # merge the adapter weights and base weights to create the model checkpoint
                 merged_state_dict = get_merged_lora_ckpt(
                     cpu_state_dict,
                     rank=self._lora_rank,
                     alpha=self._lora_alpha,
                 )
                 checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
+            checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})

             # if training is in-progress, checkpoint the optimizer state and recipe state
             # as well.
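Putting this file's hunks together: the trainable_only= flag on the old gather helper is gone, replaced by pre-filtering the sharded state dict with get_adapter_state_dict(..., device=None) when only adapter weights are being saved, so only the small LoRA/DoRA tensors get gathered to CPU. A condensed, hedged sketch of the resulting save flow; attribute names mirror the diff, and the final checkpointer call is a placeholder rather than torchtune's exact API:

from torchtune import training
from torchtune.modules.peft import get_adapter_state_dict, get_merged_lora_ckpt


def save_checkpoint_sketch(self, epoch: int) -> None:
    # Condensed illustration of the refactored flow, not the verbatim recipe code.
    checkpoint_dict = {}

    # 1. Optionally shrink the sharded state dict to adapter keys before gathering.
    state_dict = self._model.state_dict()
    if self._save_adapter_weights_only:
        state_dict = get_adapter_state_dict(state_dict, device=None)

    # 2. Consolidate whatever remains on CPU for rank 0 only.
    cpu_state_dict = training.gather_cpu_state_dict(
        state_dict,
        self._is_rank_zero,
        device=self._device,
    )

    # 3. Rank 0 decides what to write: just the adapters, or adapters plus a
    #    merged full model produced by get_merged_lora_ckpt.
    if self._is_rank_zero:
        if self._save_adapter_weights_only:
            adapter_state_dict = cpu_state_dict
        else:
            adapter_state_dict = get_adapter_state_dict(cpu_state_dict)
            merged_state_dict = get_merged_lora_ckpt(
                cpu_state_dict,
                rank=self._lora_rank,
                alpha=self._lora_alpha,
            )
            checkpoint_dict.update({training.MODEL_KEY: merged_state_dict})
        checkpoint_dict.update({training.ADAPTER_KEY: adapter_state_dict})
        self._checkpointer.save_checkpoint(checkpoint_dict, epoch=epoch)  # placeholder call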

recipes/lora_finetune_single_device.py (+2 -1)

@@ -24,6 +24,7 @@
 from torchtune.datasets import ConcatDataset
 from torchtune.modules.peft import (
     get_adapter_params,
+    get_adapter_state_dict,
     get_lora_module_names,
     get_merged_lora_ckpt,
     load_dora_magnitudes,
@@ -592,7 +593,7 @@ def save_checkpoint(self, epoch: int) -> None:
                 }
             )

-        adapter_state_dict = {k: v.cpu() for k, v in self.adapter_params.items()}
+        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
         ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

         if not self._save_adapter_weights_only:

recipes/qat_distributed.py (+2 -2)

@@ -673,8 +673,8 @@ def save_checkpoint(

         # To prevent GPU memory from spiking during checkpoint save,
         # we consolidate the full model and optim state dicts on CPU for rank 0
-        cpu_state_dict = training.get_full_model_state_dict(
-            self._model,
+        cpu_state_dict = training.gather_cpu_state_dict(
+            self._model.state_dict(),
             self._is_rank_zero,
             device=self._device,
         )

tests/recipes/test_full_finetune_distributed.py (+87 -1)

@@ -4,8 +4,8 @@
 # This source code is licensed under the BSD-style license found in the
 # LICENSE file in the root directory of this source tree.

+import os
 import runpy
-
 import sys
 from pathlib import Path

@@ -113,3 +113,89 @@ def test_loss(
         torch.testing.assert_close(
             loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
         )
+
+    @pytest.mark.integration_test
+    @pytest.mark.parametrize(
+        "config, model_type, ckpt_type, micro_batch_size, gradient_accumulation_steps, optim_in_bwd",
+        [
+            ("llama3/8B_full", "llama3", "tune", 1, 4, False),
+        ],
+    )
+    @gpu_test(gpu_count=2)
+    def test_training_state_on_resume(
+        self,
+        micro_batch_size,
+        gradient_accumulation_steps,
+        config,
+        model_type,
+        ckpt_type,
+        optim_in_bwd,
+        tmpdir,
+        monkeypatch,
+    ):
+        ckpt_component = CKPT_COMPONENT_MAP[ckpt_type]
+        ckpt = model_type + "_" + ckpt_type
+        ckpt_path = Path(CKPT_MODEL_PATHS[ckpt])
+        tokenizer_path = Path(TOKENIZER_PATHS[model_type])
+        ckpt_dir = ckpt_path.parent
+        log_file = gen_log_file_name(tmpdir)
+
+        # Config file needed for model conversion.
+        # Create a second copy for training resume
+        write_hf_ckpt_config(ckpt_dir)
+        write_hf_ckpt_config(tmpdir)
+
+        # Train for two epochs
+        cmd_1 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{ckpt_dir}' \
+            checkpointer.checkpoint_files=[{ckpt_path}]\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            clip_grad_norm=100 \
+        """.split()
+
+        model_config = MODEL_TEST_CONFIGS[model_type]
+        cmd_1 = cmd_1 + self._get_test_config_overrides() + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd_1)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        # Resume training
+        cmd_2 = f"""
+        tune run --nnodes 1 --nproc_per_node 2 full_finetune_distributed \
+            --config {config} \
+            batch_size={micro_batch_size} \
+            gradient_accumulation_steps={gradient_accumulation_steps} \
+            output_dir={tmpdir} \
+            checkpointer._component_={ckpt_component} \
+            checkpointer.checkpoint_dir='{tmpdir}' \
+            checkpointer.checkpoint_files=[{os.path.join(tmpdir, "torchtune_model_0.pt")}]\
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
+            checkpointer.output_dir={tmpdir} \
+            checkpointer.model_type={model_type.upper()} \
+            tokenizer.path='{tokenizer_path}' \
+            tokenizer.prompt_template=null \
+            resume_from_checkpoint=True \
+            metric_logger.filename={log_file} \
+            clip_grad_norm=100 \
+        """.split()
+
+        cmd_2 = cmd_2 + self._get_test_config_overrides() + model_config
+
+        monkeypatch.setattr(sys, "argv", cmd_2)
+        runpy.run_path(TUNE_PATH, run_name="__main__")
+
+        expected_loss_values = self._fetch_expected_loss_values(model_type)[2:]
+
+        loss_values = get_loss_values_from_metric_logger(log_file)
+        torch.testing.assert_close(
+            loss_values, expected_loss_values, rtol=1e-4, atol=1e-4
+        )

tests/recipes/test_full_finetune_single_device.py (+1 -1)

@@ -181,7 +181,7 @@ def test_training_state_on_resume(self, tmpdir, monkeypatch):
             checkpointer._component_=torchtune.training.FullModelHFCheckpointer \
             checkpointer.checkpoint_dir={tmpdir} \
             checkpointer.checkpoint_files=[{os.path.join(tmpdir, "hf_model_0001_0.pt")}]\
-            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}
+            checkpointer.recipe_checkpoint={os.path.join(tmpdir, "recipe_state.pt")}\
             checkpointer.output_dir={tmpdir} \
             checkpointer.model_type=LLAMA2 \
             tokenizer.path=/tmp/test-artifacts/tokenizer.model \
