Skip to content

Commit c0c6a46

Browse files
committed
Update default stride to be derived from n_step_outputs but can be overridden
1 parent d8244a9 commit c0c6a46

4 files changed

Lines changed: 29 additions & 3 deletions

File tree

configs/processor.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@ output:
1414
training:
1515
n_steps_input: 1
1616
n_steps_output: 4
17-
stride: 4
17+
stride: null
1818
autoencoder_checkpoint: null
1919
freeze_autoencoder: false
2020

src/autocast/eval/processor.py

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -107,6 +107,12 @@ def parse_args() -> argparse.Namespace:
107107
default=None,
108108
help="Override training.n_steps_output (number of target time steps).",
109109
)
110+
parser.add_argument(
111+
"--stride",
112+
type=int,
113+
default=None,
114+
help="Override training stride used for rollouts (defaults to n_steps_output).",
115+
)
110116
parser.add_argument(
111117
"--checkpoint",
112118
type=Path,

src/autocast/train/configuration.py

Lines changed: 15 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@ class TrainingParams:
1818

1919
n_steps_input: int
2020
n_steps_output: int
21+
stride: int
2122
autoencoder_checkpoint: Path | None
2223
freeze_autoencoder: bool
2324

@@ -135,6 +136,7 @@ def resolve_training_params(cfg: DictConfig, args) -> TrainingParams:
135136
if training_cfg is not None
136137
else False
137138
)
139+
stride_cfg = training_cfg.get("stride") if training_cfg is not None else None
138140

139141
n_steps_input = args.n_steps_input or n_steps_input_cfg
140142
n_steps_output = args.n_steps_output or n_steps_output_cfg
@@ -147,13 +149,26 @@ def resolve_training_params(cfg: DictConfig, args) -> TrainingParams:
147149
args.freeze_autoencoder if args.freeze_autoencoder is not None else freeze_cfg
148150
)
149151

152+
if stride_cfg in (None, "auto"):
153+
stride_cfg = n_steps_output
154+
stride_override = getattr(args, "stride", None)
155+
stride = stride_override or stride_cfg or n_steps_output
156+
if stride < 1:
157+
msg = "stride must be >= 1."
158+
raise ValueError(msg)
159+
160+
if training_cfg is not None:
161+
with open_dict(training_cfg):
162+
training_cfg["stride"] = stride
163+
150164
if n_steps_output < 1:
151165
msg = "n_steps_output must be >= 1 for processor training."
152166
raise ValueError(msg)
153167

154168
return TrainingParams(
155169
n_steps_input=n_steps_input,
156170
n_steps_output=n_steps_output,
171+
stride=stride,
157172
autoencoder_checkpoint=checkpoint,
158173
freeze_autoencoder=freeze_autoencoder,
159174
)

src/autocast/train/processor.py

Lines changed: 7 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -86,6 +86,12 @@ def parse_args() -> argparse.Namespace:
8686
default=None,
8787
help="Override training.n_steps_output (number of target time steps).",
8888
)
89+
parser.add_argument(
90+
"--stride",
91+
type=int,
92+
default=None,
93+
help="Override training stride (rollout interval between predictions).",
94+
)
8995
parser.add_argument(
9096
"--work-dir",
9197
type=Path,
@@ -254,8 +260,7 @@ def main() -> None: # noqa: PLR0915
254260
max_rollout_steps = epd_cfg.get("max_rollout_steps", 10)
255261
loss_cfg = epd_cfg.get("loss_func")
256262
loss_func = instantiate(loss_cfg) if loss_cfg is not None else nn.MSELoss()
257-
training_cfg = cfg.get("training") or {}
258-
stride = training_cfg.get("stride", 1)
263+
stride = training_params.stride
259264

260265
model = EncoderProcessorDecoder(
261266
encoder_decoder=encoder_decoder,

0 commit comments

Comments
 (0)