33import re
44import warnings
55from contextlib import contextmanager
6+ from functools import partial
67from importlib .util import find_spec
78from pathlib import Path
89from string import ascii_letters
2526
2627if TYPE_CHECKING :
2728 from os import PathLike
29+ from typing import Literal
2830
2931HERE = Path (__file__ ).parent
3032
@@ -658,30 +660,13 @@ def random_cats(n):
658660 assert_equal (orig , curr )
659661
660662
661- def test_write_string_types (tmp_path , diskfmt ):
662- # https://github.com/scverse/anndata/issues/456
663- adata_pth = tmp_path / f"adata.{ diskfmt } "
664-
665- adata = ad .AnnData (
666- obs = pd .DataFrame (
667- np .ones ((3 , 2 )),
668- columns = ["a" , np .str_ ("b" )],
669- index = ["a" , "b" , "c" ],
670- ),
671- )
672-
673- write = getattr (adata , f"write_{ diskfmt } " )
674- read = getattr (ad , f"read_{ diskfmt } " )
675-
676- write (adata_pth )
677- from_disk = read (adata_pth )
678-
679- assert_equal (adata , from_disk )
680-
def test_write_string_type_error(tmp_path, diskfmt):
    """Writing an ``obs`` column keyed by ``bytes`` raises a ``TypeError``.

    The error message must name both the element being written (``obs``)
    and the offending key (``b'c'``).
    """
    adata = ad.AnnData(obs=dict(obs_names=list("abc")))
    adata.obs[b"c"] = np.zeros(3)

    # This should error, and tell you which key is at fault
    with pytest.raises(TypeError, match=r"writing key 'obs'") as exc_info:
        getattr(adata, f"write_{diskfmt}")(tmp_path / f"adata.{diskfmt}")

    assert "b'c'" in str(exc_info.value)
687672
@@ -722,15 +707,39 @@ def test_zarr_chunk_X(tmp_path):
722707# Round-tripping scanpy datasets
723708################################
724709
725- diskfmt2 = diskfmt
710+
def _do_roundtrip(
    adata: ad.AnnData, pth: Path, diskfmt: Literal["h5ad", "zarr"]
) -> ad.AnnData:
    """Write ``adata`` to ``pth`` in format ``diskfmt`` and read it back.

    The writer is looked up as ``adata.write_<diskfmt>`` and the reader as
    ``ad.read_<diskfmt>``, so ``diskfmt`` selects both halves of the round
    trip. Returns whatever the reader produces for ``pth``.
    """
    write = getattr(adata, f"write_{diskfmt}")
    write(pth)
    read = getattr(ad, f"read_{diskfmt}")
    return read(pth)
716+
717+
@pytest.fixture
def roundtrip(diskfmt):
    """Return ``_do_roundtrip`` with its ``diskfmt`` argument pre-bound.

    The resulting callable takes ``(adata, pth)`` and returns the object
    read back from ``pth``.
    """
    return partial(_do_roundtrip, diskfmt=diskfmt)
721+
722+
def test_write_string_types(tmp_path, diskfmt, roundtrip):
    """``obs`` columns labelled with ``str`` and ``np.str_`` survive a round trip."""
    # https://github.com/scverse/anndata/issues/456
    adata_pth = tmp_path / f"adata.{diskfmt}"

    adata = ad.AnnData(
        obs=pd.DataFrame(
            np.ones((3, 2)),
            columns=["a", np.str_("b")],
            index=["a", "b", "c"],
        ),
    )

    from_disk = roundtrip(adata, adata_pth)

    assert_equal(adata, from_disk)
