Skip to content

Commit e8cf668

Browse files
authored
Merge pull request #318 from alan-turing-institute/add-dataset-channel-idx-subset
Add channel index support
2 parents (e570de0 + 71f5533), commit e8cf668

5 files changed

Lines changed: 116 additions & 34 deletions

File tree

src/autocast/data/datamodule.py

Lines changed: 6 additions & 13 deletions
Original file line number | Diff line number | Diff line change
@@ -168,9 +168,7 @@ def __init__(
168168
n_steps_input: int = 1,
169169
n_steps_output: int = 1,
170170
stride: int = 1,
171-
# TODO: support for passing data from dict
172-
input_channel_idxs: tuple[int, ...] | None = None,
173-
output_channel_idxs: tuple[int, ...] | None = None,
171+
channel_idxs: tuple[int, ...] | None = None,
174172
batch_size: int = 4,
175173
dtype: torch.dtype = torch.float32,
176174
ftype: str = "torch",
@@ -205,8 +203,7 @@ def __init__(
205203
n_steps_input=n_steps_input,
206204
n_steps_output=n_steps_output,
207205
stride=stride,
208-
input_channel_idxs=input_channel_idxs,
209-
output_channel_idxs=output_channel_idxs,
206+
channel_idxs=channel_idxs,
210207
autoencoder_mode=self.autoencoder_mode,
211208
full_trajectory_mode=full_trajectory_mode,
212209
dtype=dtype,
@@ -237,8 +234,7 @@ def __init__(
237234
n_steps_input=n_steps_input,
238235
n_steps_output=n_steps_output,
239236
stride=stride,
240-
input_channel_idxs=input_channel_idxs,
241-
output_channel_idxs=output_channel_idxs,
237+
channel_idxs=channel_idxs,
242238
autoencoder_mode=self.autoencoder_mode,
243239
full_trajectory_mode=full_trajectory_mode,
244240
dtype=dtype,
@@ -254,8 +250,7 @@ def __init__(
254250
n_steps_input=n_steps_input,
255251
n_steps_output=n_steps_output,
256252
stride=stride,
257-
input_channel_idxs=input_channel_idxs,
258-
output_channel_idxs=output_channel_idxs,
253+
channel_idxs=channel_idxs,
259254
autoencoder_mode=self.autoencoder_mode,
260255
full_trajectory_mode=full_trajectory_mode,
261256
dtype=dtype,
@@ -275,8 +270,7 @@ def __init__(
275270
n_steps_input=n_steps_input,
276271
n_steps_output=n_steps_output,
277272
stride=stride,
278-
input_channel_idxs=input_channel_idxs,
279-
output_channel_idxs=output_channel_idxs,
273+
channel_idxs=channel_idxs,
280274
full_trajectory_mode=True,
281275
dtype=dtype,
282276
verbose=self.verbose,
@@ -291,8 +285,7 @@ def __init__(
291285
n_steps_input=n_steps_input,
292286
n_steps_output=n_steps_output,
293287
stride=stride,
294-
input_channel_idxs=input_channel_idxs,
295-
output_channel_idxs=output_channel_idxs,
288+
channel_idxs=channel_idxs,
296289
full_trajectory_mode=True,
297290
dtype=dtype,
298291
verbose=self.verbose,

src/autocast/data/dataset.py

Lines changed: 15 additions & 10 deletions
Original file line number | Diff line number | Diff line change
@@ -31,16 +31,14 @@ def to_sample(data: dict) -> Sample:
3131
class SpatioTemporalDataset(Dataset, BatchMixin):
3232
"""A class for spatio-temporal datasets."""
3333

34-
def __init__(
34+
def __init__( # noqa: PLR0915
3535
self,
3636
data_path: str | None,
3737
data: dict | None = None,
3838
n_steps_input: int = 1,
3939
n_steps_output: int = 1,
4040
stride: int = 1,
41-
# TODO: support for passing data from dict
42-
input_channel_idxs: tuple[int, ...] | None = None,
43-
output_channel_idxs: tuple[int, ...] | None = None,
41+
channel_idxs: tuple[int, ...] | None = None,
4442
full_trajectory_mode: bool = False,
4543
autoencoder_mode: bool = False,
4644
dtype: torch.dtype = torch.float32,
@@ -67,10 +65,9 @@ def __init__(
6765
Stride for sampling the data.
6866
data: dict | None
6967
Preloaded data. Defaults to None.
70-
input_channel_idxs: tuple[int, ...] | None
71-
Indices of input channels to use. Defaults to None.
72-
output_channel_idxs: tuple[int, ...] | None
73-
Indices of output channels to use. Defaults to None.
68+
channel_idxs: tuple[int, ...] | None
69+
Indices of channels to select from the raw data (applied to both
70+
input and output). If None, all channels are used. Defaults to None.
7471
full_trajectory_mode: bool
7572
If True, use full trajectories without creating subtrajectories.
7673
autoencoder_mode: bool
@@ -104,8 +101,17 @@ def __init__(
104101
if data is not None:
105102
self.parse_data(data)
106103

104+
if channel_idxs is not None:
105+
self.data = self.data[..., list(channel_idxs)]
106+
107107
self.set_up_normalization()
108108

109+
if channel_idxs is not None and self.norm is not None:
110+
self.norm.core_field_names = [
111+
self.norm.core_field_names[i] for i in channel_idxs
112+
]
113+
self.norm._precompute_flattened_stats()
114+
109115
if autoencoder_mode and full_trajectory_mode:
110116
msg = "autoencoder_mode and full_trajectory_mode cannot both be True."
111117
raise ValueError(msg)
@@ -124,8 +130,7 @@ def __init__(
124130
self.n_steps_input = n_steps_input
125131
self.n_steps_output = n_steps_output
126132
self.stride = stride
127-
self.input_channel_idxs = input_channel_idxs
128-
self.output_channel_idxs = output_channel_idxs
133+
self.channel_idxs = channel_idxs
129134

130135
# Destructured here
131136
(

src/autocast/scripts/eval/encoder_processor_decoder.py

Lines changed: 7 additions & 5 deletions
Original file line number | Diff line number | Diff line change
@@ -280,11 +280,13 @@ def _resolve_rollout_channel_names(dataset: Any) -> list[str] | None:
280280

281281
norm = getattr(dataset, "norm", None)
282282
raw_names = getattr(norm, "core_field_names", None)
283+
names_already_subset = raw_names is not None
283284

284285
if not isinstance(raw_names, Sequence) or isinstance(raw_names, str):
285286
normalization_stats = getattr(dataset, "normalization_stats", None)
286287
if isinstance(normalization_stats, Mapping):
287288
raw_names = normalization_stats.get("core_field_names")
289+
names_already_subset = False
288290

289291
if not isinstance(raw_names, Sequence) or isinstance(raw_names, str):
290292
return None
@@ -293,14 +295,14 @@ def _resolve_rollout_channel_names(dataset: Any) -> list[str] | None:
293295
if not channel_names:
294296
return None
295297

296-
output_channel_idxs = getattr(dataset, "output_channel_idxs", None)
297-
if output_channel_idxs is not None:
298+
channel_idxs = getattr(dataset, "channel_idxs", None)
299+
if channel_idxs is not None and not names_already_subset:
298300
try:
299-
channel_names = [channel_names[idx] for idx in output_channel_idxs]
301+
channel_names = [channel_names[idx] for idx in channel_idxs]
300302
except (TypeError, IndexError):
301303
log.warning(
302-
"Could not apply output_channel_idxs=%s to channel names %s.",
303-
output_channel_idxs,
304+
"Could not apply channel_idxs=%s to channel names %s.",
305+
channel_idxs,
304306
channel_names,
305307
)
306308
return None

tests/data/test_dataset_normalization.py

Lines changed: 71 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -137,6 +137,77 @@ def test_normalized_data_is_transformed(deterministic_data, stats_dict):
137137
)
138138

139139

140+
def test_channel_idxs_slices_data_and_subsets_norm(deterministic_data, stats_dict):
141+
"""`channel_idxs` should slice data channels and align norm field names."""
142+
dataset = ReactionDiffusionDataset(
143+
data_path=None,
144+
data=deterministic_data,
145+
n_steps_input=2,
146+
n_steps_output=1,
147+
channel_idxs=(1,),
148+
use_normalization=True,
149+
normalization_type=ZScoreNormalization,
150+
normalization_stats=stats_dict,
151+
)
152+
153+
# Sliced data keeps only channel 1 (V).
154+
assert dataset.data.shape[-1] == 1
155+
assert dataset[0].input_fields.shape[-1] == 1
156+
157+
# Norm field names subset to match sliced channels.
158+
assert dataset.norm is not None
159+
assert dataset.norm.core_field_names == ["V"]
160+
161+
# Normalization uses V stats (mean=4.0, std=2.0) against the original V channel.
162+
expected = (deterministic_data["data"][0][:2, ..., 1] - 4.0) / 2.0
163+
assert torch.allclose(dataset[0].input_fields[..., 0], expected)
164+
165+
166+
def test_channel_idxs_none_is_noop(deterministic_data):
167+
"""`channel_idxs=None` should leave all channels intact."""
168+
dataset = ReactionDiffusionDataset(
169+
data_path=None,
170+
data=deterministic_data,
171+
n_steps_input=2,
172+
n_steps_output=1,
173+
channel_idxs=None,
174+
use_normalization=False,
175+
)
176+
assert dataset.data.shape[-1] == 2
177+
assert dataset[0].input_fields.shape[-1] == 2
178+
179+
180+
def test_datamodule_threads_channel_idxs(deterministic_data, stats_dict):
181+
"""DataModule should propagate `channel_idxs` to all sub-datasets."""
182+
dm = SpatioTemporalDataModule(
183+
data_path=None,
184+
data={
185+
"train": deterministic_data,
186+
"valid": deterministic_data,
187+
"test": deterministic_data,
188+
},
189+
dataset_cls=ReactionDiffusionDataset,
190+
n_steps_input=2,
191+
n_steps_output=1,
192+
batch_size=1,
193+
channel_idxs=(0,),
194+
use_normalization=True,
195+
normalization_type=ZScoreNormalization,
196+
normalization_stats=stats_dict,
197+
)
198+
199+
for ds in (
200+
dm.train_dataset,
201+
dm.val_dataset,
202+
dm.test_dataset,
203+
dm.rollout_val_dataset,
204+
dm.rollout_test_dataset,
205+
):
206+
assert ds.data.shape[-1] == 1
207+
assert ds.norm is not None
208+
assert ds.norm.core_field_names == ["U"]
209+
210+
140211
def test_datamodule_with_and_without_normalization(deterministic_data, stats_dict):
141212
"""Test DataModule can be configured with or without normalization."""
142213

tests/scripts/test_eval_encoder_processor_decoder.py

Lines changed: 17 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -281,10 +281,20 @@ def test_should_skip_metric_variogram_only():
281281
assert _should_skip_metric("ssr") is False
282282

283283

284-
def test_resolve_rollout_channel_names_from_norm_with_output_selection():
284+
def test_resolve_rollout_channel_names_from_norm_already_subset():
285285
dataset = SimpleNamespace(
286-
norm=SimpleNamespace(core_field_names=["u", "v", "p"]),
287-
output_channel_idxs=(2, 0),
286+
norm=SimpleNamespace(core_field_names=["p", "u"]),
287+
channel_idxs=(2, 0),
288+
)
289+
290+
assert _resolve_rollout_channel_names(dataset) == ["p", "u"]
291+
292+
293+
def test_resolve_rollout_channel_names_from_stats_applies_idxs():
294+
dataset = SimpleNamespace(
295+
norm=None,
296+
normalization_stats={"core_field_names": ["u", "v", "p"]},
297+
channel_idxs=(2, 0),
288298
)
289299

290300
assert _resolve_rollout_channel_names(dataset) == ["p", "u"]
@@ -294,16 +304,17 @@ def test_resolve_rollout_channel_names_returns_none_without_norm_names():
294304
dataset = SimpleNamespace(
295305
norm=None,
296306
metadata=SimpleNamespace(field_names={0: ["velocity_x", "velocity_y"]}),
297-
output_channel_idxs=None,
307+
channel_idxs=None,
298308
)
299309

300310
assert _resolve_rollout_channel_names(dataset) is None
301311

302312

303313
def test_resolve_rollout_channel_names_returns_none_on_invalid_output_indices():
304314
dataset = SimpleNamespace(
305-
norm=SimpleNamespace(core_field_names=["u", "v"]),
306-
output_channel_idxs=(0, 3),
315+
norm=None,
316+
normalization_stats={"core_field_names": ["u", "v"]},
317+
channel_idxs=(0, 3),
307318
)
308319

309320
assert _resolve_rollout_channel_names(dataset) is None

0 commit comments

Comments
 (0)