Skip to content

Commit 77b900d

Browse files
committed
Update saliency_analysis_tcn.py
1 parent f30068d commit 77b900d

1 file changed

Lines changed: 120 additions & 112 deletions

File tree

src/results_finalisation/saliency_analysis_tcn.py

Lines changed: 120 additions & 112 deletions
Original file line numberDiff line numberDiff line change
@@ -18,210 +18,218 @@
1818
import json
1919
from pathlib import Path
2020
import numpy as np
21+
22+
# Core deep learning library (model loading, tensor ops)
2123
import torch
2224
import matplotlib.pyplot as plt
2325
import pandas as pd
2426

25-
27+
# For progress bars when looping through patients or batches
2628
from tqdm import tqdm
2729

28-
30+
# Import the TCN model architecture definition
2931
from ml_models_tcn.tcn_model import TCNModel
3032

3133
# -----------------------
3234
# Path Directories
3335
# -----------------------
3436
SCRIPT_DIR = Path(__file__).resolve().parent
3537

36-
# === Input directories ===
38+
# === TCN Model ===
3739
TRAINED_MODEL_PATH = SCRIPT_DIR.parent.parent / "src" / "prediction_diagnostics" / "trained_models_refined" / "tcn_best_refined.pt"
3840
CONFIG_PATH = SCRIPT_DIR.parent.parent / "src" / "prediction_diagnostics" / "trained_models_refined" / "config_refined.json"
39-
TEST_DATA_DIR = SCRIPT_DIR.parent.parent / "src" / "ml_models_tcn" / "prepared_datasets" # test tensors
41+
42+
# === TCN data and preprocessing directories ===
43+
TEST_DATA_DIR = SCRIPT_DIR.parent.parent / "src" / "ml_models_tcn" / "prepared_datasets"
4044
TCN_DIR = SCRIPT_DIR.parent.parent / "src" / "ml_models_tcn" / "deployment_models" / "preprocessing"
45+
46+
# === Preprocessing artifacts ===
4147
SPLITS_PATH = TCN_DIR / "patient_splits.json"
4248
PADDING_PATH = TCN_DIR / "padding_config.json"
4349
SCALER_PATH = TCN_DIR / "standard_scaler.pkl"
4450

51+
# === Model tensors ===
52+
TEST_TENSOR = TEST_DATA_DIR / "test.pt"
53+
MASK_TENSOR = TEST_DATA_DIR / "test_mask.pt"
54+
4555
# === Output directory ===
4656
RESULTS_DIR = SCRIPT_DIR / "interpretability_tcn"
4757
RESULTS_DIR.mkdir(parents=True, exist_ok=True)
4858

4959
# -----------------------------
5060
# Sanity Checks
5161
# -----------------------------
52-
assert TRAINED_MODEL_PATH.exists(), f"Missing model: {TRAINED_MODEL_PATH}"
53-
assert CONFIG_PATH.exists(), f"Missing config: {CONFIG_PATH}"
54-
assert (TEST_DATA_DIR / "test.pt").exists(), f"Missing test tensor: {TEST_DATA_DIR / 'test.pt'}"
55-
assert (TEST_DATA_DIR / "test_mask.pt").exists(), f"Missing test mask: {TEST_DATA_DIR / 'test_mask.pt'}"
56-
assert SPLITS_PATH.exists(), f"Missing splits: {SPLITS_PATH}"
62+
assert TRAINED_MODEL_PATH.exists(), f"[ERROR] Missing trained model file: {TRAINED_MODEL_PATH}"
63+
assert CONFIG_PATH.exists(), f"[ERROR] Missing model configuration: {CONFIG_PATH}"
64+
assert (TEST_DATA_DIR / "test.pt").exists(), f"[ERROR] Missing test tensor: {TEST_DATA_DIR / 'test.pt'}"
65+
assert (TEST_DATA_DIR / "test_mask.pt").exists(), f"[ERROR] Missing test mask tensor: {TEST_DATA_DIR / 'test_mask.pt'}"
66+
assert SPLITS_PATH.exists(), f"[ERROR] Missing patient splits file: {SPLITS_PATH}"
67+
assert PADDING_PATH.exists(), f"[ERROR] Missing padding config file: {PADDING_PATH}"
68+
assert SCALER_PATH.exists(), f"[ERROR] Missing standard scaler file: {SCALER_PATH}"
69+
70+
print("[INFO] All required input files found. Ready to proceed.")
5771

5872
# -----------------------------
59-
# Device
73+
# 1. Load Model, Config, and Test Data
6074
# -----------------------------
75+
# --- Load device (cpu or gpu) ---
6176
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
6277
print(f"[INFO] Using device: {device}")
6378

