@@ -35,7 +35,11 @@ def __init__(
3535 X : np .ndarray | CSMatrix | None = None ,
3636 var : pd .DataFrame | Mapping [str , Sequence ] | None = None ,
3737 varm : AxisArrays | Mapping [str , np .ndarray ] | None = None ,
38- ):
38+ ) -> None :
39+ if X is not None and X .shape [0 ] != adata .n_obs :
40+ msg = f"X has { X .shape [0 ]} rows, but n_obs is { adata .n_obs } "
41+ raise ValueError (msg )
42+
3943 self ._adata = adata
4044 self ._n_obs = adata .n_obs
4145 # construct manually
@@ -126,14 +130,26 @@ def obs_names(self) -> pd.Index[str]:
126130 @overload
127131 def __getitem__ (self , index : AdRef ) -> InMemoryArray : ...
128132 @overload
129- def __getitem__ (self , index : Index ) -> Raw : ...
130- def __getitem__ (self , index : Index | AdRef ) -> Raw | InMemoryArray :
133+ def __getitem__ (self , index : Index | tuple [AnnData , Index ]) -> Raw : ...
134+ def __getitem__ (
135+ self , index : Index | tuple [AnnData , Index ] | AdRef
136+ ) -> Raw | InMemoryArray :
131137 from ..acc import AdRef
138+ from .anndata import AnnData
132139
133140 if isinstance (index , AdRef ):
134141 return index .acc .get (self , index .idx ) # type: ignore # no official Raw support here
135142
136- oidx , vidx = self ._normalize_indices (index )
143+ if (
144+ isinstance (index , tuple )
145+ and len (index ) == 2
146+ and isinstance (index [0 ], AnnData )
147+ ):
148+ adata , index = index
149+ oidx , vidx = self ._normalize_indices (index )
150+ else :
151+ oidx , vidx = self ._normalize_indices (index )
152+ adata = self ._adata [oidx ]
137153
138154 # To preserve two dimensional shape
139155 if isinstance (vidx , int | np .integer ):
@@ -144,7 +160,7 @@ def __getitem__(self, index: Index | AdRef) -> Raw | InMemoryArray:
144160 X = _subset (self .X , (oidx , vidx )) if not self ._adata .isbacked else None
145161
146162 var = self ._var .iloc [vidx ]
147- new = Raw (self . _adata , X = X , var = var )
163+ new = Raw (adata , X = X , var = var )
148164 if self .varm is not None :
149165 # Since there is no view of raws
150166 new .varm = self .varm ._view (_RawViewHack (self , vidx ), (vidx ,)).copy ()
0 commit comments