Skip to content

Commit 5a02063

Browse files
authored
Merge pull request #12 from theislab/concat-outer
Added outer join for concatenate function
2 parents 40a24fb + 47ad1a1 commit 5a02063

File tree

2 files changed

+73
-40
lines changed

2 files changed

+73
-40
lines changed

anndata/base.py

+56-36
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,10 @@
44
import warnings
55
import logging as logg
66
from enum import Enum
7-
from collections import Mapping, Sequence, Sized
7+
from collections import Mapping, Sequence, Sized, ChainMap
8+
from functools import reduce
9+
from typing import Union
10+
811
import numpy as np
912
from numpy import ma
1013
import pandas as pd
@@ -844,7 +847,7 @@ def shape(self):
844847
return self.n_obs, self.n_vars
845848

846849
@property
847-
def X(self):
850+
def X(self) -> Union[np.ndarray, sparse.spmatrix]:
848851
"""Data matrix of shape `n_obs` × `n_vars` (`np.ndarray`, `sp.sparse.spmatrix`)."""
849852
if self.isbacked:
850853
if not self.file.isopen: self.file.open()
@@ -1297,15 +1300,17 @@ def copy(self, filename=None):
12971300
copyfile(self.filename, filename)
12981301
return AnnData(filename=filename)
12991302

1300-
def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_unique='-'):
1303+
def concatenate(self, *adatas, join='inner', batch_key='batch', batch_categories=None, index_unique='-'):
13011304
"""Concatenate along the observations axis after intersecting the variables names.
13021305
13031306
The `.var`, `.varm`, and `.uns` attributes of the passed adatas are ignored.
13041307
13051308
Parameters
13061309
----------
1307-
adatas : :class:`~anndata.AnnData` or list of :class:`~anndata.AnnData`
1310+
adatas : :class:`~anndata.AnnData`
13081311
AnnData matrices to concatenate with.
1312+
join : `str` (default: 'inner')
1313+
Use intersection (``'inner'``) or union (``'outer'``) of variables?
13091314
batch_key : `str` (default: 'batch')
13101315
Add the batch annotation to `.obs` using this key.
13111316
batch_categories : list, optional (default: `range(len(adatas)+1)`)
@@ -1332,7 +1337,7 @@ def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_un
13321337
>>> {'anno2': ['d3', 'd4']},
13331338
>>> {'var_names': ['b', 'c', 'd']})
13341339
>>>
1335-
>>> adata = adata1.concatenate([adata2, adata3])
1340+
>>> adata = adata1.concatenate(adata2, adata3)
13361341
>>> adata.X
13371342
[[ 2. 3.]
13381343
[ 5. 6.]
@@ -1351,41 +1356,56 @@ def concatenate(self, adatas, batch_key='batch', batch_categories=None, index_un
13511356
0-2 NaN d3 2
13521357
1-2 NaN d4 2
13531358
"""
1354-
if isinstance(adatas, AnnData): adatas = [adatas]
1355-
joint_variables = self.var_names
1356-
for adata2 in adatas:
1357-
joint_variables = np.intersect1d(
1358-
joint_variables, adata2.var_names, assume_unique=True)
1359+
if len(adatas) == 0:
1360+
return self
1361+
elif len(adatas) == 1 and not isinstance(adatas[0], AnnData):
1362+
adatas = adatas[0] # backwards compatibility
1363+
all_adatas = (self,) + adatas
1364+
1365+
mergers = dict(inner=set.intersection, outer=set.union)
1366+
var_names = pd.Index(reduce(mergers[join], (set(ad.var_names) for ad in all_adatas)))
1367+
13591368
if batch_categories is None:
1360-
categories = [str(i) for i in range(len(adatas)+1)]
1361-
elif len(batch_categories) == len(adatas)+1:
1369+
categories = [str(i) for i, _ in enumerate(all_adatas)]
1370+
elif len(batch_categories) == len(all_adatas):
13621371
categories = batch_categories
13631372
else:
13641373
raise ValueError('Provide as many `batch_categories` as `adatas`.')
1365-
adatas_to_concat = []
1366-
for i, ad in enumerate([self] + adatas):
1367-
ad.obs.index.values
1368-
ad = ad[:, joint_variables]
1369-
ad.obs[batch_key] = pd.Categorical(
1370-
ad.n_obs*[categories[i]], categories=categories)
1371-
ad.obs.index.values
1374+
1375+
out_shape = (sum(a.n_obs for a in all_adatas), len(var_names))
1376+
1377+
any_sparse = any(issparse(a.X) for a in all_adatas)
1378+
mat_cls = sparse.csc_matrix if any_sparse else np.ndarray
1379+
X = mat_cls(out_shape, dtype=self.X.dtype)
1380+
var = pd.DataFrame(index=var_names)
1381+
1382+
obs_i = 0 # start of next adata’s observations in X
1383+
out_obss = []
1384+
for i, ad in enumerate(all_adatas):
1385+
vars_ad_in_res = var_names.isin(ad.var_names)
1386+
vars_res_in_ad = ad.var_names.isin(var_names)
1387+
1388+
# X
1389+
X[obs_i:obs_i+ad.n_obs, vars_ad_in_res] = ad.X[:, vars_res_in_ad]
1390+
obs_i += ad.n_obs
1391+
1392+
# obs
1393+
obs = ad.obs.copy()
1394+
obs[batch_key] = pd.Categorical(ad.n_obs * [categories[i]], categories)
13721395
if index_unique is not None:
13731396
if not is_string_dtype(ad.obs.index):
1374-
ad.obs.index = ad.obs.index.astype(str)
1375-
ad.obs.index = ad.obs.index.values + index_unique + categories[i]
1376-
adatas_to_concat.append(ad)
1377-
Xs = [ad.X for ad in adatas_to_concat]
1378-
if issparse(self.X):
1379-
from scipy.sparse import vstack
1380-
X = vstack(Xs)
1381-
else:
1382-
X = np.concatenate(Xs)
1383-
obs = pd.concat([ad.obs for ad in adatas_to_concat])
1384-
obsm = np.concatenate([ad.obsm for ad in adatas_to_concat])
1385-
var = adatas_to_concat[0].var
1386-
varm = adatas_to_concat[0].varm
1387-
uns = adatas_to_concat[0].uns
1388-
return AnnData(X, obs, var, uns, obsm, varm, filename=self.filename)
1397+
obs.index = obs.index.astype(str)
1398+
obs.index = obs.index.values + index_unique + categories[i]
1399+
out_obss.append(obs)
1400+
1401+
# var
1402+
var.loc[vars_ad_in_res, ad.var.columns] = ad.var.loc[vars_res_in_ad, :]
1403+
1404+
obs = pd.concat(out_obss)
1405+
uns = dict(ChainMap({}, *[ad.obs for ad in all_adatas]))
1406+
obsm = np.concatenate([ad.obsm for ad in all_adatas])
1407+
varm = self.varm # TODO
1408+
return AnnData(X, obs, var, uns, obsm, None, filename=self.filename)
13891409