64-
# -----------------------------
65-
# Load Model + Config + Test Tensors
66-
# -----------------------------
67-
# Load model configuration (refined)
79+
# --- Load config (architecture & parameters) ---
6880
with open(CONFIG_PATH) as f:
6981
config = json.load(f)
7082
arch = config["model_architecture"]
7183

72-
with open(SCALER_DIR / "padding_config.json", 'r') as f:
84+
# --- Load padding/feature configuration ---
85+
with open(PADDING_PATH) as f:
7386
pad_cfg = json.load(f)
74-
7587
feature_cols = pad_cfg['feature_cols']
7688
MAX_SEQ_LEN = pad_cfg['max_seq_len']
7789
target_cols = pad_cfg["target_cols"]
7890

79-
# Load tensors (x_test: [n_patients, seq_len, n_features], mask_test similar)
80-
x_test = torch.load(TEST_DATA_DIR / "test.pt", map_location=device)
81-
mask_test = torch.load(TEST_DATA_DIR / "test_mask.pt", map_location=device)
91+
# --- Load test tensors ---
92+
x_test = torch.load(TEST_TENSOR, map_location=device)
93+
mask_test = torch.load(MASK_TENSOR, map_location=device)
8294

95+
# --- Validate test tensor shapes (shape should match padding config) ---
8396
n_test, seq_len, n_features = x_test.shape
84-
print(f"[INFO] Loaded x_test: {x_test.shape}, mask: {mask_test.shape}")
8597
assert seq_len == MAX_SEQ_LEN, f"Expected seq_len {MAX_SEQ_LEN}, got {seq_len}"
8698
assert n_features == len(feature_cols), "Feature dimension mismatch with padding_config"
99+
print(f"[INFO] Loaded test data: {x_test.shape}, mask: {mask_test.shape}")
87100

88-
# --- Reload model architecture and weights ---
89-
arch = config['model_architecture']
101+
# --- Rebuild model architecture (from config) ---
90102
model = TCNModel(
91103
num_features=n_features,
92104
num_channels=arch['num_channels'],
93105
kernel_size=arch['kernel_size'],
94106
dropout=arch['dropout'],
95107
head_hidden=arch['head_hidden']
96-
)
108+
).to(device)
97109

110+
# --- Load trained model weights (tcn_best_refined.pt) ---
98111
state_dict = torch.load(TRAINED_MODEL_PATH, map_location=device)
99112
model.load_state_dict(state_dict)
100-
model.to(device)
113+
114+
# --- Set model to eval mode ---
101115
model.eval()
102-
print('[INFO] Loaded TCN model')
116+
print("[INFO] Loaded TCN model and moved to device.")
117+
118+
# -----------------------------
119+
# 2. Define Targets & Saliency Function
120+
# -----------------------------
103121

104-
# Targets we will explain and corresponding model output keys
122+
# --- Target heads to explain ---
105123
TARGETS = [
106-
("max", "logit_max"), # binary classification -> use logit (pre-sigmoid)
107-
("median", "logit_median"),
108-
("pct_time_high", "regression") # regression head
124+
("max_risk", "logit_max"), # classification
125+
("median_risk", "logit_median"), # classification
126+
("pct_time_high", "regression") # regression
109127
]
110128

