-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathhelper.py
More file actions
executable file
·289 lines (236 loc) · 9.66 KB
/
helper.py
File metadata and controls
executable file
·289 lines (236 loc) · 9.66 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchcde
from torchdiffeq import odeint
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
from scipy.signal import savgol_filter
import copy
import os
import json
import random
from sklearn.metrics import mean_squared_error, mean_absolute_error
# Default torch device for this module: CUDA when PyTorch reports it is
# available, otherwise the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
def compute_metrics(I_pred, I_true, t_split, seed, results_root="results"):
    """Score the forecast segment ``[t_split:]`` and write the scores to JSON.

    Both series are flattened to 1-D before slicing (the same normalization
    ``peak_metrics`` uses). This way a column-vector prediction of shape
    ``(T, 1)`` is compared element-wise with a ``(T,)`` ground truth, rather
    than broadcasting to a ``(T, T)`` matrix in the MAPE and KL terms.

    Args:
        I_pred: Predicted series (array-like; flattened to 1-D).
        I_true: Ground-truth series (array-like; flattened to 1-D).
        t_split: Index of the first forecast step.
        seed: Run seed; only echoed into the printout and the JSON output.
        results_root: Directory receiving ``metrics.json`` (created if
            missing). The file has a fixed name and is overwritten on every
            call.

    Returns:
        dict with ``rmse``, ``mae``, ``mape`` (in percent), ``kl_divergence``,
        ``forecast_len`` and ``seed``.

    Raises:
        AssertionError: if the two forecast segments differ in length.
    """
    os.makedirs(results_root, exist_ok=True)
    eps = 1e-8  # keeps the divisions and the log below finite

    I_pred_split = np.asarray(I_pred).reshape(-1)[t_split:]
    I_true_split = np.asarray(I_true).reshape(-1)[t_split:]
    # Explicit raise instead of `assert` so the check survives `python -O`.
    if len(I_pred_split) != len(I_true_split):
        raise AssertionError("Mismatch in forecast length")

    mse = mean_squared_error(I_true_split, I_pred_split)
    rmse = np.sqrt(mse)
    mae = mean_absolute_error(I_true_split, I_pred_split)
    # eps avoids division by zero where the ground truth is exactly 0.
    mape = np.mean(np.abs((I_true_split - I_pred_split) / (I_true_split + eps))) * 100

    # KL(P || Q) after normalizing each segment to sum to ~1.
    # NOTE(review): assumes non-negative series; a negative prediction can make
    # (P + eps) / (Q + eps) negative and the log NaN.
    P = I_true_split / (np.sum(I_true_split) + eps)
    Q = I_pred_split / (np.sum(I_pred_split) + eps)
    kl_div = np.sum(P * np.log((P + eps) / (Q + eps)))

    metrics = {
        "rmse": float(rmse),
        "mae": float(mae),
        "mape": float(mape),
        "kl_divergence": float(kl_div),
        "forecast_len": len(I_true_split),
        "seed": seed
    }
    print(f"Metrics for seed {seed}: RMSE = {rmse:.4f}, MAE = {mae:.4f}, MAPE = {mape:.4f}, KL Divergence = {kl_div:.4f}")
    out_path = os.path.join(results_root, "metrics.json")
    with open(out_path, "w") as f:
        json.dump(metrics, f, indent=2)
    print(f" Metrics saved to {out_path}")
    return metrics
