@@ -1814,22 +1814,60 @@ def test_concat_on_var_outer_join(array_type):
18141814 _ = concat ([a , b ], join = "outer" , axis = 1 )
18151815
18161816
1817- def test_concat_dask_sparse_matches_memory (join_type , merge_strategy ):
1817+ @pytest .mark .parametrize ("format" , ["csr" , "csc" ])
1818+ @pytest .mark .parametrize (
1819+ "unchunked_minor_axis" , [True , False ], ids = ["unchunked_minor" , "chunked_minor" ]
1820+ )
1821+ @pytest .mark .parametrize ("fill_value" , [0 , - 1 ])
1822+ def test_concat_dask_sparse_matches_memory (
1823+ join_type ,
1824+ merge_strategy ,
1825+ format : Literal ["csr" , "csc" ],
1826+ axis_name : Literal ["obs" , "var" ],
1827+ fill_value : Literal [- 1 , 0 ],
1828+ * ,
1829+ unchunked_minor_axis : bool ,
1830+ ):
18181831 import dask .array as da
18191832
1820- X = sparse .random (50 , 20 , density = 0.5 , format = "csr" )
1821- X_dask = da .from_array (X , chunks = (5 , 20 ))
1822- var_names_1 = [f"gene_{ i } " for i in range (20 )]
1823- var_names_2 = [f"gene_{ i } { '_foo' if (i % 2 ) else '' } " for i in range (20 )]
1833+ X = sparse .random (50 , 20 , density = 0.5 , format = format )
1834+ X_dask = da .from_array (
1835+ X ,
1836+ chunks = (
1837+ X .shape [0 ] if format == "csc" else 10 ,
1838+ X .shape [1 ] if format == "csr" else 5 ,
1839+ )
1840+ if unchunked_minor_axis
1841+ else (5 , 10 ),
1842+ )
1843+ off_axis_idx = int (axis_name == "obs" )
1844+ concat_axis_idx = int (axis_name == "var" )
1845+ off_axis = "var" if axis_name == "obs" else "obs"
1846+ axis_names_1 = [f"off_axis_{ i } " for i in range (X .shape [off_axis_idx ])]
1847+ axis_names_2 = [
1848+ f"off_axis_{ i } { '_foo' if (i % 2 ) else '' } " for i in range (X .shape [off_axis_idx ])
1849+ ]
18241850
1825- ad1 = AnnData (X = X , var = pd .DataFrame (index = var_names_1 ) )
1826- ad2 = AnnData (X = X , var = pd .DataFrame (index = var_names_2 ) )
1851+ ad1 = AnnData (X = X , ** { off_axis : pd .DataFrame (index = axis_names_1 )} )
1852+ ad2 = AnnData (X = X , ** { off_axis : pd .DataFrame (index = axis_names_2 )} )
18271853
1828- ad1_dask = AnnData (X = X_dask , var = pd .DataFrame (index = var_names_1 ) )
1829- ad2_dask = AnnData (X = X_dask , var = pd .DataFrame (index = var_names_2 ) )
1854+ ad1_dask = AnnData (X = X_dask , ** { off_axis : pd .DataFrame (index = axis_names_1 )} )
1855+ ad2_dask = AnnData (X = X_dask , ** { off_axis : pd .DataFrame (index = axis_names_2 )} )
18301856
1831- res_in_memory = concat ([ad1 , ad2 ], join = join_type , merge = merge_strategy )
1832- res_dask = concat ([ad1_dask , ad2_dask ], join = join_type , merge = merge_strategy )
1857+ res_in_memory = concat (
1858+ [ad1 , ad2 ],
1859+ join = join_type ,
1860+ merge = merge_strategy ,
1861+ axis = concat_axis_idx ,
1862+ fill_value = fill_value ,
1863+ )
1864+ res_dask = concat (
1865+ [ad1_dask , ad2_dask ],
1866+ join = join_type ,
1867+ merge = merge_strategy ,
1868+ axis = concat_axis_idx ,
1869+ fill_value = fill_value ,
1870+ )
18331871 assert_equal (res_in_memory , res_dask )
18341872
18351873
0 commit comments