13901410
def var_names_make_unique(self, join='-'):
13911411
self.var.index = utils.make_index_unique(self.var.index, join)
@@ -1423,11 +1443,11 @@ def _check_dimensions(self, key=None):
14231443
if 'obsm' in key and len(self._obsm) != self._n_obs:
14241444
raise ValueError('Observations annot. `obsm` must have number of '
14251445
'rows of `X` ({}), but has {} rows.'
1426-
.format(self._n_obs, self._obs.shape[0]))
1446+
.format(self._n_obs, len(self._obsm)))
14271447
if 'varm' in key and len(self._varm) != self._n_vars:
14281448
raise ValueError('Variables annot. `varm` must have number of '
14291449
'columns of `X` ({}), but has {} rows.'
1430-
.format(self._n_vars, self._var.shape[0]))
1450+
.format(self._n_vars, len(self._varm)))
14311451

14321452
def write(self, filename=None, compression='gzip', compression_opts=None):
14331453
"""Write `.h5ad`-formatted hdf5 file and close a potential backing file.

anndata/tests/base.py

+17-4
Original file line numberDiff line numberDiff line change
@@ -194,12 +194,12 @@ def test_concatenate():
194194
{'obs_names': ['s5', 's6'],
195195
'anno2': ['d3', 'd4']},
196196
{'var_names': ['b', 'c', 'd']})
197-
adata = adata1.concatenate([adata2, adata3])
197+
adata = adata1.concatenate(adata2, adata3)
198198
assert adata.n_vars == 2
199199
assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
200-
adata = adata1.concatenate([adata2, adata3], batch_key='batch1')
200+
adata = adata1.concatenate(adata2, adata3, batch_key='batch1')
201201
assert adata.obs_keys() == ['anno1', 'anno2', 'batch1']
202-
adata = adata1.concatenate([adata2, adata3], batch_categories=['a1', 'a2', 'a3'])
202+
adata = adata1.concatenate(adata2, adata3, batch_categories=['a1', 'a2', 'a3'])
203203
assert adata.obs['batch'].cat.categories.tolist() == ['a1', 'a2', 'a3']
204204

205205

@@ -217,10 +217,23 @@ def test_concatenate_sparse():
217217
{'obs_names': ['s5', 's6'],
218218
'anno2': ['d3', 'd4']},
219219
{'var_names': ['b', 'c', 'd']})
220-
adata = adata1.concatenate([adata2, adata3])
220+
adata = adata1.concatenate(adata2, adata3)
221221
assert adata.n_vars == 2
222222

223223

224+
def test_concatenate_outer():
225+
adata1 = AnnData(np.array([[1, 2, 3], [4, 5, 6]]),
226+
{'obs_names': ['s1', 's2'],
227+
'anno1': ['c1', 'c2']},
228+
{'var_names': ['a', 'b', 'c']})
229+
adata2 = AnnData(np.array([[1, 2, 3], [4, 5, 6], [7,8,9]]),
230+
{'obs_names': ['s3', 's4', 's5'],
231+
'anno2': ['c3', 'c4', 'c5']},
232+
{'var_names': ['b', 'c', 'd']})
233+
adata = adata1.concatenate(adata2, join='outer')
234+
assert adata.n_vars == 4
235+
assert adata.obs_keys() == ['anno1', 'anno2', 'batch']
236+
224237
# TODO: remove logging and actually test values
225238
# from scanpy import logging as logg
226239

0 commit comments

Comments
 (0)