Skip to content

Commit

Permalink
Allow snapshot.pickle to be saved for all recipes with profiler support (#1384)
Browse files Browse the repository at this point in the history
  • Loading branch information
janeyx99 authored Aug 22, 2024
1 parent 9304eeb commit f9f75bb
Show file tree
Hide file tree
Showing 3 changed files with 79 additions and 0 deletions.
27 changes: 27 additions & 0 deletions recipes/full_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -311,6 +311,12 @@ def _setup_profiler(
if self._is_rank_zero:
log.info(f" Profiler config after instantiation: {profiler_cfg}")

self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
if profiler_cfg["enabled"]:
self.profiler_wait_steps = profiler_cfg["wait_steps"]
self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
self.profiler_active_steps = profiler_cfg["active_steps"]

return profiler

def _setup_model(
Expand Down Expand Up @@ -569,6 +575,15 @@ def train(self) -> None:
):
break

# Start tracking CUDA memory for active steps for just the first epoch
if (
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
):
torch.cuda.memory._record_memory_history()

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]
# Get the attention mask and position ids from the dataset if they
Expand Down Expand Up @@ -635,6 +650,18 @@ def train(self) -> None:
num_tokens = 0
t0 = time.perf_counter()

# Stop tracking CUDA memory now that active steps are complete
if (
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and idx
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step profiler
# Note that this is called within gradient accumulation block, hence
# will include multiple forward / backward passes if gradient accumulation > 1
Expand Down
25 changes: 25 additions & 0 deletions recipes/full_finetune_single_device.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,6 +312,12 @@ def _setup_profiler(

log.info(f" Profiler config after instantiation: {profiler_cfg}")

self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
if profiler_cfg["enabled"]:
self.profiler_wait_steps = profiler_cfg["wait_steps"]
self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
self.profiler_active_steps = profiler_cfg["active_steps"]

return profiler

def _setup_model(
Expand Down Expand Up @@ -519,6 +525,14 @@ def train(self) -> None:
):
break

# Start tracking CUDA memory for active steps for just the first epoch
if (
curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
):
torch.cuda.memory._record_memory_history()

batch = {k: v.to(self._device) for k, v in batch.items()}
num_tokens += batch["tokens"].numel()

Expand Down Expand Up @@ -567,6 +581,17 @@ def train(self) -> None:
num_tokens = 0
t0 = time.perf_counter()

# Stop tracking CUDA memory now that active steps are complete
if (
curr_epoch == 0
and self.profiler_profile_memory
and idx
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step the profiler
# Note we are stepping each batch, which might not include optimizer step in the trace
# if the schedule cycle doesn't align with gradient accumulation.
Expand Down
27 changes: 27 additions & 0 deletions recipes/lora_finetune_distributed.py
Original file line number Diff line number Diff line change
Expand Up @@ -340,6 +340,12 @@ def _setup_profiler(
if self._is_rank_zero:
log.info(f" Profiler config after instantiation: {profiler_cfg}")

self.profiler_profile_memory = profiler_cfg.get("profile_memory", False)
if profiler_cfg["enabled"]:
self.profiler_wait_steps = profiler_cfg["wait_steps"]
self.profiler_warmup_steps = profiler_cfg["warmup_steps"]
self.profiler_active_steps = profiler_cfg["active_steps"]

return profiler

def _setup_model(
Expand Down Expand Up @@ -654,6 +660,15 @@ def train(self) -> None:
):
break

# Start tracking CUDA memory for active steps for just the first epoch
if (
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and idx == self.profiler_wait_steps + self.profiler_warmup_steps
):
torch.cuda.memory._record_memory_history()

# Both are shape [b, s]
tokens, labels = batch["tokens"], batch["labels"]
# Get the attention mask and position ids from the dataset if they
Expand Down Expand Up @@ -721,6 +736,18 @@ def train(self) -> None:
num_tokens = 0
t0 = time.perf_counter()

# Stop tracking CUDA memory now that active steps are complete
if (
self._is_rank_zero
and curr_epoch == 0
and self.profiler_profile_memory
and idx
== self.profiler_wait_steps
+ self.profiler_warmup_steps
+ self.profiler_active_steps
):
torch.cuda.memory._record_memory_history(enabled=None)

# Step profiler
# Note that this is called within gradient accumulation block, hence
# will include multiple forward / backward passes if gradient accumulation > 1
Expand Down

0 comments on commit f9f75bb

Please sign in to comment.