Skip to content

Commit ebed89c

Browse files
updating recipe
1 parent 0f90093 commit ebed89c

File tree

1 file changed

+96
-29
lines changed

1 file changed

+96
-29
lines changed

recipes/full_dpo_distributed.py

+96 -29
Original file line number | Diff line number | Diff line change
@@ -22,7 +22,6 @@
2222
from torchtune.recipe_interfaces import FTRecipeInterface
2323
from torchtune.training import DummyProfiler, PROFILER_KEY
2424
from torchtune.training.activations import apply_selective_activation_checkpointing
25-
from torchtune.training.checkpointing._checkpoint_client import CheckpointClient
2625
from torchtune.utils import get_world_size_and_rank
2726
from tqdm import tqdm
2827

@@ -140,7 +139,6 @@ def __init__(self, cfg: DictConfig) -> None:
140139
self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
141140
self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
142141
self._clip_grad_norm = cfg.get("clip_grad_norm", None)
143-
self._checkpoint_client = CheckpointClient(cfg)
144142

145143
# Optimizer in backward is not compatible with gradient accumulation or gradient clipping
146144
if self._optimizer_in_bwd:
@@ -189,27 +187,28 @@ def __init__(self, cfg: DictConfig) -> None:
189187
self.max_steps_per_epoch = cfg.max_steps_per_epoch
190188
self.global_step = 0
191189

192-
def load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
190+
def _load_checkpoint(self, cfg_checkpointer: DictConfig) -> Dict[str, Any]:
193191
"""
194192
Extract the checkpoint state from file and validate. If resume_from_checkpoint
195193
is True, this also includes the recipe state.
196194
"""
197195
self._checkpointer = config.instantiate(
198196
cfg_checkpointer,
199-
resume_from_checkpoint=self._resume_from_checkpoint,
197+
should_load_recipe_state=self._resume_from_checkpoint,
200198
)
201199
checkpoint_dict = self._checkpointer.load_checkpoint()
202200

203201
if self._resume_from_checkpoint:
204202
self._update_recipe_state(checkpoint_dict)
205203
return checkpoint_dict
206204

207-
def load_ref_states(self, cfg_ref_checkpointer: DictConfig) -> Dict[str, Any]:
205+
def _load_ref_checkpoint(self, cfg_ref_checkpointer: DictConfig) -> Dict[str, Any]:
208206
"""
209-
Extract the checkpoint state from file and validate. If resume_from_checkpoint
210-
is True, this also includes the recipe state.
207+
Extract the reference model checkpoint state from file.
211208
"""
212-
_ref_checkpointer = config.instantiate(cfg_ref_checkpointer)
209+
_ref_checkpointer = config.instantiate(
210+
cfg_ref_checkpointer, should_load_recipe_state=False
211+
)
213212
checkpoint_dict = _ref_checkpointer.load_checkpoint()
214213
return checkpoint_dict[training.MODEL_KEY]
215214

@@ -265,7 +264,8 @@ def setup(self, cfg: DictConfig) -> None:
265264
self._metric_logger.log_config(cfg)
266265

267266
# Load the base model
268-
checkpoint_dict = self._checkpoint_client.load_base_checkpoint()
267+
checkpoint_dict = self._load_checkpoint(cfg.checkpointer)
268+
ref_checkoint_dict = self._load_ref_checkpoint(cfg.ref_checkpointer)
269269

270270
self._compile = cfg.get("compile", False)
271271
self._model = self._setup_model(
@@ -279,16 +279,15 @@ def setup(self, cfg: DictConfig) -> None:
279279
ac_mode=cfg.get("ac_mode", None),
280280
ac_option=cfg.get("ac_option", None),
281281
)
282+
283+
# TODO (@SalmanMohammadi) investigate TP for ref model
282284
self._ref_model = self._setup_reference_model(
283285
cfg_model=cfg.model,
284-
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
285286
fsdp_cpu_offload=cfg.get("fsdp_cpu_offload", False),
286287
reshard_after_forward=cfg.get("fsdp_reshard_after_forward", True),
287-
model_state_dict=self.load_ref_states(cfg.ref_checkpointer),
288+
model_state_dict=ref_checkoint_dict,
289+
custom_sharded_layers=cfg.get("custom_sharded_layers", None),
288290
)
289-
self._ref_model.eval()
290-
for p in self._ref_model.parameters():
291-
p.requires_grad = False
292291

