4
4
import warnings
5
5
import logging as logg
6
6
from enum import Enum
7
- from collections import Mapping , Sequence , Sized
7
+ from collections import Mapping , Sequence , Sized , ChainMap
8
+ from functools import reduce
9
+ from typing import Union
10
+
8
11
import numpy as np
9
12
from numpy import ma
10
13
import pandas as pd
@@ -844,7 +847,7 @@ def shape(self):
844
847
return self .n_obs , self .n_vars
845
848
846
849
@property
847
- def X (self ):
850
+ def X (self ) -> Union [ np . ndarray , sparse . spmatrix ] :
848
851
"""Data matrix of shape `n_obs` × `n_vars` (`np.ndarray`, `sp.sparse.spmatrix`)."""
849
852
if self .isbacked :
850
853
if not self .file .isopen : self .file .open ()
@@ -1297,15 +1300,17 @@ def copy(self, filename=None):
1297
1300
copyfile (self .filename , filename )
1298
1301
return AnnData (filename = filename )
1299
1302
1300
def concatenate(self, *adatas, join='inner', batch_key='batch',
                batch_categories=None, index_unique='-'):
    """Concatenate along the observations axis.

    Variables are combined according to `join`: the intersection
    (``'inner'``) or the union (``'outer'``) of the variable names of all
    datasets.  ``.var`` columns are merged label-aligned on the variable
    names; ``.uns`` dictionaries are combined with earlier datasets
    (``self`` first) winning on duplicate keys.  The ``.varm`` attributes
    of all datasets are currently dropped.

    Parameters
    ----------
    adatas : :class:`~anndata.AnnData`
        AnnData matrices to concatenate with.  A single list of AnnData
        objects is also accepted (backwards compatibility).
    join : `str` (default: 'inner')
        Use intersection (``'inner'``) or union (``'outer'``) of variables?
    batch_key : `str` (default: 'batch')
        Add the batch annotation to `.obs` using this key.
    batch_categories : list, optional (default: `range(len(adatas)+1)`)
        Categories for the batch annotation, one per dataset including
        ``self``.
    index_unique : `str` or `None` (default: '-')
        Join the original obs index with the batch category using this
        separator to make indices unique; `None` leaves indices untouched.

    Returns
    -------
    :class:`~anndata.AnnData`
        The concatenated data, with `.obs[batch_key]` recording each
        observation's dataset of origin.

    Examples
    --------
    >>> adata = adata1.concatenate(adata2, adata3)
    """
    if len(adatas) == 0:
        return self
    elif len(adatas) == 1 and not isinstance(adatas[0], AnnData):
        adatas = adatas[0]  # backwards compatibility: a list was passed
    # tuple(...) so the legacy list form concatenates too instead of
    # raising TypeError on tuple + list
    all_adatas = (self,) + tuple(adatas)

    mergers = dict(inner=set.intersection, outer=set.union)
    var_names = pd.Index(
        reduce(mergers[join], (set(ad.var_names) for ad in all_adatas)))

    if batch_categories is None:
        categories = [str(i) for i, _ in enumerate(all_adatas)]
    elif len(batch_categories) == len(all_adatas):
        categories = batch_categories
    else:
        raise ValueError('Provide as many `batch_categories` as `adatas`.')

    out_shape = (sum(a.n_obs for a in all_adatas), len(var_names))

    any_sparse = any(issparse(a.X) for a in all_adatas)
    if any_sparse:
        # build in LIL (efficient incremental assignment), convert to CSC
        # at the end to match the layout the caller expects
        X = sparse.lil_matrix(out_shape, dtype=self.X.dtype)
    else:
        # np.zeros, NOT np.ndarray: the latter is uninitialized and would
        # leave garbage in columns missing from a dataset on an outer join
        X = np.zeros(out_shape, dtype=self.X.dtype)
    var = pd.DataFrame(index=var_names)

    obs_i = 0  # start of the next adata's observations in X
    out_obss = []
    for i, ad in enumerate(all_adatas):
        vars_ad_in_res = var_names.isin(ad.var_names)
        vars_res_in_ad = ad.var_names.isin(var_names)

        # X: fetch the source columns in the order of the result columns;
        # two plain boolean masks would pair columns in two different
        # orders (result-index order vs. this adata's own order)
        src_cols = ad.var_names.get_indexer(var_names[vars_ad_in_res])
        X[obs_i:obs_i + ad.n_obs, vars_ad_in_res] = ad.X[:, src_cols]
        obs_i += ad.n_obs

        # obs: work on a copy, annotate the batch, disambiguate the index
        obs = ad.obs.copy()
        obs[batch_key] = pd.Categorical(
            ad.n_obs * [categories[i]], categories)
        if index_unique is not None:
            if not is_string_dtype(obs.index):
                obs.index = obs.index.astype(str)
            obs.index = obs.index.values + index_unique + categories[i]
        out_obss.append(obs)

        # var: .loc aligns the assignment on the index (variable names),
        # so no explicit column reordering is needed here
        var.loc[vars_ad_in_res, ad.var.columns] = ad.var.loc[vars_res_in_ad, :]

    obs = pd.concat(out_obss)
    # merge unstructured annotations (not .obs!); ChainMap makes earlier
    # adatas — self first — win on duplicate keys
    uns = dict(ChainMap({}, *[ad.uns for ad in all_adatas]))
    obsm = np.concatenate([ad.obsm for ad in all_adatas])
    # TODO: merge .varm of the inputs; it is currently dropped (None below)
    if any_sparse:
        X = X.tocsc()
    return AnnData(X, obs, var, uns, obsm, None, filename=self.filename)
1389
1409
1390
1410
def var_names_make_unique (self , join = '-' ):
1391
1411
self .var .index = utils .make_index_unique (self .var .index , join )
@@ -1423,11 +1443,11 @@ def _check_dimensions(self, key=None):
1423
1443
if 'obsm' in key and len (self ._obsm ) != self ._n_obs :
1424
1444
raise ValueError ('Observations annot. `obsm` must have number of '
1425
1445
'rows of `X` ({}), but has {} rows.'
1426
- .format (self ._n_obs , self ._obs . shape [ 0 ] ))
1446
+ .format (self ._n_obs , len ( self ._obsm ) ))
1427
1447
if 'varm' in key and len (self ._varm ) != self ._n_vars :
1428
1448
raise ValueError ('Variables annot. `varm` must have number of '
1429
1449
'columns of `X` ({}), but has {} rows.'
1430
- .format (self ._n_vars , self ._var . shape [ 0 ] ))
1450
+ .format (self ._n_vars , len ( self ._varm ) ))
1431
1451
1432
1452
def write (self , filename = None , compression = 'gzip' , compression_opts = None ):
1433
1453
"""Write `.h5ad`-formatted hdf5 file and close a potential backing file.
0 commit comments