Skip to content

Commit d8a19d5

Browse files
authored
Merge pull request #55 from rmarkello/missingno
[REF] Handle missing data rows in pls_regression
2 parents ae162f7 + 4fbd403 commit d8a19d5

File tree

2 files changed

+48
-21
lines changed

2 files changed

+48
-21
lines changed

pyls/tests/types/test_regression.py

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -80,6 +80,17 @@ def test_regression_3dbootstrap(aggfunc):
8080
bootsamples=bootsamples, n_boot=10)
8181

8282

83+
def test_regression_missingdata():
84+
X = rs.rand(subj, Xf)
85+
X[10] = np.nan
86+
PLSRegressionTests(X=X, n_components=2)
87+
X[20] = np.nan
88+
PLSRegressionTests(X=X, n_components=2)
89+
Y = rs.rand(subj, Yf)
90+
Y[11] = np.nan
91+
PLSRegressionTests(X=X, Y=Y, n_components=2)
92+
93+
8394
def test_errors():
8495
with pytest.raises(ValueError):
8596
PLSRegressionTests(n_components=1000)

pyls/types/regression.py

Lines changed: 37 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -45,6 +45,14 @@ def resid_yscores(x_scores, y_scores, copy=True):
4545
return y_scores
4646

4747

48+
def get_mask(X, Y):
49+
""" Returns mask removing rows where either X/Y contain all NaN values
50+
"""
51+
52+
return np.logical_not(np.logical_or(np.all(np.isnan(X), axis=1),
53+
np.all(np.isnan(Y), axis=1)))
54+
55+
4856
def simpls(X, Y, n_components=None, seed=1234):
4957
"""
5058
Performs partial least squares regression with the SIMPLS algorithm
@@ -226,10 +234,11 @@ def __init__(self, X, Y, *, n_components=None,
226234
'or one of {}'.format(sorted(aggfuncs)))
227235
self.aggfunc = aggfuncs.get(aggfunc, aggfunc)
228236

237+
# these need to be zero -- they're not implemented for PLSRegression
238+
kwargs.update(n_split=0, test_split=0)
229239
super().__init__(X=np.asarray(X), Y=np.asarray(Y),
230240
n_components=n_components, n_perm=n_perm,
231-
n_boot=n_boot, n_split=0, test_split=0,
232-
rotate=rotate, ci=ci, aggfunc=aggfunc,
241+
n_boot=n_boot, rotate=rotate, ci=ci, aggfunc=aggfunc,
233242
permsamples=permsamples, bootsamples=bootsamples,
234243
seed=seed, verbose=verbose, n_proc=n_proc, **kwargs)
235244

@@ -258,7 +267,9 @@ def svd(self, X, Y, seed=None):
258267
Variance explained by PLS-derived components; diagonal array
259268
"""
260269

261-
out = simpls(X, Y, self.n_components, seed=seed)
270+
# find nan rows and mask for the decomposition
271+
mask = get_mask(X, Y)
272+
out = simpls(X[mask], Y[mask], self.n_components, seed=seed)
262273

263274
# need to return len-3 for compatibility purposes
264275
# use the variance explained in Y in lieu of the singular values since
@@ -301,18 +312,19 @@ def _single_boot(self, X, Y, inds, groups=None, original=None, seed=None):
301312
else:
302313
Xi, Yi = X[inds], Y[inds]
303314

304-
out = simpls(Xi, Yi, self.n_components, seed=seed)
315+
x_weights = self.svd(Xi, Yi, seed=seed)[0]
305316

306317
if original is not None:
307318
# flip signs of weights based on correlations with `original`
308-
flip = np.sign(compute.efficient_corr(out['x_weights'], original))
309-
out['x_weights'] = out['x_weights'] * flip
319+
flip = np.sign(compute.efficient_corr(x_weights, original))
320+
x_weights *= flip
310321
# NOTE: should we be doing a procrustes here?
311322

