11import pathlib
22
33import numpy as np
4+ import pytest
45from numpy .typing import NDArray
56
67import imgviz
78
89
10+ @pytest .mark .parametrize (
11+ "shape" , [(15 , 20 , 3 ), (15 , 20 , 4 ), (15 , 20 )], ids = ["rgb" , "rgba" , "grayscale" ]
12+ )
13+ def test_imsave_imread_roundtrip (
14+ tmp_path : pathlib .Path , shape : tuple [int , ...]
15+ ) -> None :
16+ image = np .random .RandomState (0 ).uniform (0 , 255 , shape ).astype (np .uint8 )
17+
18+ imgviz .io .imsave (tmp_path / "image.png" , image )
19+ read = imgviz .io .imread (tmp_path / "image.png" )
20+
21+ assert read .dtype == np .uint8
22+ np .testing .assert_array_equal (read , image )
23+
24+
25+ def test_imsave_creates_parent_dirs (tmp_path : pathlib .Path ) -> None :
26+ rgb = np .zeros ((4 , 4 , 3 ), dtype = np .uint8 )
27+
28+ imgviz .io .imsave (tmp_path / "sub" / "deep" / "rgb.png" , rgb )
29+
30+ assert (tmp_path / "sub" / "deep" / "rgb.png" ).exists ()
31+
32+
33+ def test_lblsave_rejects_non_png_extension (tmp_path : pathlib .Path ) -> None :
34+ lbl = np .zeros ((4 , 4 ), dtype = np .uint8 )
35+
36+ with pytest .raises (ValueError , match = r"filename must end with '\.png'" ):
37+ imgviz .io .lblsave (tmp_path / "label.jpg" , lbl )
38+
39+
40+ def test_lblsave_rejects_non_uint8 (tmp_path : pathlib .Path ) -> None :
41+ lbl = np .zeros ((4 , 4 ), dtype = np .int32 )
42+
43+ with pytest .raises (ValueError , match = r"lbl\.dtype must be np\.uint8" ):
44+ imgviz .io .lblsave (tmp_path / "label.png" , lbl )
45+
46+
947def test_lblsave (tmp_path : pathlib .Path ) -> None :
1048 data = imgviz .data .arc2017 ()
1149
@@ -20,4 +58,4 @@ def test_lblsave(tmp_path: pathlib.Path) -> None:
2058 imgviz .io .lblsave (png_file , label_cls )
2159 label_cls_read = imgviz .io .imread (png_file )
2260
23- np .testing .assert_allclose (label_cls , label_cls_read )
61+ np .testing .assert_array_equal (label_cls , label_cls_read )
0 commit comments