
Commit 0f90093

updating full recipe
1 parent 753e822 commit 0f90093

File tree

1 file changed (+18, -29 lines)


recipes/full_dpo_distributed.py (+18, -29)
@@ -22,6 +22,7 @@
 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import CheckpointClient
 from torchtune.utils import get_world_size_and_rank
 from tqdm import tqdm

@@ -112,7 +113,6 @@ class FullDPORecipeDistributed(FTRecipeInterface):
     """

     def __init__(self, cfg: DictConfig) -> None:
-
         self._device = utils.get_device(device=cfg.device)
         self._dtype = training.get_dtype(cfg.dtype, device=self._device)

@@ -121,15 +121,6 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )

-        if (
-            cfg.get("fsdp_cpu_offload", False)
-            and cfg.optimizer.get("fused", False)
-            and not utils.torch_version_ge("2.4.0")
-        ):
-            raise RuntimeError(
-                "Using fused optimizer on CPU is only supported in PyTorch nightly."
-            )
-
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -141,8 +132,6 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False

-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         _, rank = get_world_size_and_rank()
         self._is_rank_zero = rank == 0

@@ -151,6 +140,7 @@ def __init__(self, cfg: DictConfig) -> None:
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -273,9 +263,9 @@ def setup(self, cfg: DictConfig) -> None:

         # log config with parameter override
         self._metric_logger.log_config(cfg)
-        log.info("_metric_logger is initialized.")

-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
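This hunk, together with the earlier `__init__` hunk, moves base-checkpoint loading behind the new CheckpointClient: the client is built once from the recipe config and then asked for the base checkpoint during `setup()`. The sketch below isolates that wiring; it relies only on the two calls visible in this diff, and the helper name `load_base_state` is illustrative rather than part of the recipe.

# Hedged sketch of the new checkpoint wiring, based only on calls shown in this diff.
# `load_base_state` is an illustrative helper, not a torchtune API.
from omegaconf import DictConfig
from torchtune.training.checkpointing._checkpoint_client import CheckpointClient

def load_base_state(cfg: DictConfig):
    # __init__ now builds the client directly from the full recipe config ...
    checkpoint_client = CheckpointClient(cfg)
    # ... and setup() asks it for the base model checkpoint, replacing the old
    # self.load_checkpoint(cfg_checkpointer=cfg.checkpointer) path.
    return checkpoint_client.load_base_checkpoint()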
@@ -443,22 +433,18 @@ def _setup_model(
               full state dicts are loaded with ``torch.load(mmap=True)``
         """

-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()

         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

-        model.load_state_dict(model_state_dict, assign=True)
-        if self._dtype == torch.bfloat16:
-            model = model.to(torch.bfloat16)
-
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
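This hunk (and the similar one below) replaces explicit `if self._is_rank_zero:` guards around `log.info` with calls to `utils.log_rank_zero`. For readers unfamiliar with the helper, the following is an illustrative stand-in showing the behaviour the new call sites rely on; it is not torchtune's implementation.

# Illustrative stand-in for a rank-zero logging helper: only rank 0 emits the
# record, so multi-process runs don't print one copy of every message per rank.
import logging
import torch.distributed as dist

def log_rank_zero(logger: logging.Logger, msg: str, level: int = logging.INFO) -> None:
    rank = dist.get_rank() if dist.is_available() and dist.is_initialized() else 0
    if rank != 0:
        return
    logger.log(level, msg)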
@@ -516,10 +502,12 @@ def _setup_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)

+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
+
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)

@@ -536,7 +524,7 @@ def _setup_model(
         torch.distributed.barrier()

         return model
-
+
     def _setup_reference_model(
         self,
         cfg_model: DictConfig,
@@ -551,15 +539,15 @@ def _setup_reference_model(
              the right dtype
           b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since
              full state dicts are loaded with ``torch.load(mmap=True)``
-            """
+        """
         return self._setup_model(
             cfg_model,
             False,
             False,
             fsdp_cpu_offload,
             reshard_after_forward,
             model_state_dict,
-            custom_sharded_layers
+            custom_sharded_layers,
         )

     def _setup_optimizer(
@@ -963,6 +951,7 @@ def recipe_main(cfg: DictConfig) -> None:
             "Distributed finetune recipe should be run via a distributed launcher."
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
+
     init_process_group("cuda:nccl,cpu:gloo")
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
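The added blank line sits just before the process-group setup in `recipe_main`. As a reference for the surrounding context, here is a minimal sketch of that entry-point pattern; the `RANK` environment-variable check is an assumption standing in for the recipe's own launcher guard, which is not shown in this diff.

# Minimal sketch of the launcher guard around init_process_group; the RANK
# check is an assumed placeholder for the recipe's actual distributed check.
import os
from torch.distributed import init_process_group

def main() -> None:
    if "RANK" not in os.environ:
        raise RuntimeError(
            "Distributed finetune recipe should be run via a distributed launcher."
            "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
        )
    # NCCL backend for CUDA tensors, Gloo for CPU tensors (needed when
    # fsdp_cpu_offload moves parameters and optimizer state to CPU).
    init_process_group("cuda:nccl,cpu:gloo")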
