@@ -45,6 +45,14 @@ def resid_yscores(x_scores, y_scores, copy=True):
     return y_scores
 
 
+def get_mask(X, Y):
+    """ Returns mask removing rows where either X or Y contain all NaN values
+    """
+
+    return np.logical_not(np.logical_or(np.all(np.isnan(X), axis=1),
+                                        np.all(np.isnan(Y), axis=1)))
+
+
 def simpls(X, Y, n_components=None, seed=1234):
     """
     Performs partial least squares regression with the SIMPLS algorithm
@@ -226,10 +234,11 @@ def __init__(self, X, Y, *, n_components=None,
                                  'or one of {}'.format(sorted(aggfuncs)))
         self.aggfunc = aggfuncs.get(aggfunc, aggfunc)
 
+        # these need to be zero -- they're not implemented for PLSRegression
+        kwargs.update(n_split=0, test_split=0)
         super().__init__(X=np.asarray(X), Y=np.asarray(Y),
                          n_components=n_components, n_perm=n_perm,
-                         n_boot=n_boot, n_split=0, test_split=0,
-                         rotate=rotate, ci=ci, aggfunc=aggfunc,
+                         n_boot=n_boot, rotate=rotate, ci=ci, aggfunc=aggfunc,
                          permsamples=permsamples, bootsamples=bootsamples,
                          seed=seed, verbose=verbose, n_proc=n_proc, **kwargs)
 
@@ -258,7 +267,9 @@ def svd(self, X, Y, seed=None):
             Variance explained by PLS-derived components; diagonal array
         """
 
-        out = simpls(X, Y, self.n_components, seed=seed)
+        # find nan rows and mask for the decomposition
+        mask = get_mask(X, Y)
+        out = simpls(X[mask], Y[mask], self.n_components, seed=seed)
 
         # need to return len-3 for compatibility purposes
         # use the variance explained in Y in lieu of the singular values since
@@ -301,18 +312,19 @@ def _single_boot(self, X, Y, inds, groups=None, original=None, seed=None):
         else:
             Xi, Yi = X[inds], Y[inds]
 
-        out = simpls(Xi, Yi, self.n_components, seed=seed)
+        x_weights = self.svd(Xi, Yi, seed=seed)[0]
 
         if original is not None:
             # flip signs of weights based on correlations with `original`
-            flip = np.sign(compute.efficient_corr(out['x_weights'], original))
-            out['x_weights'] = out['x_weights'] * flip
+            flip = np.sign(compute.efficient_corr(x_weights, original))
+            x_weights *= flip
             # NOTE: should we be doing a procrustes here?
 
-        # recompute y_loadings based on new x_weight signs
-        out['y_loadings'] = Yi.T @ (Xi @ out['x_weights'])
+        # compute y_loadings
+        mask = get_mask(Xi, Yi)
+        y_loadings = Yi[mask].T @ (Xi @ x_weights)[mask]
 
-        return out['y_loadings'], out['x_weights']
+        return y_loadings, x_weights
 
     def _single_perm(self, X, Y, inds, groups=None, original=None, seed=None):
         """
@@ -340,20 +352,22 @@ def _single_perm(self, X, Y, inds, groups=None, original=None, seed=None):
             Variance explained by PLS decomposition of permuted data
         """
 
+        # should permute Y (but not X) by default
         Xp, Yp = self.make_permutation(X, Y, inds)
-        out = simpls(Xp, Yp, self.n_components, seed=seed)
+        x_weights, varexp, _ = self.svd(Xp, Yp, seed=seed)
 
         if self.inputs.rotate and original is not None:
             # flip signs of weights based on correlations with `original`
-            flip = np.sign(compute.efficient_corr(out['x_weights'], original))
-            out['x_weights'] = out['x_weights'] * flip
+            flip = np.sign(compute.efficient_corr(x_weights, original))
+            x_weights *= flip
             # NOTE: should we be doing a procrustes here?
 
             # recompute pctvar based on new x_weight signs
-            y_loadings = Y[inds].T @ (X[inds] @ out['x_weights'])
-            varexp = np.sum(y_loadings ** 2, axis=0) / np.sum(Yp ** 2)
+            mask = get_mask(Xp, Yp)
+            y_loadings = Yp[mask].T @ (Xp @ x_weights)[mask]
+            varexp = np.sum(y_loadings ** 2, axis=0) / np.sum(Yp[mask] ** 2)
         else:
-            varexp = out['pctvar'][1]
+            varexp = np.diag(varexp)
 
         # need to return len-3 for compatibility purposes
         return varexp, None, None
@@ -378,13 +392,15 @@ def run_pls(self, X, Y):
                              'the specified axis.')
 
         # mean-center here so that our outputs are generated accordingly
-        X -= X.mean(axis=0, keepdims=True)
-        Y_agg -= Y_agg.mean(axis=0, keepdims=True)
+        X -= np.nanmean(X, axis=0, keepdims=True)
+        Y_agg -= np.nanmean(Y_agg, axis=0, keepdims=True)
+        mask = get_mask(X, Y_agg)
 
         res = super().run_pls(X, Y_agg)
-        res['y_loadings'] = Y_agg.T @ res['x_scores']
-        res['y_scores'] = resid_yscores(res['x_scores'],
-                                        Y_agg @ res['y_loadings'])
+        res['y_loadings'] = Y_agg[mask].T @ res['x_scores'][mask]
+        res['y_scores'] = np.full((len(Y_agg), self.n_components), np.nan)
+        res['y_scores'][mask] = resid_yscores(res['x_scores'][mask],
+                                              Y_agg[mask] @ res['y_loadings'])
 
         if self.inputs.n_boot > 0:
             # compute bootstraps
@@ -400,7 +416,7 @@ def run_pls(self, X, Y):
             corrci = np.stack(compute.boot_ci(distrib, ci=self.inputs.ci), -1)
             res['bootres'].update(dict(x_weights_normed=bsrs,
                                        x_weights_stderr=uboot_se,
-                                       y_loadings=res['y_loadings'].copy(),
+                                       y_loadings=res['y_loadings'],
                                        y_loadings_boot=distrib,
                                        y_loadings_ci=corrci,
                                        bootsamples=self.bootsamp,))
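
Below is a small illustrative sketch (not part of the diff) of what the new get_mask helper does: it flags the rows that are usable for the decomposition, dropping only rows where all of X or all of Y is NaN. The helper body is copied from the first hunk; the toy arrays are made up for illustration.

import numpy as np

def get_mask(X, Y):
    # True for rows kept: a row is dropped only when every X value
    # or every Y value in that row is NaN
    return np.logical_not(np.logical_or(np.all(np.isnan(X), axis=1),
                                        np.all(np.isnan(Y), axis=1)))

# toy data: the second row of Y is entirely NaN, so it is excluded
X = np.array([[1.0, 2.0],
              [3.0, 4.0],
              [5.0, 6.0]])
Y = np.array([[0.5, 1.5],
              [np.nan, np.nan],
              [2.5, 3.5]])

mask = get_mask(X, Y)              # array([ True, False,  True])
X_kept, Y_kept = X[mask], Y[mask]  # each has shape (2, 2)

This is the mask applied before simpls() in the svd hunk, so the decomposition only ever sees complete rows, while run_pls re-expands outputs such as y_scores back to the full number of rows with NaN placeholders.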