
Commit e18f54c

Merge pull request #3 from clumsy/main
Add support for wandb and tensorboard loggers; add legacy pytorch_lightning.loggers support
2 parents fa96054 + 6307d51 commit e18f54c

File tree

17 files changed: +604, -93 lines changed

.gitignore

Lines changed: 3 additions & 1 deletion
@@ -147,7 +147,9 @@ venv.bak/
 runs/

 # mlflow
-mlruns/
+examples/mlflow/
+examples/tensorboard/
+examples/wandb/

 # Pyre type checker
 .pyre/

examples/conf/hf.yaml

Lines changed: 11 additions & 4 deletions
@@ -19,10 +19,17 @@ trainer:
   val_check_interval: 5
   limit_val_batches: 1
   logger:
-    _target_: lightning.pytorch.loggers.MLFlowLogger
-    experiment_name: ${model_name}-train
-    tracking_uri: ./mlruns
-    synchronous: false
+    - _target_: lightning.pytorch.loggers.MLFlowLogger
+      experiment_name: ${model_name}-train
+      tracking_uri: ./mlflow
+      synchronous: false
+    - _target_: lightning.pytorch.loggers.TensorBoardLogger
+      save_dir: ./tensorboard
+      name: ${model_name}-train
+    - _target_: lightning.pytorch.loggers.WandbLogger
+      project: ${model_name}-train
+      save_dir: ./wandb
+      offline: true
   callbacks:
     - _target_: fkat.pytorch.callbacks.cuda.memory.MemoryObserver
     - _target_: fkat.pytorch.callbacks.monitoring.HardwareStats
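
The logger: key changes from a single mapping to a list, so Lightning fans metrics out to all three backends at once. For reference, a minimal hand-written equivalent of what Hydra instantiates from the _target_ entries above, with ${model_name} resolved by hand to an illustrative "hf" (directories mirror the config and the new .gitignore entries):

from lightning.pytorch import Trainer
from lightning.pytorch.loggers import MLFlowLogger, TensorBoardLogger, WandbLogger

# The same three loggers the YAML declares, constructed directly.
loggers = [
    MLFlowLogger(experiment_name="hf-train", tracking_uri="./mlflow", synchronous=False),
    TensorBoardLogger(save_dir="./tensorboard", name="hf-train"),
    WandbLogger(project="hf-train", save_dir="./wandb", offline=True),
]
trainer = Trainer(logger=loggers)  # Trainer accepts a list of loggers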

examples/sync_wandb_offline.py

Lines changed: 105 additions & 0 deletions
@@ -0,0 +1,105 @@
#!/usr/bin/env python3
"""Sync wandb offline runs by manually uploading metrics from summary files."""

import json
import os
import sys
from pathlib import Path

import wandb


def _extract_project_name(config_file: Path) -> str:
    """Extract project name from wandb config file."""
    if not config_file.exists():
        return "unknown"

    import yaml

    with open(config_file) as f:
        config = yaml.safe_load(f)

    # Try to extract model_name from various locations
    if "model_name" in config and isinstance(config["model_name"], dict):
        return f"{config['model_name']['value']}-train"

    if "_wandb" in config and "value" in config["_wandb"]:
        wandb_val = config["_wandb"]["value"]
        e_dict = wandb_val.get("e", {})
        if e_dict:
            first_exec = list(e_dict.values())[0]
            args = first_exec.get("args", [])
            for arg in args:
                if "model_name=" in str(arg):
                    model = str(arg).split("model_name=")[1]
                    return f"{model}-train"

    return "unknown"


def _upload_files(run_path: Path, run_dir_path: str) -> None:
    """Upload files from offline run to wandb."""
    import shutil

    files_dir = run_path / "files"
    if not files_dir.exists():
        return

    for file_path in files_dir.rglob("*"):
        if file_path.is_file() and not file_path.name.startswith("wandb-"):
            try:
                rel_path = file_path.relative_to(files_dir)
                dest = Path(run_dir_path) / rel_path
                dest.parent.mkdir(parents=True, exist_ok=True)
                shutil.copy2(file_path, dest)
                wandb.save(str(rel_path), base_path=run_dir_path, policy="now")
            except Exception:
                pass


def sync_offline_run(run_dir: str, base_url: str, api_key: str) -> None:
    """Sync an offline run by reading summary and uploading metrics."""
    run_path = Path(run_dir)
    summary_file = run_path / "files" / "wandb-summary.json"

    if not summary_file.exists():
        print(f"No summary file found in {run_dir}")
        return

    run_id = run_path.name.split("-")[-1]

    with open(summary_file) as f:
        summary = json.load(f)

    config_file = run_path / "files" / "config.yaml"
    project = _extract_project_name(config_file)

    os.environ["WANDB_BASE_URL"] = base_url
    os.environ["WANDB_API_KEY"] = api_key

    run = wandb.init(project=project, id=run_id, resume="allow", mode="online")

    for key, value in summary.items():
        if not key.startswith("_"):
            run.log({key: value})

    _upload_files(run_path, run.dir)

    run.finish()
    print(f"✓ Synced {run_dir}")


if __name__ == "__main__":
    if len(sys.argv) < 4:
        print("Usage: sync_wandb_offline.py <base_url> <api_key> <run_dir1> [run_dir2 ...]")
        sys.exit(1)

    base_url = sys.argv[1]
    api_key = sys.argv[2]
    run_dirs = sys.argv[3:]

    for run_dir in run_dirs:
        try:
            sync_offline_run(run_dir, base_url, api_key)
        except Exception as e:
            print(f"✗ Failed to sync {run_dir}: {e}")

fkat/__version__.py

Lines changed: 1 addition & 1 deletion
@@ -1,3 +1,3 @@
 # Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
 # SPDX-License-Identifier: Apache-2.0
-__version__ = "0.1.1"
+__version__ = "0.1.2"

fkat/pytorch/callbacks/cuda/memory.py

Lines changed: 3 additions & 2 deletions
@@ -18,14 +18,15 @@
 )
 from fkat.pytorch.loggers import LightningLogger
 from fkat.pytorch.callbacks.loggers import CallbackLogger
+from fkat.utils import safe_timestamp

 logger: logging.Logger = logging.getLogger(__name__)


 def _artifact_path(root_dir: str, rank: int, file_type: str, ext: str) -> tuple[str, str]:
     base_dir = os.path.join(root_dir, "torch.cuda.memory")
-    now = datetime.now(timezone.utc).isoformat()
-    file_path = os.path.join(base_dir, f"rank{rank}/{file_type}/rank{rank}_{now}.{ext}")
+    timestamp = safe_timestamp()
+    file_path = os.path.join(base_dir, f"rank{rank}/{file_type}/rank{rank}_{timestamp}.{ext}")
     os.makedirs(os.path.dirname(file_path), exist_ok=True)
     return base_dir, file_path
3132

fkat/pytorch/callbacks/cuda/test/memory_test.py

Lines changed: 10 additions & 7 deletions
@@ -14,10 +14,13 @@

 class TestArtifactPath(unittest.TestCase):
     @patch("os.makedirs")
-    @patch(f"{memory.__name__}.datetime")
+    @patch("fkat.utils.datetime")
     def test_artifact_path_creates_correct_structure(self, mock_datetime, mock_makedirs):
         # Arrange
-        mock_datetime.now.return_value.isoformat.return_value = "2025-06-18T20:00:00"
+        mock_now = MagicMock()
+        mock_now.strftime.return_value = "2025-06-18_20-00-00-"
+        mock_now.microsecond = 0
+        mock_datetime.now.return_value = mock_now
         root_dir = "/tmp/test"
         rank = 1
         file_type = "snapshot"
@@ -28,17 +31,17 @@ def test_artifact_path_creates_correct_structure(self, mock_datetime, mock_makedirs):

         # Assert
         expected_base_dir = "/tmp/test/torch.cuda.memory"
-        expected_file_path = "/tmp/test/torch.cuda.memory/rank1/snapshot/rank1_2025-06-18T20:00:00.pickle"
+        expected_file_path = "/tmp/test/torch.cuda.memory/rank1/snapshot/rank1_2025-06-18_20-00-00-000.pickle"

         assert base_dir == expected_base_dir
         assert file_path == expected_file_path
         mock_makedirs.assert_called_once_with("/tmp/test/torch.cuda.memory/rank1/snapshot", exist_ok=True)

     @patch("os.makedirs")
-    @patch(f"{memory.__name__}.datetime")
-    def test_artifact_path_different_parameters(self, mock_datetime, mock_makedirs):
+    @patch(f"{memory.__name__}.safe_timestamp")
+    def test_artifact_path_different_parameters(self, mock_safe_timestamp, mock_makedirs):
         # Arrange
-        mock_datetime.now.return_value.isoformat.return_value = "2025-06-18T15:30:45"
+        mock_safe_timestamp.return_value = "2025-06-18_15-30-45-123"
         root_dir = "/var/logs"
         rank = 0
         file_type = "flamegraph"
@@ -49,7 +52,7 @@ def test_artifact_path_different_parameters(self, mock_datetime, mock_makedirs):

         # Assert
         expected_base_dir = "/var/logs/torch.cuda.memory"
-        expected_file_path = "/var/logs/torch.cuda.memory/rank0/flamegraph/rank0_2025-06-18T15:30:45.svg"
+        expected_file_path = "/var/logs/torch.cuda.memory/rank0/flamegraph/rank0_2025-06-18_15-30-45-123.svg"

         assert base_dir == expected_base_dir
         assert file_path == expected_file_path
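
The reworked tests pin down the contract of the new safe_timestamp helper: a filesystem-safe wall-clock string such as 2025-06-18_20-00-00-000, with no colons (ISO-8601 timestamps break paths on Windows) and a three-digit millisecond suffix. fkat/utils itself is not part of this diff, so the following is only a sketch consistent with the strftime and microsecond mocks above, not the actual implementation:

from datetime import datetime, timezone

def safe_timestamp() -> str:
    # strftime yields everything up to the trailing dash; the suffix is the
    # millisecond count, zero-padded to three digits.
    now = datetime.now(timezone.utc)
    return now.strftime("%Y-%m-%d_%H-%M-%S-") + f"{now.microsecond // 1000:03d}"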

fkat/pytorch/callbacks/loggers.py

Lines changed: 73 additions & 6 deletions
@@ -4,22 +4,22 @@
 from typing_extensions import override

 import lightning as L
-from lightning.pytorch.loggers import MLFlowLogger
+from lightning.pytorch.loggers import TensorBoardLogger, WandbLogger
 from mlflow.entities import Metric, RunTag, Param
 from mlflow.tracking import MlflowClient  # type: ignore[possibly-unbound-import]

 if TYPE_CHECKING:
-    from lightning.pytorch.loggers import MLFlowLogger
+    pass

-from fkat.pytorch.loggers import LightningLogger
+from fkat.pytorch.loggers import LightningLogger, _is_logger_type
 from fkat.utils import assert_not_none
 from fkat.utils.logging import rank0_logger
 from fkat.utils.mlflow import broadcast_mlflow_run_id, mlflow_logger

 log = rank0_logger(__name__)


-class MLFlowCallbackLogger:
+class MLFlowCallbackLogger(LightningLogger):
     """
     Mlflow logger class that supports distributed logging of tags, metrics and artifacts.

@@ -86,6 +86,69 @@ def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
         )


+class TensorBoardCallbackLogger(LightningLogger):
+    """TensorBoard logger for distributed logging."""
+
+    def __init__(self, logger: TensorBoardLogger) -> None:
+        self._logger = logger
+
+    def log_tag(self, key: str, value: str) -> None:
+        self._logger.experiment.add_text(key, value)
+
+    def tags(self) -> dict[str, Any]:
+        return {}
+
+    def log_batch(
+        self,
+        metrics: dict[str, float] | None = None,
+        params: dict[str, Any] | None = None,
+        tags: dict[str, str] | None = None,
+        timestamp: int | None = None,
+        step: int | None = None,
+    ) -> None:
+        if metrics:
+            for k, v in metrics.items():
+                self._logger.experiment.add_scalar(k, v, step)
+        if tags:
+            for k, v in tags.items():
+                self._logger.experiment.add_text(k, v, step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+        pass
+
+
+class WandbCallbackLogger(LightningLogger):
+    """WandB logger for distributed logging."""
+
+    def __init__(self, logger: WandbLogger) -> None:
+        self._logger = logger
+
+    def log_tag(self, key: str, value: str) -> None:
+        self._logger.experiment.config.update({key: value})
+
+    def tags(self) -> dict[str, Any]:
+        return dict(self._logger.experiment.config)
+
+    def log_batch(
+        self,
+        metrics: dict[str, float] | None = None,
+        params: dict[str, Any] | None = None,
+        tags: dict[str, str] | None = None,
+        timestamp: int | None = None,
+        step: int | None = None,
+    ) -> None:
+        log_dict = {}
+        if metrics:
+            log_dict.update(metrics)
+        if tags:
+            log_dict.update(tags)
+        if log_dict:
+            self._logger.experiment.log(log_dict, step=step)
+
+    def log_artifact(self, local_path: str, artifact_path: str | None = None) -> None:
+        self._logger.experiment.save(local_path)
+
+
 class CallbackLogger(LightningLogger):
     """
     A wrapper on top of the collection of Logger instances,
@@ -104,9 +167,13 @@ class CallbackLogger(LightningLogger):
     def __init__(self, trainer: "L.Trainer | None", loggers: list[LightningLogger] | None = None) -> None:
         if trainer:
             self.loggers = []
-            for logger in trainer.logger if isinstance(trainer.logger, list) else [trainer.logger]:
-                if isinstance(logger, MLFlowLogger):
+            for logger in trainer.loggers:
+                if _is_logger_type(logger, "MLFlowLogger"):
                     self.loggers.append(MLFlowCallbackLogger(trainer=trainer))
+                elif _is_logger_type(logger, "TensorBoardLogger"):
+                    self.loggers.append(TensorBoardCallbackLogger(logger=logger))  # type: ignore[arg-type]
+                elif _is_logger_type(logger, "WandbLogger"):
+                    self.loggers.append(WandbCallbackLogger(logger=logger))  # type: ignore[arg-type]
         else:
             assert loggers
             self.loggers = loggers
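
Swapping the isinstance checks for _is_logger_type is what delivers the legacy pytorch_lightning.loggers support in the commit title: matching on class name rather than class identity accepts a logger from either namespace. fkat/pytorch/loggers.py is not shown in this diff, so this is a plausible sketch rather than the PR's actual implementation:

def _is_logger_type(logger: object, class_name: str) -> bool:
    # Compare by class name across the MRO, so loggers from lightning.pytorch
    # and from the legacy pytorch_lightning namespace (and subclasses of
    # either) are all recognized.
    return any(cls.__name__ == class_name for cls in type(logger).__mro__)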
