1+ import pytest
2+ import numpy as np
3+ from unittest .mock import patch , MagicMock
4+ from sklearn .neighbors import KDTree
5+ import open3d as o3d
6+ from PIL import Image
7+ from detectionmetrics .utils .lidar import (
8+ Sampler ,
9+ recenter ,
10+ build_point_cloud ,
11+ view_point_cloud ,
12+ render_point_cloud ,
13+ REFERENCE_SIZE ,
14+ CAMERA_VIEWS
15+ )
16+
17+
18+ @pytest .fixture
19+ def sample_points ():
20+ """Fixture to generate reproducible sample points for testing."""
21+ np .random .seed (42 )
22+ return np .random .rand (100 , 3 )
23+
24+
25+ @pytest .fixture
26+ def sample_colors ():
27+ """Fixture to generate reproducible sample colors for testing."""
28+ np .random .seed (42 )
29+ return np .random .rand (100 , 3 )
30+
31+
32+ @pytest .fixture
33+ def sample_kdtree (sample_points ):
34+ """Create a KDTree from sample points."""
35+ return KDTree (sample_points )
36+
37+
38+ class TestSampler :
39+ """Tests for the Sampler class."""
40+
41+ def test_valid_samplers (self , sample_points , sample_kdtree ):
42+ """Test initialization with valid samplers."""
43+ # Test with random sampler
44+ random_sampler = Sampler (
45+ point_cloud_size = len (sample_points ),
46+ search_tree = sample_kdtree ,
47+ sampler_name = "random" ,
48+ num_classes = 10 ,
49+ seed = 42
50+ )
51+
52+ assert random_sampler .num_classes == 10
53+ assert random_sampler .test_probs .shape == (len (sample_points ), 10 )
54+ assert random_sampler .sample .__name__ == "random"
55+
56+ # Test with spatially_regular sampler
57+ spatial_sampler = Sampler (
58+ point_cloud_size = len (sample_points ),
59+ search_tree = sample_kdtree ,
60+ sampler_name = "spatially_regular" ,
61+ num_classes = 10 ,
62+ seed = 42
63+ )
64+
65+ assert spatial_sampler .sample .__name__ == "spatially_regular"
66+
67+ def test_invalid_sampler (self , sample_points , sample_kdtree ):
68+ """Test initialization with invalid sampler name."""
69+ # Handling the fact that the original code tries to access self.model_cfg['sampler']
70+ # We expect an AttributeError rather than NotImplementedError
71+ with pytest .raises (AttributeError ):
72+ Sampler (
73+ point_cloud_size = len (sample_points ),
74+ search_tree = sample_kdtree ,
75+ sampler_name = "invalid_sampler" ,
76+ num_classes = 10 ,
77+ seed = 42
78+ )
79+
80+ def test_get_indices_small_cloud (self , sample_points , sample_kdtree ):
81+ """Test _get_indices when point_cloud_size < num_points."""
82+ sampler = Sampler (
83+ point_cloud_size = len (sample_points ),
84+ search_tree = sample_kdtree ,
85+ sampler_name = "random" ,
86+ num_classes = 10 ,
87+ seed = 42
88+ )
89+
90+ point_cloud_size = 20
91+ num_points = 30
92+ center_point = np .array ([[0.5 , 0.5 , 0.5 ]])
93+
94+ indices = sampler ._get_indices (point_cloud_size , num_points , center_point )
95+
96+ assert len (indices ) == num_points
97+ assert np .max (indices ) < point_cloud_size # All indices should be within range
98+
99+ def test_get_indices_large_cloud (self , sample_points , sample_kdtree ):
100+ """Test _get_indices when point_cloud_size >= num_points."""
101+ sampler = Sampler (
102+ point_cloud_size = len (sample_points ),
103+ search_tree = sample_kdtree ,
104+ sampler_name = "random" ,
105+ num_classes = 10 ,
106+ seed = 42
107+ )
108+
109+ point_cloud_size = 100
110+ num_points = 10
111+ center_point = np .array ([[0.5 , 0.5 , 0.5 ]])
112+
113+ indices = sampler ._get_indices (point_cloud_size , num_points , center_point )
114+
115+ assert len (indices ) == num_points
116+ assert np .max (indices ) < point_cloud_size
117+
118+ def test_random_sampler_functionality (self , sample_points , sample_kdtree ):
119+ """Test the random sampler's sampling behavior."""
120+ sampler = Sampler (
121+ point_cloud_size = len (sample_points ),
122+ search_tree = sample_kdtree ,
123+ sampler_name = "random" ,
124+ num_classes = 10 ,
125+ seed = 42
126+ )
127+
128+ num_points = 20
129+ points , indices , center_point = sampler .random (sample_points , num_points )
130+
131+ assert points .shape == (num_points , 3 )
132+ assert len (indices ) == num_points
133+ assert center_point .shape == (1 , 3 )
134+ assert indices .max () < len (sample_points )
135+
136+ def test_spatially_regular_with_num_points (self , sample_points , sample_kdtree ):
137+ """Test spatially regular sampler with num_points parameter."""
138+ sampler = Sampler (
139+ point_cloud_size = len (sample_points ),
140+ search_tree = sample_kdtree ,
141+ sampler_name = "spatially_regular" ,
142+ num_classes = 10 ,
143+ seed = 42
144+ )
145+
146+ num_points = 20
147+ points , indices , center_point = sampler .spatially_regular (sample_points , num_points = num_points )
148+
149+ assert points .shape == (len (indices ), 3 )
150+ assert len (indices ) >= 2 # Should have at least 2 points
151+ assert center_point .shape == (1 , 3 )
152+ assert np .min (sampler .p ) >= sampler .min_p
153+
154+ def test_spatially_regular_with_radius (self , sample_points , sample_kdtree ):
155+ """Test spatially regular sampler with radius parameter."""
156+ sampler = Sampler (
157+ point_cloud_size = len (sample_points ),
158+ search_tree = sample_kdtree ,
159+ sampler_name = "spatially_regular" ,
160+ num_classes = 10 ,
161+ seed = 42
162+ )
163+
164+ radius = 0.3
165+ points , indices , center_point = sampler .spatially_regular (sample_points , radius = radius )
166+
167+ assert points .shape == (len (indices ), 3 )
168+ assert len (indices ) >= 2
169+ assert center_point .shape == (1 , 3 )
170+
171+ def test_spatially_regular_missing_params (self , sample_points , sample_kdtree ):
172+ """Test spatially_regular raises error when parameters are missing."""
173+ sampler = Sampler (
174+ point_cloud_size = len (sample_points ),
175+ search_tree = sample_kdtree ,
176+ sampler_name = "spatially_regular" ,
177+ num_classes = 10 ,
178+ seed = 42
179+ )
180+
181+ with pytest .raises (ValueError , match = "Either num_points or radius must be provided" ):
182+ sampler .spatially_regular (sample_points )
183+
184+
185+ class TestUtilityFunctions :
186+ """Tests for standalone utility functions."""
187+
188+ def test_recenter (self , sample_points ):
189+ """Test recenter function properly centers point cloud dimensions."""
190+ dims = [0 , 2 ]
191+ recentered_points = recenter (sample_points .copy (), dims )
192+
193+ # Check that mean along specified dimensions is close to zero
194+ assert np .abs (recentered_points [:, dims ].mean (0 )).max () < 1e-10
195+
196+ # Check that unspecified dimension is unchanged
197+ assert np .allclose (recentered_points [:, 1 ], sample_points [:, 1 ])
198+
199+ def test_build_point_cloud (self , sample_points , sample_colors ):
200+ """Test build_point_cloud creates proper Open3D point cloud."""
201+ point_cloud = build_point_cloud (sample_points , sample_colors )
202+
203+ assert isinstance (point_cloud , o3d .geometry .PointCloud )
204+ assert len (point_cloud .points ) == len (sample_points )
205+ assert len (point_cloud .colors ) == len (sample_colors )
206+ assert np .allclose (np .asarray (point_cloud .points ), sample_points )
207+ assert np .allclose (np .asarray (point_cloud .colors ), sample_colors )
208+
209+ @patch ('open3d.visualization.draw_geometries' )
210+ def test_view_point_cloud (self , mock_draw , sample_points , sample_colors ):
211+ """Test view_point_cloud correctly calls visualization function."""
212+ view_point_cloud (sample_points , sample_colors )
213+
214+ mock_draw .assert_called_once ()
215+ args = mock_draw .call_args [0 ][0 ]
216+ assert len (args ) == 1
217+ assert isinstance (args [0 ], o3d .geometry .PointCloud )
218+
219+ @patch ('open3d.visualization.rendering.OffscreenRenderer' )
220+ def test_render_point_cloud (self , mock_renderer_class , sample_points , sample_colors ):
221+ """Test render_point_cloud produces expected output."""
222+ # Setup mock
223+ mock_renderer = MagicMock ()
224+ mock_renderer_class .return_value = mock_renderer
225+ mock_image_array = np .zeros ((1080 , 1920 , 4 ), dtype = np .uint8 )
226+ mock_renderer .render_to_image .return_value = mock_image_array
227+
228+ # Call function with custom parameters
229+ result = render_point_cloud (
230+ sample_points ,
231+ sample_colors ,
232+ camera_view = "3rd_person" ,
233+ bg_color = [0.5 , 0.5 , 0.5 , 1.0 ],
234+ color_jitter = 0.1 ,
235+ point_size = 5.0 ,
236+ resolution = (800 , 600 )
237+ )
238+
239+ # Verify expectations
240+ mock_renderer_class .assert_called_once_with (800 , 600 )
241+ mock_renderer .scene .add_geometry .assert_called_once ()
242+ mock_renderer .scene .set_background .assert_called_once ()
243+ mock_renderer .setup_camera .assert_called_once ()
244+ mock_renderer .render_to_image .assert_called_once ()
245+ mock_renderer .scene .clear_geometry .assert_called_once ()
246+
247+ assert isinstance (result , Image .Image )
248+
249+ def test_render_point_cloud_invalid_camera_view (self , sample_points , sample_colors ):
250+ """Test render_point_cloud with invalid camera view."""
251+ with pytest .raises (AssertionError ):
252+ render_point_cloud (
253+ sample_points ,
254+ sample_colors ,
255+ camera_view = "invalid_view"
256+ )
257+
258+
259+ class TestConstants :
260+ """Tests for constants in the module."""
261+
262+ def test_camera_views_structure (self ):
263+ """Test the structure of CAMERA_VIEWS constant."""
264+ assert "3rd_person" in CAMERA_VIEWS
265+ view = CAMERA_VIEWS ["3rd_person" ]
266+
267+ required_keys = ["zoom" , "front" , "lookat" , "up" ]
268+ for key in required_keys :
269+ assert key in view
270+
271+ for vector_key in ["front" , "lookat" , "up" ]:
272+ assert isinstance (view [vector_key ], np .ndarray )
273+ assert view [vector_key ].shape == (3 ,)
0 commit comments