Commit f9f75bb

Allow snapshot.pickle to be saved for all recipes with profiler support (#1384)
1 parent 9304eeb commit f9f75bb

3 files changed: +79 −0 lines changed

recipes/full_finetune_distributed.py

Lines changed: 27 additions & 0 deletions
@@ -311,6 +311,12 @@ def _setup_profiler(
         if self._is_rank_zero:
             log.info(f" Profiler config after instantiation: {profiler_cfg}")
 
+        self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+        if profiler_cfg["enabled"]:
+            self.profiler_wait_steps = profiler_cfg["wait_steps"]
+            self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+            self.profiler_active_steps = profiler_cfg["active_steps"]
+
         return profiler
 
     def _setup_model(
@@ -569,6 +575,15 @@ def train(self) -> None:
                 ):
                     break
 
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                ):
+                    torch.cuda.memory._record_memory_history()
+
                 # Both are shape [b, s]
                 tokens, labels = batch["tokens"], batch["labels"]
                 # Get the attention mask and position ids from the dataset if they
@@ -635,6 +650,18 @@ def train(self) -> None:
                     num_tokens = 0
                     t0 = time.perf_counter()
 
+                # Stop tracking CUDA memory now that active steps are complete
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
                 # Step profiler
                 # Note that this is called within gradient accumulation block, hence
                 # will include multiple forward / backward passes if gradient accumulation > 1
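The bracketing logic is easiest to see in isolation. Below is a minimal standalone sketch (not part of the commit; the function name and step counts are illustrative) of the same idea: recording starts on the first "active" step, i.e. at index wait_steps + warmup_steps, and stops one full active window later. The `_record_memory_history` calls are the same PyTorch private APIs used in the diff; the commented `_dump_snapshot` line shows how a snapshot.pickle could be written in a standalone script, whereas in the recipes the profiler's trace handler is responsible for saving it.

```python
import torch

# Illustrative schedule values; in the recipes these come from profiler_cfg,
# mirroring torch.profiler.schedule(wait=..., warmup=..., active=...).
WAIT_STEPS, WARMUP_STEPS, ACTIVE_STEPS = 5, 3, 2


def maybe_toggle_memory_history(idx: int, curr_epoch: int, profile_memory: bool) -> None:
    """Start/stop CUDA allocator history capture around the active window."""
    # Only instrument the first epoch, as the recipes do; the CUDA check is
    # added here for safety in this standalone sketch.
    if curr_epoch != 0 or not profile_memory or not torch.cuda.is_available():
        return
    if idx == WAIT_STEPS + WARMUP_STEPS:
        # First active step: begin recording allocator events.
        torch.cuda.memory._record_memory_history()
    elif idx == WAIT_STEPS + WARMUP_STEPS + ACTIVE_STEPS:
        # One step past the active window: enabled=None disables capture.
        torch.cuda.memory._record_memory_history(enabled=None)
        # In a standalone script you could write the snapshot yourself:
        # torch.cuda.memory._dump_snapshot("snapshot.pickle")
```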

recipes/full_finetune_single_device.py

Lines changed: 25 additions & 0 deletions
@@ -312,6 +312,12 @@ def _setup_profiler(
 
         log.info(f" Profiler config after instantiation: {profiler_cfg}")
 
+        self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+        if profiler_cfg["enabled"]:
+            self.profiler_wait_steps = profiler_cfg["wait_steps"]
+            self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+            self.profiler_active_steps = profiler_cfg["active_steps"]
+
         return profiler
 
     def _setup_model(
@@ -519,6 +525,14 @@ def train(self) -> None:
                 ):
                     break
 
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                ):
+                    torch.cuda.memory._record_memory_history()
+
                 batch = {k: v.to(self._device) for k, v in batch.items()}
                 num_tokens += batch["tokens"].numel()
 
@@ -567,6 +581,17 @@ def train(self) -> None:
                     num_tokens = 0
                     t0 = time.perf_counter()
 
+                # Stop tracking CUDA memory now that active steps are complete
+                if (
+                    curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
                 # Step the profiler
                 # Note we are stepping each batch, which might not include optimizer step in the trace
                 # if the schedule cycle doesn't align with gradient accumulation.
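For reference, here is a hypothetical `profiler_cfg` containing only the keys the `_setup_profiler` hunk reads; the real values come from the recipe's YAML profiler config after instantiation, and the values below are illustrative. Note the asymmetry the diff preserves: `profile_memory` is read with a `.get(...)` default of False, while the schedule step counts are only read inside the `enabled` guard, since a disabled profiler carries no schedule.

```python
# Hypothetical config contents, inferred from the keys the diff accesses.
profiler_cfg = {
    "enabled": True,
    "profile_memory": True,  # gate for the memory-history capture above
    "wait_steps": 5,
    "warmup_steps": 3,
    "active_steps": 2,
}

# The same caching pattern as the hunk, on a plain object for illustration.
profiler_profile_memory = profiler_cfg.get("profile_memory", False)
if profiler_cfg["enabled"]:
    wait_steps = profiler_cfg["wait_steps"]
    warmup_steps = profiler_cfg["warmup_steps"]
    active_steps = profiler_cfg["active_steps"]
```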

recipes/lora_finetune_distributed.py

Lines changed: 27 additions & 0 deletions
@@ -340,6 +340,12 @@ def _setup_profiler(
         if self._is_rank_zero:
             log.info(f" Profiler config after instantiation: {profiler_cfg}")
 
+        self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
+        if profiler_cfg["enabled"]:
+            self.profiler_wait_steps = profiler_cfg["wait_steps"]
+            self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
+            self.profiler_active_steps = profiler_cfg["active_steps"]
+
         return profiler
 
     def _setup_model(
@@ -654,6 +660,15 @@ def train(self) -> None:
                 ):
                     break
 
+                # Start tracking CUDA memory for active steps for just the first epoch
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx == self.profiler_wait_steps + self.profiler_warmup_steps
+                ):
+                    torch.cuda.memory._record_memory_history()
+
                 # Both are shape [b, s]
                 tokens, labels = batch["tokens"], batch["labels"]
                 # Get the attention mask and position ids from the dataset if they
@@ -721,6 +736,18 @@ def train(self) -> None:
                     num_tokens = 0
                     t0 = time.perf_counter()
 
+                # Stop tracking CUDA memory now that active steps are complete
+                if (
+                    self._is_rank_zero
+                    and curr_epoch == 0
+                    and self.profiler_profile_memory
+                    and idx
+                    == self.profiler_wait_steps
+                    + self.profiler_warmup_steps
+                    + self.profiler_active_steps
+                ):
+                    torch.cuda.memory._record_memory_history(enabled=None)
+
                 # Step profiler
                 # Note that this is called within gradient accumulation block, hence
                 # will include multiple forward / backward passes if gradient accumulation > 1
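Once a run with `profile_memory` enabled completes, the saved snapshot.pickle is a plain pickle of the CUDA allocator history, and the usual workflow is to drag the file onto https://pytorch.org/memory_viz for visualization. A hedged sketch for a quick offline inspection instead (the path is illustrative, and the key names reflect the typical snapshot layout rather than a guaranteed schema):

```python
import pickle

# Path is illustrative; use wherever your profiler output lands.
with open("snapshot.pickle", "rb") as f:
    snapshot = pickle.load(f)

# The snapshot is typically a dict: "segments" holds allocator segments,
# "device_traces" the per-device allocation/free event history.
print(sorted(snapshot.keys()))
```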