111-
# --- Helper: compute gradient*input saliency for a batch of patients ---
129+
# --- Saliency Computation Helper ---
112130
def compute_saliency_for_batch(model, x_batch, mask_batch, head_key):
113-
"""Compute gradient * input saliency for a batch.
114-
x_batch: torch tensor (B, T, F) requires_grad=False
115-
mask_batch: torch tensor (B, T) — not used directly in gradient calc but kept for clarity
116-
head_key: string key returned by model(x, mask) -> selects which scalar to compute gradient of
117-
118-
Returns: saliency_abs: numpy array shape (B, T, F) of absolute(grad * input)
119131
"""
120-
# Ensure we run on device
132+
Compute |grad * input| saliency for a mini-batch.
133+
Args:
134+
model: trained TCN model
135+
x_batch: tensor (B, T, F)
136+
mask_batch: tensor (B, T)
137+
head_key: output head to explain ('logit_max', 'logit_median', 'regression')
138+
Returns:
139+
np.ndarray of shape (B, T, F) = |gradient * input|
140+
"""
121141
x = x_batch.clone().detach().to(device)
122142
x.requires_grad = True
123-
124-
mask = mask_batch.to(device) if mask_batch is not None else None
143+
mask = mask_batch.to(device)
125144

126145
outputs = model(x, mask)
127-
out = outputs[head_key] # shape (B, 1) or (B,)
128-
out = out.squeeze()
146+
out = outputs[head_key].squeeze()
129147

130-
# For multi-output batch: compute gradient of each scalar with respect to inputs
131-
# We'll compute a vector-Jacobian product that results in gradients of shape (B, T, F)
132148
grads = []
133149
for i in range(out.shape[0]):
134-
# Zero existing grads
135150
if x.grad is not None:
136151
x.grad.zero_()
137152
scalar = out[i]
138-
# Backprop scalar
139153
scalar.backward(retain_graph=True)
140-
g = x.grad[i].detach().cpu().numpy().copy() # (T, F)
141-
grads.append(g)
142-
grads = np.stack(grads, axis=0) # (B, T, F)
154+
grads.append(x.grad[i].detach().cpu().numpy())
155+
grads = np.stack(grads, axis=0)
156+
157+
saliency = grads * x.detach().cpu().numpy()
158+
return np.abs(saliency)
143159

144-
x_np = x.detach().cpu().numpy() # (B, T, F)
145160

146-
saliency = grads * x_np # elementwise grad * input
147-
saliency_abs = np.abs(saliency)
148-
return saliency_abs
149161

150-
# --- Main loop: compute per-patient saliency for each target ---
162+
# =============================================================
163+
# 3. Generate Per-Patient & Global Saliency Outputs
164+
# =============================================================
165+
batch_size = 4
166+
151167
for target_name, head_key in TARGETS:
152-
print(f"[INFO] Computing saliency for target: {target_name} (head: {head_key})")
153-
154-
# We'll compute per-patient saliency in batches to be memory-friendly
155-
batch_size = 4
156-
per_patient_saliency = [] # list of (T, F) arrays
157-
158-
with torch.no_grad():
159-
# NOTE: we need gradients; temporarily enable gradient context by switching to requires_grad path inside helper
160-
# We'll compute in small batches and re-enable grad per-batch using compute_saliency_for_batch
161-
for i in range(0, n_test, batch_size):
162-
xb = x_test[i:i+batch_size].to(device)
163-
mb = mask_test[i:i+batch_size].to(device)
164-
# compute with gradient tracking ON inside helper
165-
sal_b = compute_saliency_for_batch(model, xb, mb, head_key)
166-
per_patient_saliency.append(sal_b)
167-
168-
per_patient_saliency = np.concatenate(per_patient_saliency, axis=0) # shape (n_test, T, F)
169-
print(f"[INFO] Saliency array shape for {target_name}: {per_patient_saliency.shape}")
170-
171-
# Save per-patient saliency arrays into npz (keyed by patient index)
168+
print(f"\n[INFO] ===== Saliency for target: {target_name} ({head_key}) =====")
169+
170+
# --- Compute per-patient saliency ---
171+
per_patient_saliency = []
172+
for i in tqdm(range(0, n_test, batch_size), desc=f"Processing {target_name}"):
173+
xb = x_test[i:i+batch_size].to(device)
174+
mb = mask_test[i:i+batch_size].to(device)
175+
sal_b = compute_saliency_for_batch(model, xb, mb, head_key)
176+
per_patient_saliency.append(sal_b)
177+
178+
per_patient_saliency = np.concatenate(per_patient_saliency, axis=0)
179+
print(f"[INFO] Saliency shape: {per_patient_saliency.shape}")
180+
181+
# --- Save per-patient arrays ---
172182
save_npz = RESULTS_DIR / f"patient_saliency_{target_name}.npz"
173-
npz_dict = {f"patient_{i}": per_patient_saliency[i] for i in range(per_patient_saliency.shape[0])}
174-
np.savez_compressed(save_npz, **npz_dict)
175-
print(f"[INFO] Saved per-patient saliency arrays → {save_npz}")
176-
177-
# --- Generate patient-level heatmap PNGs (optional: small number, here save all test patients) ---
178-
for i in range(per_patient_saliency.shape[0]):
179-
arr = per_patient_saliency[i] # (T, F)
180-
# Optionally zero-out padded timesteps using mask
183+
np.savez_compressed(save_npz, **{f"patient_{i}": per_patient_saliency[i] for i in range(n_test)})
184+
print(f"[INFO] Saved saliency arrays → {save_npz}")
185+
186+
# =========================================================
187+
# 3A. Patient-Level Heatmaps
188+
# =========================================================
189+
for i in range(n_test):
190+
arr = per_patient_saliency[i]
181191
mask_np = mask_test[i].cpu().numpy().astype(bool)
182-
# Masked heatmap: set padded rows to NaN for plotting transparency
183-
plot_arr = arr.copy()
184-
if mask_np.shape[0] == plot_arr.shape[0]:
185-
plot_arr[~mask_np, :] = np.nan
192+
arr[~mask_np, :] = np.nan
186193

187194
plt.figure(figsize=(14, 6))
188-
plt.imshow(plot_arr.T, aspect='auto', interpolation='nearest')
195+
plt.imshow(arr.T, aspect='auto', interpolation='nearest')
189196
plt.colorbar(label='|grad * input|')
190-
plt.ylabel('Feature (index)')
191-
plt.yticks(ticks=np.arange(len(feature_cols)), labels=feature_cols, fontsize=6)
197+
plt.ylabel('Feature')
198+
plt.yticks(np.arange(len(feature_cols)), feature_cols, fontsize=6)
192199
plt.xlabel('Time (timestep)')
193-
plt.title(f"Saliency heatmap{target_name}patient {i}")
200+
plt.title(f"Saliency Heatmap{target_name}Patient {i}")
194201
plt.tight_layout()
195-
out_png = RESULTS_DIR / f"{target_name}_patient_{i:02d}_heatmap.png"
196-
plt.savefig(out_png, dpi=200)
202+
plt.savefig(RESULTS_DIR / f"{target_name}_patient_{i:02d}_heatmap.png", dpi=200)
197203
plt.close()
198204

199-
print(f"[INFO] Saved per-patient heatmaps for {target_name} (n={per_patient_saliency.shape[0]})")
205+
print(f"[INFO] Saved {n_test} patient-level heatmaps for {target_name}")
200206

201-
# --- Global mean heatmap (mean over patients) ---
202-
mean_over_patients = np.nanmean(per_patient_saliency, axis=0) # (T, F)
207+
# =========================================================
208+
# 3B. Global Mean Heatmap
209+
# =========================================================
210+
mean_saliency = np.nanmean(per_patient_saliency, axis=0)
203211
plt.figure(figsize=(14, 6))
204-
plt.imshow(mean_over_patients.T, aspect='auto', interpolation='nearest')
212+
plt.imshow(mean_saliency.T, aspect='auto', interpolation='nearest')
205213
plt.colorbar(label='mean |grad * input|')
206-
plt.ylabel('Feature (index)')
207-
plt.yticks(ticks=np.arange(len(feature_cols)), labels=feature_cols, fontsize=6)
214+
plt.ylabel('Feature')
215+
plt.yticks(np.arange(len(feature_cols)), feature_cols, fontsize=6)
208216
plt.xlabel('Time (timestep)')
209-
plt.title(f"Mean Saliency heatmap {target_name}mean across test patients")
217+
plt.title(f"Mean Saliency — {target_name}Averaged Across Patients")
210218
plt.tight_layout()
211-
out_mean_png = RESULTS_DIR / f"{target_name}_mean_heatmap.png"
212-
plt.savefig(out_mean_png, dpi=200)
219+
plt.savefig(RESULTS_DIR / f"{target_name}_mean_heatmap.png", dpi=200)
213220
plt.close()
214-
print(f"[INFO] Saved mean heatmap → {out_mean_png}")
215-
216-
# --- Top-10 features by mean absolute saliency (averaged over patients & time) ---
217-
# Compute mean abs saliency per feature: first mean over patients and time
218-
feature_mean_sal = per_patient_saliency.mean(axis=(0, 1)) # (F,)
219-
df_top = pd.DataFrame({
220-
'feature': feature_cols,
221-
'mean_abs_saliency': feature_mean_sal
222-
}).sort_values('mean_abs_saliency', ascending=False)
223-
221+
print(f"[INFO] Saved global mean heatmap for {target_name}")
222+
223+
# =========================================================
224+
# 3C. Top-10 Feature Ranking
225+
# =========================================================
226+
feature_mean_sal = per_patient_saliency.mean(axis=(0, 1))
227+
df_top = (
228+
pd.DataFrame({"feature": feature_cols, "mean_abs_saliency": feature_mean_sal})
229+
.sort_values("mean_abs_saliency", ascending=False)
230+
.head(10)
231+
)
224232
df_top.to_csv(RESULTS_DIR / f"{target_name}_top10_saliency.csv", index=False)
225-
print(f"[INFO] Saved top-10 feature saliency CSV{RESULTS_DIR / f'{target_name}_top10_saliency.csv'}")
233+
print(f"[INFO] Saved Top-10 Features{RESULTS_DIR / f'{target_name}_top10_saliency.csv'}")
226234

227-
print('[INFO] TCN saliency computation complete')
235+
print("\n[INFO] TCN saliency computation complete.")

0 commit comments

Comments
 (0)