66from collections import Counter , defaultdict
77from collections .abc import Mapping
88from functools import partial , singledispatch , wraps
9+ from importlib .util import find_spec
910from string import ascii_letters
1011from typing import TYPE_CHECKING
1112
@@ -311,7 +312,6 @@ def gen_adata( # noqa: PLR0913
311312 (csr, csc)
312313 """
313314 import dask .array as da
314- import xarray as xr
315315
316316 if random_state is None :
317317 random_state = np .random .default_rng ()
@@ -325,10 +325,11 @@ def gen_adata( # noqa: PLR0913
325325 obs .rename (columns = dict (cat = "obs_cat" ), inplace = True )
326326 var .rename (columns = dict (cat = "var_cat" ), inplace = True )
327327
328- if obs_xdataset :
329- obs = XDataset .from_dataframe (obs )
330- if var_xdataset :
331- var = XDataset .from_dataframe (var )
328+ if has_xr := find_spec ("xarray" ):
329+ if obs_xdataset :
330+ obs = XDataset .from_dataframe (obs )
331+ if var_xdataset :
332+ var = XDataset .from_dataframe (var )
332333
333334 if X_type is None :
334335 X = None
@@ -341,27 +342,28 @@ def gen_adata( # noqa: PLR0913
341342 df = gen_typed_df (M , obs_names , dtypes = obs_dtypes ),
342343 awk_2d_ragged = gen_awkward ((M , None )),
343344 da = da .random .random ((M , 50 )),
344- xdataset = xr .Dataset .from_dataframe (
345- gen_typed_df (M , obs_names , dtypes = obs_dtypes )
346- ),
347- )
348- obsm = {k : v for k , v in obsm .items () if type (v ) in obsm_types }
349- obsm = maybe_add_sparse_array (
350- mapping = obsm ,
351- types = obsm_types ,
352- format = sparse_fmt ,
353- random_state = random_state ,
354- shape = (M , 100 ),
355345 )
356346 varm = dict (
357347 array = np .random .random ((N , 50 )),
358348 sparse = sparse .random (N , 100 , format = sparse_fmt , random_state = random_state ),
359349 df = gen_typed_df (N , var_names , dtypes = var_dtypes ),
360350 awk_2d_ragged = gen_awkward ((N , None )),
361351 da = da .random .random ((N , 50 )),
362- xdataset = xr .Dataset .from_dataframe (
352+ )
353+ if has_xr :
354+ obsm ["xdataset" ] = XDataset .from_dataframe (
355+ gen_typed_df (M , obs_names , dtypes = obs_dtypes )
356+ )
357+ varm ["xdataset" ] = XDataset .from_dataframe (
363358 gen_typed_df (N , var_names , dtypes = var_dtypes )
364- ),
359+ )
360+ obsm = {k : v for k , v in obsm .items () if type (v ) in obsm_types }
361+ obsm = maybe_add_sparse_array (
362+ mapping = obsm ,
363+ types = obsm_types ,
364+ format = sparse_fmt ,
365+ random_state = random_state ,
366+ shape = (M , 100 ),
365367 )
366368 varm = {k : v for k , v in varm .items () if type (v ) in varm_types }
367369 varm = maybe_add_sparse_array (
0 commit comments