Skip to content

Commit b6919f1

Browse files
authored
create metrax/logging package (#116)
* create metrax/logging * format files * only change src/metrax/logging * only modify src/metrax/logging * add tests * fix dependency * fix with statements lint
1 parent cf96547 commit b6919f1

File tree

7 files changed

+326
-0
lines changed

7 files changed

+326
-0
lines changed

pyproject.toml

Lines changed: 2 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -21,6 +21,8 @@ dependencies = [
2121
"flax==0.11.1",
2222
"jax==0.6.2",
2323
"numpy==2.1.3",
24+
"tensorboardX==2.6.4",
25+
"wandb==0.22.3",
2426
]
2527

2628
[tool.hatch.build]

src/metrax/logging/__init__.py

Lines changed: 21 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,21 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from metrax.logging import protocol
16+
from metrax.logging import tensorboard_backend
17+
from metrax.logging import wandb_backend
18+
19+
LoggingBackend = protocol.LoggingBackend
20+
TensorboardBackend = tensorboard_backend.TensorboardBackend
21+
WandbBackend = wandb_backend.WandbBackend

src/metrax/logging/protocol.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,30 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Metrax LoggingBackend protocol."""
16+
17+
from typing import Protocol
18+
import numpy as np
19+
20+
21+
class LoggingBackend(Protocol):
22+
"""Defines the interface for a pluggable logging backend."""
23+
24+
def log_scalar(self, event: str, value: float | np.ndarray, **kwargs):
25+
"""Logs a scalar value. Must match jax.monitoring listener signature."""
26+
...
27+
28+
def close(self):
29+
"""Closes the logger and flushes any pending data."""
30+
...
Lines changed: 56 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,56 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Metrax LoggingBackend implemenation for Tensorboard."""
16+
17+
import jax
18+
import numpy as np
19+
from tensorboardX import writer
20+
21+
_DEFAULT_STEP = 0
22+
23+
24+
def _get_step(kwargs: dict[str, str | int]) -> int:
25+
"""Returns the step from the kwargs, or 0 if not provided."""
26+
step = kwargs.get("step")
27+
return _DEFAULT_STEP if step is None else int(step)
28+
29+
30+
def _preprocess_event_name(event_name: str) -> str:
31+
"""Preprocesses the event name before logging."""
32+
return event_name.lstrip("/") # Remove leading slashes
33+
34+
35+
class TensorboardBackend:
36+
"""A logging backend for Tensorboard that conforms to the LoggingBackend protocol."""
37+
38+
def __init__(self, log_dir: str, flush_every_n_steps: int = 100):
39+
self._flush_every_n_steps = flush_every_n_steps
40+
if jax.process_index() == 0:
41+
self._writer = writer.SummaryWriter(logdir=log_dir)
42+
else:
43+
self._writer = None
44+
45+
def log_scalar(self, event: str, value: float | np.ndarray, **kwargs):
46+
if self._writer is None:
47+
return
48+
current_step = _get_step(kwargs)
49+
event_name = _preprocess_event_name(event)
50+
self._writer.add_scalar(event_name, value, current_step)
51+
if current_step % self._flush_every_n_steps == 0:
52+
self._writer.flush()
53+
54+
def close(self):
55+
if self._writer:
56+
self._writer.close()
Lines changed: 64 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,64 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# http://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
from unittest import mock
16+
17+
from absl.testing import absltest
18+
from metrax import logging as metrax_logging
19+
20+
TensorboardBackend = metrax_logging.TensorboardBackend
21+
22+
23+
class TensorboardBackendTest(absltest.TestCase):
24+
25+
@mock.patch("metrax.logging.tensorboard_backend.writer.SummaryWriter")
26+
def test_init_and_log_success_main_process(self, mock_summary_writer):
27+
"""Tests successful init, logging, and closing on the main process."""
28+
mock_writer_instance = mock_summary_writer.return_value
29+
30+
with mock.patch("jax.process_index", return_value=0):
31+
backend = TensorboardBackend(log_dir="/fake/logs", flush_every_n_steps=2)
32+
33+
mock_summary_writer.assert_called_once_with(logdir="/fake/logs")
34+
35+
backend.log_scalar("/event1", 1.0, step=1)
36+
mock_writer_instance.add_scalar.assert_called_with("event1", 1.0, 1)
37+
mock_writer_instance.flush.assert_not_called()
38+
39+
backend.log_scalar("event2", 2.0, step=2)
40+
mock_writer_instance.add_scalar.assert_called_with("event2", 2.0, 2)
41+
mock_writer_instance.flush.assert_called_once()
42+
43+
backend.close()
44+
mock_writer_instance.close.assert_called_once()
45+
46+
@mock.patch("metrax.logging.tensorboard_backend.writer.SummaryWriter")
47+
def test_init_non_main_process_is_noop(self, mock_summary_writer):
48+
"""Tests that the backend does nothing on non-main processes."""
49+
mock_writer_instance = mock_summary_writer.return_value
50+
51+
with mock.patch("jax.process_index", return_value=1):
52+
backend = TensorboardBackend(log_dir="/fake/logs")
53+
54+
mock_summary_writer.assert_not_called()
55+
56+
backend.log_scalar("myevent", 1.0, step=1)
57+
mock_writer_instance.add_scalar.assert_not_called()
58+
59+
backend.close()
60+
mock_writer_instance.close.assert_not_called()
61+
62+
63+
if __name__ == "__main__":
64+
absltest.main()
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
# Copyright 2025 Google LLC
2+
#
3+
# Licensed under the Apache License, Version 2.0 (the "License");
4+
# you may not use this file except in compliance with the License.
5+
# You may obtain a copy of the License at
6+
#
7+
# https://www.apache.org/licenses/LICENSE-2.0
8+
#
9+
# Unless required by applicable law or agreed to in writing, software
10+
# distributed under the License is distributed on an "AS IS" BASIS,
11+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
# See the License for the specific language governing permissions and
13+
# limitations under the License.
14+
15+
"""Metrax LoggingBackend implemenation for Weight & Bias."""
16+
17+
import datetime
18+
import logging
19+
20+
import jax
21+
import numpy as np
22+
23+
_DEFAULT_STEP = 0
24+
25+
26+
def _get_step(kwargs: dict[str, str | int]) -> int:
27+
"""Returns the step from the kwargs, or 0 if not provided."""
28+
step = kwargs.get("step")
29+
return _DEFAULT_STEP if step is None else int(step)
30+
31+
32+
def _preprocess_event_name(event_name: str) -> str:
33+
"""Preprocesses the event name before logging."""
34+
return event_name.lstrip("/") # Remove leading slashes
35+
36+
37+
class WandbBackend:
38+
"""A logging backend for W&B that conforms to the LoggingBackend protocol."""
39+
40+
def __init__(self, project: str, name: str | None = None, **kwargs):
41+
if jax.process_index() != 0:
42+
self._is_active = False
43+
self.wandb = None # Ensure the attribute exists
44+
return
45+
46+
try:
47+
# pylint: disable=g-import-not-at-top
48+
# pytype: disable=import-error
49+
import wandb
50+
except ImportError as e:
51+
raise ImportError(
52+
"The 'wandb' library is not installed. Please install it with "
53+
"'pip install wandb' to use the WandbBackend."
54+
) from e
55+
self.wandb = wandb
56+
57+
run_name = name or datetime.datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
58+
wandb.init(project=project, name=run_name, anonymous="allow", **kwargs)
59+
if wandb.run:
60+
logging.info("W&B run URL: %s", wandb.run.url)
61+
self._is_active = True
62+
else:
63+
self._is_active = False
64+
65+
def log_scalar(self, event: str, value: float | np.ndarray, **kwargs):
66+
if self.wandb is None or not self._is_active:
67+
return
68+
current_step = _get_step(kwargs)
69+
event_name = _preprocess_event_name(event)
70+
self.wandb.log({event_name: value}, step=current_step)
71+
72+
def close(self):
73+
if self.wandb is None or not self._is_active:
74+
return
75+
if hasattr(self.wandb, "run") and self.wandb.run:
76+
self.wandb.finish()
Lines changed: 77 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,77 @@
1+
"""Tests for the W&B backend."""
2+
3+
import builtins
4+
from unittest import mock
5+
6+
from absl.testing import absltest
7+
from metrax import logging as metrax_logging
8+
9+
WandbBackend = metrax_logging.WandbBackend
10+
11+
12+
_real_import = builtins.__import__
13+
14+
15+
class WandbBackendTest(absltest.TestCase):
16+
17+
def setUp(self):
18+
super().setUp()
19+
self.mock_wandb = mock.Mock()
20+
self.mock_wandb.run = mock.Mock()
21+
self.mock_datetime = mock.Mock()
22+
self.mock_datetime.datetime.now.return_value.strftime.return_value = (
23+
"fixed-run-name"
24+
)
25+
26+
def _mock_successful_import(self, name, *args, **kwargs):
27+
"""Mock __import__ to return our mock_wandb for 'wandb'."""
28+
if name == "wandb":
29+
return self.mock_wandb
30+
return _real_import(name, *args, **kwargs)
31+
32+
def test_init_and_log_success_main_process(self):
33+
"""Tests successful init, logging, and closing on the main process."""
34+
with mock.patch("jax.process_index", return_value=0), mock.patch(
35+
"metrax.logging.wandb_backend.datetime", self.mock_datetime
36+
), mock.patch(
37+
"builtins.__import__", side_effect=self._mock_successful_import
38+
):
39+
40+
backend = WandbBackend(project="test-project")
41+
self.mock_wandb.init.assert_called_once_with(
42+
project="test-project", name="fixed-run-name", anonymous="allow"
43+
)
44+
self.assertTrue(backend._is_active)
45+
46+
backend.log_scalar("/myevent", 123.45, step=50)
47+
self.mock_wandb.log.assert_called_once_with({"myevent": 123.45}, step=50)
48+
49+
backend.close()
50+
self.mock_wandb.finish.assert_called_once()
51+
52+
def test_init_non_main_process_is_noop(self):
53+
"""Tests that the backend does nothing on non-main processes."""
54+
with mock.patch("jax.process_index", return_value=1):
55+
backend = WandbBackend(project="test-project")
56+
self.assertFalse(backend._is_active)
57+
self.assertIsNone(backend.wandb)
58+
59+
def test_init_fails_if_wandb_not_installed(self):
60+
"""Tests that __init__ raises an ImportError if wandb is missing."""
61+
62+
def failing_import(name, *args, **kwargs):
63+
"""Mock __import__ to raise an error for 'wandb'."""
64+
if name == "wandb":
65+
raise ImportError("Mocked import failure")
66+
return _real_import(name, *args, **kwargs)
67+
68+
with mock.patch(
69+
"builtins.__import__", side_effect=failing_import
70+
), mock.patch("jax.process_index", return_value=0):
71+
with self.assertRaises(ImportError) as cm:
72+
WandbBackend(project="test-project")
73+
self.assertIn("pip install wandb", str(cm.exception))
74+
75+
76+
if __name__ == "__main__":
77+
absltest.main()

0 commit comments

Comments
 (0)