Skip to content

Commit 8e9eb88

Browse files
authored
Simplify roundtrip io tests (#1702)
1 parent e939e95 commit 8e9eb88

1 file changed

Lines changed: 45 additions & 46 deletions

File tree

tests/test_readwrite.py

Lines changed: 45 additions & 46 deletions
Original file line number | Diff line number | Diff line change
@@ -3,6 +3,7 @@
33
import re
44
import warnings
55
from contextlib import contextmanager
6+
from functools import partial
67
from importlib.util import find_spec
78
from pathlib import Path
89
from string import ascii_letters
@@ -25,6 +26,7 @@
2526

2627
if TYPE_CHECKING:
2728
from os import PathLike
29+
from typing import Literal
2830

2931
HERE = Path(__file__).parent
3032

@@ -658,30 +660,13 @@ def random_cats(n):
658660
assert_equal(orig, curr)
659661

660662

661-
def test_write_string_types(tmp_path, diskfmt):
662-
# https://github.com/scverse/anndata/issues/456
663-
adata_pth = tmp_path / f"adata.{diskfmt}"
664-
665-
adata = ad.AnnData(
666-
obs=pd.DataFrame(
667-
np.ones((3, 2)),
668-
columns=["a", np.str_("b")],
669-
index=["a", "b", "c"],
670-
),
671-
)
672-
673-
write = getattr(adata, f"write_{diskfmt}")
674-
read = getattr(ad, f"read_{diskfmt}")
675-
676-
write(adata_pth)
677-
from_disk = read(adata_pth)
678-
679-
assert_equal(adata, from_disk)
680-
663+
def test_write_string_type_error(tmp_path, diskfmt):
664+
adata = ad.AnnData(obs=dict(obs_names=list("abc")))
681665
adata.obs[b"c"] = np.zeros(3)
666+
682667
# This should error, and tell you which key is at fault
683668
with pytest.raises(TypeError, match=r"writing key 'obs'") as exc_info:
684-
write(adata_pth)
669+
getattr(adata, f"write_{diskfmt}")(tmp_path / f"adata.{diskfmt}")
685670

686671
assert "b'c'" in str(exc_info.value)
687672

@@ -722,15 +707,39 @@ def test_zarr_chunk_X(tmp_path):
722707
# Round-tripping scanpy datasets
723708
################################
724709

725-
diskfmt2 = diskfmt
710+
711+
def _do_roundtrip(
712+
adata: ad.AnnData, pth: Path, diskfmt: Literal["h5ad", "zarr"]
713+
) -> ad.AnnData:
714+
getattr(adata, f"write_{diskfmt}")(pth)
715+
return getattr(ad, f"read_{diskfmt}")(pth)
716+
717+
718+
@pytest.fixture
719+
def roundtrip(diskfmt):
720+
return partial(_do_roundtrip, diskfmt=diskfmt)
721+
722+
723+
def test_write_string_types(tmp_path, diskfmt, roundtrip):
724+
# https://github.com/scverse/anndata/issues/456
725+
adata_pth = tmp_path / f"adata.{diskfmt}"
726+
727+
adata = ad.AnnData(
728+
obs=pd.DataFrame(
729+
np.ones((3, 2)),
730+
columns=["a", np.str_("b")],
731+
index=["a", "b", "c"],
732+
),
733+
)
734+
735+
from_disk = roundtrip(adata, adata_pth)
736+
737+
assert_equal(adata, from_disk)
726738

727739

728740
@pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed")
729-
def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2):
730-
read1 = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
731-
write1 = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth)
732-
read2 = lambda pth: getattr(ad, f"read_{diskfmt2}")(pth)
733-
write2 = lambda adata, pth: getattr(adata, f"write_{diskfmt2}")(pth)
741+
def test_scanpy_pbmc68k(tmp_path, diskfmt, roundtrip, diskfmt2):
742+
roundtrip2 = partial(_do_roundtrip, diskfmt=diskfmt2)
734743

