Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 8 additions & 2 deletions sleap_io/io/slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -845,10 +845,14 @@ def read_suggestions(labels_path: str, videos: list[Video]) -> list[SuggestionFr
suggestions = [json.loads(x) for x in suggestions]
suggestions_objects = []
for suggestion in suggestions:
# Extract metadata (e.g., "group")
metadata = {"group": suggestion.get("group", 0)}

suggestions_objects.append(
SuggestionFrame(
video=videos[int(suggestion["video"])],
frame_idx=suggestion["frame_idx"],
metadata=metadata,
)
)
return suggestions_objects
Expand All @@ -864,13 +868,15 @@ def write_suggestions(
suggestions: A list of `SuggestionFrame` objects to store the metadata for.
videos: A list of `Video` objects.
"""
GROUP = 0 # TODO: Handle storing extraneous metadata.
suggestions_json = []
for suggestion in suggestions:
# Get group from metadata if available, otherwise use default
group = suggestion.metadata.get("group", 0) if suggestion.metadata else 0

suggestion_dict = {
"video": str(videos.index(suggestion.video)),
"frame_idx": suggestion.frame_idx,
"group": GROUP,
"group": group,
}
suggestion_json = np.bytes_(json.dumps(suggestion_dict, separators=(",", ":")))
suggestions_json.append(suggestion_json)
Expand Down
4 changes: 4 additions & 0 deletions sleap_io/model/suggestions.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,11 @@ class SuggestionFrame:
Attributes:
video: The video associated with the frame.
frame_idx: The index of the frame in the video.
metadata: Dictionary containing additional metadata that is not explicitly
represented in the data model. This is used to store arbitrary metadata
such as the "group" key when reading/writing SLP files.
"""

video: Video
frame_idx: int
metadata: dict[str, any] = attrs.field(factory=dict)
34 changes: 34 additions & 0 deletions tests/io/test_slp.py
Original file line number Diff line number Diff line change
Expand Up @@ -866,6 +866,40 @@ def test_suggestions(tmpdir):
assert len(loaded_suggestions) == 0


def test_suggestions_metadata(tmpdir):
"""Test that suggestion metadata (e.g., group) is preserved during read/write."""
labels = Labels()
labels.videos.append(Video.from_filename("fake.mp4"))

# Create suggestions with different group values in metadata
labels.suggestions.append(
SuggestionFrame(video=labels.video, frame_idx=0, metadata={"group": 0})
)
labels.suggestions.append(
SuggestionFrame(video=labels.video, frame_idx=1, metadata={"group": 1})
)
labels.suggestions.append(
SuggestionFrame(video=labels.video, frame_idx=2, metadata={"group": 2})
)

# Write and read suggestions
write_suggestions(tmpdir / "test.slp", labels.suggestions, labels.videos)
loaded_suggestions = read_suggestions(tmpdir / "test.slp", labels.videos)

# Verify metadata is preserved
assert len(loaded_suggestions) == 3
assert loaded_suggestions[0].metadata["group"] == 0
assert loaded_suggestions[1].metadata["group"] == 1
assert loaded_suggestions[2].metadata["group"] == 2

# Test backward compatibility: suggestions without metadata default to group 0
suggestion_no_metadata = SuggestionFrame(video=labels.video, frame_idx=3)
write_suggestions(tmpdir / "test2.slp", [suggestion_no_metadata], labels.videos)
loaded_suggestions = read_suggestions(tmpdir / "test2.slp", labels.videos)
assert len(loaded_suggestions) == 1
assert loaded_suggestions[0].metadata["group"] == 0


def test_pkg_roundtrip(tmpdir, slp_minimal_pkg):
labels = read_labels(slp_minimal_pkg)
assert type(labels.video.backend) is HDF5Video
Expand Down
52 changes: 52 additions & 0 deletions tests/io/test_video_reading.py
Original file line number Diff line number Diff line change
Expand Up @@ -752,6 +752,31 @@ def test_plugin_name_normalization():
normalize_plugin_name("invalid_plugin")


def test_image_plugin_name_normalization():
"""Test image plugin name normalization with various aliases."""
from sleap_io.io.video_reading import normalize_image_plugin_name

# Test opencv aliases
assert normalize_image_plugin_name("opencv") == "opencv"
assert normalize_image_plugin_name("OpenCV") == "opencv"
assert normalize_image_plugin_name("cv") == "opencv"
assert normalize_image_plugin_name("cv2") == "opencv"
assert normalize_image_plugin_name("CV2") == "opencv"
assert normalize_image_plugin_name("ocv") == "opencv"

# Test imageio aliases
assert normalize_image_plugin_name("imageio") == "imageio"
assert normalize_image_plugin_name("iio") == "imageio"

# Test invalid plugin
with pytest.raises(ValueError, match="Unknown image plugin"):
normalize_image_plugin_name("invalid_plugin")

# Test invalid plugin that's valid for video but not images
with pytest.raises(ValueError, match="Unknown image plugin"):
normalize_image_plugin_name("pyav")


def test_global_default_plugin():
"""Test global default plugin functionality."""
import sleap_io as sio
Expand Down Expand Up @@ -915,3 +940,30 @@ def test_image_video_plugin_with_grayscale(centered_pair_frame_paths):
np.testing.assert_array_equal(frame_opencv, frame_imageio)
assert frame_opencv.ndim == 3 # Always 3D (H, W, C)
assert frame_opencv.shape[-1] in (1, 3) # Grayscale or RGB


def test_image_video_default_plugin_without_opencv(
centered_pair_frame_paths, monkeypatch
):
"""Test ImageVideo defaults to imageio when opencv not available."""
import sys

# Mock sys.modules to simulate opencv not being available
if "cv2" in sys.modules:
monkeypatch.delitem(sys.modules, "cv2")

# Clear any global default
import sleap_io as sio

original_default = sio.get_default_image_plugin()
try:
sio.set_default_image_plugin(None)

# Create ImageVideo without specifying plugin
backend = ImageVideo(centered_pair_frame_paths)

# Should default to imageio since opencv is not available
assert backend.plugin == "imageio"
finally:
# Restore
sio.set_default_image_plugin(original_default)