726738
727739
728740@pytest .mark .skipif (not find_spec ("scanpy" ), reason = "Scanpy is not installed" )
729- def test_scanpy_pbmc68k (tmp_path , diskfmt , diskfmt2 ):
730- read1 = lambda pth : getattr (ad , f"read_{ diskfmt } " )(pth )
731- write1 = lambda adata , pth : getattr (adata , f"write_{ diskfmt } " )(pth )
732- read2 = lambda pth : getattr (ad , f"read_{ diskfmt2 } " )(pth )
733- write2 = lambda adata , pth : getattr (adata , f"write_{ diskfmt2 } " )(pth )
741+ def test_scanpy_pbmc68k (tmp_path , diskfmt , roundtrip , diskfmt2 ):
742+ roundtrip2 = partial (_do_roundtrip , diskfmt = diskfmt2 )
734743
735744 filepth1 = tmp_path / f"test1.{ diskfmt } "
736745 filepth2 = tmp_path / f"test2.{ diskfmt2 } "
@@ -745,17 +754,15 @@ def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2):
745754 warnings .simplefilter ("ignore" , ad .OldFormatWarning )
746755 pbmc = sc .datasets .pbmc68k_reduced ()
747756
748- write1 (pbmc , filepth1 )
749- from_disk1 = read1 (filepth1 ) # Do we read okay
750- write2 (from_disk1 , filepth2 ) # Can we round trip
751- from_disk2 = read2 (filepth2 )
757+ from_disk1 = roundtrip (pbmc , filepth1 ) # Do we read okay
758+ from_disk2 = roundtrip2 (from_disk1 , filepth2 ) # Can we round trip
752759
753760 assert_equal (pbmc , from_disk1 ) # Not expected to be exact due to `nan`s
754761 assert_equal (pbmc , from_disk2 )
755762
756763
757764@pytest .mark .skipif (not find_spec ("scanpy" ), reason = "Scanpy is not installed" )
758- def test_scanpy_krumsiek11 (tmp_path , diskfmt ):
765+ def test_scanpy_krumsiek11 (tmp_path , diskfmt , roundtrip ):
759766 filepth = tmp_path / f"test.{ diskfmt } "
760767 with warnings .catch_warnings ():
761768 warnings .filterwarnings (
@@ -769,11 +776,10 @@ def test_scanpy_krumsiek11(tmp_path, diskfmt):
769776 del orig .uns ["highlights" ] # Can’t write int keys
770777 # Can’t write "string" dtype: https://github.com/scverse/anndata/issues/679
771778 orig .obs ["cell_type" ] = orig .obs ["cell_type" ].astype (str )
772- getattr (orig , f"write_{ diskfmt } " )(filepth )
773779 with pytest .warns (UserWarning , match = r"Observation names are not unique" ):
774- read = getattr ( ad , f"read_ { diskfmt } " )( filepth )
780+ curr = roundtrip ( orig , filepth )
775781
776- assert_equal (orig , read , exact = True )
782+ assert_equal (orig , curr , exact = True )
777783
778784
779785# Checking if we can read legacy zarr files
@@ -808,11 +814,8 @@ def test_backwards_compat_zarr():
808814 assert_equal (pbmc_zarr , pbmc_orig )
809815
810816
811- # TODO: use diskfmt fixture once zarr backend implemented
812- def test_adata_in_uns (tmp_path , diskfmt ):
817+ def test_adata_in_uns (tmp_path , diskfmt , roundtrip ):
813818 pth = tmp_path / f"adatas_in_uns.{ diskfmt } "
814- read = lambda pth : getattr (ad , f"read_{ diskfmt } " )(pth )
815- write = lambda adata , pth : getattr (adata , f"write_{ diskfmt } " )(pth )
816819
817820 orig = gen_adata ((4 , 5 ))
818821 orig .uns ["adatas" ] = {
@@ -823,20 +826,16 @@ def test_adata_in_uns(tmp_path, diskfmt):
823826 another_one .raw = gen_adata ((2 , 7 ))
824827 orig .uns ["adatas" ]["b" ].uns ["another_one" ] = another_one
825828
826- write (orig , pth )
827- curr = read (pth )
829+ curr = roundtrip (orig , pth )
828830
829831 assert_equal (orig , curr )
830832
831833
def test_io_dtype(tmp_path, diskfmt, dtype, roundtrip):
    """``X`` keeps the dtype it was created with after a write/read round trip."""
    pth = tmp_path / f"adata_dtype.{diskfmt}"

    orig = ad.AnnData(np.ones((5, 8), dtype=dtype))
    curr = roundtrip(orig, pth)

    assert curr.X.dtype == dtype
842841
0 commit comments