|
36 | 36 | from ._utils import ( |
37 | 37 | _deprecated_scale, |
38 | 38 | _dk, |
| 39 | + _obs_vector_compat, |
39 | 40 | check_colornorm, |
40 | 41 | scatter_base, |
41 | 42 | scatter_group, |
|
55 | 56 | from seaborn.matrix import ClusterGrid |
56 | 57 |
|
57 | 58 | from .._utils import Empty |
58 | | - from ._utils import ( |
59 | | - ColorLike, |
60 | | - DensityNorm, |
61 | | - _FontSize, |
62 | | - _FontWeight, |
63 | | - _LegendLoc, |
64 | | - ) |
| 59 | + from ._utils import ColorLike, DensityNorm, _FontSize, _FontWeight, _LegendLoc |
65 | 60 |
|
66 | 61 | # TODO: is that all? |
67 | 62 | type _Basis = Literal["pca", "tsne", "umap", "diffmap", "draw_graph_fr"] |
@@ -324,27 +319,18 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915 |
324 | 319 | # ignore the '0th' diffusion component |
325 | 320 | if basis == "diffmap": |
326 | 321 | components += 1 |
327 | | - xy = adata.obsm["X_" + basis][:, components] |
| 322 | + xy = adata.obsm[f"X_{basis}"][:, components] |
328 | 323 | # correct the component vector for use in labeling etc. |
329 | 324 | if basis == "diffmap": |
330 | 325 | components -= 1 |
331 | 326 | except KeyError: |
332 | 327 | msg = f"compute coordinates using visualization tool {basis} first" |
333 | 328 | raise KeyError(msg) from None |
334 | 329 | elif x is not None and y is not None: |
335 | | - if use_raw: |
336 | | - if x in adata.obs.columns: |
337 | | - x_arr = adata.obs_vector(x) |
338 | | - else: |
339 | | - x_arr = adata.raw.obs_vector(x) |
340 | | - if y in adata.obs.columns: |
341 | | - y_arr = adata.obs_vector(y) |
342 | | - else: |
343 | | - y_arr = adata.raw.obs_vector(y) |
344 | | - else: |
345 | | - x_arr = adata.obs_vector(x, layer=layers[0]) |
346 | | - y_arr = adata.obs_vector(y, layer=layers[1]) |
347 | | - |
| 330 | + x_arr, y_arr = ( |
| 331 | + _obs_vector_compat(adata, k, use_raw=use_raw, layer=layer) |
| 332 | + for layer, k in zip(layers, [x, y], strict=False) |
| 333 | + ) |
348 | 334 | xy = np.c_[x_arr, y_arr] |
349 | 335 | else: |
350 | 336 | msg = "Either provide a `basis` or `x` and `y`." |
@@ -399,10 +385,10 @@ def _scatter_obs( # noqa: PLR0912, PLR0913, PLR0915 |
399 | 385 | else: |
400 | 386 | c = adata.obs[key].to_numpy() |
401 | 387 | # coloring according to gene expression |
402 | | - elif use_raw and adata.raw is not None and key in adata.raw.var_names: |
403 | | - c = adata.raw.obs_vector(key) |
404 | | - elif key in adata.var_names: |
405 | | - c = adata.obs_vector(key, layer=layers[2]) |
| 388 | + elif (use_raw and adata.raw is not None and key in adata.raw.var_names) or ( |
| 389 | + key in adata.var_names |
| 390 | + ): |
| 391 | + c = _obs_vector_compat(adata, key, use_raw=use_raw, layer=layers[2]) |
406 | 392 | elif is_color_like(key): # a flat color |
407 | 393 | c = key |
408 | 394 | colorbar = False |
|
0 commit comments