def save_training_artifacts(save_dir, model_dict, optimizer,
                            loss_history,
                            pred,
                            params=None, z_t=None, metadata=None, seed=0):
    """Persist a run's model, optimizer, losses and outputs into ``save_dir``.

    Files written (``save_dir`` is created if missing):
      - ``model.pt``: ``model_dict`` via ``torch.save``.
      - ``optimizer.pt``: the ``optimizer`` object itself via ``torch.save``
        (the whole object, not ``optimizer.state_dict()``).
      - ``train_loss_history.npy``: ``np.array(loss_history)``.
      - ``pred.npy``: ``pred``.
      - ``params.npy``: only when ``params`` is given; its three components
        stacked along a new last axis.
      - ``z_t.npy``: only when ``z_t`` is given.
      - ``metadata.json``: only when ``metadata`` is truthy.

    Tensors are detached and moved to CPU before conversion, so outputs that
    still require grad no longer make ``.numpy()`` raise. Non-tensor
    array-likes go through ``np.asarray``.

    Args:
        save_dir: Output directory.
        model_dict: Object handed to ``torch.save`` as the model checkpoint.
        optimizer: Object handed to ``torch.save`` as the optimizer checkpoint.
        loss_history: Sequence of per-epoch training losses.
        pred: Prediction tensor or array.
        params: Optional 3-sequence ``(beta, gamma, delta)`` of tensors/arrays.
        z_t: Optional latent trajectory (tensor or array).
        metadata: Optional JSON-serializable mapping.
        seed: Currently unused; kept for call-site compatibility.
    """
    def _as_numpy(value):
        # Tensors: drop the autograd graph and any device placement first.
        if isinstance(value, torch.Tensor):
            return value.detach().cpu().numpy()
        return np.asarray(value)

    os.makedirs(save_dir, exist_ok=True)

    torch.save(model_dict, os.path.join(save_dir, "model.pt"))
    torch.save(optimizer, os.path.join(save_dir, "optimizer.pt"))
    np.save(os.path.join(save_dir, "train_loss_history.npy"), np.array(loss_history))
    np.save(os.path.join(save_dir, "pred.npy"), _as_numpy(pred))

    if params is not None:
        beta, gamma, delta = params
        stacked = np.stack([_as_numpy(beta), _as_numpy(gamma), _as_numpy(delta)], axis=-1)
        np.save(os.path.join(save_dir, "params.npy"), stacked)

    if z_t is not None:
        np.save(os.path.join(save_dir, "z_t.npy"), _as_numpy(z_t))

    if metadata:
        with open(os.path.join(save_dir, "metadata.json"), "w") as f:
            json.dump(metadata, f, indent=4)
def _finish_figure(path):
    """Tight-layout the current pyplot figure, save it to ``path`` and close it.

    The figure is closed even if layout or saving raises.
    """
    try:
        plt.tight_layout()
        plt.savefig(path)
    finally:
        plt.close()


def save_plots(save_dir, seed, t, t_split=None,
               S_true=None, I_true=None, R_true=None, beta_true=None, gamma_true=None, delta_true=None,
               pred_np=None, beta=None, gamma=None, delta=None,
               train_losses=None
               ):
    """Write the diagnostic PNGs for one run into ``save_dir``.

    Files produced (``save_dir`` is created if missing):
      - ``loss_curve.png``: ``train_losses`` per epoch, or a placeholder
        panel when no losses are given.
      - ``forecast_I.png``: ``I_true`` vs column 1 of ``pred_np``.
      - ``true_vs_forecast.png``: ``S_true``/``I_true``/``R_true`` vs
        columns 0/1/2 of ``pred_np`` (labelled S, I, R).
      - ``parameters_over_time.png``: true (solid) and estimated (dashed)
        beta/gamma/delta, each drawn against the first ``len(series)``
        entries of ``t``.

    Every series argument is optional, and curves passed as ``None`` are
    skipped. The dashed train/forecast split line at ``t[t_split]`` is drawn
    only when ``t_split`` is given.

    Args:
        save_dir: Output directory.
        seed: Currently unused; kept for call-site compatibility.
        t: Time axis shared by all curves.
        t_split: Index into ``t`` of the first forecast step, or ``None``.
        S_true, I_true, R_true: Ground-truth compartments, indexed like ``t``.
        beta_true, gamma_true, delta_true: Ground-truth parameter trajectories.
        pred_np: Forecast array whose columns 0/1/2 are plotted as S/I/R;
            presumably shaped ``(len(t), >=3)`` — TODO confirm with callers.
        beta, gamma, delta: Estimated parameter trajectories.
        train_losses: Sequence of per-epoch training losses.
    """
    os.makedirs(save_dir, exist_ok=True)

    def _draw_split():
        # Only meaningful (and only valid as an index) when t_split is known.
        if t_split is not None:
            plt.axvline(x=t[t_split], color='gray', linestyle='--', label='Train/Forecast Split')

    # 1) Training curve, or a placeholder panel when there is nothing to plot.
    plt.figure(figsize=(6, 4))
    if train_losses is not None and len(train_losses) > 0:
        plt.plot(train_losses, label="Train Loss")
        plt.xlabel("Epoch")
        plt.ylabel("Loss")
        plt.title("Training Curve")
        plt.legend()
        plt.grid(True)
    else:
        plt.axis("off")
        plt.text(0.5, 0.5, "No training loss for this model",
                 ha="center", va="center", fontsize=12)
    _finish_figure(os.path.join(save_dir, "loss_curve.png"))

    # 2) I compartment only: ground truth vs forecast column 1.
    plt.figure(figsize=(6, 4))
    if I_true is not None:
        plt.plot(t, I_true, label="True I", color="black")
    if pred_np is not None:
        plt.plot(t, pred_np[:, 1], label="Forecast I", linestyle="--", color="red")
    _draw_split()
    plt.xlabel("Time")
    plt.ylabel("I(t)")
    plt.title("Forecast vs Ground Truth")
    plt.legend()
    plt.grid(True)
    _finish_figure(os.path.join(save_dir, "forecast_I.png"))

    # 3) All three compartments: solid = truth, dashed = forecast.
    plt.figure(figsize=(6, 4))
    for series, label, color in ((S_true, "True S", 'blue'),
                                 (I_true, "True I", 'black'),
                                 (R_true, "True R", 'green')):
        if series is not None:
            plt.plot(t, series, label=label, color=color)
    if pred_np is not None:
        for col, (label, color) in enumerate((("Forecast S", "blue"),
                                              ("Forecast I", "red"),
                                              ("Forecast R", "green"))):
            plt.plot(t, pred_np[:, col], label=label, linestyle="--", color=color)
    _draw_split()
    plt.title("True vs Forecast S, I, R")
    plt.xlabel("Time")
    plt.ylabel("Value")
    plt.legend()
    plt.grid(True)
    _finish_figure(os.path.join(save_dir, "true_vs_forecast.png"))

    # 4) Parameter trajectories; each curve may be shorter than t.
    plt.figure(figsize=(6, 4))
    param_curves = (
        (beta_true, "True β", 'brown', False),
        (gamma_true, "True γ", 'orange', False),
        (delta_true, "True δ", 'purple', False),
        (beta, "Estimated β", 'brown', True),
        (gamma, "Estimated γ", 'orange', True),
        (delta, "Estimated δ", 'purple', True),
    )
    for values, label, color, dashed in param_curves:
        if values is None:
            continue
        extra = {"linestyle": "--"} if dashed else {}
        plt.plot(t[:len(values)], values, label=label, color=color, **extra)
    plt.title("Parameters Over Time")
    plt.xlabel("Time")
    plt.ylabel("Value")
    plt.legend()
    plt.grid(True)
    _finish_figure(os.path.join(save_dir, "parameters_over_time.png"))
