Skip to content

Commit fc8551a

Browse files
test: add unit tests for COCODataset
16 tests covering: - Dataset properties and filtering (mock DataFrame) - Sample/batch creation with multi-instance data - Mask decoding: polygon, RLE, error handling - SEMANTIC vs INSTANCE annotation modes
1 parent 611d82a commit fc8551a

File tree

1 file changed

+208
-0
lines changed

1 file changed

+208
-0
lines changed
Lines changed: 208 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,208 @@
1+
# Copyright (C) 2025-2026 Intel Corporation
2+
# SPDX-License-Identifier: Apache-2.0
3+
4+
"""Unit tests for COCO dataset functionality."""
5+
6+
from typing import Any
7+
from unittest.mock import MagicMock, patch
8+
9+
import numpy as np
10+
import polars as pl
11+
import pytest
12+
import torch
13+
from pycocotools import mask as mask_utils
14+
15+
from instantlearn.data.base import Batch, Dataset, Sample
16+
from instantlearn.data.coco import COCODataset
17+
from instantlearn.data.lvis import LVISAnnotationMode
18+
19+
20+
class TestCOCODatasetMock:
    """Exercise COCODataset behaviour through a mock that bypasses disk and the COCO API."""

    @pytest.fixture
    def mock_coco_dataframe(self) -> pl.DataFrame:
        """Build a DataFrame shaped like COCODataset._load_dataframe output (SEMANTIC mode)."""
        columns = {
            "image_id": [1, 1, 2, 3],
            "image_path": [
                "/dummy/img_001.jpg",
                "/dummy/img_001.jpg",
                "/dummy/img_002.jpg",
                "/dummy/img_003.jpg",
            ],
            "categories": [["cat"], ["dog"], ["cat"], ["dog"]],
            "category_ids": [[1], [2], [1], [2]],
            "segmentations": [
                [[[10, 10, 50, 10, 50, 50, 10, 50]]],
                [[[60, 60, 90, 60, 90, 90, 60, 90]]],
                [[[20, 20, 80, 20, 80, 80, 20, 80]]],
                [[[5, 5, 40, 5, 40, 40, 5, 40]]],
            ],
            "bboxes": [None] * 4,
            "is_reference": [[True], [True], [False], [False]],
            "n_shot": [[0], [0], [-1], [-1]],
            "img_dim": [(100, 100)] * 4,
        }
        return pl.DataFrame(columns)

    @pytest.fixture
    def mock_coco_dataset(self, mock_coco_dataframe: pl.DataFrame) -> Dataset:
        """Return a dataset whose file-system hooks are stubbed out."""

        class _StubCOCODataset(Dataset):
            def _load_dataframe(self) -> pl.DataFrame:
                return mock_coco_dataframe

            def _load_masks(self, raw_sample: dict[str, Any]) -> np.ndarray:
                # One deterministic pseudo-random binary mask per segmentation entry.
                segmentations = raw_sample.get("segmentations", [])
                if not segmentations:
                    return np.zeros((0, 100, 100), dtype=np.uint8)
                rng = np.random.default_rng(42)
                return rng.integers(0, 2, (len(segmentations), 100, 100), dtype=np.uint8)

        stub = _StubCOCODataset()
        stub.df = mock_coco_dataframe
        return stub

    def test_dataset_length(self, mock_coco_dataset: Dataset) -> None:
        """len(dataset) equals the row count of the backing DataFrame."""
        assert len(mock_coco_dataset) == 4

    def test_categories_property(self, mock_coco_dataset: Dataset) -> None:
        """The categories property exposes every unique category name."""
        assert set(mock_coco_dataset.categories) == {"cat", "dog"}

    def test_num_categories(self, mock_coco_dataset: Dataset) -> None:
        """num_categories counts distinct categories."""
        assert mock_coco_dataset.num_categories == 2

    def test_reference_filtering(self, mock_coco_dataset: Dataset) -> None:
        """Only the two reference rows survive reference filtering."""
        assert len(mock_coco_dataset.get_reference_samples_df()) == 2

    def test_target_filtering(self, mock_coco_dataset: Dataset) -> None:
        """Only the two target rows survive target filtering."""
        assert len(mock_coco_dataset.get_target_samples_df()) == 2

    def test_category_reference_filtering(self, mock_coco_dataset: Dataset) -> None:
        """Reference filtering narrows further when a category is given."""
        for name in ("cat", "dog"):
            assert len(mock_coco_dataset.get_reference_samples_df(category=name)) == 1

    @patch("instantlearn.data.base.base.read_image")
    def test_sample_creation(self, mock_read_image: MagicMock, mock_coco_dataset: Dataset) -> None:
        """Indexing yields a Sample with a single category and mask."""
        mock_read_image.return_value = np.zeros((100, 100, 3), dtype=np.uint8)

        first = mock_coco_dataset[0]
        assert isinstance(first, Sample)
        assert len(first.categories) == 1
        assert first.categories == ["cat"]
        assert first.masks is not None
        assert first.masks.shape[0] == 1

    @patch("instantlearn.data.base.base.read_image")
    def test_sample_metadata(self, mock_read_image: MagicMock, mock_coco_dataset: Dataset) -> None:
        """Reference flag and n-shot metadata come through unchanged."""
        mock_read_image.return_value = np.zeros((100, 100, 3), dtype=np.uint8)

        first = mock_coco_dataset[0]
        assert first.is_reference == [True]
        assert first.n_shot == [0]

    @patch("instantlearn.data.base.base.read_image")
    def test_batch_creation(self, mock_read_image: MagicMock, mock_coco_dataset: Dataset) -> None:
        """Collating every sample yields a four-element Batch."""
        mock_read_image.return_value = np.zeros((100, 100, 3), dtype=np.uint8)

        all_samples = [mock_coco_dataset[idx] for idx in range(len(mock_coco_dataset))]
        batch = Batch.collate(all_samples)

        assert isinstance(batch, Batch)
        assert len(batch) == 4
        assert len(batch.categories) == 4
        assert len(batch.images) == 4

    @patch("instantlearn.data.base.base.read_image")
    def test_data_consistency(self, mock_read_image: MagicMock, mock_coco_dataset: Dataset) -> None:
        """Per-sample metadata lists all share one length."""
        mock_read_image.return_value = np.zeros((100, 100, 3), dtype=np.uint8)

        for idx in range(len(mock_coco_dataset)):
            sample = mock_coco_dataset[idx]
            expected = len(sample.categories)
            assert len(sample.category_ids) == expected
            assert len(sample.is_reference) == expected
            assert len(sample.n_shot) == expected
