Skip to content

Commit 2a9a9f8

Browse files
committed
update
1 parent 154735f commit 2a9a9f8

3 files changed

Lines changed: 193 additions & 12 deletions

File tree

src/litlogger/experiment_support.py

Lines changed: 5 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616
import contextlib
1717
import re
1818
from datetime import datetime
19-
from typing import TYPE_CHECKING, Callable, cast
19+
from typing import TYPE_CHECKING, Callable
2020

2121
from lightning_sdk.lightning_cloud.openapi import V1MediaType
2222

@@ -143,12 +143,8 @@ def resolve_remote_model(exp: "Experiment", key: str) -> Model | Series | None:
143143

144144
@staticmethod
145145
def rebuild_state(exp: "Experiment") -> None:
146-
"""Rebuild state from remote metadata, steps, artifacts, and media.
147-
148-
TODO: Add backend-supported recovery for model bindings so resumed
149-
experiments can reconstruct ``Model`` values without storing them in
150-
frontend-visible metadata tags.
151-
"""
146+
"""Rebuild state from remote metadata, steps, artifacts, and media."""
147+
# TODO: add BE support for restoring model states as well
152148
exp._update_metrics_store()
153149
tags = getattr(exp._metrics_store, "tags", None) or []
154150
for tag in tags:
@@ -295,7 +291,7 @@ def wrap_media_file(exp: "Experiment", media_name: str, media_type: V1MediaType)
295291
text.path = media_name
296292
return text
297293

298-
if media_type == "MEDIA_TYPE_VIDEO": # TODO: use proper V1MediaType
294+
if media_type == V1MediaType.VIDEO:
299295
return Video(media_name)
300296
return File(media_name)
301297

@@ -321,7 +317,7 @@ def media_type_to_v1(exp: "Experiment", media_type: MediaType) -> V1MediaType:
321317
return V1MediaType.TEXT
322318

323319
if media_type == MediaType.VIDEO:
324-
return cast(V1MediaType, "MEDIA_TYPE_VIDEO") # TODO: Use proper V1MediaType
320+
return V1MediaType.VIDEO
325321
raise ValueError(f"Unsupported media type for file upload: {media_type}")
326322

327323
@staticmethod

tests/unittests/test_experiment_support.py

Lines changed: 28 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,7 @@
99
from lightning_sdk.lightning_cloud.openapi import V1MediaType
1010
from litlogger.experiment import Experiment
1111
from litlogger.experiment_support import ExperimentIOSupport, ExperimentStateSupport
12-
from litlogger.media import Model, Text
12+
from litlogger.media import Model, Text, Video
1313
from litlogger.series import Series
1414
from litlogger.types import MediaType
1515

@@ -69,6 +69,25 @@ def test_log_file_series_value_routes_text_with_exact_key_name(self):
6969
exp._upload_media_value.assert_called_once_with("logs", text, name="logs", step=7)
7070
exp._upload_model_value.assert_not_called()
7171

72+
def test_log_file_series_value_routes_video_with_exact_key_name(self):
73+
exp = MagicMock(spec=Experiment)
74+
exp._upload_media_value = MagicMock()
75+
exp._upload_model_value = MagicMock()
76+
exp._stats = MagicMock()
77+
exp._stats.media_logged = 0
78+
79+
video = Video("preview.mp4")
80+
81+
ExperimentIOSupport.log_file_series_value(exp, "clips", video, 2, step=7)
82+
83+
exp._upload_media_value.assert_called_once_with("clips", video, name="clips", step=7)
84+
exp._upload_model_value.assert_not_called()
85+
86+
def test_media_type_to_v1_maps_video(self):
87+
exp = MagicMock(spec=Experiment)
88+
89+
assert ExperimentIOSupport.media_type_to_v1(exp, MediaType.VIDEO) == V1MediaType.VIDEO
90+
7291
def test_log_file_series_value_auto_versions_models_from_v1(self):
7392
exp = MagicMock(spec=Experiment)
7493
exp._upload_model_value = MagicMock()
@@ -92,6 +111,14 @@ def test_wrap_media_file_returns_text_wrapper(self):
92111
assert isinstance(wrapped, Text)
93112
assert wrapped.path == "logs/0"
94113