# Windowed Accuracy
def rmse_windows(I_pred, I_true, t_split, window_sizes=(4, 5, 6, 8, 10, 20, 30)):
    """Compute RMSE over fixed-length forecast windows that all start at ``t_split``.

    For each ``w`` in ``window_sizes`` the RMSE covers indices
    ``[t_split, t_split + w)``. A window that would run past the end of
    either series is reported as ``np.nan`` rather than truncated.

    Both inputs are flattened to 1-D (as in ``peak_metrics``), so a ``(T, 1)``
    prediction is not broadcast against a ``(T,)`` target into a ``(w, w)``
    error matrix.

    Args:
        I_pred: Predicted series (array-like).
        I_true: Ground-truth series (array-like).
        t_split: Index of the first forecast step.
        window_sizes: Iterable of window lengths. The default is an immutable
            tuple, so it cannot be mutated across calls.

    Returns:
        dict mapping each window size to its RMSE, or ``np.nan`` if the
        window does not fit.
    """
    pred_arr = np.asarray(I_pred).reshape(-1)
    true_arr = np.asarray(I_true).reshape(-1)
    # A window has to fit inside BOTH series to be scored.
    available = min(len(pred_arr), len(true_arr))

    results = {}
    for w in window_sizes:
        end = t_split + w
        if end > available:
            results[w] = np.nan
            continue
        err = pred_arr[t_split:end] - true_arr[t_split:end]
        results[w] = np.sqrt(np.mean(err ** 2))
    return results
# Forecast Horizon Quality
def forecast_horizon(I_pred, I_true, thresholds=(5e-2, 1e-2, 5e-4, 1e-4)):
    """Report, for each RMSE threshold, the last index whose running RMSE is below it.

    The running RMSE at index ``k`` is computed over all points ``0..k``,
    counted from the start of the series rather than from a train/forecast
    split. For each threshold, the LAST index with running RMSE strictly below
    it is returned, or ``None`` if there is none. The running RMSE can rise
    above a threshold and later fall back below it, so the returned index does
    not imply that every earlier index met the threshold.

    Both inputs are flattened to 1-D (as in ``peak_metrics``), so ``(T, 1)``
    vs ``(T,)`` inputs are compared element-wise instead of broadcasting.

    Args:
        I_pred: Predicted series (array-like).
        I_true: Ground-truth series (array-like).
        thresholds: Iterable of RMSE thresholds. The default is an immutable
            tuple.

    Returns:
        dict ``{threshold: last_good_idx or None}``.
    """
    pred_arr = np.asarray(I_pred).reshape(-1)
    true_arr = np.asarray(I_true).reshape(-1)
    T = len(pred_arr)
    # Divide the cumulative squared error by the number of points seen so far.
    cumulative_rmse = np.sqrt(np.cumsum((pred_arr - true_arr) ** 2) / np.arange(1, T + 1))

    results = {}
    for th in thresholds:
        good_idxs = np.flatnonzero(cumulative_rmse < th)
        results[th] = int(good_idxs[-1]) if good_idxs.size else None
    return results
