22
22
from torchtune .recipe_interfaces import FTRecipeInterface
23
23
from torchtune .training import DummyProfiler , PROFILER_KEY
24
24
from torchtune .training .activations import apply_selective_activation_checkpointing
25
- from torchtune .training .checkpointing ._checkpoint_client import CheckpointClient
26
25
from torchtune .utils import get_world_size_and_rank
27
26
from tqdm import tqdm
28
27
@@ -140,7 +139,6 @@ def __init__(self, cfg: DictConfig) -> None:
140
139
self ._gradient_accumulation_steps = cfg .gradient_accumulation_steps
141
140
self ._optimizer_in_bwd = cfg .get ("optimizer_in_bwd" , False )
142
141
self ._clip_grad_norm = cfg .get ("clip_grad_norm" , None )
143
- self ._checkpoint_client = CheckpointClient (cfg )
144
142
145
143
# Optimizer in backward is not compatible with gradient accumulation or gradient clipping
146
144
if self ._optimizer_in_bwd :
@@ -189,27 +187,28 @@ def __init__(self, cfg: DictConfig) -> None:
189
187
self .max_steps_per_epoch = cfg .max_steps_per_epoch
190
188
self .global_step = 0
191
189
192
- def load_checkpoint (self , cfg_checkpointer : DictConfig ) -> Dict [str , Any ]:
190
+ def _load_checkpoint (self , cfg_checkpointer : DictConfig ) -> Dict [str , Any ]:
193
191
"""
194
192
Extract the checkpoint state from file and validate. If resume_from_checkpoint
195
193
is True, this also includes the recipe state.
196
194
"""
197
195
self ._checkpointer = config .instantiate (
198
196
cfg_checkpointer ,
199
- resume_from_checkpoint = self ._resume_from_checkpoint ,
197
+ should_load_recipe_state = self ._resume_from_checkpoint ,
200
198
)
201
199
checkpoint_dict = self ._checkpointer .load_checkpoint ()
202
200
203
201
if self ._resume_from_checkpoint :
204
202
self ._update_recipe_state (checkpoint_dict )
205
203
return checkpoint_dict
206
204
207
- def load_ref_states (self , cfg_ref_checkpointer : DictConfig ) -> Dict [str , Any ]:
205
+ def _load_ref_checkpoint (self , cfg_ref_checkpointer : DictConfig ) -> Dict [str , Any ]:
208
206
"""
209
- Extract the checkpoint state from file and validate. If resume_from_checkpoint
210
- is True, this also includes the recipe state.
207
+ Extract the reference model checkpoint state from file.
211
208
"""
212
- _ref_checkpointer = config .instantiate (cfg_ref_checkpointer )
209
+ _ref_checkpointer = config .instantiate (
210
+ cfg_ref_checkpointer , should_load_recipe_state = False
211
+ )
213
212
checkpoint_dict = _ref_checkpointer .load_checkpoint ()
214
213
return checkpoint_dict [training .MODEL_KEY ]
215
214
@@ -265,7 +264,8 @@ def setup(self, cfg: DictConfig) -> None:
265
264
self ._metric_logger .log_config (cfg )
266
265
267
266
# Load the base model
268
- checkpoint_dict = self ._checkpoint_client .load_base_checkpoint ()
267
+ checkpoint_dict = self ._load_checkpoint (cfg .checkpointer )
268
+ ref_checkoint_dict = self ._load_ref_checkpoint (cfg .ref_checkpointer )
269
269
270
270
self ._compile = cfg .get ("compile" , False )
271
271
self ._model = self ._setup_model (
@@ -279,16 +279,15 @@ def setup(self, cfg: DictConfig) -> None:
279
279
ac_mode = cfg .get ("ac_mode" , None ),
280
280
ac_option = cfg .get ("ac_option" , None ),
281
281
)
282
+
283
+ # TODO (@SalmanMohammadi) investigate TP for ref model
282
284
self ._ref_model = self ._setup_reference_model (
283
285
cfg_model = cfg .model ,
284
- custom_sharded_layers = cfg .get ("custom_sharded_layers" , None ),
285
286
fsdp_cpu_offload = cfg .get ("fsdp_cpu_offload" , False ),
286
287
reshard_after_forward = cfg .get ("fsdp_reshard_after_forward" , True ),
287
- model_state_dict = self .load_ref_states (cfg .ref_checkpointer ),
288
+ model_state_dict = ref_checkoint_dict ,
289
+ custom_sharded_layers = cfg .get ("custom_sharded_layers" , None ),
288
290
)
289
- self ._ref_model .eval ()
290
- for p in self ._ref_model .parameters ():
291
- p .requires_grad = False
292
291
293
292
self ._tokenizer = config .instantiate (cfg .tokenizer )
294
293
@@ -534,22 +533,89 @@ def _setup_reference_model(
534
533
custom_sharded_layers : Optional [List [str ]] = None ,
535
534
) -> nn .Module :
536
535
"""
537
- Model initialization has some important considerations :
536
+ Similar to `self._setup_model` :
538
537
a. To minimize GPU peak memory, we initialize the model on meta device with
539
538
the right dtype
540
539
b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since
541
540
full state dicts are loaded with ``torch.load(mmap=True)``
541
+
542
+ Additionally, since the reference model is inference-only, we omit some training-specific
543
+ optimizations.
542
544
"""
543
- return self ._setup_model (
544
- cfg_model ,
545
- False ,
546
- False ,
547
- fsdp_cpu_offload ,
548
- reshard_after_forward ,
545
+
546
+ utils .log_rank_zero (
547
+ log ,
548
+ "FSDP is enabled. Instantiating reference model and loading checkpoint on Rank 0 ..." ,
549
+ )
550
+ init_start = time .perf_counter ()
551
+
552
+ with training .set_default_dtype (self ._dtype ), torch .device ("meta" ):
553
+ model = config .instantiate (cfg_model )
554
+
555
+ if self ._compile :
556
+ training .compile_model (model , verbose = self ._is_rank_zero )
557
+
558
+ # For FSDP sharding
559
+ fsdp_shard_conditions = [
560
+ partial (
561
+ training .get_shard_conditions ,
562
+ names_to_match = custom_sharded_layers ,
563
+ )
564
+ ]
565
+ training .shard_model (
566
+ model = model ,
567
+ shard_conditions = fsdp_shard_conditions ,
568
+ cpu_offload = fsdp_cpu_offload ,
569
+ reshard_after_forward = reshard_after_forward ,
570
+ )
571
+
572
+ with training .set_default_dtype (self ._dtype ), self ._device :
573
+ for m in model .modules ():
574
+ # RoPE is not covered in state dict
575
+ if hasattr (m , "rope_init" ):
576
+ m .rope_init ()
577
+
578
+ # This method will convert the full model state dict into a sharded state
579
+ # dict and load into the model
580
+ training .load_from_full_model_state_dict (
581
+ model ,
549
582
model_state_dict ,
550
- custom_sharded_layers ,
583
+ self ._device ,
584
+ strict = True ,
585
+ cpu_offload = fsdp_cpu_offload ,
586
+ )
587
+
588
+ # Ensure no params and buffers are on meta device
589
+ training .validate_no_params_on_meta_device (model )
590
+
591
+ utils .log_rank_zero (
592
+ log ,
593
+ f"Instantiating reference model and loading checkpoint took { time .perf_counter () - init_start :.2f} secs" ,
551
594
)
552
595
596
+ if self ._is_rank_zero :
597
+ memory_stats = training .get_memory_stats (device = self ._device )
598
+ training .log_memory_stats (memory_stats )
599
+
600
+ # disabling dropout if found - non-determinism leads to issues in e.g. comparing logprobs
601
+ # between ref policy and current policy
602
+ for module in model .modules ():
603
+ if isinstance (module , torch .nn .Dropout ):
604
+ warn (
605
+ f"Dropout found in { module } . This is likely to cause issues during training. Disabling."
606
+ )
607
+ module .p = 0
608
+
609
+ for p in self ._ref_model .parameters ():
610
+ p .requires_grad = False
611
+
612
+ model .eval ()
613
+
614
+ # synchronize before training begins
615
+ torch .distributed .barrier ()
616
+
617
+ return model
618
+
553
619
def _setup_optimizer (
554
620
self ,
555
621
cfg_optimizer : DictConfig ,
@@ -831,13 +897,14 @@ def train(self) -> None:
831
897
break
832
898
833
899
# batch is input_ids, labels
834
- num_tokens += batch [0 ].numel ()
835
- (
836
- policy_chosen_log_probs ,
837
- policy_rejected_log_probs ,
838
- policy_chosen_logits ,
839
- policy_rejected_logits ,
840
- ) = self .concatenated_forward (self ._model , batch )
900
+ with self .activations_handling_ctx :
901
+ num_tokens += batch [0 ].numel ()
902
+ (
903
+ policy_chosen_log_probs ,
904
+ policy_rejected_log_probs ,
905
+ policy_chosen_logits ,
906
+ policy_rejected_logits ,
907
+ ) = self .concatenated_forward (self ._model , batch )
841
908
842
909
policy_chosen_logits_mean = policy_chosen_logits .detach ().mean ()
843
910
policy_rejected_logits_mean = policy_rejected_logits .detach ().mean ()
0 commit comments