 from torchtune.recipe_interfaces import FTRecipeInterface
 from torchtune.training import DummyProfiler, PROFILER_KEY
 from torchtune.training.activations import apply_selective_activation_checkpointing
+from torchtune.training.checkpointing._checkpoint_client import CheckpointClient
 from torchtune.utils import get_world_size_and_rank
 from tqdm import tqdm

@@ -112,7 +113,6 @@ class FullDPORecipeDistributed(FTRecipeInterface):
     """

     def __init__(self, cfg: DictConfig) -> None:
-
         self._device = utils.get_device(device=cfg.device)
         self._dtype = training.get_dtype(cfg.dtype, device=self._device)

@@ -121,15 +121,6 @@ def __init__(self, cfg: DictConfig) -> None:
                 "full fp16 training is not supported with this recipe. Please use bf16 or fp32 instead."
             )

-        if (
-            cfg.get("fsdp_cpu_offload", False)
-            and cfg.optimizer.get("fused", False)
-            and not utils.torch_version_ge("2.4.0")
-        ):
-            raise RuntimeError(
-                "Using fused optimizer on CPU is only supported in PyTorch nightly."
-            )
-
         # logging attributes
         self._output_dir = cfg.output_dir
         self._log_every_n_steps = cfg.get("log_every_n_steps", 1)
@@ -141,8 +132,6 @@ def __init__(self, cfg: DictConfig) -> None:
             )
             self._log_peak_memory_stats = False

-        # _is_rank_zero is used primarily for logging. In the future, the logger
-        # should directly take care of this
         _, rank = get_world_size_and_rank()
         self._is_rank_zero = rank == 0

@@ -151,6 +140,7 @@ def __init__(self, cfg: DictConfig) -> None:
         self._gradient_accumulation_steps = cfg.gradient_accumulation_steps
         self._optimizer_in_bwd = cfg.get("optimizer_in_bwd", False)
         self._clip_grad_norm = cfg.get("clip_grad_norm", None)
+        self._checkpoint_client = CheckpointClient(cfg)

         # Optimizer in backward is not compatible with gradient accumulation or gradient clipping
         if self._optimizer_in_bwd:
@@ -273,9 +263,9 @@ def setup(self, cfg: DictConfig) -> None:

         # log config with parameter override
         self._metric_logger.log_config(cfg)
-        log.info("_metric_logger is initialized.")

-        checkpoint_dict = self.load_checkpoint(cfg_checkpointer=cfg.checkpointer)
+        # Load the base model
+        checkpoint_dict = self._checkpoint_client.load_base_checkpoint()

         self._compile = cfg.get("compile", False)
         self._model = self._setup_model(
@@ -443,22 +433,18 @@ def _setup_model(
              full state dicts are loaded with ``torch.load(mmap=True)``
         """

-        if self._is_rank_zero:
-            log.info(
-                "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ..."
-            )
-            init_start = time.perf_counter()
+        utils.log_rank_zero(
+            log,
+            "FSDP is enabled. Instantiating model and loading checkpoint on Rank 0 ...",
+        )
+        init_start = time.perf_counter()

         with training.set_default_dtype(self._dtype), torch.device("meta"):
             model = config.instantiate(cfg_model)

         if self._compile:
             training.compile_model(model, verbose=self._is_rank_zero)

-        model.load_state_dict(model_state_dict, assign=True)
-        if self._dtype == torch.bfloat16:
-            model = model.to(torch.bfloat16)
-
         # We currently have two versions of activation checkpointing in this recipe
         # for testing and BC purposes. ``enable_activation_checkpointing`` controls
         # the older version of AC and this behavior is unchanged
@@ -516,10 +502,12 @@ def _setup_model(
         # Ensure no params and buffers are on meta device
         training.validate_no_params_on_meta_device(model)

+        utils.log_rank_zero(
+            log,
+            f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs",
+        )
+
         if self._is_rank_zero:
-            log.info(
-                f"Instantiating model and loading checkpoint took {time.perf_counter() - init_start:.2f} secs"
-            )
             memory_stats = training.get_memory_stats(device=self._device)
             training.log_memory_stats(memory_stats)

@@ -536,7 +524,7 @@ def _setup_model(
         torch.distributed.barrier()

         return model
-
+
     def _setup_reference_model(
         self,
         cfg_model: DictConfig,
@@ -551,15 +539,15 @@ def _setup_reference_model(
              the right dtype
           b. All ranks calls ``load_state_dict`` without peaking CPU RAMs since
              full state dicts are loaded with ``torch.load(mmap=True)``
-        """
+        """
         return self._setup_model(
             cfg_model,
             False,
             False,
             fsdp_cpu_offload,
             reshard_after_forward,
             model_state_dict,
-            custom_sharded_layers
+            custom_sharded_layers,
         )

     def _setup_optimizer(
@@ -963,6 +951,7 @@ def recipe_main(cfg: DictConfig) -> None:
             "Distributed finetune recipe should be run via a distributed launcher."
             "If using tune CLI, please specify --nnodes 1 and --nproc_per_node [num_gpus]"
         )
+
     init_process_group("cuda:nccl,cpu:gloo")
     if cfg.get("fsdp_cpu_offload", False):
         # Utilize all available CPU cores for intra-op parallelism. This provides ~2x
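
The two recurring changes across these hunks are checkpoint loading via CheckpointClient.load_base_checkpoint() and rank-zero logging via utils.log_rank_zero, both of which appear verbatim in the diff. Below is a minimal standalone sketch of how those two calls fit together outside the recipe class; the load_base() helper, the cfg argument, and the "DEBUG" logger name are illustrative assumptions, not part of this PR.

# Minimal sketch of the pattern adopted by this diff; `cfg` is assumed to be a
# fully populated recipe DictConfig (it must contain the checkpointer section
# that CheckpointClient reads).
from omegaconf import DictConfig

from torchtune import utils
from torchtune.training.checkpointing._checkpoint_client import CheckpointClient

log = utils.get_logger("DEBUG")


def load_base(cfg: DictConfig) -> dict:
    # CheckpointClient wraps the checkpointer config, so the recipe no longer
    # instantiates a checkpointer and calls load_checkpoint() by hand.
    checkpoint_client = CheckpointClient(cfg)
    checkpoint_dict = checkpoint_client.load_base_checkpoint()

    # log_rank_zero replaces the manual `if self._is_rank_zero: log.info(...)`
    # guard: it emits the message only on rank 0 of the distributed group.
    utils.log_rank_zero(log, "Base checkpoint loaded.")
    return checkpoint_dict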