
Commit 9af9ed5

Merge pull request #324 from alan-turing-institute/calculate-cosine-epochs
Calculate epoch timings
2 parents 8bbd68e + 18b4a52 commit 9af9ed5

4 files changed

Lines changed: 575 additions & 18 deletions


docs/SCRIPTS_AND_CONFIGS.md

Lines changed: 134 additions & 0 deletions
@@ -141,6 +141,140 @@ For launching many prewritten runs from a manifest list:
bash scripts/launch_from_manifest.sh run_manifests/example_runs.txt
```

### Timing epochs and computing `max_epochs` for cosine schedules

When using the `adamw_half` optimizer (half-period cosine LR schedule), the
learning rate decays from its initial value to zero over exactly
`trainer.max_epochs` epochs. If training is cut short by `trainer.max_time`
before all epochs complete, the schedule will not have reached zero.

The `time-epochs` subcommand solves this by running a short timing run (a few
epochs), measuring per-epoch wall-clock duration, and computing the
`max_epochs` that fits within a given budget:

```bash
# Time 3 EPD epochs (default) and compute max_epochs for a 24h budget
uv run autocast time-epochs datamodule=advection_diffusion_multichannel

# Time an autoencoder run
uv run autocast time-epochs --kind ae datamodule=reaction_diffusion

# Time a processor run
uv run autocast time-epochs --kind processor datamodule=reaction_diffusion

# Custom: 5 timing epochs, 12h budget, 2% safety margin
uv run autocast time-epochs -n 5 -b 12 -m 0.02 \
    datamodule=shallow_water2d

# With experiment overrides
uv run autocast time-epochs experiment=epd_crps_vit_large_ps4_64

# Dry-run to inspect the generated command
uv run autocast time-epochs --dry-run datamodule=reaction_diffusion
```

`--kind` selects the training type to time: `ae`, `epd` (default), or
`processor`. Use the same kind you intend to train so that the per-epoch
measurement reflects the actual model and data pipeline.

#### Batch timing via SLURM

With `--mode slurm` the timing run is submitted as a SLURM job and the CLI
exits immediately, printing a follow-up command to retrieve results once the
job completes:

```bash
# Submit timing jobs for several configs at once
uv run autocast time-epochs --mode slurm --kind ae \
    datamodule=reaction_diffusion --run-group timing
uv run autocast time-epochs --mode slurm --kind epd \
    datamodule=shallow_water2d --run-group timing \
    experiment=epd_crps_vit_large_ps4_64

# Once the SLURM jobs finish, compute results from the checkpoints
uv run autocast time-epochs --from-checkpoint outputs/timing/ae_.../timing.ckpt
uv run autocast time-epochs --from-checkpoint outputs/timing/epd_.../timing.ckpt
```

`--from-checkpoint` reads an existing checkpoint, extracts the per-epoch
times, and prints the recommendation — no training is run. You can also
use it to recompute with a different budget or margin:

```bash
uv run autocast time-epochs --from-checkpoint outputs/timing/epd_.../timing.ckpt \
    -b 12 -m 0.05
```
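
If you need the raw numbers rather than the printed recommendation, the
per-epoch times live inside the checkpoint itself. A rough sketch of pulling
them out manually (the path is a placeholder, the callback entry follows
Lightning's convention of keying callback state by class name, and the inner
key name is hypothetical rather than verified against this repository):

```python
import torch

# Placeholder path; use the timing checkpoint printed by the CLI.
ckpt = torch.load("outputs/timing/epd_run/timing.ckpt",
                  map_location="cpu", weights_only=False)

# Lightning stores each callback's state_dict() under checkpoint["callbacks"].
timer_state = ckpt["callbacks"]["TrainingTimerCallback"]
epoch_times = timer_state["epoch_times_s"]  # hypothetical key name
print(f"{sum(epoch_times) / len(epoch_times):.1f} s/epoch "
      f"over {len(epoch_times)} epochs")
```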

The output includes recommended Hydra overrides ready to copy-paste:

```
============================================================
Seconds/epoch: 150.0s
Budget: 24.0h (margin: 2%)
max_epochs: 564
Expected time: 23.5h
Headroom: 0.5h
============================================================

Recommended overrides:
trainer.max_epochs=564 trainer.max_time=01:00:00:00 optimizer=adamw_half
```

Here `trainer.max_time=01:00:00:00` uses Lightning's `DD:HH:MM:SS` format,
i.e. a hard stop after one day.

The calculation is conservative:
- A 2% safety margin (configurable with `-m`) is subtracted from the budget.
- The result is rounded **down** to a whole epoch (`floor`), so the cosine
  schedule always completes its full half-period.
- `trainer.max_time` is set to the full (un-margined) budget as a hard stop.

Per-epoch times are extracted from the `TrainingTimerCallback` saved in the
checkpoint, which excludes model setup and data loading overhead.
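
The arithmetic is simple enough to check by hand. A minimal sketch (plain
Python, not the repository's implementation; the helper name is made up for
illustration) that reproduces the numbers in the example output above:

```python
import math

def recommend_max_epochs(seconds_per_epoch: float, budget_hours: float,
                         margin: float = 0.02) -> int:
    """Largest whole number of epochs that fits inside the margined budget."""
    usable_s = budget_hours * 3600 * (1 - margin)    # apply the safety margin
    return math.floor(usable_s / seconds_per_epoch)  # round down to whole epochs

# Values from the example output: 150 s/epoch, 24 h budget, 2% margin.
max_epochs = recommend_max_epochs(150.0, 24.0, 0.02)
print(max_epochs)                                      # 564
print(f"{max_epochs * 150.0 / 3600:.1f} h expected")   # 23.5 h expected
```
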
#### How `max_epochs` and `max_time` interact at runtime

The recommended overrides set **two** stopping conditions:

| Condition | Controlled by | What happens |
|---|---|---|
| Epoch limit | `trainer.max_epochs` | Training stops cleanly after completing this many epochs. |
| Wall-clock limit | `trainer.max_time` | Lightning hard-stops training when the clock runs out. |

Lightning stops at whichever fires first.

**Faster than expected** (each epoch takes less time than the timing run
measured): `max_epochs` fires first. All epochs complete, and the cosine LR
schedule reaches exactly zero. `max_time` is never triggered. This is the
ideal outcome.

**Slower than expected** (each epoch takes more time): `max_time` fires first,
cutting training short before all `max_epochs` have completed. The cosine
schedule has *not* reached zero — the final LR is positive.

The 2% default margin tolerates up to ~2% slower epochs before `max_time`
intervenes. The `floor()` rounding adds a small additional buffer (up to
one epoch's worth). For workloads where epoch duration is stable
(compute-bound, data in memory), 2% is sufficient. For I/O-bound workloads
that stream from a shared parallel filesystem, consider `--margin 0.05` or
higher.

**The cosine cannot overshoot and start increasing.**
`cosine_lambda(t) = 0.5 * (1 + cos(pi * t / max_epochs))` is monotonically
decreasing over `[0, max_epochs]`. Training terminates at `max_epochs`, so
the second half of the cosine period is never entered. If `max_time`
intervenes earlier, the LR is still on the decreasing branch — it simply
hasn't reached zero yet.
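
A quick numerical spot-check of that claim, using the formula quoted above
(a standalone sketch, not code from the repository):

```python
import math

def cosine_lambda(t: float, max_epochs: int) -> float:
    # Half-period cosine: 1.0 at t=0, exactly 0.0 at t=max_epochs.
    return 0.5 * (1 + math.cos(math.pi * t / max_epochs))

max_epochs = 564
samples = [cosine_lambda(t, max_epochs) for t in range(0, max_epochs + 1, 94)]
print([round(s, 3) for s in samples])
# [1.0, 0.933, 0.75, 0.5, 0.25, 0.067, 0.0]
assert all(a > b for a, b in zip(samples, samples[1:]))  # strictly decreasing
```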

#### Choosing a margin

| Scenario | Recommended `--margin` |
|---|---|
| Data in memory, single GPU (very stable epoch times) | 0.02 (default) |
| Local NVMe data loading | 0.02 – 0.03 |
| Streaming from Lustre / GPFS | 0.05 – 0.10 |

To empirically check variance, run `time-epochs` twice at different cluster
load levels. If the two per-epoch estimates agree within 3%, 2% margin is
safe. If they diverge more, match the margin to the observed variance.
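
One simple way to turn two timing runs into a margin choice (plain arithmetic,
not a repository helper; the example numbers are made up):

```python
# Seconds/epoch measured in two timing runs at different cluster load levels.
run_a, run_b = 150.0, 154.2

spread = abs(run_a - run_b) / min(run_a, run_b)  # relative disagreement
margin = max(0.02, spread)                       # never go below the 2% default
print(f"observed spread {spread:.1%} -> use --margin {margin:.2f}")
# observed spread 2.8% -> use --margin 0.03
```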

## Lower-level script entry points (advanced)

AutoCast uses a set of Python scripts located in `src/autocast/scripts/` as entry points for training and evaluation. These scripts are exposed as CLI commands via `pyproject.toml`.

src/autocast/scripts/training.py

Lines changed: 23 additions & 17 deletions
@@ -269,9 +269,14 @@ def _attach_reset_timer_callback(
 class TrainingTimerCallback(Callback):
     """Measures wall-clock training time and persists it to the checkpoint.

-    Records total training time and per-epoch durations. The values are
-    stored via ``state_dict()`` so the eval script can read them directly
-    from the checkpoint's ``callbacks`` block.
+    Records total training time and per-epoch durations. Each epoch
+    measurement spans the **full cycle** — training batches *and* the
+    subsequent validation loop — so that the ``time-epochs`` command can
+    accurately predict wall-clock budget consumption.
+
+    Epoch boundaries are measured from one ``on_train_epoch_start`` to the
+    next; the final epoch is closed out in ``on_train_end`` (which fires
+    after the last validation loop).

     Note
     ----
@@ -305,19 +310,20 @@ def on_train_epoch_start(
         self, trainer: L.Trainer, pl_module: L.LightningModule
     ) -> None:
         del trainer, pl_module
-        self._epoch_start = perf_counter()
-
-    def on_train_epoch_end(
-        self, trainer: L.Trainer, pl_module: L.LightningModule
-    ) -> None:
-        del trainer, pl_module
+        now = perf_counter()
+        # Close out the *previous* epoch (training + validation + overhead).
         if self._epoch_start is not None:
-            self._epoch_times_s.append(perf_counter() - self._epoch_start)
+            self._epoch_times_s.append(now - self._epoch_start)
+        self._epoch_start = now

     def on_train_end(self, trainer: L.Trainer, pl_module: L.LightningModule) -> None:
         del trainer, pl_module
+        now = perf_counter()
+        # Close out the final epoch (includes its validation loop).
+        if self._epoch_start is not None:
+            self._epoch_times_s.append(now - self._epoch_start)
         if self._train_start is not None:
-            self.training_runtime_total_s = perf_counter() - self._train_start
+            self.training_runtime_total_s = now - self._train_start

     def state_dict(self) -> dict:  # type: ignore[override]
         runtime_elapsed_s = self._current_elapsed_runtime_s()
@@ -570,6 +576,7 @@ def train_autoencoder(
     trainer = instantiate(
         trainer_cfg, logger=wandb_logger, default_root_dir=str(work_dir)
     )
+    trainer.callbacks.append(TrainingTimerCallback())
     output_cfg = config.get("output", {})
     if output_cfg.get("save_config", False) and trainer.is_global_zero:
         save_resolved_config(
@@ -622,12 +629,11 @@ def train_autoencoder(
         log.info("Starting training from scratch (no resume checkpoint).")
     trainer.fit(model=model, datamodule=datamodule)

-    checkpoint_name = output_cfg.get("checkpoint_name", "autoencoder.ckpt")
-    checkpoint_target = Path(checkpoint_name)
-    checkpoint_path = (
-        checkpoint_target
-        if checkpoint_target.is_absolute()
-        else (work_dir / checkpoint_target)
+    checkpoint_path = _resolve_checkpoint_path(
+        work_dir,
+        output_cfg,
+        output_cfg.get("checkpoint_path"),
+        default_name="autoencoder.ckpt",
     )
     if trainer.is_global_zero:
         trainer.save_checkpoint(checkpoint_path)
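
The start-to-start measurement pattern in the diff above, reduced to a
standalone sketch (illustrative only, not the repository's class):

```python
from time import perf_counter, sleep

epoch_times: list[float] = []
epoch_start: float | None = None

def on_train_epoch_start() -> None:
    """Close out the previous epoch (training + validation) and start a new one."""
    global epoch_start
    now = perf_counter()
    if epoch_start is not None:
        epoch_times.append(now - epoch_start)
    epoch_start = now

def on_train_end() -> None:
    """Close out the final epoch, which ends after its validation loop."""
    if epoch_start is not None:
        epoch_times.append(perf_counter() - epoch_start)

for _ in range(3):          # three "epochs"
    on_train_epoch_start()
    sleep(0.10)             # stand-in for the training batches
    sleep(0.02)             # stand-in for the validation loop -- still measured
on_train_end()
print([round(t, 2) for t in epoch_times])  # roughly [0.12, 0.12, 0.12]
```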

src/autocast/scripts/workflow/cli.py

Lines changed: 72 additions & 0 deletions
@@ -12,6 +12,7 @@
     eval_command,
     infer_dataset_from_workdir,
     infer_resume_checkpoint,
+    time_epochs_command,
     train_command,
     train_eval_single_job_command,
 )
@@ -154,6 +155,53 @@ def build_parser() -> argparse.ArgumentParser:
     )
     _add_common_args(cache_parser)

+    # -- time-epochs -------------------------------------------------------
+    time_parser = subparsers.add_parser(
+        "time-epochs",
+        description=(
+            "Run a short training (ae, epd, or processor) to time per-epoch "
+            "duration and compute the recommended trainer.max_epochs for a "
+            "cosine half-period schedule within a given wall-clock budget."
+        ),
+    )
+    _add_train_args(time_parser)
+    time_parser.add_argument(
+        "--kind",
+        choices=["ae", "epd", "processor"],
+        default="epd",
+        help="Training kind to time (default: epd).",
+    )
+    time_parser.add_argument(
+        "-n",
+        "--num-epochs",
+        type=int,
+        default=3,
+        help="Number of epochs to run for timing (default: 3).",
+    )
+    time_parser.add_argument(
+        "-b",
+        "--budget",
+        type=float,
+        default=24.0,
+        help="Wall-clock budget in hours (default: 24).",
+    )
+    time_parser.add_argument(
+        "-m",
+        "--margin",
+        type=float,
+        default=0.02,
+        help="Safety margin fraction subtracted from budget (default: 0.02 = 2%%).",
+    )
+    time_parser.add_argument(
+        "--from-checkpoint",
+        metavar="CKPT",
+        help=(
+            "Path to an existing timing checkpoint. Skips training and "
+            "computes the recommendation directly."
+        ),
+    )
+    _add_common_args(time_parser)
+
     return parser


@@ -320,6 +368,30 @@ def main() -> None:
         )
         return

+    if args.command == "time-epochs":
+        dataset = _resolve_dataset(
+            work_dir=args.workdir,
+            overrides=combined_overrides,
+        )
+
+        time_epochs_command(
+            kind=args.kind,
+            mode=args.mode,
+            dataset=dataset,
+            output_base=args.output_base,
+            overrides=combined_overrides,
+            num_epochs=args.num_epochs,
+            budget_hours=args.budget,
+            margin=args.margin,
+            run_group=args.run_group,
+            run_id=args.run_id,
+            work_dir=args.workdir,
+            from_checkpoint=args.from_checkpoint,
+            runtime_typechecking=args.runtime_typechecking,
+            dry_run=args.dry_run,
+        )
+        return
+
     raise ValueError(f"Unsupported command: {args.command}")