@@ -935,6 +935,16 @@ def _compute_max_epochs(
935935 2. A cosine half-period schedule (``cosine_epochs = max_epochs``)
936936 reaches exactly zero and never starts increasing again.
937937 """
938+ if seconds_per_epoch <= 0 :
939+ msg = "seconds_per_epoch must be positive"
940+ raise ValueError (msg )
941+ if budget_hours <= 0 :
942+ msg = "budget_hours must be positive"
943+ raise ValueError (msg )
944+ if not (0.0 <= margin < 1.0 ):
945+ msg = "margin must be in [0, 1)"
946+ raise ValueError (msg )
947+
938948 budget_seconds = budget_hours * 3600
939949 usable_seconds = budget_seconds * (1.0 - margin )
940950 max_epochs = math .floor (usable_seconds / seconds_per_epoch )
@@ -952,25 +962,42 @@ def _compute_max_epochs(
952962
953963def _format_max_time (budget_hours : float ) -> str :
954964 """Format *budget_hours* as a ``DD:HH:MM:SS`` string for Lightning."""
955- if budget_hours != int (budget_hours ):
956- return f"{ int (budget_hours ):02d} :{ int (budget_hours % 1 * 60 ):02d} :00:00"
957- return f"{ int (budget_hours ):02d} :00:00:00"
965+ if budget_hours <= 0 :
966+ msg = "budget_hours must be positive"
967+ raise ValueError (msg )
968+
969+ total_seconds = round (budget_hours * 3600 )
970+ days , rem = divmod (total_seconds , 24 * 3600 )
971+ hours , rem = divmod (rem , 3600 )
972+ minutes , seconds = divmod (rem , 60 )
973+ return f"{ days :02d} :{ hours :02d} :{ minutes :02d} :{ seconds :02d} "
958974
959975
960976def _print_timing_results (
961977 epoch_times : list [float ],
962978 budget_hours : float ,
963979 margin : float ,
964- ) -> dict :
980+ ) -> dict | None :
965981 """Compute and print the ``max_epochs`` recommendation from epoch timings."""
966982 seconds_per_epoch = sum (epoch_times ) / len (epoch_times )
967983 print (
968984 "\n Per-epoch times (from TrainingTimerCallback): "
969985 + ", " .join (f"{ t :.1f} s" for t in epoch_times )
970986 )
971987
972- result = _compute_max_epochs (seconds_per_epoch , budget_hours , margin )
973- max_time_str = _format_max_time (budget_hours )
988+ try :
989+ result = _compute_max_epochs (seconds_per_epoch , budget_hours , margin )
990+ max_time_str = _format_max_time (budget_hours )
991+ except ValueError as exc :
992+ print (f"\n ERROR: { exc } " )
993+ return None
994+
995+ if result ["max_epochs" ] < 1 :
996+ print (
997+ "\n ERROR: Computed max_epochs < 1. Increase the budget, reduce the "
998+ "margin, or re-check the epoch timing estimate."
999+ )
1000+ return None
9741001
9751002 print (f"\n { '=' * 60 } " )
9761003 print (f" Seconds/epoch: { result ['seconds_per_epoch' ]:.1f} s" )
@@ -988,83 +1015,39 @@ def _print_timing_results(
9881015 return result
9891016
9901017
991- def time_epochs_command (
1018+ def _validate_time_epochs_args (
1019+ * , num_epochs : int , budget_hours : float , margin : float
1020+ ) -> None :
1021+ if num_epochs < 1 :
1022+ msg = "--num-epochs must be >= 1"
1023+ raise ValueError (msg )
1024+ if budget_hours <= 0 :
1025+ msg = "--budget must be > 0"
1026+ raise ValueError (msg )
1027+ if not (0.0 <= margin < 1.0 ):
1028+ msg = "--margin must be in [0, 1)"
1029+ raise ValueError (msg )
1030+
1031+
1032+ def _run_time_epochs_training (
9921033 * ,
993- kind : str = "epd" ,
1034+ kind : str ,
9941035 mode : str ,
9951036 dataset : str | None ,
9961037 output_base : str ,
9971038 overrides : list [str ],
998- num_epochs : int = 3 ,
999- budget_hours : float = 24.0 ,
1000- margin : float = 0.02 ,
1001- run_group : str | None = None ,
1002- run_id : str | None = None ,
1003- work_dir : str | None = None ,
1004- from_checkpoint : str | None = None ,
1005- runtime_typechecking : bool = False ,
1006- dry_run : bool = False ,
1007- ) -> dict | None :
1008- """Run a short training to time per-epoch duration and recommend ``max_epochs``.
1009-
1010- Executes *num_epochs* epochs of training (ae, epd, or processor) with
1011- W&B logging and testing disabled, saves a checkpoint so that per-epoch
1012- wall-clock times can be extracted from ``TrainingTimerCallback``, and
1013- prints the recommended ``trainer.max_epochs`` for a cosine half-period
1014- schedule (``optimizer=adamw_half``) that completes within *budget_hours*.
1015-
1016- The calculation is conservative: a *margin* fraction is subtracted
1017- from the budget **and** the result is rounded down to a whole epoch,
1018- so the schedule will always reach zero before the wall-clock limit.
1019- ``trainer.max_time`` is emitted as a hard safety stop equal to the
1020- full (un-margined) budget.
1021-
1022- With ``--mode slurm`` the timing run is submitted via sbatch and the
1023- command exits immediately, printing a ``--from-checkpoint`` command to
1024- retrieve results once the job completes.
1025-
1026- Parameters
1027- ----------
1028- kind:
1029- Training kind: ``"ae"``, ``"epd"``, or ``"processor"``.
1030- dataset:
1031- Hydra datamodule group name (e.g. ``"advection_diffusion_multichannel"``).
1032- output_base:
1033- Root output directory (forwarded to ``build_train_overrides``).
1034- overrides:
1035- Additional Hydra overrides forwarded to the timing run.
1036- num_epochs:
1037- How many epochs to run for the timing measurement.
1038- budget_hours:
1039- Target wall-clock budget in hours.
1040- margin:
1041- Fraction of *budget_hours* held back as safety headroom (default 2 %).
1042- from_checkpoint:
1043- Path to an existing checkpoint; skips training and computes the
1044- recommendation directly.
1045- """
1046- # ------------------------------------------------------------------
1047- # Fast path: compute from an existing checkpoint (no training needed)
1048- # ------------------------------------------------------------------
1049- if from_checkpoint is not None :
1050- ckpt = Path (from_checkpoint )
1051- epoch_times = _extract_epoch_times_from_checkpoint (ckpt )
1052- if not epoch_times :
1053- print (
1054- f"ERROR: Could not extract per-epoch times from { ckpt } . "
1055- "Check that the checkpoint was produced by a timing run with "
1056- "TrainingTimerCallback."
1057- )
1058- return None
1059- return _print_timing_results (epoch_times , budget_hours , margin )
1060-
1061- # ------------------------------------------------------------------
1062- # Training path: run a short timing job (local or SLURM)
1063- # ------------------------------------------------------------------
1039+ num_epochs : int ,
1040+ budget_hours : float ,
1041+ margin : float ,
1042+ run_group : str | None ,
1043+ run_id : str | None ,
1044+ work_dir : str | None ,
1045+ runtime_typechecking : bool ,
1046+ dry_run : bool ,
1047+ ) -> tuple [list [float ] | None , bool ]:
1048+ """Run timing training job and return (epoch_times, exit_early)."""
10641049 timing_run_id = run_id or "timing"
10651050
1066- # Local without explicit workdir: use a tempdir (cleaned up after).
1067- # SLURM or explicit workdir: use a persistent path so results survive.
10681051 use_tempdir = mode == "local" and work_dir is None
10691052 tmpdir_ctx = (
10701053 tempfile .TemporaryDirectory (prefix = "autocast_timing_" ) if use_tempdir else None
@@ -1074,9 +1057,6 @@ def time_epochs_command(
10741057 try :
10751058 effective_work_dir = tmpdir if use_tempdir else work_dir
10761059
1077- # Build overrides: short run, no wandb, no test, checkpoint for
1078- # timer extraction. Use a relative checkpoint name so the
1079- # training script resolves it against its own work_dir.
10801060 timing_overrides = [
10811061 f"++trainer.max_epochs={ num_epochs } " ,
10821062 "++trainer.max_time=null" ,
@@ -1093,9 +1073,9 @@ def time_epochs_command(
10931073 output_base = output_base ,
10941074 run_group = run_group ,
10951075 run_id = timing_run_id ,
1096- work_dir = (
1097- str ( effective_work_dir ) if effective_work_dir is not None else None
1098- ) ,
1076+ work_dir = str ( effective_work_dir )
1077+ if effective_work_dir is not None
1078+ else None ,
10991079 resume_from = None ,
11001080 overrides = [* timing_overrides , * overrides ],
11011081 )
@@ -1107,7 +1087,7 @@ def time_epochs_command(
11071087 print (f"DRY-RUN: { format_command (cmd )} " )
11081088 print (f"\n Would time { num_epochs } epochs, then compute max_epochs" )
11091089 print (f"for a { budget_hours } h budget with { margin :.0%} margin." )
1110- return None
1090+ return None , True
11111091
11121092 if mode == "slurm" :
11131093 run_module (
@@ -1122,8 +1102,6 @@ def time_epochs_command(
11221102 f"--from-checkpoint { ckpt_path } "
11231103 f"-b { budget_hours } -m { margin } "
11241104 )
1125- # Write retrieval command to workdir so batch results are easy
1126- # to collect: for f in outputs/timing/*/retrieve.sh; do bash "$f"; done
11271105 final_work_dir .mkdir (parents = True , exist_ok = True )
11281106 (final_work_dir / "retrieve.sh" ).write_text (
11291107 f"#!/usr/bin/env bash\n { retrieve_cmd } \n "
@@ -1133,12 +1111,10 @@ def time_epochs_command(
11331111 print (f" { retrieve_cmd } " )
11341112 print (
11351113 "\n Or collect all timing results at once:\n "
1136- " for f in outputs/timing/*/retrieve.sh; "
1137- 'do bash "$f"; done'
1114+ ' for f in outputs/timing/*/retrieve.sh; do bash "$f"; done'
11381115 )
1139- return None
1116+ return None , True
11401117
1141- # Local execution
11421118 print (f"Timing { num_epochs } epoch(s) to estimate per-epoch duration..." )
11431119 run_module (
11441120 TRAIN_MODULES [kind ],
@@ -1147,12 +1123,107 @@ def time_epochs_command(
11471123 mode = "local" ,
11481124 runtime_typechecking = runtime_typechecking ,
11491125 )
1150-
1151- epoch_times = _extract_epoch_times_from_checkpoint (ckpt_path )
1126+ return _extract_epoch_times_from_checkpoint (ckpt_path ), False
11521127 finally :
11531128 if tmpdir_ctx is not None :
11541129 tmpdir_ctx .__exit__ (None , None , None )
11551130
1131+
1132+ def time_epochs_command (
1133+ * ,
1134+ kind : str = "epd" ,
1135+ mode : str ,
1136+ dataset : str | None ,
1137+ output_base : str ,
1138+ overrides : list [str ],
1139+ num_epochs : int = 3 ,
1140+ budget_hours : float = 24.0 ,
1141+ margin : float = 0.02 ,
1142+ run_group : str | None = None ,
1143+ run_id : str | None = None ,
1144+ work_dir : str | None = None ,
1145+ from_checkpoint : str | None = None ,
1146+ runtime_typechecking : bool = False ,
1147+ dry_run : bool = False ,
1148+ ) -> dict | None :
1149+ """Run a short training to time per-epoch duration and recommend ``max_epochs``.
1150+
1151+ Executes *num_epochs* epochs of training (ae, epd, or processor) with
1152+ W&B logging and testing disabled, saves a checkpoint so that per-epoch
1153+ wall-clock times can be extracted from ``TrainingTimerCallback``, and
1154+ prints the recommended ``trainer.max_epochs`` for a cosine half-period
1155+ schedule (``optimizer=adamw_half``) that completes within *budget_hours*.
1156+
1157+ The calculation is conservative: a *margin* fraction is subtracted
1158+ from the budget **and** the result is rounded down to a whole epoch,
1159+ so the schedule will always reach zero before the wall-clock limit.
1160+ ``trainer.max_time`` is emitted as a hard safety stop equal to the
1161+ full (un-margined) budget.
1162+
1163+ With ``--mode slurm`` the timing run is submitted via sbatch and the
1164+ command exits immediately, printing a ``--from-checkpoint`` command to
1165+ retrieve results once the job completes.
1166+
1167+ Parameters
1168+ ----------
1169+ kind:
1170+ Training kind: ``"ae"``, ``"epd"``, or ``"processor"``.
1171+ dataset:
1172+ Hydra datamodule group name (e.g. ``"advection_diffusion_multichannel"``).
1173+ output_base:
1174+ Root output directory (forwarded to ``build_train_overrides``).
1175+ overrides:
1176+ Additional Hydra overrides forwarded to the timing run.
1177+ num_epochs:
1178+ How many epochs to run for the timing measurement.
1179+ budget_hours:
1180+ Target wall-clock budget in hours.
1181+ margin:
1182+ Fraction of *budget_hours* held back as safety headroom (default 2 %).
1183+ from_checkpoint:
1184+ Path to an existing checkpoint; skips training and computes the
1185+ recommendation directly.
1186+ """
1187+ try :
1188+ _validate_time_epochs_args (
1189+ num_epochs = num_epochs ,
1190+ budget_hours = budget_hours ,
1191+ margin = margin ,
1192+ )
1193+ except ValueError as exc :
1194+ print (f"ERROR: { exc } " )
1195+ return None
1196+
1197+ if from_checkpoint is not None :
1198+ ckpt = Path (from_checkpoint )
1199+ epoch_times = _extract_epoch_times_from_checkpoint (ckpt )
1200+ if not epoch_times :
1201+ print (
1202+ f"ERROR: Could not extract per-epoch times from { ckpt } . "
1203+ "Check that the checkpoint was produced by a timing run with "
1204+ "TrainingTimerCallback."
1205+ )
1206+ return None
1207+ return _print_timing_results (epoch_times , budget_hours , margin )
1208+
1209+ epoch_times , exit_early = _run_time_epochs_training (
1210+ kind = kind ,
1211+ mode = mode ,
1212+ dataset = dataset ,
1213+ output_base = output_base ,
1214+ overrides = overrides ,
1215+ num_epochs = num_epochs ,
1216+ budget_hours = budget_hours ,
1217+ margin = margin ,
1218+ run_group = run_group ,
1219+ run_id = run_id ,
1220+ work_dir = work_dir ,
1221+ runtime_typechecking = runtime_typechecking ,
1222+ dry_run = dry_run ,
1223+ )
1224+ if exit_early :
1225+ return None
1226+
11561227 if not epoch_times :
11571228 print (
11581229 "\n WARNING: Could not extract per-epoch times from checkpoint. "
0 commit comments