Skip to content

Commit 4f54c63

Browse files
authored
fix: check .raw length on setting (#2351)
1 parent c67c66d commit 4f54c63

4 files changed

Lines changed: 36 additions & 14 deletions

File tree

docs/release-notes/2351.fix.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1 @@
1+
Check `Raw` length on creation and fix associated `.adata` in `Raw` slicing {user}`P Angerer`

src/anndata/_core/anndata.py

Lines changed: 9 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -340,8 +340,7 @@ def _init_as_view(
340340
# set raw, easy, as it’s immutable anyways...
341341
if adata_ref._raw is not None:
342342
# slicing along variables axis is ignored
343-
self._raw = adata_ref.raw[oidx]
344-
self._raw._adata = self
343+
self._raw = adata_ref.raw[self, oidx]
345344
else:
346345
self._raw = None
347346

@@ -699,19 +698,20 @@ def raw(self) -> Raw:
699698
return self._raw
700699

701700
@raw.setter
702-
def raw(self, value: AnnData):
701+
def raw(self, value: AnnData) -> None:
703702
if value is None:
704703
del self.raw
705-
elif not isinstance(value, AnnData):
704+
return
705+
if not isinstance(value, AnnData):
706706
msg = "Can only init raw attribute with an AnnData object."
707707
raise ValueError(msg)
708-
else:
709-
if self.is_view:
710-
self._init_as_actual(self.copy())
711-
self._raw = Raw(self, X=value.X, var=value.var, varm=value.varm)
708+
raw = Raw(self, X=value.X, var=value.var, varm=value.varm)
709+
if self.is_view:
710+
self._init_as_actual(self.copy())
711+
self._raw = raw
712712

713713
@raw.deleter
714-
def raw(self):
714+
def raw(self) -> None:
715715
if self.is_view:
716716
self._init_as_actual(self.copy())
717717
self._raw = None

src/anndata/_core/raw.py

Lines changed: 21 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,11 @@ def __init__(
3535
X: np.ndarray | CSMatrix | None = None,
3636
var: pd.DataFrame | Mapping[str, Sequence] | None = None,
3737
varm: AxisArrays | Mapping[str, np.ndarray] | None = None,
38-
):
38+
) -> None:
39+
if X is not None and X.shape[0] != adata.n_obs:
40+
msg = f"X has {X.shape[0]} rows, but n_obs is {adata.n_obs}"
41+
raise ValueError(msg)
42+
3943
self._adata = adata
4044
self._n_obs = adata.n_obs
4145
# construct manually
@@ -126,14 +130,26 @@ def obs_names(self) -> pd.Index[str]:
126130
@overload
127131
def __getitem__(self, index: AdRef) -> InMemoryArray: ...
128132
@overload
129-
def __getitem__(self, index: Index) -> Raw: ...
130-
def __getitem__(self, index: Index | AdRef) -> Raw | InMemoryArray:
133+
def __getitem__(self, index: Index | tuple[AnnData, Index]) -> Raw: ...
134+
def __getitem__(
135+
self, index: Index | tuple[AnnData, Index] | AdRef
136+
) -> Raw | InMemoryArray:
131137
from ..acc import AdRef
138+
from .anndata import AnnData
132139

133140
if isinstance(index, AdRef):
134141
return index.acc.get(self, index.idx) # type: ignore # no official Raw support here
135142

136-
oidx, vidx = self._normalize_indices(index)
143+
if (
144+
isinstance(index, tuple)
145+
and len(index) == 2
146+
and isinstance(index[0], AnnData)
147+
):
148+
adata, index = index
149+
oidx, vidx = self._normalize_indices(index)
150+
else:
151+
oidx, vidx = self._normalize_indices(index)
152+
adata = self._adata[oidx]
137153

138154
# To preserve two dimensional shape
139155
if isinstance(vidx, int | np.integer):
@@ -144,7 +160,7 @@ def __getitem__(self, index: Index | AdRef) -> Raw | InMemoryArray:
144160
X = _subset(self.X, (oidx, vidx)) if not self._adata.isbacked else None
145161

146162
var = self._var.iloc[vidx]
147-
new = Raw(self._adata, X=X, var=var)
163+
new = Raw(adata, X=X, var=var)
148164
if self.varm is not None:
149165
# Since there is no view of raws
150166
new.varm = self.varm._view(_RawViewHack(self, vidx), (vidx,)).copy()

tests/test_raw.py

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -71,6 +71,11 @@ def test_raw_set_as_none(adata_raw: ad.AnnData):
7171
assert_equal(a, b)
7272

7373

74+
def test_raw_set_error(adata_raw: ad.AnnData) -> None:
75+
with pytest.raises(ValueError, match=r"X has 2 rows, but n_obs is 3"):
76+
adata_raw.raw = adata_raw[:2].copy()
77+
78+
7479
def test_raw_of_view(adata_raw: ad.AnnData):
7580
adata_view = adata_raw[adata_raw.obs["oanno1"] == "cat2"]
7681
assert adata_view.raw.X.tolist() == [

0 commit comments

Comments
 (0)