Skip to content

Commit 18b4a52

Browse files
committed
Fix time-epochs max_time formatting
Correct max_time output to DD:HH:MM:SS and add validation for budget, margin, and num_epochs. Update docs example to match.
1 parent 9a87a02 commit 18b4a52

2 files changed

Lines changed: 163 additions & 92 deletions

File tree

docs/SCRIPTS_AND_CONFIGS.md

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -217,7 +217,7 @@ The output includes recommended Hydra overrides ready to copy-paste:
217217
============================================================
218218
219219
Recommended overrides:
220-
trainer.max_epochs=564 trainer.max_time=24:00:00:00 optimizer=adamw_half
220+
trainer.max_epochs=564 trainer.max_time=01:00:00:00 optimizer=adamw_half
221221
```
222222

223223
The calculation is conservative:

src/autocast/scripts/workflow/commands.py

Lines changed: 162 additions & 91 deletions
Original file line numberDiff line numberDiff line change
@@ -935,6 +935,16 @@ def _compute_max_epochs(
935935
2. A cosine half-period schedule (``cosine_epochs = max_epochs``)
936936
reaches exactly zero and never starts increasing again.
937937
"""
938+
if seconds_per_epoch <= 0:
939+
msg = "seconds_per_epoch must be positive"
940+
raise ValueError(msg)
941+
if budget_hours <= 0:
942+
msg = "budget_hours must be positive"
943+
raise ValueError(msg)
944+
if not (0.0 <= margin < 1.0):
945+
msg = "margin must be in [0, 1)"
946+
raise ValueError(msg)
947+
938948
budget_seconds = budget_hours * 3600
939949
usable_seconds = budget_seconds * (1.0 - margin)
940950
max_epochs = math.floor(usable_seconds / seconds_per_epoch)
@@ -952,25 +962,42 @@ def _compute_max_epochs(
952962

953963
def _format_max_time(budget_hours: float) -> str:
954964
"""Format *budget_hours* as a ``DD:HH:MM:SS`` string for Lightning."""
955-
if budget_hours != int(budget_hours):
956-
return f"{int(budget_hours):02d}:{int(budget_hours % 1 * 60):02d}:00:00"
957-
return f"{int(budget_hours):02d}:00:00:00"
965+
if budget_hours <= 0:
966+
msg = "budget_hours must be positive"
967+
raise ValueError(msg)
968+
969+
total_seconds = round(budget_hours * 3600)
970+
days, rem = divmod(total_seconds, 24 * 3600)
971+
hours, rem = divmod(rem, 3600)
972+
minutes, seconds = divmod(rem, 60)
973+
return f"{days:02d}:{hours:02d}:{minutes:02d}:{seconds:02d}"
958974

959975

960976
def _print_timing_results(
961977
epoch_times: list[float],
962978
budget_hours: float,
963979
margin: float,
964-
) -> dict:
980+
) -> dict | None:
965981
"""Compute and print the ``max_epochs`` recommendation from epoch timings."""
966982
seconds_per_epoch = sum(epoch_times) / len(epoch_times)
967983
print(
968984
"\nPer-epoch times (from TrainingTimerCallback): "
969985
+ ", ".join(f"{t:.1f}s" for t in epoch_times)
970986
)
971987

972-
result = _compute_max_epochs(seconds_per_epoch, budget_hours, margin)
973-
max_time_str = _format_max_time(budget_hours)
988+
try:
989+
result = _compute_max_epochs(seconds_per_epoch, budget_hours, margin)
990+
max_time_str = _format_max_time(budget_hours)
991+
except ValueError as exc:
992+
print(f"\nERROR: {exc}")
993+
return None
994+
995+
if result["max_epochs"] < 1:
996+
print(
997+
"\nERROR: Computed max_epochs < 1. Increase the budget, reduce the "
998+
"margin, or re-check the epoch timing estimate."
999+
)
1000+
return None
9741001

9751002
print(f"\n{'=' * 60}")
9761003
print(f" Seconds/epoch: {result['seconds_per_epoch']:.1f}s")
@@ -988,83 +1015,39 @@ def _print_timing_results(
9881015
return result
9891016

9901017

991-
def time_epochs_command(
1018+
def _validate_time_epochs_args(
1019+
*, num_epochs: int, budget_hours: float, margin: float
1020+
) -> None:
1021+
if num_epochs < 1:
1022+
msg = "--num-epochs must be >= 1"
1023+
raise ValueError(msg)
1024+
if budget_hours <= 0:
1025+
msg = "--budget must be > 0"
1026+
raise ValueError(msg)
1027+
if not (0.0 <= margin < 1.0):
1028+
msg = "--margin must be in [0, 1)"
1029+
raise ValueError(msg)
1030+
1031+
1032+
def _run_time_epochs_training(
9921033
*,
993-
kind: str = "epd",
1034+
kind: str,
9941035
mode: str,
9951036
dataset: str | None,
9961037
output_base: str,
9971038
overrides: list[str],
998-
num_epochs: int = 3,
999-
budget_hours: float = 24.0,
1000-
margin: float = 0.02,
1001-
run_group: str | None = None,
1002-
run_id: str | None = None,
1003-
work_dir: str | None = None,
1004-
from_checkpoint: str | None = None,
1005-
runtime_typechecking: bool = False,
1006-
dry_run: bool = False,
1007-
) -> dict | None:
1008-
"""Run a short training to time per-epoch duration and recommend ``max_epochs``.
1009-
1010-
Executes *num_epochs* epochs of training (ae, epd, or processor) with
1011-
W&B logging and testing disabled, saves a checkpoint so that per-epoch
1012-
wall-clock times can be extracted from ``TrainingTimerCallback``, and
1013-
prints the recommended ``trainer.max_epochs`` for a cosine half-period
1014-
schedule (``optimizer=adamw_half``) that completes within *budget_hours*.
1015-
1016-
The calculation is conservative: a *margin* fraction is subtracted
1017-
from the budget **and** the result is rounded down to a whole epoch,
1018-
so the schedule will always reach zero before the wall-clock limit.
1019-
``trainer.max_time`` is emitted as a hard safety stop equal to the
1020-
full (un-margined) budget.
1021-
1022-
With ``--mode slurm`` the timing run is submitted via sbatch and the
1023-
command exits immediately, printing a ``--from-checkpoint`` command to
1024-
retrieve results once the job completes.
1025-
1026-
Parameters
1027-
----------
1028-
kind:
1029-
Training kind: ``"ae"``, ``"epd"``, or ``"processor"``.
1030-
dataset:
1031-
Hydra datamodule group name (e.g. ``"advection_diffusion_multichannel"``).
1032-
output_base:
1033-
Root output directory (forwarded to ``build_train_overrides``).
1034-
overrides:
1035-
Additional Hydra overrides forwarded to the timing run.
1036-
num_epochs:
1037-
How many epochs to run for the timing measurement.
1038-
budget_hours:
1039-
Target wall-clock budget in hours.
1040-
margin:
1041-
Fraction of *budget_hours* held back as safety headroom (default 2 %).
1042-
from_checkpoint:
1043-
Path to an existing checkpoint; skips training and computes the
1044-
recommendation directly.
1045-
"""
1046-
# ------------------------------------------------------------------
1047-
# Fast path: compute from an existing checkpoint (no training needed)
1048-
# ------------------------------------------------------------------
1049-
if from_checkpoint is not None:
1050-
ckpt = Path(from_checkpoint)
1051-
epoch_times = _extract_epoch_times_from_checkpoint(ckpt)
1052-
if not epoch_times:
1053-
print(
1054-
f"ERROR: Could not extract per-epoch times from {ckpt}. "
1055-
"Check that the checkpoint was produced by a timing run with "
1056-
"TrainingTimerCallback."
1057-
)
1058-
return None
1059-
return _print_timing_results(epoch_times, budget_hours, margin)
1060-
1061-
# ------------------------------------------------------------------
1062-
# Training path: run a short timing job (local or SLURM)
1063-
# ------------------------------------------------------------------
1039+
num_epochs: int,
1040+
budget_hours: float,
1041+
margin: float,
1042+
run_group: str | None,
1043+
run_id: str | None,
1044+
work_dir: str | None,
1045+
runtime_typechecking: bool,
1046+
dry_run: bool,
1047+
) -> tuple[list[float] | None, bool]:
1048+
"""Run timing training job and return (epoch_times, exit_early)."""
10641049
timing_run_id = run_id or "timing"
10651050

1066-
# Local without explicit workdir: use a tempdir (cleaned up after).
1067-
# SLURM or explicit workdir: use a persistent path so results survive.
10681051
use_tempdir = mode == "local" and work_dir is None
10691052
tmpdir_ctx = (
10701053
tempfile.TemporaryDirectory(prefix="autocast_timing_") if use_tempdir else None
@@ -1074,9 +1057,6 @@ def time_epochs_command(
10741057
try:
10751058
effective_work_dir = tmpdir if use_tempdir else work_dir
10761059

1077-
# Build overrides: short run, no wandb, no test, checkpoint for
1078-
# timer extraction. Use a relative checkpoint name so the
1079-
# training script resolves it against its own work_dir.
10801060
timing_overrides = [
10811061
f"++trainer.max_epochs={num_epochs}",
10821062
"++trainer.max_time=null",
@@ -1093,9 +1073,9 @@ def time_epochs_command(
10931073
output_base=output_base,
10941074
run_group=run_group,
10951075
run_id=timing_run_id,
1096-
work_dir=(
1097-
str(effective_work_dir) if effective_work_dir is not None else None
1098-
),
1076+
work_dir=str(effective_work_dir)
1077+
if effective_work_dir is not None
1078+
else None,
10991079
resume_from=None,
11001080
overrides=[*timing_overrides, *overrides],
11011081
)
@@ -1107,7 +1087,7 @@ def time_epochs_command(
11071087
print(f"DRY-RUN: {format_command(cmd)}")
11081088
print(f"\nWould time {num_epochs} epochs, then compute max_epochs")
11091089
print(f"for a {budget_hours}h budget with {margin:.0%} margin.")
1110-
return None
1090+
return None, True
11111091

11121092
if mode == "slurm":
11131093
run_module(
@@ -1122,8 +1102,6 @@ def time_epochs_command(
11221102
f"--from-checkpoint {ckpt_path} "
11231103
f"-b {budget_hours} -m {margin}"
11241104
)
1125-
# Write retrieval command to workdir so batch results are easy
1126-
# to collect: for f in outputs/timing/*/retrieve.sh; do bash "$f"; done
11271105
final_work_dir.mkdir(parents=True, exist_ok=True)
11281106
(final_work_dir / "retrieve.sh").write_text(
11291107
f"#!/usr/bin/env bash\n{retrieve_cmd}\n"
@@ -1133,12 +1111,10 @@ def time_epochs_command(
11331111
print(f" {retrieve_cmd}")
11341112
print(
11351113
"\nOr collect all timing results at once:\n"
1136-
" for f in outputs/timing/*/retrieve.sh; "
1137-
'do bash "$f"; done'
1114+
' for f in outputs/timing/*/retrieve.sh; do bash "$f"; done'
11381115
)
1139-
return None
1116+
return None, True
11401117

1141-
# Local execution
11421118
print(f"Timing {num_epochs} epoch(s) to estimate per-epoch duration...")
11431119
run_module(
11441120
TRAIN_MODULES[kind],
@@ -1147,12 +1123,107 @@ def time_epochs_command(
11471123
mode="local",
11481124
runtime_typechecking=runtime_typechecking,
11491125
)
1150-
1151-
epoch_times = _extract_epoch_times_from_checkpoint(ckpt_path)
1126+
return _extract_epoch_times_from_checkpoint(ckpt_path), False
11521127
finally:
11531128
if tmpdir_ctx is not None:
11541129
tmpdir_ctx.__exit__(None, None, None)
11551130

1131+
1132+
def time_epochs_command(
1133+
*,
1134+
kind: str = "epd",
1135+
mode: str,
1136+
dataset: str | None,
1137+
output_base: str,
1138+
overrides: list[str],
1139+
num_epochs: int = 3,
1140+
budget_hours: float = 24.0,
1141+
margin: float = 0.02,
1142+
run_group: str | None = None,
1143+
run_id: str | None = None,
1144+
work_dir: str | None = None,
1145+
from_checkpoint: str | None = None,
1146+
runtime_typechecking: bool = False,
1147+
dry_run: bool = False,
1148+
) -> dict | None:
1149+
"""Run a short training to time per-epoch duration and recommend ``max_epochs``.
1150+
1151+
Executes *num_epochs* epochs of training (ae, epd, or processor) with
1152+
W&B logging and testing disabled, saves a checkpoint so that per-epoch
1153+
wall-clock times can be extracted from ``TrainingTimerCallback``, and
1154+
prints the recommended ``trainer.max_epochs`` for a cosine half-period
1155+
schedule (``optimizer=adamw_half``) that completes within *budget_hours*.
1156+
1157+
The calculation is conservative: a *margin* fraction is subtracted
1158+
from the budget **and** the result is rounded down to a whole epoch,
1159+
so the schedule will always reach zero before the wall-clock limit.
1160+
``trainer.max_time`` is emitted as a hard safety stop equal to the
1161+
full (un-margined) budget.
1162+
1163+
With ``--mode slurm`` the timing run is submitted via sbatch and the
1164+
command exits immediately, printing a ``--from-checkpoint`` command to
1165+
retrieve results once the job completes.
1166+
1167+
Parameters
1168+
----------
1169+
kind:
1170+
Training kind: ``"ae"``, ``"epd"``, or ``"processor"``.
1171+
dataset:
1172+
Hydra datamodule group name (e.g. ``"advection_diffusion_multichannel"``).
1173+
output_base:
1174+
Root output directory (forwarded to ``build_train_overrides``).
1175+
overrides:
1176+
Additional Hydra overrides forwarded to the timing run.
1177+
num_epochs:
1178+
How many epochs to run for the timing measurement.
1179+
budget_hours:
1180+
Target wall-clock budget in hours.
1181+
margin:
1182+
Fraction of *budget_hours* held back as safety headroom (default 2 %).
1183+
from_checkpoint:
1184+
Path to an existing checkpoint; skips training and computes the
1185+
recommendation directly.
1186+
"""
1187+
try:
1188+
_validate_time_epochs_args(
1189+
num_epochs=num_epochs,
1190+
budget_hours=budget_hours,
1191+
margin=margin,
1192+
)
1193+
except ValueError as exc:
1194+
print(f"ERROR: {exc}")
1195+
return None
1196+
1197+
if from_checkpoint is not None:
1198+
ckpt = Path(from_checkpoint)
1199+
epoch_times = _extract_epoch_times_from_checkpoint(ckpt)
1200+
if not epoch_times:
1201+
print(
1202+
f"ERROR: Could not extract per-epoch times from {ckpt}. "
1203+
"Check that the checkpoint was produced by a timing run with "
1204+
"TrainingTimerCallback."
1205+
)
1206+
return None
1207+
return _print_timing_results(epoch_times, budget_hours, margin)
1208+
1209+
epoch_times, exit_early = _run_time_epochs_training(
1210+
kind=kind,
1211+
mode=mode,
1212+
dataset=dataset,
1213+
output_base=output_base,
1214+
overrides=overrides,
1215+
num_epochs=num_epochs,
1216+
budget_hours=budget_hours,
1217+
margin=margin,
1218+
run_group=run_group,
1219+
run_id=run_id,
1220+
work_dir=work_dir,
1221+
runtime_typechecking=runtime_typechecking,
1222+
dry_run=dry_run,
1223+
)
1224+
if exit_early:
1225+
return None
1226+
11561227
if not epoch_times:
11571228
print(
11581229
"\nWARNING: Could not extract per-epoch times from checkpoint. "

0 commit comments

Comments (0)