1818import json
1919from pathlib import Path
2020import numpy as np
21+
22+ # Core deep learning library (model loading, tensor ops)
2123import torch
2224import matplotlib .pyplot as plt
2325import pandas as pd
2426
25-
27+ # For progress bars when looping through patients or batches
2628from tqdm import tqdm
2729
28-
30+ # Import the TCN model architecture definition
2931from ml_models_tcn .tcn_model import TCNModel
3032
# -----------------------
# Path Directories
# -----------------------
# All locations are resolved relative to this script so the pipeline can be
# launched from any working directory.
SCRIPT_DIR = Path(__file__).resolve().parent

# === TCN Model ===
TRAINED_MODEL_PATH = SCRIPT_DIR.parent.parent / "src" / "prediction_diagnostics" / "trained_models_refined" / "tcn_best_refined.pt"
CONFIG_PATH = SCRIPT_DIR.parent.parent / "src" / "prediction_diagnostics" / "trained_models_refined" / "config_refined.json"

# === TCN data and preprocessing directories ===
TEST_DATA_DIR = SCRIPT_DIR.parent.parent / "src" / "ml_models_tcn" / "prepared_datasets"
TCN_DIR = SCRIPT_DIR.parent.parent / "src" / "ml_models_tcn" / "deployment_models" / "preprocessing"

# === Preprocessing artifacts ===
SPLITS_PATH = TCN_DIR / "patient_splits.json"
PADDING_PATH = TCN_DIR / "padding_config.json"
SCALER_PATH = TCN_DIR / "standard_scaler.pkl"

# === Model tensors ===
TEST_TENSOR = TEST_DATA_DIR / "test.pt"
MASK_TENSOR = TEST_DATA_DIR / "test_mask.pt"

# === Output directory ===
RESULTS_DIR = SCRIPT_DIR / "interpretability_tcn"
RESULTS_DIR.mkdir(parents=True, exist_ok=True)

# -----------------------------
# Sanity Checks
# -----------------------------
# Explicit exceptions are used instead of `assert` so the checks still run
# when Python is started with -O (which strips assert statements).
_REQUIRED_INPUTS = (
    ("trained model file", TRAINED_MODEL_PATH),
    ("model configuration", CONFIG_PATH),
    ("test tensor", TEST_TENSOR),
    ("test mask tensor", MASK_TENSOR),
    ("patient splits file", SPLITS_PATH),
    ("padding config file", PADDING_PATH),
    ("standard scaler file", SCALER_PATH),
)
for _label, _required_path in _REQUIRED_INPUTS:
    if not _required_path.exists():
        raise FileNotFoundError(f"[ERROR] Missing {_label}: {_required_path}")

print("[INFO] All required input files found. Ready to proceed.")

# -----------------------------
# 1. Load Model, Config, and Test Data
# -----------------------------
# --- Select device (GPU when available, otherwise CPU) ---
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"[INFO] Using device: {device}")

# --- Load config (architecture & parameters) ---
with open(CONFIG_PATH, encoding="utf-8") as f:
    config = json.load(f)
arch = config["model_architecture"]

# --- Load padding/feature configuration ---
with open(PADDING_PATH, encoding="utf-8") as f:
    pad_cfg = json.load(f)
feature_cols = pad_cfg["feature_cols"]
MAX_SEQ_LEN = pad_cfg["max_seq_len"]
target_cols = pad_cfg["target_cols"]

# --- Load test tensors ---
# NOTE(review): torch.load unpickles the file contents; only load artifacts
# produced by this project's own pipeline.
x_test = torch.load(TEST_TENSOR, map_location=device)
mask_test = torch.load(MASK_TENSOR, map_location=device)

# --- Validate test tensor shapes (shape should match padding config) ---
# Raised as ValueError rather than asserted so the checks survive `python -O`.
n_test, seq_len, n_features = x_test.shape
if seq_len != MAX_SEQ_LEN:
    raise ValueError(f"Expected seq_len {MAX_SEQ_LEN}, got {seq_len}")
if n_features != len(feature_cols):
    raise ValueError("Feature dimension mismatch with padding_config")
print(f"[INFO] Loaded test data: {x_test.shape}, mask: {mask_test.shape}")

# --- Rebuild model architecture (from config) ---
model = TCNModel(
    num_features=n_features,
    num_channels=arch["num_channels"],
    kernel_size=arch["kernel_size"],
    dropout=arch["dropout"],
    head_hidden=arch["head_hidden"],
).to(device)

# --- Load trained model weights (tcn_best_refined.pt) ---
state_dict = torch.load(TRAINED_MODEL_PATH, map_location=device)
model.load_state_dict(state_dict)

# --- Set model to eval mode ---
model.eval()
print("[INFO] Loaded TCN model and moved to device.")

# -----------------------------
# 2. Define Targets & Saliency Function
# -----------------------------

# --- Target heads to explain ---
# Each entry is (name used in output file names, key into the model's output dict).
TARGETS = [
    ("max_risk", "logit_max"),         # classification
    ("median_risk", "logit_median"),   # classification
    ("pct_time_high", "regression"),   # regression
]

