Commit 19a576e

Merge branch 'main' into diffusion_model

2 parents: 2d5f2a1 + cd9f8c2

11 files changed: 668 additions & 38 deletions


.github/workflows/ci.yaml

Lines changed: 2 additions & 2 deletions
@@ -22,8 +22,8 @@ jobs:
       fail-fast: true
       matrix:
         os: ["ubuntu-latest", "macos-latest", "windows-latest"]
-        # only run for MSV Python 3.10 on PRs, but all versions on pushes to main
-        python-version: ${{ github.event_name == 'push' && fromJSON('["3.10", "3.11", "3.12"]') || fromJSON('["3.10"]') }}
+        # only run for MSV Python 3.11 on PRs, but all versions on pushes to main
+        python-version: ${{ github.event_name == 'push' && fromJSON('["3.11", "3.12"]') || fromJSON('["3.11"]') }}
     env:
       TURN_OFF_MPS_IF_RUNNING_CI: 1
       MPLBACKEND: Agg
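
The changed expression uses the GitHub Actions `a && b || c` idiom, which behaves like a ternary as long as `b` is truthy: pushes run the full version matrix, while pull requests run only the minimum supported version (now 3.11). A rough Python equivalent of how the matrix resolves, with `event_name` standing in for `github.event_name`:

import json

def resolve_python_versions(event_name: str) -> list[str]:
    """Mirror the workflow expression: full matrix on push, MSV only on PRs."""
    if event_name == "push":
        return json.loads('["3.11", "3.12"]')  # fromJSON('["3.11", "3.12"]')
    return json.loads('["3.11"]')  # fromJSON('["3.11"]')

assert resolve_python_versions("push") == ["3.11", "3.12"]
assert resolve_python_versions("pull_request") == ["3.11"]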

.github/workflows/release.yaml

Lines changed: 2 additions & 2 deletions
@@ -15,12 +15,12 @@ jobs:
       - name: Set up Python
         uses: actions/setup-python@v5
         with:
-          python-version: '3.10'
+          python-version: '3.11'

       - name: Install uv
         uses: astral-sh/setup-uv@v7
         with:
-          python-version: '3.10'
+          python-version: '3.11'
           enable-cache: true
           cache-suffix: release
           activate-environment: true

src/auto_cast/data/datamodule.py

Lines changed: 175 additions & 0 deletions
@@ -0,0 +1,175 @@
from pathlib import Path

import torch
from the_well.data.datamodule import WellDataModule
from the_well.data.normalization import ZScoreNormalization
from torch.utils.data import DataLoader

from auto_cast.data.dataset import SpatioTemporalDataset
from auto_cast.types import collate_batches


class SpatioTemporalDataModule(WellDataModule):
    """A class for spatio-temporal data modules."""

    def __init__(
        self,
        data_path: str | None,
        data: dict[str, dict] | None = None,
        dataset_cls: type[SpatioTemporalDataset] = SpatioTemporalDataset,
        n_steps_input: int = 1,
        n_steps_output: int = 1,
        stride: int = 1,
        # TODO: support for passing data from dict
        input_channel_idxs: tuple[int, ...] | None = None,
        output_channel_idxs: tuple[int, ...] | None = None,
        batch_size: int = 4,
        dtype: torch.dtype = torch.float32,
        ftype: str = "torch",
        verbose: bool = False,
        use_normalization: bool = False,
    ):
        self.verbose = verbose
        self.use_normalization = use_normalization

        base_path = Path(data_path) if data_path is not None else None
        suffix = ".pt" if ftype == "torch" else ".h5"
        fname = f"data{suffix}"
        train_path = base_path / "train" / fname if base_path is not None else None
        valid_path = base_path / "valid" / fname if base_path is not None else None
        test_path = base_path / "test" / fname if base_path is not None else None

        # Create training dataset first (without normalization)
        self.train_dataset = dataset_cls(
            data_path=str(train_path) if train_path is not None else None,
            data=data["train"] if data is not None else None,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            stride=stride,
            input_channel_idxs=input_channel_idxs,
            output_channel_idxs=output_channel_idxs,
            dtype=dtype,
            verbose=self.verbose,
            use_normalization=False,  # Temporarily disable to compute stats
            norm=None,
        )

        # Compute normalization from training data if requested
        norm = None
        if self.use_normalization:
            if self.verbose:
                print("Computing normalization statistics from training data...")
            norm = ZScoreNormalization
            # if self.verbose:
            #     print(f"  Mean (per channel): {norm.mean}")
            #     print(f"  Std (per channel): {norm.std}")

            # Now enable normalization for training dataset
            self.train_dataset.use_normalization = True
            self.train_dataset.norm = norm

        self.val_dataset = dataset_cls(
            data_path=str(valid_path) if valid_path is not None else None,
            data=data["valid"] if data is not None else None,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            stride=stride,
            input_channel_idxs=input_channel_idxs,
            output_channel_idxs=output_channel_idxs,
            dtype=dtype,
            verbose=self.verbose,
            use_normalization=self.use_normalization,
            norm=norm,
        )
        self.test_dataset = dataset_cls(
            data_path=str(test_path) if test_path is not None else None,
            data=data["test"] if data is not None else None,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            stride=stride,
            input_channel_idxs=input_channel_idxs,
            output_channel_idxs=output_channel_idxs,
            dtype=dtype,
            verbose=self.verbose,
            use_normalization=self.use_normalization,
            norm=norm,
        )
        self.rollout_val_dataset = dataset_cls(
            data_path=str(train_path) if train_path is not None else None,
            data=data["train"] if data is not None else None,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            stride=stride,
            input_channel_idxs=input_channel_idxs,
            output_channel_idxs=output_channel_idxs,
            full_trajectory_mode=True,
            dtype=dtype,
            verbose=self.verbose,
            use_normalization=self.use_normalization,
            norm=norm,
        )
        self.rollout_test_dataset = dataset_cls(
            data_path=str(test_path) if test_path is not None else None,
            data=data["test"] if data is not None else None,
            n_steps_input=n_steps_input,
            n_steps_output=n_steps_output,
            stride=stride,
            input_channel_idxs=input_channel_idxs,
            output_channel_idxs=output_channel_idxs,
            full_trajectory_mode=True,
            dtype=dtype,
            verbose=self.verbose,
            use_normalization=self.use_normalization,
            norm=norm,
        )
        self.batch_size = batch_size

    def train_dataloader(self) -> DataLoader:
        """DataLoader for training."""
        return DataLoader(
            self.train_dataset,
            batch_size=self.batch_size,
            shuffle=True,
            num_workers=1,
            collate_fn=collate_batches,
        )

    def val_dataloader(self) -> DataLoader:
        """DataLoader for standard validation (not full trajectory rollouts)."""
        return DataLoader(
            self.val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=1,
            collate_fn=collate_batches,
        )

    def rollout_val_dataloader(self) -> DataLoader:
        """DataLoader for full trajectory rollouts on validation data."""
        return DataLoader(
            self.rollout_val_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=1,
            collate_fn=collate_batches,
        )

    def test_dataloader(self) -> DataLoader:
        """DataLoader for testing."""
        return DataLoader(
            self.test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=1,
            collate_fn=collate_batches,
        )

    def rollout_test_dataloader(self) -> DataLoader:
        """DataLoader for full trajectory rollouts on test data."""
        return DataLoader(
            self.rollout_test_dataset,
            batch_size=self.batch_size,
            shuffle=False,
            num_workers=1,
            collate_fn=collate_batches,
        )
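
The constructor above follows a two-phase normalization pattern: the training dataset is first built with use_normalization=False, then `ZScoreNormalization` is attached to it and shared with the validation, test, and rollout datasets, so every split is normalized with statistics derived from the training split only (note that `rollout_val_dataset` is also built from the train paths in this commit). The commit passes the `ZScoreNormalization` class itself and leaves the actual statistics to the dataset internals; a minimal sketch of what per-channel z-score fitting typically looks like (the helpers below are hypothetical, not part of this commit, and assume fields shaped (samples, time, height, width, channels)):

import torch

def fit_zscore(train_fields: torch.Tensor, eps: float = 1e-8) -> tuple[torch.Tensor, torch.Tensor]:
    """Hypothetical helper: per-channel mean/std over all dims except the last (channels)."""
    reduce_dims = tuple(range(train_fields.ndim - 1))
    mean = train_fields.mean(dim=reduce_dims)
    std = train_fields.std(dim=reduce_dims).clamp_min(eps)  # guard against zero variance
    return mean, std

def apply_zscore(fields: torch.Tensor, mean: torch.Tensor, std: torch.Tensor) -> torch.Tensor:
    """Hypothetical helper: normalize with statistics fitted on the train split."""
    return (fields - mean) / std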
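
A sketch of how the new datamodule might be driven end to end; the dataset root below is made up, but its layout (train/, valid/, and test/ subdirectories, each holding a data.pt file) matches the path construction in __init__:

from auto_cast.data.datamodule import SpatioTemporalDataModule

# Hypothetical dataset root; ftype="torch" (the default) selects the .pt suffix.
dm = SpatioTemporalDataModule(
    data_path="datasets/my_simulation",
    n_steps_input=4,
    n_steps_output=1,
    batch_size=8,
    use_normalization=True,
    verbose=True,
)

for batch in dm.train_dataloader():
    # Batches are assembled by collate_batches; only the training loader shuffles.
    break

Note that all five dataloaders hard-code num_workers=1, so data loading runs with a single worker regardless of machine size.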
