Skip to content

Commit 50020f0

Browse files
committed
fix sydney model for batches > 128
1 parent a41f42c commit 50020f0

3 files changed

Lines changed: 42 additions & 21 deletions

File tree

src_py/magnethub/sydney.py

Lines changed: 9 additions & 15 deletions
Original file line numberDiff line numberDiff line change
@@ -159,18 +159,14 @@ def __call__(self, data_B, data_F, data_T):
159159
if data_B.ndim == 1:
160160
data_B = np.array(data_B).reshape(1, -1)
161161

162-
loader = get_dataloader(data_B, data_F, data_T, self.mdl.norm)
162+
_, ts_feats, scalar_feats = get_dataloader(data_B, data_F, data_T, self.mdl.norm)
163163

164164
# 2.Validate the models
165-
data_P = torch.Tensor([]).to(self.device) # Allocate memory to store loss density
166-
167-
with torch.no_grad():
165+
self.mdl.eval()
166+
with torch.inference_mode():
168167
# Start model evaluation explicitly
169-
self.mdl.eval()
170-
for inputs, vars in loader:
171-
Pv, h_series = self.mdl(inputs.to(self.device), vars.to(self.device))
168+
data_P, h_series = self.mdl(ts_feats.to(self.device), scalar_feats.to(self.device))
172169

173-
data_P = torch.cat((data_P, Pv.to(self.device)), dim=0)
174170
data_P, h_series = data_P.cpu().numpy(), h_series.cpu().numpy()
175171

176172
# 3.Return results
@@ -337,9 +333,7 @@ def forward(self, x, hidden=None):
337333

338334

339335
def get_dataloader(data_B, data_F, data_T, norm, n_init=32):
340-
"""
341-
Preprocess data into a data loader.
342-
336+
"""Preprocess data into a data loader.
343337
Get a test dataloader.
344338
345339
Parameters
@@ -394,14 +388,14 @@ def get_dataloader(data_B, data_F, data_T, norm, n_init=32):
394388

395389
s0 = get_operator_init(in_B[:, 0, 0] - in_dB[:, 0, 0], in_dB, max_B, min_B) # Operator inital state
396390

391+
ts_feats = torch.cat((in_B, in_dB, in_dB_dt), dim=2)
392+
scalar_feats = torch.cat((in_F, in_T, s0), dim=1)
397393
# 6. Create dataloader to speed up data processing
398-
test_dataset = torch.utils.data.TensorDataset(
399-
torch.cat((in_B, in_dB, in_dB_dt), dim=2), torch.cat((in_F, in_T, s0), dim=1)
400-
)
394+
test_dataset = torch.utils.data.TensorDataset(ts_feats, scalar_feats)
401395
kwargs = {"num_workers": 0, "batch_size": 128, "drop_last": False}
402396
test_loader = torch.utils.data.DataLoader(test_dataset, **kwargs)
403397

404-
return test_loader
398+
return test_loader, ts_feats, scalar_feats
405399

406400

407401
# %% Predict the operator state at t0

tests/debug.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,27 @@
1+
import numpy as np
2+
import pandas as pd
3+
from pathlib import Path
4+
from magnethub.loss import LossModel, MATERIALS
5+
6+
test_ds = pd.read_csv(
7+
Path.cwd() / "tests" / "test_files" / "all_data.csv.gzip", dtype={"material": str}
8+
)
9+
errs_d = {}
10+
for m_lbl in MATERIALS:
11+
mdl = LossModel(material=m_lbl, team="paderborn")
12+
test_mat_df = test_ds.query("material == @m_lbl")
13+
p, h = mdl(
14+
test_mat_df.loc[:, [c for c in test_mat_df if c.startswith("B_t_")]].to_numpy(),
15+
test_mat_df.loc[:, "freq"].to_numpy(),
16+
test_mat_df.loc[:, "temp"].to_numpy(),
17+
)
18+
rel_err = np.abs(test_mat_df.ploss - p) / test_mat_df.ploss
19+
errs_d[m_lbl] = {
20+
"avg": np.mean(rel_err),
21+
"95th": np.quantile(rel_err, 0.95),
22+
"99th": np.quantile(rel_err, 0.99),
23+
'samples': len(rel_err),
24+
}
25+
rel_df = pd.DataFrame(errs_d).T
26+
print(f"Rel. errors")
27+
print(rel_df)

tests/test_sydney.py

Lines changed: 6 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -52,14 +52,14 @@ def test_longer_sequence():
5252

5353
def test_batch_execution():
5454
mdl = LossModel(material="3C92", team=TEAM_NAME)
55-
56-
b_waves = np.random.randn(100, 1024) * 200e-3 # mT
57-
freqs = np.random.randint(100e3, 750e3, size=100)
58-
temps = np.random.randint(20, 80, size=100)
55+
seq_len = 1412
56+
b_waves = np.random.randn(seq_len, 1024) * 200e-3 # mT
57+
freqs = np.random.randint(100e3, 750e3, size=seq_len)
58+
temps = np.random.randint(20, 80, size=seq_len)
5959
p, h = mdl(b_waves, freqs, temps)
6060

61-
assert p.size == 100, f"{p.size=}"
62-
assert h.shape == (100, 1024), f"{h.shape=}"
61+
assert p.size == seq_len, f"{p.size=}"
62+
assert h.shape == (seq_len, 1024), f"{h.shape=}"
6363

6464

6565
def test_material_availability():

0 commit comments

Comments
 (0)