# --- Saliency Computation Helper ---
def compute_saliency_for_batch(model, x_batch, mask_batch, head_key):
    """
    Compute |grad * input| saliency for a mini-batch.

    Args:
        model: trained TCN model; called as ``model(x, mask)`` and expected to
            return a mapping that contains ``head_key``.
        x_batch: input tensor, documented by the caller as (B, T, F).
        mask_batch: mask tensor, documented by the caller as (B, T); may be
            None, in which case None is forwarded to the model.
        head_key: output head to explain ('logit_max', 'logit_median', 'regression').

    Returns:
        np.ndarray with the same shape as ``x_batch`` holding |gradient * input|,
        where each sample's gradient is taken of that sample's own scalar output.

    Raises:
        ValueError: if the selected head does not yield exactly one scalar per
            sample in the batch.
    """
    # Run on whatever device the model lives on; fall back to the input's
    # device for callables that expose no parameters.
    try:
        run_device = next(model.parameters()).device
    except (AttributeError, StopIteration):
        run_device = x_batch.device

    # Fresh leaf tensor so gradients are taken w.r.t. this batch only.
    x = x_batch.clone().detach().to(run_device).requires_grad_(True)
    mask = mask_batch.to(run_device) if mask_batch is not None else None
    n_samples = x.shape[0]

    # enable_grad keeps the helper working even if the caller wrapped the
    # loop in torch.no_grad().
    with torch.enable_grad():
        # reshape(-1) instead of squeeze(): squeeze() turns a batch of one
        # into a 0-d tensor, which cannot be indexed per sample.
        out = model(x, mask)[head_key].reshape(-1)
        if out.numel() != n_samples:
            raise ValueError(
                f"Head '{head_key}' produced {out.numel()} values for a batch of "
                f"{n_samples}; expected one scalar per sample"
            )

        x_np = x.detach().cpu().numpy()
        grads = np.zeros_like(x_np)
        for i in range(n_samples):
            # autograd.grad returns a new tensor on every call, so earlier
            # results are never overwritten (x.grad + zero_() aliased the
            # NumPy views on CPU) and parameter .grad buffers are untouched.
            (grad_i,) = torch.autograd.grad(out[i], x, retain_graph=True)
            grads[i] = grad_i[i].cpu().numpy()  # assignment copies into grads

    return np.abs(grads * x_np)


# =============================================================
# 3. Generate Per-Patient & Global Saliency Outputs
# =============================================================
# Small batches keep the per-sample gradient computation memory-friendly.
batch_size = 4


def _save_heatmap(matrix, title, colorbar_label, out_path):
    """Render a (time, feature) saliency matrix as a feature-by-time heatmap PNG."""
    plt.figure(figsize=(14, 6))
    plt.imshow(matrix.T, aspect='auto', interpolation='nearest')
    plt.colorbar(label=colorbar_label)
    plt.ylabel('Feature')
    plt.yticks(np.arange(len(feature_cols)), feature_cols, fontsize=6)
    plt.xlabel('Time (timestep)')
    plt.title(title)
    plt.tight_layout()
    plt.savefig(out_path, dpi=200)
    plt.close()  # release the figure; one is created per patient


for target_name, head_key in TARGETS:
    print(f"\n[INFO] ===== Saliency for target: {target_name} ({head_key}) =====")

    # --- Compute per-patient saliency ---
    saliency_batches = []
    for start in tqdm(range(0, n_test, batch_size), desc=f"Processing {target_name}"):
        xb = x_test[start:start + batch_size].to(device)
        mb = mask_test[start:start + batch_size].to(device)
        saliency_batches.append(compute_saliency_for_batch(model, xb, mb, head_key))

    per_patient_saliency = np.concatenate(saliency_batches, axis=0)
    print(f"[INFO] Saliency shape: {per_patient_saliency.shape}")

    # --- Save per-patient arrays ---
    save_npz = RESULTS_DIR / f"patient_saliency_{target_name}.npz"
    np.savez_compressed(
        save_npz,
        **{f"patient_{i}": per_patient_saliency[i] for i in range(n_test)},
    )
    print(f"[INFO] Saved saliency arrays → {save_npz}")

    # =========================================================
    # 3A. Patient-Level Heatmaps
    # =========================================================
    for i in range(n_test):
        mask_np = mask_test[i].cpu().numpy().astype(bool)
        # Mask a COPY: per_patient_saliency[i] is a view, and writing NaN
        # through it would poison the global mean and the Top-10 ranking
        # computed below (plain .mean() returns NaN for every feature).
        plot_arr = per_patient_saliency[i].copy()
        plot_arr[~mask_np, :] = np.nan  # padded timesteps are left blank in the plot

        _save_heatmap(
            plot_arr,
            f"Saliency Heatmap — {target_name} — Patient {i}",
            '|grad * input|',
            RESULTS_DIR / f"{target_name}_patient_{i:02d}_heatmap.png",
        )

    print(f"[INFO] Saved {n_test} patient-level heatmaps for {target_name}")

    # =========================================================
    # 3B. Global Mean Heatmap
    # =========================================================
    mean_saliency = np.nanmean(per_patient_saliency, axis=0)
    _save_heatmap(
        mean_saliency,
        f"Mean Saliency — {target_name} — Averaged Across Patients",
        'mean |grad * input|',
        RESULTS_DIR / f"{target_name}_mean_heatmap.png",
    )
    print(f"[INFO] Saved global mean heatmap for {target_name}")

    # =========================================================
    # 3C. Top-10 Feature Ranking
    # =========================================================
    # Mean absolute saliency per feature, averaged over patients and timesteps.
    feature_mean_sal = per_patient_saliency.mean(axis=(0, 1))
    df_top = (
        pd.DataFrame({"feature": feature_cols, "mean_abs_saliency": feature_mean_sal})
        .sort_values("mean_abs_saliency", ascending=False)
        .head(10)
    )
    top10_csv = RESULTS_DIR / f"{target_name}_top10_saliency.csv"
    df_top.to_csv(top10_csv, index=False)
    print(f"[INFO] Saved Top-10 Features → {top10_csv}")

print("\n[INFO] ✅ TCN saliency computation complete.")
0 commit comments