114+
def test_wrap_media_file_returns_video_wrapper(self):
115+
exp = MagicMock(spec=Experiment)
116+
117+
wrapped = ExperimentStateSupport.wrap_media_file(exp, "clips/0", V1MediaType.VIDEO)
118+
119+
assert isinstance(wrapped, Video)
120+
assert wrapped.path == "clips/0"
121+
95122
def test_rebuild_state_reconstructs_sorted_text_series(self):
96123
media1 = MagicMock()
97124
media1.name = "logs/1"

tests/unittests/test_media.py

Lines changed: 160 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -2,14 +2,15 @@
22
# Licensed under the Apache License, Version 2.0 (the "License");
33
# http://www.apache.org/licenses/LICENSE-2.0
44
#
5-
"""Tests for File, Image, Text, and Model media types."""
5+
"""Tests for File, Image, Text, Video, and Model media types."""
66

77
import os
88
import tempfile
9+
from types import SimpleNamespace
910
from unittest.mock import MagicMock, patch
1011

1112
import pytest
12-
from litlogger.media import File, Image, Model, Text, _sanitize_version_for_model_name
13+
from litlogger.media import File, Image, Model, Text, Video, _sanitize_version_for_model_name
1314
from litlogger.types import MediaType
1415

1516

@@ -406,6 +407,163 @@ def test_object_repr_before_render(self):
406407
assert repr(Image(pil)) == "Image('')"
407408

408409