312-
# recompute y_loadings based on new x_weight signs
313-
out['y_loadings'] = Yi.T @ (Xi @ out['x_weights'])
323+
# compute y_loadings
324+
mask = get_mask(Xi, Yi)
325+
y_loadings = Yi[mask].T @ (Xi @ x_weights)[mask]
314326

315-
return out['y_loadings'], out['x_weights']
327+
return y_loadings, x_weights
316328

317329
def _single_perm(self, X, Y, inds, groups=None, original=None, seed=None):
318330
"""
@@ -340,20 +352,22 @@ def _single_perm(self, X, Y, inds, groups=None, original=None, seed=None):
340352
Variance explained by PLS decomposition of permuted data
341353
"""
342354

355+
# should permute Y (but not X) by default
343356
Xp, Yp = self.make_permutation(X, Y, inds)
344-
out = simpls(Xp, Yp, self.n_components, seed=seed)
357+
x_weights, varexp, _ = self.svd(Xp, Yp, seed=seed)
345358

346359
if self.inputs.rotate and original is not None:
347360
# flip signs of weights based on correlations with `original`
348-
flip = np.sign(compute.efficient_corr(out['x_weights'], original))
349-
out['x_weights'] = out['x_weights'] * flip
361+
flip = np.sign(compute.efficient_corr(x_weights, original))
362+
x_weights *= flip
350363
# NOTE: should we be doing a procrustes here?
351364

352365
# recompute pctvar based on new x_weight signs
353-
y_loadings = Y[inds].T @ (X[inds] @ out['x_weights'])
354-
varexp = np.sum(y_loadings ** 2, axis=0) / np.sum(Yp ** 2)
366+
mask = get_mask(Xp, Yp)
367+
y_loadings = Yp[mask].T @ (Xp @ x_weights)[mask]
368+
varexp = np.sum(y_loadings ** 2, axis=0) / np.sum(Yp[mask] ** 2)
355369
else:
356-
varexp = out['pctvar'][1]
370+
varexp = np.diag(varexp)
357371

358372
# need to return len-3 for compatibility purposes
359373
return varexp, None, None
@@ -378,13 +392,15 @@ def run_pls(self, X, Y):
378392
'the specified axis.')
379393

380394
# mean-center here so that our outputs are generated accordingly
381-
X -= X.mean(axis=0, keepdims=True)
382-
Y_agg -= Y_agg.mean(axis=0, keepdims=True)
395+
X -= np.nanmean(X, axis=0, keepdims=True)
396+
Y_agg -= np.nanmean(Y_agg, axis=0, keepdims=True)
397+
mask = get_mask(X, Y_agg)
383398

384399
res = super().run_pls(X, Y_agg)
385-
res['y_loadings'] = Y_agg.T @ res['x_scores']
386-
res['y_scores'] = resid_yscores(res['x_scores'],
387-
Y_agg @ res['y_loadings'])
400+
res['y_loadings'] = Y_agg[mask].T @ res['x_scores'][mask]
401+
res['y_scores'] = np.full((len(Y_agg), self.n_components), np.nan)
402+
res['y_scores'][mask] = resid_yscores(res['x_scores'][mask],
403+
Y_agg[mask] @ res['y_loadings'])
388404

389405
if self.inputs.n_boot > 0:
390406
# compute bootstraps
@@ -400,7 +416,7 @@ def run_pls(self, X, Y):
400416
corrci = np.stack(compute.boot_ci(distrib, ci=self.inputs.ci), -1)
401417
res['bootres'].update(dict(x_weights_normed=bsrs,
402418
x_weights_stderr=uboot_se,
403-
y_loadings=res['y_loadings'].copy(),
419+
y_loadings=res['y_loadings'],
404420
y_loadings_boot=distrib,
405421
y_loadings_ci=corrci,
406422
bootsamples=self.bootsamp,))

0 commit comments

Comments
 (0)