Skip to content

Commit ddbfa7d

Browse files
committed
mypy+pre-commit
1 parent 2a9a9f8 commit ddbfa7d

3 files changed

Lines changed: 220 additions & 22 deletions

File tree

src/litlogger/media.py

Lines changed: 12 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -370,24 +370,24 @@ def _maybe_as_moviepy_clip(self, data: Any) -> None | Any:
370370
try:
371371
# MoviePy 2.x layout
372372
video_clip_mod = import_module("moviepy.video.VideoClip")
373-
VideoClip = getattr(video_clip_mod, "VideoClip")
374-
if isinstance(data, VideoClip):
373+
video_clip = video_clip_mod.VideoClip
374+
if isinstance(data, video_clip):
375375
return data
376376
except Exception:
377377
pass
378378

379379
try:
380380
# Older common import path
381381
editor_mod = import_module("moviepy.editor")
382-
VideoClip = getattr(editor_mod, "VideoClip")
383-
if isinstance(data, VideoClip):
382+
video_clip = editor_mod.VideoClip
383+
if isinstance(data, video_clip):
384384
return data
385385
except Exception:
386386
pass
387387

388388
return None
389389

390-
def _moviepy_clip_from_array(self, data, fps: float):
390+
def _moviepy_clip_from_array(self, data: Any, fps: float) -> Any:
391391
np = import_module("numpy")
392392

393393
if data.dtype != np.uint8:
@@ -402,8 +402,7 @@ def _moviepy_clip_from_array(self, data, fps: float):
402402

403403
if data.ndim not in (3, 4):
404404
raise ValueError(
405-
f"Unsupported array shape for video: {data.shape}. "
406-
"Expected (T,H,W), (T,H,W,C), or (T,C,H,W)."
405+
f"Unsupported array shape for video: {data.shape}. Expected (T,H,W), (T,H,W,C), or (T,C,H,W)."
407406
)
408407

409408
# (T, H, W) -> grayscale => expand to (T, H, W, 1)
@@ -420,34 +419,26 @@ def _moviepy_clip_from_array(self, data, fps: float):
420419
data = np.repeat(data, 3, axis=-1)
421420

422421
if data.shape[-1] not in (3, 4):
423-
raise ValueError(
424-
f"Unsupported channel count for video frames: {data.shape[-1]}"
425-
)
422+
raise ValueError(f"Unsupported channel count for video frames: {data.shape[-1]}")
426423

427424
# Drop alpha for now unless you explicitly want to preserve/use masks.
428425
if data.shape[-1] == 4:
429426
data = data[..., :3]
430427

431428
try:
432429
# MoviePy 2.x
433-
ImageSequenceClip = getattr(
434-
import_module("moviepy.video.io.ImageSequenceClip"),
435-
"ImageSequenceClip",
436-
)
430+
image_sequence_clip = import_module("moviepy.video.io.ImageSequenceClip").ImageSequenceClip
437431
except Exception:
438432
# Older common path
439-
ImageSequenceClip = getattr(
440-
import_module("moviepy.editor"),
441-
"ImageSequenceClip",
442-
)
433+
image_sequence_clip = import_module("moviepy.editor").ImageSequenceClip
443434

444435
# list(...) avoids some ndarray edge cases in callers and matches
445436
# common usage for frame sequences.
446-
return ImageSequenceClip(list(data), fps=fps)
437+
return image_sequence_clip(list(data), fps=fps)
447438

