@@ -1301,9 +1301,12 @@ def copy(self, filename=None):
1301
1301
return AnnData (filename = filename )
1302
1302
1303
1303
def concatenate (self , * adatas , join = 'inner' , batch_key = 'batch' , batch_categories = None , index_unique = None ):
1304
- """Concatenate along the observations axis after intersecting the variables names .
1304
+ """Concatenate along the observations axis.
1305
1305
1306
- The `.var`, `.varm`, and `.uns` attributes of the passed adatas are ignored.
1306
+ The `.uns` and `.varm` attributes of the passed `adatas` are ignored.
1307
+
1308
+ If you use `join='outer'`, then note that this fills 0s for data that is
1309
+ non-present. Use this with care.
1307
1310
1308
1311
Parameters
1309
1312
----------
@@ -1337,7 +1340,7 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
1337
1340
>>> {'anno2': ['d3', 'd4']},
1338
1341
>>> {'var_names': ['b', 'c', 'd']})
1339
1342
>>>
1340
- >>> adata = adata1.concatenate(adata2, adata3)
1343
+ >>> adata = adata1.concatenate(adata2, adata3, index_unique='-' )
1341
1344
>>> adata.X
1342
1345
[[ 2. 3.]
1343
1346
[ 5. 6.]
@@ -1372,9 +1375,17 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
1372
1375
'Making variable names unique for controlled concatenation.' )
1373
1376
printed_info = True
1374
1377
1378
+ # define variable names of joint AnnData
1375
1379
mergers = dict (inner = set .intersection , outer = set .union )
1376
- var_names = pd .Index (reduce (mergers [join ], (set (ad .var_names ) for ad in all_adatas )))
1377
-
1380
+ var_names_reduce = reduce (mergers [join ], (set (ad .var_names ) for ad in all_adatas ))
1381
+ # restore order of initial var_names, append non-sortable names at the end
1382
+ var_names = []
1383
+ for v in all_adatas [0 ].var_names :
1384
+ if v in var_names_reduce :
1385
+ var_names .append (v )
1386
+ var_names_reduce .remove (v ) # update the set
1387
+ var_names = pd .Index (var_names + list (var_names_reduce ))
1388
+
1378
1389
if batch_categories is None :
1379
1390
categories = [str (i ) for i , _ in enumerate (all_adatas )]
1380
1391
elif len (batch_categories ) == len (all_adatas ):
@@ -1392,11 +1403,11 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
1392
1403
obs_i = 0 # start of next adata’s observations in X
1393
1404
out_obss = []
1394
1405
for i , ad in enumerate (all_adatas ):
1395
- vars_ad_in_res = var_names .isin (ad .var_names )
1396
- vars_res_in_ad = ad .var_names .isin (var_names )
1406
+ vars_intersect = [v for v in var_names if v in ad .var_names ]
1397
1407
1398
1408
# X
1399
- X [obs_i :obs_i + ad .n_obs , vars_ad_in_res ] = ad .X [:, vars_res_in_ad ]
1409
+ X [obs_i :obs_i + ad .n_obs ,
1410
+ var_names .isin (vars_intersect )] = ad [:, vars_intersect ].X
1400
1411
obs_i += ad .n_obs
1401
1412
1402
1413
# obs
@@ -1412,13 +1423,14 @@ def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories
1412
1423
out_obss .append (obs )
1413
1424
1414
1425
# var
1415
- var .loc [vars_ad_in_res , ad .var .columns ] = ad .var .loc [vars_res_in_ad , :]
1426
+ # potential add additional columns
1427
+ var .loc [vars_intersect , ad .var .columns ] = ad .var .loc [vars_intersect , :]
1416
1428
1417
1429
obs = pd .concat (out_obss )
1418
1430
uns = all_adatas [0 ].uns
1419
1431
obsm = np .concatenate ([ad .obsm for ad in all_adatas ])
1420
1432
varm = self .varm # TODO
1421
-
1433
+
1422
1434
new_adata = AnnData (X , obs , var , uns , obsm , None , filename = self .filename )
1423
1435
if not obs .index .is_unique :
1424
1436
logg .info (
0 commit comments