@@ -33,7 +33,11 @@ def __init__(
3333 X : np .ndarray | CSMatrix | None = None ,
3434 var : pd .DataFrame | Mapping [str , Sequence ] | None = None ,
3535 varm : AxisArrays | Mapping [str , np .ndarray ] | None = None ,
36- ):
36+ ) -> None :
37+ if X is not None and X .shape [0 ] != adata .n_obs :
38+ msg = f"X has { X .shape [0 ]} rows, but n_obs is { adata .n_obs } "
39+ raise ValueError (msg )
40+
3741 self ._adata = adata
3842 self ._n_obs = adata .n_obs
3943 # construct manually
@@ -121,8 +125,19 @@ def var_names(self) -> pd.Index[str]:
121125 def obs_names (self ) -> pd .Index [str ]:
122126 return self ._adata .obs_names
123127
124- def __getitem__ (self , index : Index ) -> Raw :
125- oidx , vidx = self ._normalize_indices (index )
128+ def __getitem__ (self , index : Index | tuple [AnnData , Index ]) -> Raw :
129+ from .anndata import AnnData
130+
131+ if (
132+ isinstance (index , tuple )
133+ and len (index ) == 2
134+ and isinstance (index [0 ], AnnData )
135+ ):
136+ adata , index = index
137+ oidx , vidx = self ._normalize_indices (index )
138+ else :
139+ oidx , vidx = self ._normalize_indices (index )
140+ adata = self ._adata [oidx ]
126141
127142 # To preserve two dimensional shape
128143 if isinstance (vidx , int | np .integer ):
@@ -133,7 +148,7 @@ def __getitem__(self, index: Index) -> Raw:
133148 X = _subset (self .X , (oidx , vidx )) if not self ._adata .isbacked else None
134149
135150 var = self ._var .iloc [vidx ]
136- new = Raw (self . _adata , X = X , var = var )
151+ new = Raw (adata , X = X , var = var )
137152 if self .varm is not None :
138153 # Since there is no view of raws
139154 new .varm = self .varm ._view (_RawViewHack (self , vidx ), (vidx ,)).copy ()
0 commit comments