class TestCOCODatasetMaskDecoding:
    """Directly exercise the COCODataset mask-decoding helpers."""

    @staticmethod
    def _bare_dataset(mode: "LVISAnnotationMode") -> COCODataset:
        """Return an uninitialised COCODataset carrying only an annotation mode."""
        dataset = COCODataset.__new__(COCODataset)
        dataset.annotation_mode = mode
        return dataset

    def test_decode_single_polygon(self) -> None:
        """A polygon segmentation decodes to a boolean torch mask."""
        square = [[10, 10, 90, 10, 90, 90, 10, 90]]
        decoded = COCODataset._decode_single(square, h=100, w=100)  # noqa: SLF001
        assert isinstance(decoded, torch.Tensor)
        assert decoded.dtype == torch.bool
        assert decoded.shape == (100, 100)
        assert decoded.any()

    def test_decode_single_rle(self) -> None:
        """An RLE segmentation decodes to a boolean torch mask."""
        # Build a valid RLE by round-tripping a polygon through pycocotools.
        rles = mask_utils.frPyObjects([[10, 10, 90, 10, 90, 90, 10, 90]], 100, 100)
        decoded = COCODataset._decode_single(mask_utils.merge(rles), h=100, w=100)  # noqa: SLF001
        assert isinstance(decoded, torch.Tensor)
        assert decoded.dtype == torch.bool
        assert decoded.shape == (100, 100)
        assert decoded.any()

    def test_decode_single_invalid_type(self) -> None:
        """An unsupported segmentation payload raises TypeError."""
        with pytest.raises(TypeError, match="Unknown segmentation format"):
            COCODataset._decode_single(12345, h=100, w=100)  # noqa: SLF001

    def test_load_masks_semantic_mode(self) -> None:
        """SEMANTIC mode collapses all polygons into a single merged mask."""
        dataset = self._bare_dataset(LVISAnnotationMode.SEMANTIC)
        raw_sample = {
            "segmentations": [
                [[10, 10, 40, 10, 40, 40, 10, 40]],
                [[60, 60, 90, 60, 90, 90, 60, 90]],
            ],
            "img_dim": (100, 100),
        }
        masks = dataset._load_masks(raw_sample)  # noqa: SLF001
        assert masks is not None
        assert masks.shape == (1, 100, 100)
        assert masks.any()

    def test_load_masks_instance_mode(self) -> None:
        """INSTANCE mode keeps one mask per instance."""
        dataset = self._bare_dataset(LVISAnnotationMode.INSTANCE)
        raw_sample = {
            "segmentations": [
                [[10, 10, 40, 10, 40, 40, 10, 40]],
                [[60, 60, 90, 60, 90, 90, 60, 90]],
            ],
            "img_dim": (100, 100),
        }
        masks = dataset._load_masks(raw_sample)  # noqa: SLF001
        assert masks is not None
        assert masks.shape == (2, 100, 100)

    def test_load_masks_empty(self) -> None:
        """Missing or empty segmentations yield None."""
        dataset = self._bare_dataset(LVISAnnotationMode.SEMANTIC)
        assert dataset._load_masks({"segmentations": [], "img_dim": (100, 100)}) is None  # noqa: SLF001
        assert dataset._load_masks({"img_dim": (100, 100)}) is None  # noqa: SLF001

0 commit comments

Comments
 (0)