448-
def _write_moviepy_clip(self, clip, path: str, fps: float) -> None:
439+
def _write_moviepy_clip(self, clip: Any, path: str, fps: float) -> None:
449440
# For MP4, libx264 is the usual sensible default.
450-
kwargs = {
441+
kwargs: dict[str, Any] = {
451442
"fps": fps,
452443
"logger": None,
453444
}

tests/integrations/test_standalone_media.py

Lines changed: 60 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -377,3 +377,63 @@ def test_new_dict_api_file_download():
377377
project_id=project_id,
378378
body=LitLoggerServiceDeleteMetricsStreamBody(ids=[stream_id]),
379379
)
380+
381+
382+
@pytest.mark.cloud()
383+
def test_new_dict_api_video_download():
384+
"""Test the new dict-like API for uploading and downloading videos."""
385+
from litlogger import Video
386+
387+
experiment_name = f"standalone_dict_video_test-{uuid.uuid4().hex}"
388+
389+
with tempfile.TemporaryDirectory() as tmpdir:
390+
exp = litlogger.init(
391+
name=experiment_name,
392+
teamspace="oss-litlogger",
393+
root_dir=tmpdir,
394+
)
395+
396+
video_path = os.path.join(tmpdir, "preview.mp4")
397+
with open(video_path, "wb") as f:
398+
f.write(b"\x00\x00\x00\x18ftypmp42\x00\x00\x00\x00mp42isom")
399+
400+
exp["preview"] = Video(video_path)
401+
402+
assert isinstance(exp["preview"], Video)
403+
assert exp["preview"].name == "preview"
404+
405+
litlogger.finalize()
406+
407+
exp2 = litlogger.init(name=experiment_name, teamspace="oss-litlogger")
408+
409+
download_path = os.path.join(tmpdir, "downloaded_preview.mp4")
410+
video_downloaded = False
411+
last_exception = None
412+
for attempt in range(30):
413+
try:
414+
preview = exp2["preview"]
415+
if isinstance(preview, Video):
416+
preview.save(download_path)
417+
video_downloaded = True
418+
break
419+
except Exception as e:
420+
last_exception = e
421+
if attempt < 29:
422+
sleep(1)
423+
424+
litlogger.finalize()
425+
426+
if video_downloaded:
427+
assert os.path.exists(download_path)
428+
with open(video_path, "rb") as original, open(download_path, "rb") as downloaded:
429+
assert downloaded.read() == original.read()
430+
else:
431+
print(f"\nWarning: Could not download video. Last exception: {last_exception}")
432+
433+
project_id = exp._teamspace.id
434+
stream_id = exp._metrics_store.id
435+
client = LitRestClient()
436+
client.lit_logger_service_delete_metrics_stream(
437+
project_id=project_id,
438+
body=LitLoggerServiceDeleteMetricsStreamBody(ids=[stream_id]),
439+
)

tests/unittests/test_experiment_media.py

Lines changed: 148 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -11,7 +11,7 @@
1111
import pytest
1212
from lightning_sdk.lightning_cloud.openapi import V1MediaType
1313
from litlogger.experiment import Experiment
14-
from litlogger.media import File, Image, Model, Text
14+
from litlogger.media import File, Image, Model, Text, Video
1515
from litlogger.series import Series
1616

1717

@@ -92,6 +92,15 @@ def test_setitem_text(self):
9292
assert exp._static_files["notes"] is t
9393
exp._set_static_file.assert_called_once_with("notes", t)
9494

95+
def test_setitem_video(self):
96+
exp = _make_exp()
97+
video = Video("preview.mp4")
98+
exp["preview"] = video
99+
100+
assert exp._key_types["preview"] == "static_file"
101+
assert exp._static_files["preview"] is video
102+
exp._set_static_file.assert_called_once_with("preview", video)
103+
95104
def test_overwrite_same_type(self):
96105
"""Overwriting a static_file key with another File is allowed."""
97106
exp = _make_exp()
@@ -188,6 +197,48 @@ def test_non_file_media_uses_media_api(self):
188197
assert kwargs["media_type"] == V1MediaType.IMAGE
189198
assert exp._stats.media_logged == 1
190199

