Skip to content

Commit 2cdb490

Browse files
Update new_simpleNIPA.py
Checking crossvalpcr
1 parent 974ab94 commit 2cdb490

File tree

1 file changed

+179
-12
lines changed

1 file changed

+179
-12
lines changed

albatross/new_simpleNIPA.py

Lines changed: 179 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -166,17 +166,17 @@ def gridCheck(self, lim=5, ntim=2, debug=False):
166166

167167
return
168168

169-
def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
169+
"""def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
170170
import numpy as np
171171
from scipy.stats import pearsonr as corr
172172
from sklearn.linear_model import LinearRegression
173173
from sklearn.model_selection import KFold
174-
from albatross.utils import weightsst
174+
from albatross.utils import weight_glo_var
175175
176176
predictand = self.clim_data
177177
178178
# Check for insufficient SST data
179-
if self.corr_grid.mask.sum() >= len(self.sst.lat) * len(self.sst.lon) - 4:
179+
if self.corr_grid.mask.sum() >= len(self.glo_var.lat) * len(self.glo_var.lon) - 4:
180180
self.flags [ "noSST" ] = True
181181
self.hindcast = None
182182
self.pcs = None
@@ -186,13 +186,13 @@ def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
186186
return
187187
188188
self.flags [ "noSST" ] = False
189-
sstidx = ~self.corr_grid.mask
190-
raw_sst = weightsst(self.sst).data [ :, sstidx ]
189+
glo_var_idx = ~self.corr_grid.mask
190+
raw_glo_var = weight_glo_var(self.glo_var).data [ :, glo_var_idx ]
191191
n_samples = len(predictand)
192192
193193
if not xval:
194194
# Standard PCA regression (no CV)
195-
cov_matrix = np.cov(raw_sst.T)
195+
cov_matrix = np.cov(raw_glo_var.T)
196196
eigval, eigvec = np.linalg.eig(cov_matrix)
197197
eigval, eigvec = np.real(eigval), np.real(eigvec)
198198
@@ -204,7 +204,7 @@ def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
204204
n_pc = np.searchsorted(cumulative_var, explained_variance_threshold) + 1
205205
206206
eofs = eigvec [ :, :n_pc ]
207-
pcs = raw_sst.dot(eofs)
207+
pcs = raw_glo_var.dot(eofs)
208208
209209
reg = LinearRegression().fit(pcs, predictand)
210210
yhat = reg.predict(pcs)
@@ -221,13 +221,13 @@ def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
221221
222222
# Cross-validation PCA regression
223223
yhat = np.zeros(n_samples)
224-
pcs_all = np.zeros((n_samples, raw_sst.shape [ 1 ]))
224+
pcs_all = np.zeros((n_samples, raw_glo_var.shape [ 1 ]))
225225
models = [ ]
226226
227227
kf = KFold(n_splits=5, shuffle=True, random_state=42)
228228
229-
for train_idx, test_idx in kf.split(raw_sst):
230-
X_train, X_test = raw_sst [ train_idx ], raw_sst [ test_idx ]
229+
for train_idx, test_idx in kf.split(raw_glo_var):
230+
X_train, X_test = raw_glo_var [ train_idx ], raw_glo_var [ test_idx ]
231231
y_train = predictand [ train_idx ]
232232
233233
cov_matrix = np.cov(X_train.T)
@@ -281,15 +281,182 @@ def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
281281
n_pc_best = best_model [ "n_pc" ]
282282
283283
# Refit on full dataset using best number of PCs
284-
cov_matrix = np.cov(raw_sst.T)
284+
cov_matrix = np.cov(raw_glo_var.T)
285+
eigval, eigvec = np.linalg.eig(cov_matrix)
286+
eigval, eigvec = np.real(eigval), np.real(eigvec)
287+
288+
sorted_idx = np.argsort(eigval) [ ::-1 ]
289+
eigvec = eigvec [ :, sorted_idx ]
290+
291+
eofs = eigvec [ :, :n_pc_best ]
292+
pcs_full = raw_glo_var.dot(eofs)
293+
reg_full = LinearRegression().fit(pcs_full, predictand)
294+
295+
self.pcs = pcs_full
296+
self.lin_model = {
297+
"eofs": eofs,
298+
"regression": reg_full,
299+
"n_pc": n_pc_best
300+
}"""
301+
302+
def crossvalpcr(self, xval=True, explained_variance_threshold=0.95):
303+
import numpy as np
304+
from scipy.stats import pearsonr as corr
305+
from sklearn.linear_model import LinearRegression
306+
from sklearn.model_selection import KFold
307+
from albatross.utils import weight_glo_var, vcorr, sig_test
308+
309+
predictand = self.clim_data
310+
n_samples = len(predictand)
311+
yhat = np.zeros(n_samples)
312+
313+
# NOTE: The full corr_grid and raw_glo_var are no longer calculated here.
314+
# They will be calculated *inside* each CV loop to prevent data leakage.
315+
316+
if not xval:
317+
# Standard PCA regression (no CV)
318+
# This part still needs the full corr_grid, but it's not for CV
319+
# so data leakage isn't an issue here.
320+
self.bootcorr(corrconf=0.95)
321+
self.gridCheck()
322+
if self.corr_grid.mask.sum() >= len(self.glo_var.lat) * len(self.glo_var.lon) - 4:
323+
self.flags [ "noSST" ] = True
324+
self.hindcast = None
325+
self.pcs = None
326+
self.lin_model = None
327+
self.correlation = None
328+
print("Insufficient SST data for PCA regression.")
329+
return
330+
331+
glo_var_idx = ~self.corr_grid.mask
332+
raw_glo_var = weight_glo_var(self.glo_var).data [ :, glo_var_idx ]
333+
334+
cov_matrix = np.cov(raw_glo_var.T)
335+
eigval, eigvec = np.linalg.eig(cov_matrix)
336+
eigval, eigvec = np.real(eigval), np.real(eigvec)
337+
338+
sorted_idx = np.argsort(eigval) [ ::-1 ]
339+
eigvec = eigvec [ :, sorted_idx ]
340+
341+
explained_ratio = eigval / eigval.sum()
342+
cumulative_var = np.cumsum(explained_ratio)
343+
n_pc = np.searchsorted(cumulative_var, explained_variance_threshold) + 1
344+
345+
eofs = eigvec [ :, :n_pc ]
346+
pcs = raw_glo_var.dot(eofs)
347+
348+
reg = LinearRegression().fit(pcs, predictand)
349+
yhat = reg.predict(pcs)
350+
351+
self.pcs = pcs
352+
self.hindcast = yhat
353+
self.correlation = corr(predictand, yhat) [ 0 ]
354+
self.lin_model = {
355+
"eofs": eofs,
356+
"regression": reg,
357+
"n_pc": n_pc
358+
}
359+
return
360+
361+
# Cross-validation PCA regression
362+
models = [ ]
363+
kf = KFold(n_splits=5, shuffle=True, random_state=42)
364+
p_value_threshold = 1 - 0.95 # This should match the value used in bootcorr
365+
366+
# Get full global data for slicing later
367+
full_glo_var_data = weight_glo_var(self.glo_var).data
368+
369+
for train_idx, test_idx in kf.split(predictand):
370+
X_train_full = full_glo_var_data [ train_idx ]
371+
X_test_full = full_glo_var_data [ test_idx ]
372+
y_train = predictand [ train_idx ]
373+
374+
# --- CORRECTION: Step 1 (Feature Selection within the loop) ---
375+
# Calculate correlation mask on *training data only*
376+
corr_grid_train = vcorr(X=X_train_full, y=y_train)
377+
n_yrs_train = len(y_train)
378+
p_value = sig_test(corr_grid_train, n_yrs_train)
379+
380+
# This is a simplified version of bootcorr for demonstration.
381+
# You may want to call a modified `bootcorr_fold` function.
382+
glo_var_idx_train = ~np.ma.masked_array(corr_grid_train, ~(p_value < p_value_threshold)).mask
383+
384+
# Check for insufficient data in the training set
385+
if glo_var_idx_train.sum()==0:
386+
print("Skipping fold: Insufficient SST data in training set.")
387+
continue
388+
389+
# --- CORRECTION: Step 2 (Data Filtering) ---
390+
# Apply the mask from the training set to both training and test data
391+
X_train = X_train_full [ :, glo_var_idx_train ]
392+
X_test = X_test_full [ :, glo_var_idx_train ]
393+
394+
# --- CORRECTION: Step 3 (PCA and Regression within the loop) ---
395+
# Perform PCA and regression on the training data
396+
cov_matrix = np.cov(X_train.T)
397+
eigval, eigvec = np.linalg.eig(cov_matrix)
398+
eigval, eigvec = np.real(eigval), np.real(eigvec)
399+
400+
sorted_idx = np.argsort(eigval) [ ::-1 ]
401+
eigvec = eigvec [ :, sorted_idx ]
402+
403+
explained_ratio = eigval / eigval.sum()
404+
cumulative_var = np.cumsum(explained_ratio)
405+
n_pc = np.searchsorted(cumulative_var, explained_variance_threshold) + 1
406+
407+
if n_pc==0 or np.isnan(eigval [ :n_pc ]).any():
408+
continue
409+
410+
eofs = eigvec [ :, :n_pc ]
411+
pcs_train = X_train.dot(eofs)
412+
pcs_test = X_test.dot(eofs)
413+
414+
if pcs_train.shape [ 0 ] < n_pc:
415+
continue
416+
417+
reg = LinearRegression().fit(pcs_train, y_train)
418+
preds = reg.predict(pcs_test)
419+
420+
yhat [ test_idx ] = preds
421+
422+
models.append({
423+
"eofs": eofs,
424+
"regression": reg,
425+
"n_pc": n_pc,
426+
"corr": corr(y_train, reg.predict(pcs_train)) [ 0 ]
427+
})
428+
429+
if not models:
430+
self.hindcast = None
431+
self.pcs = None
432+
self.lin_model = None
433+
self.correlation = None
434+
self.flags [ "noSST" ] = True
435+
return
436+
437+
# Store hindcast from CV
438+
self.hindcast = yhat
439+
self.correlation = corr(predictand, yhat) [ 0 ]
440+
441+
# Select best number of PCs and refit on all data
442+
best_model = max(models, key=lambda m: m [ "corr" ])
443+
n_pc_best = best_model [ "n_pc" ]
444+
445+
# --- Refit on full dataset (this can still use the old method) ---
446+
self.bootcorr(corrconf=0.95)
447+
self.gridCheck()
448+
glo_var_idx_full = ~self.corr_grid.mask
449+
raw_glo_var_full = weight_glo_var(self.glo_var).data [ :, glo_var_idx_full ]
450+
451+
cov_matrix = np.cov(raw_glo_var_full.T)
285452
eigval, eigvec = np.linalg.eig(cov_matrix)
286453
eigval, eigvec = np.real(eigval), np.real(eigvec)
287454

288455
sorted_idx = np.argsort(eigval) [ ::-1 ]
289456
eigvec = eigvec [ :, sorted_idx ]
290457

291458
eofs = eigvec [ :, :n_pc_best ]
292-
pcs_full = raw_sst.dot(eofs)
459+
pcs_full = raw_glo_var_full.dot(eofs)
293460
reg_full = LinearRegression().fit(pcs_full, predictand)
294461

295462
self.pcs = pcs_full

0 commit comments

Comments
 (0)