735744
filepth1 = tmp_path / f"test1.{diskfmt}"
736745
filepth2 = tmp_path / f"test2.{diskfmt2}"
@@ -745,17 +754,15 @@ def test_scanpy_pbmc68k(tmp_path, diskfmt, diskfmt2):
745754
warnings.simplefilter("ignore", ad.OldFormatWarning)
746755
pbmc = sc.datasets.pbmc68k_reduced()
747756

748-
write1(pbmc, filepth1)
749-
from_disk1 = read1(filepth1) # Do we read okay
750-
write2(from_disk1, filepth2) # Can we round trip
751-
from_disk2 = read2(filepth2)
757+
from_disk1 = roundtrip(pbmc, filepth1) # Do we read okay
758+
from_disk2 = roundtrip2(from_disk1, filepth2) # Can we round trip
752759

753760
assert_equal(pbmc, from_disk1) # Not expected to be exact due to `nan`s
754761
assert_equal(pbmc, from_disk2)
755762

756763

757764
@pytest.mark.skipif(not find_spec("scanpy"), reason="Scanpy is not installed")
758-
def test_scanpy_krumsiek11(tmp_path, diskfmt):
765+
def test_scanpy_krumsiek11(tmp_path, diskfmt, roundtrip):
759766
filepth = tmp_path / f"test.{diskfmt}"
760767
with warnings.catch_warnings():
761768
warnings.filterwarnings(
@@ -769,11 +776,10 @@ def test_scanpy_krumsiek11(tmp_path, diskfmt):
769776
del orig.uns["highlights"] # Can’t write int keys
770777
# Can’t write "string" dtype: https://github.com/scverse/anndata/issues/679
771778
orig.obs["cell_type"] = orig.obs["cell_type"].astype(str)
772-
getattr(orig, f"write_{diskfmt}")(filepth)
773779
with pytest.warns(UserWarning, match=r"Observation names are not unique"):
774-
read = getattr(ad, f"read_{diskfmt}")(filepth)
780+
curr = roundtrip(orig, filepth)
775781

776-
assert_equal(orig, read, exact=True)
782+
assert_equal(orig, curr, exact=True)
777783

778784

779785
# Checking if we can read legacy zarr files
@@ -808,11 +814,8 @@ def test_backwards_compat_zarr():
808814
assert_equal(pbmc_zarr, pbmc_orig)
809815

810816

811-
# TODO: use diskfmt fixture once zarr backend implemented
812-
def test_adata_in_uns(tmp_path, diskfmt):
817+
def test_adata_in_uns(tmp_path, diskfmt, roundtrip):
813818
pth = tmp_path / f"adatas_in_uns.{diskfmt}"
814-
read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
815-
write = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth)
816819

817820
orig = gen_adata((4, 5))
818821
orig.uns["adatas"] = {
@@ -823,20 +826,16 @@ def test_adata_in_uns(tmp_path, diskfmt):
823826
another_one.raw = gen_adata((2, 7))
824827
orig.uns["adatas"]["b"].uns["another_one"] = another_one
825828

826-
write(orig, pth)
827-
curr = read(pth)
829+
curr = roundtrip(orig, pth)
828830

829831
assert_equal(orig, curr)
830832

831833

832-
def test_io_dtype(tmp_path, diskfmt, dtype):
834+
def test_io_dtype(tmp_path, diskfmt, dtype, roundtrip):
833835
pth = tmp_path / f"adata_dtype.{diskfmt}"
834-
read = lambda pth: getattr(ad, f"read_{diskfmt}")(pth)
835-
write = lambda adata, pth: getattr(adata, f"write_{diskfmt}")(pth)
836836

837837
orig = ad.AnnData(np.ones((5, 8), dtype=dtype))
838-
write(orig, pth)
839-
curr = read(pth)
838+
curr = roundtrip(orig, pth)
840839

841840
assert curr.X.dtype == dtype
842841

0 commit comments

Comments (0)