Skip to content

Commit 256ae5e

Browse files
authored
Merge pull request #291 from jayzalani/tests
Build: complete test for remaining utils files name: io.py and lidar.py (#239)
2 parents 206e289 + b444612 commit 256ae5e

File tree

2 files changed

+343
-0
lines changed

2 files changed

+343
-0
lines changed

tests/test_io.py

Lines changed: 70 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,70 @@
1+
import json
2+
from PIL import Image
3+
from unittest.mock import mock_open, patch
4+
from detectionmetrics.utils.io import (
5+
read_txt,
6+
read_yaml,
7+
read_json,
8+
write_json,
9+
get_image_mode,
10+
extract_wildcard_matches,
11+
)
12+
13+
14+
# Test read_txt
15+
def test_read_txt():
16+
mock_content = "line1\nline2\nline3\n"
17+
with patch("builtins.open", mock_open(read_data=mock_content)):
18+
result = read_txt("dummy.txt")
19+
assert result == ["line1", "line2", "line3"]
20+
21+
22+
# Test read_yaml
23+
def test_read_yaml():
24+
mock_yaml = "key1: value1\nkey2: value2"
25+
with patch("builtins.open", mock_open(read_data=mock_yaml)):
26+
result = read_yaml("dummy.yaml")
27+
assert result == {"key1": "value1", "key2": "value2"}
28+
29+
30+
# Test read_json
31+
def test_read_json():
32+
mock_json = json.dumps({"key": "value"})
33+
with patch("builtins.open", mock_open(read_data=mock_json)):
34+
result = read_json("dummy.json")
35+
assert result == {"key": "value"}
36+
37+
38+
# Test write_json
39+
def test_write_json():
40+
data = {"name": "pytest"}
41+
mock_file = mock_open()
42+
43+
with patch("builtins.open", mock_file):
44+
write_json("dummy.json", data)
45+
46+
# Retrieve all calls to write()
47+
written_data = "".join(call.args[0] for call in mock_file().write.call_args_list)
48+
49+
assert json.loads(written_data) == data
50+
51+
52+
# Test get_image_mode
53+
def test_get_image_mode(tmp_path):
54+
img_path = tmp_path / "test_image.png"
55+
img = Image.new("RGB", (10, 10), color="red")
56+
img.save(img_path)
57+
58+
assert get_image_mode(str(img_path)) == "RGB"
59+
60+
61+
# Test extract_wildcard_matches
62+
def test_extract_wildcard_matches(tmp_path):
63+
(tmp_path / "file1.txt").touch()
64+
(tmp_path / "file2.txt").touch()
65+
with patch(
66+
"detectionmetrics.utils.io.glob",
67+
return_value=[str(tmp_path / "file1.txt"), str(tmp_path / "file2.txt")],
68+
):
69+
matches = extract_wildcard_matches(str(tmp_path / "*.txt"))
70+
assert len(matches) == 2

tests/test_lidar.py

Lines changed: 273 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,273 @@
1+
import pytest
2+
import numpy as np
3+
from unittest.mock import patch, MagicMock
4+
from sklearn.neighbors import KDTree
5+
import open3d as o3d
6+
from PIL import Image
7+
from detectionmetrics.utils.lidar import (
8+
Sampler,
9+
recenter,
10+
build_point_cloud,
11+
view_point_cloud,
12+
render_point_cloud,
13+
REFERENCE_SIZE,
14+
CAMERA_VIEWS
15+
)
16+
17+
18+
@pytest.fixture
19+
def sample_points():
20+
"""Fixture to generate reproducible sample points for testing."""
21+
np.random.seed(42)
22+
return np.random.rand(100, 3)
23+
24+
25+
@pytest.fixture
26+
def sample_colors():
27+
"""Fixture to generate reproducible sample colors for testing."""
28+
np.random.seed(42)
29+
return np.random.rand(100, 3)
30+
31+
32+
@pytest.fixture
33+
def sample_kdtree(sample_points):
34+
"""Create a KDTree from sample points."""
35+
return KDTree(sample_points)
36+
37+
38+
class TestSampler:
39+
"""Tests for the Sampler class."""
40+
41+
def test_valid_samplers(self, sample_points, sample_kdtree):
42+
"""Test initialization with valid samplers."""
43+
# Test with random sampler
44+
random_sampler = Sampler(
45+
point_cloud_size=len(sample_points),
46+
search_tree=sample_kdtree,
47+
sampler_name="random",
48+
num_classes=10,
49+
seed=42
50+
)
51+
52+
assert random_sampler.num_classes == 10
53+
assert random_sampler.test_probs.shape == (len(sample_points), 10)
54+
assert random_sampler.sample.__name__ == "random"
55+
56+
# Test with spatially_regular sampler
57+
spatial_sampler = Sampler(
58+
point_cloud_size=len(sample_points),
59+
search_tree=sample_kdtree,
60+
sampler_name="spatially_regular",
61+
num_classes=10,
62+
seed=42
63+
)
64+
65+
assert spatial_sampler.sample.__name__ == "spatially_regular"
66+
67+
def test_invalid_sampler(self, sample_points, sample_kdtree):
68+
"""Test initialization with invalid sampler name."""
69+
# Handling the fact that the original code tries to access self.model_cfg['sampler']
70+
# We expect an AttributeError rather than NotImplementedError
71+
with pytest.raises(AttributeError):
72+
Sampler(
73+
point_cloud_size=len(sample_points),
74+
search_tree=sample_kdtree,
75+
sampler_name="invalid_sampler",
76+
num_classes=10,
77+
seed=42
78+
)
79+
80+
def test_get_indices_small_cloud(self, sample_points, sample_kdtree):
81+
"""Test _get_indices when point_cloud_size < num_points."""
82+
sampler = Sampler(
83+
point_cloud_size=len(sample_points),
84+
search_tree=sample_kdtree,
85+
sampler_name="random",
86+
num_classes=10,
87+
seed=42
88+
)
89+
90+
point_cloud_size = 20
91+
num_points = 30
92+
center_point = np.array([[0.5, 0.5, 0.5]])
93+
94+
indices = sampler._get_indices(point_cloud_size, num_points, center_point)
95+
96+
assert len(indices) == num_points
97+
assert np.max(indices) < point_cloud_size # All indices should be within range
98+
99+
def test_get_indices_large_cloud(self, sample_points, sample_kdtree):
100+
"""Test _get_indices when point_cloud_size >= num_points."""
101+
sampler = Sampler(
102+
point_cloud_size=len(sample_points),
103+
search_tree=sample_kdtree,
104+
sampler_name="random",
105+
num_classes=10,
106+
seed=42
107+
)
108+
109+
point_cloud_size = 100
110+
num_points = 10
111+
center_point = np.array([[0.5, 0.5, 0.5]])
112+
113+
indices = sampler._get_indices(point_cloud_size, num_points, center_point)
114+
115+
assert len(indices) == num_points
116+
assert np.max(indices) < point_cloud_size
117+
118+
def test_random_sampler_functionality(self, sample_points, sample_kdtree):
119+
"""Test the random sampler's sampling behavior."""
120+
sampler = Sampler(
121+
point_cloud_size=len(sample_points),
122+
search_tree=sample_kdtree,
123+
sampler_name="random",
124+
num_classes=10,
125+
seed=42
126+
)
127+
128+
num_points = 20
129+
points, indices, center_point = sampler.random(sample_points, num_points)
130+
131+
assert points.shape == (num_points, 3)
132+
assert len(indices) == num_points
133+
assert center_point.shape == (1, 3)
134+
assert indices.max() < len(sample_points)
135+
136+
def test_spatially_regular_with_num_points(self, sample_points, sample_kdtree):
137+
"""Test spatially regular sampler with num_points parameter."""
138+
sampler = Sampler(
139+
point_cloud_size=len(sample_points),
140+
search_tree=sample_kdtree,
141+
sampler_name="spatially_regular",
142+
num_classes=10,
143+
seed=42
144+
)
145+
146+
num_points = 20
147+
points, indices, center_point = sampler.spatially_regular(sample_points, num_points=num_points)
148+
149+
assert points.shape == (len(indices), 3)
150+
assert len(indices) >= 2 # Should have at least 2 points
151+
assert center_point.shape == (1, 3)
152+
assert np.min(sampler.p) >= sampler.min_p
153+
154+
def test_spatially_regular_with_radius(self, sample_points, sample_kdtree):
155+
"""Test spatially regular sampler with radius parameter."""
156+
sampler = Sampler(
157+
point_cloud_size=len(sample_points),
158+
search_tree=sample_kdtree,
159+
sampler_name="spatially_regular",
160+
num_classes=10,
161+
seed=42
162+
)
163+
164+
radius = 0.3
165+
points, indices, center_point = sampler.spatially_regular(sample_points, radius=radius)
166+
167+
assert points.shape == (len(indices), 3)
168+
assert len(indices) >= 2
169+
assert center_point.shape == (1, 3)
170+
171+
def test_spatially_regular_missing_params(self, sample_points, sample_kdtree):
172+
"""Test spatially_regular raises error when parameters are missing."""
173+
sampler = Sampler(
174+
point_cloud_size=len(sample_points),
175+
search_tree=sample_kdtree,
176+
sampler_name="spatially_regular",
177+
num_classes=10,
178+
seed=42
179+
)
180+
181+
with pytest.raises(ValueError, match="Either num_points or radius must be provided"):
182+
sampler.spatially_regular(sample_points)
183+
184+
185+
class TestUtilityFunctions:
186+
"""Tests for standalone utility functions."""
187+
188+
def test_recenter(self, sample_points):
189+
"""Test recenter function properly centers point cloud dimensions."""
190+
dims = [0, 2]
191+
recentered_points = recenter(sample_points.copy(), dims)
192+
193+
# Check that mean along specified dimensions is close to zero
194+
assert np.abs(recentered_points[:, dims].mean(0)).max() < 1e-10
195+
196+
# Check that unspecified dimension is unchanged
197+
assert np.allclose(recentered_points[:, 1], sample_points[:, 1])
198+
199+
def test_build_point_cloud(self, sample_points, sample_colors):
200+
"""Test build_point_cloud creates proper Open3D point cloud."""
201+
point_cloud = build_point_cloud(sample_points, sample_colors)
202+
203+
assert isinstance(point_cloud, o3d.geometry.PointCloud)
204+
assert len(point_cloud.points) == len(sample_points)
205+
assert len(point_cloud.colors) == len(sample_colors)
206+
assert np.allclose(np.asarray(point_cloud.points), sample_points)
207+
assert np.allclose(np.asarray(point_cloud.colors), sample_colors)
208+
209+
@patch('open3d.visualization.draw_geometries')
210+
def test_view_point_cloud(self, mock_draw, sample_points, sample_colors):
211+
"""Test view_point_cloud correctly calls visualization function."""
212+
view_point_cloud(sample_points, sample_colors)
213+
214+
mock_draw.assert_called_once()
215+
args = mock_draw.call_args[0][0]
216+
assert len(args) == 1
217+
assert isinstance(args[0], o3d.geometry.PointCloud)
218+
219+
@patch('open3d.visualization.rendering.OffscreenRenderer')
220+
def test_render_point_cloud(self, mock_renderer_class, sample_points, sample_colors):
221+
"""Test render_point_cloud produces expected output."""
222+
# Setup mock
223+
mock_renderer = MagicMock()
224+
mock_renderer_class.return_value = mock_renderer
225+
mock_image_array = np.zeros((1080, 1920, 4), dtype=np.uint8)
226+
mock_renderer.render_to_image.return_value = mock_image_array
227+
228+
# Call function with custom parameters
229+
result = render_point_cloud(
230+
sample_points,
231+
sample_colors,
232+
camera_view="3rd_person",
233+
bg_color=[0.5, 0.5, 0.5, 1.0],
234+
color_jitter=0.1,
235+
point_size=5.0,
236+
resolution=(800, 600)
237+
)
238+
239+
# Verify expectations
240+
mock_renderer_class.assert_called_once_with(800, 600)
241+
mock_renderer.scene.add_geometry.assert_called_once()
242+
mock_renderer.scene.set_background.assert_called_once()
243+
mock_renderer.setup_camera.assert_called_once()
244+
mock_renderer.render_to_image.assert_called_once()
245+
mock_renderer.scene.clear_geometry.assert_called_once()
246+
247+
assert isinstance(result, Image.Image)
248+
249+
def test_render_point_cloud_invalid_camera_view(self, sample_points, sample_colors):
250+
"""Test render_point_cloud with invalid camera view."""
251+
with pytest.raises(AssertionError):
252+
render_point_cloud(
253+
sample_points,
254+
sample_colors,
255+
camera_view="invalid_view"
256+
)
257+
258+
259+
class TestConstants:
260+
"""Tests for constants in the module."""
261+
262+
def test_camera_views_structure(self):
263+
"""Test the structure of CAMERA_VIEWS constant."""
264+
assert "3rd_person" in CAMERA_VIEWS
265+
view = CAMERA_VIEWS["3rd_person"]
266+
267+
required_keys = ["zoom", "front", "lookat", "up"]
268+
for key in required_keys:
269+
assert key in view
270+
271+
for vector_key in ["front", "lookat", "up"]:
272+
assert isinstance(view[vector_key], np.ndarray)
273+
assert view[vector_key].shape == (3,)

0 commit comments

Comments
 (0)