Skip to content

Commit e795dd3

Browse files
authored
Merge pull request #178 from wkentaro/test/io-coverage
test: cover imgviz.io round-trips and lblsave error paths
2 parents 5dc8ec3 + adae1be commit e795dd3

1 file changed

Lines changed: 39 additions & 1 deletion

File tree

tests/unit/io_test.py

Lines changed: 39 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,49 @@
11
import pathlib
22

33
import numpy as np
4+
import pytest
45
from numpy.typing import NDArray
56

67
import imgviz
78

89

10+
@pytest.mark.parametrize(
11+
"shape", [(15, 20, 3), (15, 20, 4), (15, 20)], ids=["rgb", "rgba", "grayscale"]
12+
)
13+
def test_imsave_imread_roundtrip(
14+
tmp_path: pathlib.Path, shape: tuple[int, ...]
15+
) -> None:
16+
image = np.random.RandomState(0).uniform(0, 255, shape).astype(np.uint8)
17+
18+
imgviz.io.imsave(tmp_path / "image.png", image)
19+
read = imgviz.io.imread(tmp_path / "image.png")
20+
21+
assert read.dtype == np.uint8
22+
np.testing.assert_array_equal(read, image)
23+
24+
25+
def test_imsave_creates_parent_dirs(tmp_path: pathlib.Path) -> None:
26+
rgb = np.zeros((4, 4, 3), dtype=np.uint8)
27+
28+
imgviz.io.imsave(tmp_path / "sub" / "deep" / "rgb.png", rgb)
29+
30+
assert (tmp_path / "sub" / "deep" / "rgb.png").exists()
31+
32+
33+
def test_lblsave_rejects_non_png_extension(tmp_path: pathlib.Path) -> None:
34+
lbl = np.zeros((4, 4), dtype=np.uint8)
35+
36+
with pytest.raises(ValueError, match=r"filename must end with '\.png'"):
37+
imgviz.io.lblsave(tmp_path / "label.jpg", lbl)
38+
39+
40+
def test_lblsave_rejects_non_uint8(tmp_path: pathlib.Path) -> None:
41+
lbl = np.zeros((4, 4), dtype=np.int32)
42+
43+
with pytest.raises(ValueError, match=r"lbl\.dtype must be np\.uint8"):
44+
imgviz.io.lblsave(tmp_path / "label.png", lbl)
45+
46+
947
def test_lblsave(tmp_path: pathlib.Path) -> None:
1048
data = imgviz.data.arc2017()
1149

@@ -20,4 +58,4 @@ def test_lblsave(tmp_path: pathlib.Path) -> None:
2058
imgviz.io.lblsave(png_file, label_cls)
2159
label_cls_read = imgviz.io.imread(png_file)
2260

23-
np.testing.assert_allclose(label_cls, label_cls_read)
61+
np.testing.assert_array_equal(label_cls, label_cls_read)

0 commit comments

Comments
 (0)