200+
def test_video_media_uses_media_api(self):
201+
exp = MagicMock(spec=Experiment)
202+
exp._media_api = MagicMock()
203+
exp._metrics_store = MagicMock()
204+
exp._metrics_store.id = "store-1"
205+
exp._teamspace = MagicMock()
206+
exp._stats = MagicMock()
207+
exp._stats.media_logged = 0
208+
exp._media_type_to_v1 = lambda media_type: Experiment._media_type_to_v1(exp, media_type)
209+
exp._upload_media = (
210+
lambda name, file_path, media_type, step=None, epoch=None, caption=None: Experiment._upload_media(
211+
exp,
212+
name,
213+
file_path,
214+
media_type,
215+
step=step,
216+
epoch=epoch,
217+
caption=caption,
218+
)
219+
)
220+
exp._upload_media_value = (
221+
lambda key, value, name=None, step=None, epoch=None, caption=None: Experiment._upload_media_value(
222+
exp,
223+
key,
224+
value,
225+
name=name,
226+
step=step,
227+
epoch=epoch,
228+
caption=caption,
229+
)
230+
)
231+
232+
video = Video("preview.mp4")
233+
Experiment._set_static_file(exp, "preview", video)
234+
235+
exp._media_api.upload_media.assert_called_once()
236+
_, kwargs = exp._media_api.upload_media.call_args
237+
assert kwargs["name"] == "preview"
238+
assert kwargs["file_path"] == "preview.mp4"
239+
assert kwargs["media_type"] == V1MediaType.VIDEO
240+
assert exp._stats.media_logged == 1
241+
191242
@patch.object(Model, "_log_model", return_value="owner/team/exp-model:latest")
192243
def test_model_artifact_uses_litmodels(self, mock_log_model):
193244
exp = Experiment.__new__(Experiment)
@@ -208,6 +259,7 @@ def test_model_artifact_uses_litmodels(self, mock_log_model):
208259
mock_log_model.assert_called_once_with(
209260
experiment_name="exp",
210261
teamspace=exp._teamspace,
262+
key="checkpoint",
211263
experiment=exp,
212264
cloud_account="acc-1",
213265
)
@@ -235,6 +287,7 @@ def test_model_object_uses_litmodels(self, mock_log_model):
235287
mock_log_model.assert_called_once_with(
236288
experiment_name="exp",
237289
teamspace=exp._teamspace,
290+
key="model-object",
238291
experiment=exp,
239292
cloud_account="acc-1",
240293
)
@@ -300,6 +353,14 @@ def test_append_text_to_series(self):
300353
assert len(exp["logs"]) == 1
301354
assert exp._key_types["logs"] == "file_series"
302355

356+
def test_append_video_to_series(self):
357+
exp = _make_exp()
358+
video = Video("preview.mp4")
359+
exp["clips"].append(video)
360+
361+
assert len(exp["clips"]) == 1
362+
assert exp._key_types["clips"] == "file_series"
363+
303364

304365
class TestFileSeriesBindings:
305366
"""Test that _log_file_series_value binds name and _download_fn."""
@@ -400,6 +461,48 @@ def test_non_file_series_uses_media_api(self):
400461
assert kwargs["media_type"] == V1MediaType.TEXT
401462
assert exp._stats.media_logged == 1
402463