410+
# ---------------------------------------------------------------------------
411+
# Video
412+
# ---------------------------------------------------------------------------
413+
414+
415+
class TestVideoInit:
416+
"""Test Video construction."""
417+
418+
def test_from_path(self):
419+
video = Video("clip.mp4")
420+
assert video.path == "clip.mp4"
421+
assert video._data == "clip.mp4"
422+
assert video._format == "mp4"
423+
assert video._fps is None
424+
425+
def test_from_path_with_description(self):
426+
video = Video("clip.mp4", description="preview clip")
427+
assert video.description == "preview clip"
428+
429+
def test_custom_format_and_fps(self):
430+
video = Video("clip.mov", format="mov", fps=12)
431+
assert video._format == "mov"
432+
assert video._fps == 12
433+
434+
def test_media_type(self):
435+
assert Video("x.mp4")._media_type == MediaType.VIDEO
436+
437+
438+
class TestVideoUploadPath:
439+
"""Test Video._get_upload_path for different data types."""
440+
441+
def test_string_path_nonexistent_returns_path(self):
442+
video = Video("nonexistent.mp4")
443+
assert video._get_upload_path() == "nonexistent.mp4"
444+
445+
@patch.object(Video, "_write_moviepy_clip")
446+
@patch.object(Video, "_moviepy_clip_from_array")
447+
def test_numpy_frames_render_to_temp(self, mock_clip_from_array, mock_write_moviepy_clip):
448+
try:
449+
import numpy as np
450+
except ImportError:
451+
pytest.skip("numpy not available")
452+
453+
clip = object()
454+
mock_clip_from_array.return_value = clip
455+
456+
frames = np.zeros((2, 8, 8, 3), dtype=np.uint8)
457+
video = Video(frames, fps=12)
458+
path = video._get_upload_path()
459+
460+
assert path.endswith(".mp4")
461+
assert os.path.exists(path)
462+
mock_clip_from_array.assert_called_once()
463+
mock_write_moviepy_clip.assert_called_once_with(clip, path, 12)
464+
video._cleanup()
465+
466+
def test_moviepy_clip_from_array_transposes_tchw_and_scales_floats(self):
467+
try:
468+
import numpy as np
469+
except ImportError:
470+
pytest.skip("numpy not available")
471+
472+
captured: dict[str, object] = {}
473+
474+
class FakeImageSequenceClip:
475+
def __init__(self, frames, fps):
476+
captured["frames"] = frames
477+
captured["fps"] = fps
478+
479+
def fake_import_module(name: str):
480+
if name == "numpy":
481+
return np
482+
if name == "moviepy.video.io.ImageSequenceClip":
483+
return SimpleNamespace(ImageSequenceClip=FakeImageSequenceClip)
484+
raise ImportError(name)
485+
486+
frames = np.full((2, 3, 4, 5), 0.5, dtype=np.float32)
487+
video = Video(frames)
488+
with patch("litlogger.media.import_module", side_effect=fake_import_module):
489+
clip = video._moviepy_clip_from_array(frames, fps=7)
490+
491+
assert isinstance(clip, FakeImageSequenceClip)
492+
assert captured["fps"] == 7
493+
rendered_frames = captured["frames"]
494+
assert isinstance(rendered_frames, list)
495+
assert len(rendered_frames) == 2
496+
assert rendered_frames[0].shape == (4, 5, 3)
497+
assert rendered_frames[0].dtype == np.uint8
498+
assert rendered_frames[0][0, 0, 0] in (127, 128)
499+
500+
def test_moviepy_clip_from_array_promotes_grayscale_to_rgb(self):
501+
try:
502+
import numpy as np
503+
except ImportError:
504+
pytest.skip("numpy not available")
505+
506+
captured: dict[str, object] = {}
507+
508+
class FakeImageSequenceClip:
509+
def __init__(self, frames, fps):
510+
captured["frames"] = frames
511+
captured["fps"] = fps
512+
513+
def fake_import_module(name: str):
514+
if name == "numpy":
515+
return np
516+
if name == "moviepy.video.io.ImageSequenceClip":
517+
return SimpleNamespace(ImageSequenceClip=FakeImageSequenceClip)
518+
raise ImportError(name)
519+
520+
frames = np.zeros((2, 6, 7), dtype=np.uint8)
521+
video = Video(frames)
522+
with patch("litlogger.media.import_module", side_effect=fake_import_module):
523+
video._moviepy_clip_from_array(frames, fps=5)
524+
525+
rendered_frames = captured["frames"]
526+
assert isinstance(rendered_frames, list)
527+
assert rendered_frames[0].shape == (6, 7, 3)
528+
assert captured["fps"] == 5
529+
530+
def test_numpy_unsupported_shape_raises(self):
531+
try:
532+
import numpy as np
533+
except ImportError:
534+
pytest.skip("numpy not available")
535+
536+
video = Video(np.zeros((2, 3, 4, 5, 6), dtype=np.uint8))
537+
with pytest.raises(ValueError, match="Unsupported array shape"):
538+
video._moviepy_clip_from_array(video._data, fps=Video.DEFAULT_FPS)
539+
540+
def test_unsupported_type_raises(self):
541+
video = Video({"not": "a video"})
542+
with pytest.raises(TypeError, match="Unsupported video type"):
543+
video._get_upload_path()
544+
545+
546+
class TestVideoCleanup:
547+
"""Test Video._cleanup removes rendered temp files."""
548+
549+
@patch.object(Video, "_write_moviepy_clip")
550+
@patch.object(Video, "_moviepy_clip_from_array")
551+
def test_cleanup_removes_rendered_temp(self, mock_clip_from_array, mock_write_moviepy_clip):
552+
try:
553+
import numpy as np
554+
except ImportError:
555+
pytest.skip("numpy not available")
556+
557+
mock_clip_from_array.return_value = object()
558+
video = Video(np.zeros((1, 4, 4, 3), dtype=np.uint8))
559+
path = video._get_upload_path()
560+
assert os.path.exists(path)
561+
562+
video._cleanup()
563+
assert not os.path.exists(path)
564+
assert video._temp_path is None
565+
566+
409567
# ---------------------------------------------------------------------------
410568
# Text
411569
# ---------------------------------------------------------------------------

0 commit comments

Comments
 (0)