def peak_metrics(I_pred, I_true, dataset=None, t_split=None,
                 peak_after_split_datasets=("flu-multi", "hhs")):
    """Locate and compare the peaks (argmax) of a predicted and a true series.

    If ``dataset`` (compared case-insensitively) is in
    ``peak_after_split_datasets``, both peaks are searched only in
    ``[t_split:]``, so they are guaranteed to lie at or after the split; in
    that case ``t_split`` is required. For any other dataset the whole series
    is searched. The window start is clamped into ``[0, len - 1]``, so an
    out-of-range ``t_split`` still leaves a non-empty search window.

    All returned indices are in the coordinates of the original series.

    Args:
        I_pred: Predicted series (array-like; flattened to 1-D).
        I_true: Ground-truth series (array-like; flattened to 1-D).
        dataset: Dataset name used to decide whether the split is enforced.
        t_split: Index of the first forecast step.
        peak_after_split_datasets: Dataset names that enforce the split.

    Returns:
        dict with ``peak_window_start``, ``true_peak_idx``, ``pred_peak_idx``,
        ``peak_time_error`` (absolute), ``true_peak_val``, ``pred_peak_val``,
        ``peak_value_error`` (absolute), and the signed ``pred - true``
        variants ``signed_peak_time_error`` / ``signed_peak_value_error``.

    Raises:
        ValueError: if the lengths differ, if ``t_split`` is missing for a
            dataset that requires it, or if the series are empty.
    """
    I_pred = np.asarray(I_pred).reshape(-1)
    I_true = np.asarray(I_true).reshape(-1)
    if len(I_pred) != len(I_true):
        raise ValueError(f"I_pred and I_true length mismatch: {len(I_pred)} vs {len(I_true)}")

    ds = str(dataset).lower() if dataset is not None else ""
    enforce = ds in {d.lower() for d in peak_after_split_datasets}

    # Decide where the peak search starts.
    if enforce:
        if t_split is None:
            raise ValueError(f"t_split must be provided when dataset is {dataset}")
        start = int(t_split)
    else:
        start = 0

    if len(I_true) == 0:
        # Nothing to take an argmax over.
        raise ValueError("I_pred and I_true must be non-empty to locate a peak")
    # Clamping guarantees [start:] holds at least one element.
    start = max(0, min(start, len(I_true) - 1))

    true_peak_idx = start + int(np.argmax(I_true[start:]))
    pred_peak_idx = start + int(np.argmax(I_pred[start:]))
    true_peak_val = float(I_true[true_peak_idx])
    pred_peak_val = float(I_pred[pred_peak_idx])
    return {
        "peak_window_start": int(start),
        "true_peak_idx": int(true_peak_idx),
        "pred_peak_idx": int(pred_peak_idx),
        "peak_time_error": int(abs(pred_peak_idx - true_peak_idx)),
        "true_peak_val": true_peak_val,
        "pred_peak_val": pred_peak_val,
        "peak_value_error": float(abs(pred_peak_val - true_peak_val)),
        "signed_peak_time_error": int(pred_peak_idx - true_peak_idx),
        "signed_peak_value_error": float(pred_peak_val - true_peak_val),
    }
def set_seed(seed):
    """Seed every RNG this module relies on and request deterministic cuDNN.

    Seeds Python's ``random``, NumPy's global RNG and PyTorch (CPU, current
    CUDA device and all CUDA devices). It then sets cuDNN to deterministic
    mode with autotuning (``benchmark``) disabled, and exports
    ``PYTHONHASHSEED``.

    NOTE: setting ``PYTHONHASHSEED`` at runtime does not change the running
    interpreter's hash seed; it only affects Python processes started
    afterwards (e.g. spawned workers).

    Args:
        seed: Integer seed applied to all generators.
    """
    seeders = (
        random.seed,
        np.random.seed,
        torch.manual_seed,
        torch.cuda.manual_seed,
        torch.cuda.manual_seed_all,
    )
    for apply_seed in seeders:
        apply_seed(seed)

    cudnn = torch.backends.cudnn
    cudnn.deterministic = True
    cudnn.benchmark = False

    os.environ["PYTHONHASHSEED"] = str(seed)