1313from matplotlib import colormaps , gridspec , patheffects , rcParams
1414from matplotlib import pyplot as plt
1515from matplotlib .colors import is_color_like
16+ from packaging .version import Version
1617from pandas .api .types import CategoricalDtype , is_numeric_dtype
1718
1819from .. import get
1920from .. import logging as logg
20- from .._compat import CSBase , old_positionals
21+ from .._compat import CSBase , old_positionals , pkg_version
2122from .._settings import settings
2223from .._utils import (
2324 _doc_params ,
5556 from seaborn .matrix import ClusterGrid
5657
5758 from .._utils import Empty
58- from ._utils import (
59- ColorLike ,
60- DensityNorm ,
61- _FontSize ,
62- _FontWeight ,
63- _LegendLoc ,
64- )
59+ from ._utils import ColorLike , DensityNorm , _FontSize , _FontWeight , _LegendLoc
6560
6661# TODO: is that all?
6762type _Basis = Literal ["pca" , "tsne" , "umap" , "diffmap" , "draw_graph_fr" ]
@@ -319,6 +314,9 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
319314 if title is not None and isinstance (title , str ):
320315 title = [title ]
321316 highlights = adata .uns .get ("highlights" , [])
317+ is_anndata_13 = pkg_version ("anndata" ) >= Version ("0.13.0rc0" )
318+ if is_anndata_13 :
319+ from anndata .acc import A
322320 if basis is not None :
323321 try :
324322 # ignore the '0th' diffusion component
@@ -334,16 +332,28 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
334332 elif x is not None and y is not None :
335333 if use_raw :
336334 if x in adata .obs .columns :
337- x_arr = adata .obs_vector (x )
335+ x_arr = adata [ A . obs [ x ]] if is_anndata_13 else adata .obs_vector (x )
338336 else :
339- x_arr = adata .raw .obs_vector (x )
337+ x_arr = (
338+ adata .raw [A .X [:, x ]] if is_anndata_13 else adata .raw .obs_vector (x )
339+ )
340340 if y in adata .obs .columns :
341- y_arr = adata .obs_vector (y )
341+ y_arr = adata [ A . obs [ y ]] if is_anndata_13 else adata .obs_vector (y )
342342 else :
343- y_arr = adata .raw .obs_vector (y )
343+ y_arr = (
344+ adata .raw [A .X [:, y ]] if is_anndata_13 else adata .raw .obs_vector (y )
345+ )
344346 else :
345- x_arr = adata .obs_vector (x , layer = layers [0 ])
346- y_arr = adata .obs_vector (y , layer = layers [1 ])
347+ x_arr = (
348+ adata [A .layers [layers [0 ]][:, x ]]
349+ if is_anndata_13
350+ else adata .obs_vector (x , layer = layers [0 ])
351+ )
352+ y_arr = (
353+ adata [A .layers [layers [1 ]][:, y ]]
354+ if is_anndata_13
355+ else adata .obs_vector (y , layer = layers [1 ])
356+ )
347357
348358 xy = np .c_ [x_arr , y_arr ]
349359 else :
@@ -400,9 +410,13 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915
400410 c = adata .obs [key ].to_numpy ()
401411 # coloring according to gene expression
402412 elif use_raw and adata .raw is not None and key in adata .raw .var_names :
403- c = adata .raw .obs_vector (key )
413+ c = adata [ A . obs [ key ]] if is_anndata_13 else adata .raw .obs_vector (key )
404414 elif key in adata .var_names :
405- c = adata .obs_vector (key , layer = layers [2 ])
415+ c = (
416+ adata [A .layers [layers [2 ][:, key ]]]
417+ if is_anndata_13
418+ else adata .obs_vector (key , layer = layers [2 ])
419+ )
406420 elif is_color_like (key ): # a flat color
407421 c = key
408422 colorbar = False
0 commit comments