293292
self._tokenizer = config.instantiate(cfg.tokenizer)
294293

@@ -534,22 +533,89 @@ def _setup_reference_model(
534533
custom_sharded_layers: Optional[List[str]] = None,
535534
) -> nn.Module:
536535
"""
537-
Model initialization has some important considerations:
536+
Similar to `self._setup_model`:
538537
a. To minimize GPU peak memory, we initialize the model on meta device with
539538
the right dtype
540539
b. All ranks call ``load_state_dict`` without peaking CPU RAM since
541540
full state dicts are loaded with ``torch.load(mmap=True)``
541+
542+
Additionally, since the reference model is inference-only, we omit some training-specific
543+
optimizations.
542544
"""
543-
return self._setup_model(
544-
cfg_model,
545-
False,
546-
False,
547-
fsdp_cpu_offload,
548-
reshard_after_forward,
545+
546+
utils.log_rank_zero(
547+
log,
548+
"FSDP is enabled. Instantiating reference model and loading checkpoint on Rank 0 ...",
549+
)
550+
init_start = time.perf_counter()
551+
552+
with training.set_default_dtype(self._dtype), torch.device("meta"):
553+
model = config.instantiate(cfg_model)
554+
555+
if self._compile:
556+
training.compile_model(model, verbose=self._is_rank_zero)
557+
558+
# For FSDP sharding
559+
fsdp_shard_conditions = [
560+
partial(
561+
training.get_shard_conditions,
562+
names_to_match=custom_sharded_layers,
563+
)
564+
]
565+
training.shard_model(
566+
model=model,
567+
shard_conditions=fsdp_shard_conditions,
568+
cpu_offload=fsdp_cpu_offload,
569+
reshard_after_forward=reshard_after_forward,
570+
)
571+
572+
with training.set_default_dtype(self._dtype), self._device:
573+
for m in model.modules():
574+
# RoPE is not covered in state dict
575+
if hasattr(m, "rope_init"):
576+
m.rope_init()
577+
578+
# This method will convert the full model state dict into a sharded state
579+
# dict and load into the model
580+
training.load_from_full_model_state_dict(
581+
model,
549582
model_state_dict,
550-
custom_sharded_layers,
583+
self._device,
584+
strict=True,
585+
cpu_offload=fsdp_cpu_offload,
586+
)
587+
588+
# Ensure no params and buffers are on meta device
589+
training.validate_no_params_on_meta_device(model)
590+
591+
utils.log_rank_zero(
592+
log,
593+
f"Instantiating reference model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
551594
)
552595

596+
if self._is_rank_zero:
597+
memory_stats = training.get_memory_stats(device=self._device)
598+
training.log_memory_stats(memory_stats)
599+
600+
# disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs
601+
# between ref policy and current policy
602+
for module in model.modules():
603+
if isinstance(module, torch.nn.Dropout):
604+
warn(
605+
f"Dropout found in {module}. This is likely to cause issues during training. Disabling."
606+
)
607+
module.p = 0
608+
609+
for p in self._ref_model.parameters():
610+
p.requires_grad = False
611+
612+
model.eval()
613+
614+
# synchronize before training begins
615+
torch.distributed.barrier()
616+
617+
return model
618+
553619
def _setup_optimizer(
554620
self,
555621
cfg_optimizer: DictConfig,
@@ -831,13 +897,14 @@ def train(self) -> None:
831897
break
832898

833899
# batch is input_ids, labels
834-
num_tokens += batch[0].numel()
835-
(
836-
policy_chosen_log_probs,
837-
policy_rejected_log_probs,
838-
policy_chosen_logits,
839-
policy_rejected_logits,
840-
) = self.concatenated_forward(self._model, batch)
900+
with self.activations_handling_ctx:
901+
num_tokens += batch[0].numel()
902+
(
903+
policy_chosen_log_probs,
904+
policy_rejected_log_probs,
905+
policy_chosen_logits,
906+
policy_rejected_logits,
907+
) = self.concatenated_forward(self._model, batch)
841908

842909
policy_chosen_logits_mean = policy_chosen_logits.detach().mean()
843910
policy_rejected_logits_mean = policy_rejected_logits.detach().mean()

0 commit comments

Comments
 (0)