464+
def test_video_series_uses_media_api(self):
465+
exp = MagicMock(spec=Experiment)
466+
exp._media_api = MagicMock()
467+
exp._metrics_store = MagicMock()
468+
exp._metrics_store.id = "store-1"
469+
exp._teamspace = MagicMock()
470+
exp._stats = MagicMock()
471+
exp._stats.media_logged = 0
472+
exp._media_type_to_v1 = lambda media_type: Experiment._media_type_to_v1(exp, media_type)
473+
exp._upload_media = (
474+
lambda name, file_path, media_type, step=None, epoch=None, caption=None: Experiment._upload_media(
475+
exp,
476+
name,
477+
file_path,
478+
media_type,
479+
step=step,
480+
epoch=epoch,
481+
caption=caption,
482+
)
483+
)
484+
exp._upload_media_value = (
485+
lambda key, value, name=None, step=None, epoch=None, caption=None: Experiment._upload_media_value(
486+
exp,
487+
key,
488+
value,
489+
name=name,
490+
step=step,
491+
epoch=epoch,
492+
caption=caption,
493+
)
494+
)
495+
496+
video = Video("preview.mp4")
497+
Experiment._log_file_series_value(exp, "clips", video, 2, step=7)
498+
499+
exp._media_api.upload_media.assert_called_once()
500+
_, kwargs = exp._media_api.upload_media.call_args
501+
assert kwargs["name"] == "clips"
502+
assert kwargs["step"] == 7
503+
assert kwargs["media_type"] == V1MediaType.VIDEO
504+
assert exp._stats.media_logged == 1
505+
403506
@patch.object(Model, "_log_model", return_value="owner/team/exp-model-series:latest")
404507
def test_model_series_uses_series_key_for_remote_binding(self, mock_log_model):
405508
exp = Experiment.__new__(Experiment)
@@ -420,6 +523,7 @@ def test_model_series_uses_series_key_for_remote_binding(self, mock_log_model):
420523
mock_log_model.assert_called_once_with(
421524
experiment_name="exp",
422525
teamspace=exp._teamspace,
526+
key="models",
423527
experiment=exp,
424528
cloud_account="acc-1",
425529
)
@@ -460,6 +564,13 @@ def test_getitem_returns_text(self):
460564

461565
assert exp["notes"] is t
462566

567+
def test_getitem_returns_video(self):
568+
exp = _make_exp()
569+
video = Video("preview.mp4")
570+
exp["preview"] = video
571+
572+
assert exp["preview"] is video
573+
463574
def test_save_without_upload_raises(self):
464575
"""File.save() fails if not yet uploaded (no _download_fn)."""
465576
f = File("local.txt")
@@ -749,6 +860,42 @@ def test_rebuilds_static_media_with_wrapper(self):
749860
assert wrapped.name == "preview"
750861
assert wrapped._download_fn is not None
751862

863+
def test_rebuilds_static_video_with_wrapper(self):
864+
media = MagicMock()
865+
media.name = "preview"
866+
media.storage_path = "media/preview.mp4"
867+
media.cluster_id = "cloud-1"
868+
media.media_type = V1MediaType.VIDEO
869+
media.id = "media-1"
870+
871+
exp = MagicMock(spec=Experiment)
872+
exp._key_types = {}
873+
exp._metadata_values = {}
874+
exp._static_files = {}
875+
exp._series = {}
876+
exp._metrics_store = MagicMock()
877+
exp._metrics_store.id = "store-1"
878+
exp._update_metrics_store = MagicMock()
879+
exp._metrics_store.tags = []
880+
exp._metrics_store.artifacts = []
881+
exp._metrics_api = MagicMock()
882+
exp._teamspace = MagicMock()
883+
exp._teamspace.id = "ts-1"
884+
exp._media_api = MagicMock()
885+
exp._media_api.client.lit_logger_service_list_lit_logger_media.return_value.media = [media]
886+
exp._wrap_media_file = lambda media_name, media_type: Experiment._wrap_media_file(exp, media_name, media_type)
887+
exp._create_media_download_fn = lambda storage_path, cloud_account=None: Experiment._create_media_download_fn(
888+
exp, storage_path, cloud_account
889+
)
890+
exp._resumed_steps = {}
891+
892+
Experiment._rebuild_state(exp)
893+
894+
wrapped = exp._static_files["preview"]
895+
assert isinstance(wrapped, Video)
896+
assert wrapped.name == "preview"
897+
assert wrapped._download_fn is not None
898+
752899
def test_rebuilds_media_series_with_wrapper(self):
753900
media0 = MagicMock()
754901
media0.name = "logs/0"

0 commit comments

Comments
 (0)