Skip to content

Commit a77d6cd

Browse files
authored
Add experimental log_media method (#16)
1 parent 60af930 commit a77d6cd

6 files changed

Lines changed: 233 additions & 2 deletions

File tree

src/litlogger/experiment.py

Lines changed: 67 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,7 @@
1515

1616
import atexit
1717
import contextlib
18+
import mimetypes
1819
import os
1920
import signal
2021
import sys
@@ -26,15 +27,18 @@
2627
from types import FrameType
2728
from typing import TYPE_CHECKING, Any, Dict, List, Union
2829

30+
from lightning_sdk.lightning_cloud.openapi import V1MediaType
31+
2932
from litlogger.api.artifacts_api import ArtifactsApi
3033
from litlogger.api.auth_api import AuthApi
34+
from litlogger.api.media_api import MediaApi
3135
from litlogger.api.metrics_api import MetricsApi
3236
from litlogger.api.utils import _resolve_teamspace, build_experiment_url, get_accessible_url, get_guest_url
3337
from litlogger.artifacts import Artifact, Model, ModelArtifact
3438
from litlogger.background import _BackgroundThread
3539
from litlogger.capture import rerun_and_record
3640
from litlogger.printer import Printer, RunStats
37-
from litlogger.types import Metrics, MetricValue
41+
from litlogger.types import MediaType, Metrics, MetricValue
3842

3943
if TYPE_CHECKING:
4044
from lightning_sdk import Teamspace
@@ -111,6 +115,7 @@ def __init__(
111115
teamspace = None
112116

113117
self._metrics_api = MetricsApi()
118+
self._media_api = MediaApi(client=self._metrics_api.client)
114119
self._artifacts_api = ArtifactsApi()
115120
self._teamspace = _resolve_teamspace(teamspace)
116121

@@ -561,6 +566,67 @@ def get_model(self, staging_dir: str | None = None, verbose: bool = False, versi
561566
self._printer.print_success("Retrieved model object")
562567
return result
563568

569+
def log_media(
570+
self,
571+
name: str,
572+
path: str,
573+
kind: MediaType | None = None,
574+
step: int | None = None,
575+
epoch: int | None = None,
576+
caption: str | None = None,
577+
verbose: bool = False,
578+
) -> None:
579+
"""Upload a media file (image, text, etc.) to the experiment.
580+
581+
Args:
582+
name: Name of the media.
583+
path: Local path to the media file.
584+
kind: Type of media (MediaType.IMAGE or MediaType.TEXT).
585+
If None, attempts to guess from file extension or mime type.
586+
step: Optional training step.
587+
epoch: Optional training epoch.
588+
caption: Optional caption for the media.
589+
verbose: Whether to print a confirmation message after upload.
590+
591+
Raises:
592+
ValueError: If the file type cannot be determined or is not supported.
593+
FileNotFoundError: If the file does not exist.
594+
"""
595+
if not os.path.exists(path):
596+
raise FileNotFoundError(f"Media file not found: {path}")
597+
598+
media_type = V1MediaType.UNSPECIFIED
599+
600+
if kind is not None:
601+
if kind == MediaType.IMAGE:
602+
media_type = V1MediaType.IMAGE
603+
elif kind == MediaType.TEXT:
604+
media_type = V1MediaType.TEXT
605+
else:
606+
mime_type, _ = mimetypes.guess_type(path)
607+
if mime_type:
608+
if mime_type.startswith("image/"):
609+
media_type = V1MediaType.IMAGE
610+
elif mime_type.startswith("text/"):
611+
media_type = V1MediaType.TEXT
612+
613+
if media_type == V1MediaType.UNSPECIFIED:
614+
raise ValueError(f"Unsupported media type for file: {path}")
615+
616+
self._media_api.upload_media(
617+
experiment_id=self._metrics_store.id,
618+
teamspace=self._teamspace,
619+
file_path=path,
620+
name=name,
621+
media_type=media_type,
622+
step=step,
623+
epoch=epoch,
624+
caption=caption,
625+
)
626+
self._stats.media_logged += 1
627+
if verbose:
628+
self._printer.media_logged(path, step)
629+
564630
def print_url(self) -> None:
565631
"""Print the experiment URL and initialization info with styled output."""
566632
self._printer.experiment_start(

src/litlogger/logger.py

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -28,6 +28,7 @@
2828

2929
from litlogger.experiment import Experiment
3030
from litlogger.generator import _create_name
31+
from litlogger.types import MediaType
3132

3233
_base_classes = []
3334

@@ -418,3 +419,29 @@ def log_file(self, path: str) -> None:
418419
self._is_ready = True
419420
self._store_step = True
420421
self.experiment.log_file(path)
422+
423+
def log_media(
424+
self,
425+
name: str,
426+
path: str,
427+
kind: MediaType | None = None,
428+
step: int | None = None,
429+
epoch: int | None = None,
430+
caption: str | None = None,
431+
verbose: bool = False,
432+
) -> None:
433+
"""Log a media file to the experiment.
434+
435+
Args:
436+
name: Name of the media.
437+
path: Local path to the media file.
438+
kind: Kind of media (MediaType.IMAGE or MediaType.TEXT).
439+
If None, attempts to guess from file extension or mime type.
440+
step: Optional training step.
441+
epoch: Optional training epoch.
442+
caption: Optional caption.
443+
verbose: Whether to print a confirmation message after upload.
444+
"""
445+
self._is_ready = True
446+
self._store_step = True
447+
self.experiment.log_media(name, path, kind, step, epoch, caption, verbose)

src/litlogger/printer.py

Lines changed: 16 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -121,6 +121,7 @@ class RunStats:
121121

122122
metrics_logged: int = 0
123123
artifacts_logged: int = 0
124+
media_logged: int = 0
124125
models_logged: int = 0
125126
# Store recent values for sparklines (metric_name -> list of values)
126127
metrics_history: Dict[str, List[float]] = field(default_factory=dict)
@@ -406,6 +407,8 @@ def _print_run_stats(self, stats: RunStats) -> None:
406407
rows.append(("Metrics logged", f"{stats.metrics_logged:,}"))
407408
if stats.artifacts_logged > 0:
408409
rows.append(("Artifacts", f"{stats.artifacts_logged:,}"))
410+
if stats.media_logged > 0:
411+
rows.append(("Media logged", f"{stats.media_logged:,}"))
409412
if stats.models_logged > 0:
410413
rows.append(("Models", f"{stats.models_logged:,}"))
411414

@@ -477,6 +480,19 @@ def artifact_logged(self, path: str, remote_path: str | None = None) -> None:
477480
display_path = remote_path or path
478481
self._echo(f"{self.success(check)} Logged {self.files(display_path)}")
479482

483+
def media_logged(self, path: str, step: int | None = None) -> None:
484+
"""Print media upload confirmation.
485+
486+
Example:
487+
litlogger: ✓ Logged output.png (step 100)
488+
"""
489+
if not self.verbose:
490+
return
491+
492+
check = self.emoji("check") or "✓"
493+
step_str = f" (step {step})" if step is not None else ""
494+
self._echo(f"{self.success(check)} Logged {self.files(path)}{step_str}")
495+
480496
def artifact_retrieved(self, path: str) -> None:
481497
"""Print artifact download confirmation.
482498

src/litlogger/types.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -33,6 +33,13 @@ class PhaseType(str, Enum):
3333
STOPPED = "stopped"
3434

3535

36+
class MediaType(str, Enum):
37+
"""Type of media to upload."""
38+
39+
IMAGE = "image"
40+
TEXT = "text"
41+
42+
3643
@dataclass
3744
class MetricValue:
3845
"""A single metric value with optional step and timestamp.

tests/unittests/test_experiment.py

Lines changed: 103 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -676,4 +676,106 @@ def test_atexit_handler_registered(self):
676676
def test_signal_handlers_exist(self):
677677
"""Test that signal handler method exists."""
678678
assert hasattr(experiment_module.Experiment, "_signal_handler")
679-
assert callable(experiment_module.Experiment._signal_handler)
679+
680+
681+
class TestExperimentLogMedia:
682+
"""Test log_media method."""
683+
684+
def test_log_media_basic(self):
685+
"""Test basic media upload with explicit type."""
686+
exp = MagicMock()
687+
exp.name = "test_exp"
688+
exp._manager = MagicMock()
689+
exp._media_api = MagicMock()
690+
exp._printer = MagicMock()
691+
exp._stats = MagicMock()
692+
exp._stats.media_logged = 0
693+
694+
# Import MediaType from types
695+
from unittest.mock import patch
696+
697+
from lightning_sdk.lightning_cloud.openapi import V1MediaType
698+
from litlogger.experiment import Experiment
699+
from litlogger.types import MediaType
700+
701+
with patch("os.path.exists", return_value=True):
702+
Experiment.log_media(exp, "image", "/path/to/image.png", kind=MediaType.IMAGE)
703+
704+
# Verify upload_media called with correct args
705+
exp._media_api.upload_media.assert_called_once()
706+
_, kwargs = exp._media_api.upload_media.call_args
707+
assert kwargs["file_path"] == "/path/to/image.png"
708+
assert kwargs["media_type"] == V1MediaType.IMAGE
709+
assert exp._stats.media_logged == 1
710+
711+
def test_log_media_guess_type_image(self):
712+
"""Test media upload with guessed image type."""
713+
exp = MagicMock()
714+
exp._media_api = MagicMock()
715+
exp._printer = MagicMock()
716+
exp._stats = MagicMock()
717+
exp._stats.media_logged = 0
718+
719+
from unittest.mock import patch
720+
721+
from lightning_sdk.lightning_cloud.openapi import V1MediaType
722+
from litlogger.experiment import Experiment
723+
724+
with patch("os.path.exists", return_value=True):
725+
Experiment.log_media(exp, "image", "/path/to/image.jpg")
726+
727+
_, kwargs = exp._media_api.upload_media.call_args
728+
assert kwargs["media_type"] == V1MediaType.IMAGE
729+
assert exp._stats.media_logged == 1
730+
731+
def test_log_media_guess_type_text(self):
732+
"""Test media upload with guessed text type."""
733+
exp = MagicMock()
734+
exp._media_api = MagicMock()
735+
exp._printer = MagicMock()
736+
exp._stats = MagicMock()
737+
exp._stats.media_logged = 0
738+
739+
from unittest.mock import patch
740+
741+
from lightning_sdk.lightning_cloud.openapi import V1MediaType
742+
from litlogger.experiment import Experiment
743+
744+
with patch("os.path.exists", return_value=True):
745+
Experiment.log_media(exp, "file", "/path/to/file.txt")
746+
747+
_, kwargs = exp._media_api.upload_media.call_args
748+
assert kwargs["media_type"] == V1MediaType.TEXT
749+
assert exp._stats.media_logged == 1
750+
751+
def test_log_media_unsupported_type(self):
752+
"""Test log_media raises ValueError for guessed unsupported media type."""
753+
exp = MagicMock()
754+
exp._media_api = MagicMock()
755+
exp._printer = MagicMock()
756+
exp._stats = MagicMock()
757+
exp._stats.media_logged = 0
758+
759+
from unittest.mock import patch
760+
761+
import pytest
762+
from litlogger.experiment import Experiment
763+
764+
with (
765+
patch("os.path.exists", return_value=True),
766+
patch("mimetypes.guess_type", return_value=("application/zip", None)),
767+
pytest.raises(ValueError, match="Unsupported media type for file: /path/to/file.txt"),
768+
):
769+
Experiment.log_media(exp, "file", "/path/to/file.txt")
770+
771+
def test_log_media_raises_file_not_found(self):
772+
"""Test log_media raises FileNotFoundError."""
773+
exp = MagicMock()
774+
775+
from unittest.mock import patch
776+
777+
import pytest
778+
from litlogger.experiment import Experiment
779+
780+
with patch("os.path.exists", return_value=False), pytest.raises(FileNotFoundError):
781+
Experiment.log_media(exp, "file", "/non/existent/file.png")

tests/unittests/test_logger.py

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -47,3 +47,16 @@ def test_log_file():
4747
# Call log_file and verify it delegates to the experiment
4848
logger.log_file("test.txt")
4949
logger._experiment.log_file.assert_called_once_with("test.txt")
50+
51+
52+
def test_log_media():
53+
"""Test that LightningLogger has a log_media method."""
54+
from litlogger import LightningLogger
55+
from litlogger.types import MediaType
56+
57+
logger = object.__new__(LightningLogger)
58+
logger._experiment = MagicMock()
59+
60+
logger.log_media("image", "test.png", kind=MediaType.IMAGE, step=10, caption="Test caption")
61+
62+
logger._experiment.log_media.assert_called_once_with("image", "test.png", MediaType.IMAGE, 10, None, "Test caption")

0 commit comments

Comments
 (0)