Skip to content

Commit f4fcce8

Browse files
committed
[feat] Added MLflow to Slime Logging Backend
1 parent d56d56b commit f4fcce8

File tree

6 files changed

+322
-15
lines changed

6 files changed

+322
-15
lines changed

setup.py

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -39,7 +39,10 @@ def get_tag(self):
3939
extras_require={
4040
"fsdp": [
4141
"torch>=2.0",
42-
]
42+
],
43+
"mlflow": [
44+
"mlflow",
45+
],
4346
},
4447
python_requires=">=3.10",
4548
classifiers=[

slime/utils/arguments.py

Lines changed: 25 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1113,6 +1113,30 @@ def add_wandb_arguments(parser):
11131113
parser.add_argument("--wandb-run-id", type=str, default=None)
11141114
return parser
11151115

1116+
# mlflow
def add_mlflow_arguments(parser):
    """Register the MLflow tracking CLI flags (mirrors ``add_wandb_arguments``)."""
    parser.add_argument("--use-mlflow", action="store_true", default=False)
    # All remaining options are plain string flags; register them from a table.
    string_options = (
        (
            "--mlflow-tracking-uri",
            None,
            "MLflow tracking server URI. Defaults to MLFLOW_TRACKING_URI env var, or local mlruns/ directory.",
        ),
        (
            "--mlflow-experiment-name",
            "slime",
            "MLflow experiment name.",
        ),
        (
            "--mlflow-run-name",
            None,
            "MLflow run name. Defaults to --wandb-group if not set.",
        ),
    )
    for flag, default, help_text in string_options:
        parser.add_argument(flag, type=str, default=default, help=help_text)
    parser.add_argument("--mlflow-run-id", type=str, default=None)
    return parser
1139+
11161140
# tensorboard
11171141
def add_tensorboard_arguments(parser):
11181142
# tb_project_name, tb_experiment_name
@@ -1455,6 +1479,7 @@ def add_sglang_tp_size():
14551479
parser = add_algo_arguments(parser)
14561480
parser = add_on_policy_distillation_arguments(parser)
14571481
parser = add_wandb_arguments(parser)
1482+
parser = add_mlflow_arguments(parser)
14581483
parser = add_tensorboard_arguments(parser)
14591484
parser = add_router_arguments(parser)
14601485
parser = add_debug_arguments(parser)

slime/utils/logging_utils.py

Lines changed: 8 additions & 14 deletions
Original file line numberDiff line numberDiff line change
@@ -1,11 +1,9 @@
11
import logging
22

3-
import wandb
4-
5-
from . import wandb_utils
6-
from .tensorboard_utils import _TensorboardAdapter
3+
from .tracking import TrackingManager
74

85
_LOGGER_CONFIGURED = False
6+
_manager = TrackingManager()
97

108

119
# ref: SGLang
@@ -25,17 +23,13 @@ def configure_logger(prefix: str = ""):
2523

2624

2725
def init_tracking(args, primary: bool = True, **kwargs):
    """Initialise every tracking backend enabled on ``args`` (wandb/tensorboard/mlflow)."""
    _manager.init(args, primary=primary, **kwargs)
3227

3328

34-
# TODO further refactor, e.g. put TensorBoard init to the "init" part
3529
def log(args, metrics, step_key: str):
    """Forward one ``metrics`` dict to all active tracking backends.

    ``args`` is retained for backward compatibility with existing callers;
    backend selection now happens once in ``init_tracking``.  The step is
    read from ``metrics[step_key]`` when present (``None`` otherwise).
    """
    _manager.log(metrics, step=metrics.get(step_key))
32+
3833

39-
if args.use_tensorboard:
40-
metrics_except_step = {k: v for k, v in metrics.items() if k != step_key}
41-
_TensorboardAdapter(args).log(data=metrics_except_step, step=metrics[step_key])
34+
def finish_tracking():
    """Finish and clear every active tracking backend."""
    _manager.finish()

slime/utils/tracking/__init__.py

Lines changed: 10 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,10 @@
1+
"""
2+
Shared tracking interface for experiment logging backends.
3+
4+
Exports :class:`TrackingManager` so existing ``from .tracking import TrackingManager``
5+
imports continue to work.
6+
"""
7+
8+
from .manager import TrackingBackend, TrackingManager # noqa: F401
9+
from .mlflow_utils import finish as mlflow_finish # noqa: F401
10+
from .mlflow_utils import init_mlflow, log_metrics # noqa: F401

slime/utils/tracking/manager.py

Lines changed: 140 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,140 @@
1+
"""
2+
Shared tracking interface for experiment logging backends.
3+
4+
Each backend implements ``init / log / finish``, and :class:`TrackingManager` fans out
5+
calls to every active backend.
6+
7+
To add a new backend:
8+
--------------------
9+
1. Subclass :class:`TrackingBackend`.
10+
2. Register it in :data:`BACKEND_REGISTRY`.
11+
3. Add a corresponding ``--use-<name>`` CLI flag in ``arguments.py``.
12+
"""
13+
14+
from __future__ import annotations
15+
16+
import logging
17+
from abc import ABC, abstractmethod
18+
from typing import Any
19+
20+
logger = logging.getLogger(__name__)
21+
22+
23+
24+
class TrackingBackend(ABC):
25+
# Interface every logging backend must satisfy.
26+
27+
@abstractmethod
28+
def init(self, args, *, primary: bool = True, **kwargs) -> None:
29+
...
30+
31+
@abstractmethod
32+
def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
33+
...
34+
35+
@abstractmethod
36+
def finish(self) -> None:
37+
...
38+
39+
40+
# Thin adapters for backwards compatibility to keep wandb_utils and tensorboard_utils untouched.
class WandbBackend(TrackingBackend):
    """Adapter that forwards to the existing ``wandb_utils`` helpers."""

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        # Imported lazily so wandb is only required when this backend is enabled.
        from .. import wandb_utils

        init_fn = wandb_utils.init_wandb_primary if primary else wandb_utils.init_wandb_secondary
        init_fn(args, **kwargs)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        import wandb

        # wandb keeps its own step counter; the explicit ``step`` is not used here.
        wandb.log(metrics)

    def finish(self) -> None:
        import wandb

        wandb.finish()
62+
63+
class TensorboardBackend(TrackingBackend):
    """Adapter around the existing ``_TensorboardAdapter``."""

    # Set by init(); log()/finish() are no-ops until then.
    _adapter = None

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        from ..tensorboard_utils import _TensorboardAdapter

        self._adapter = _TensorboardAdapter(args)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        if self._adapter is None:
            return
        # Drop step-key entries (e.g. "train/step", "rollout/step") —
        # tensorboard receives the step as an explicit argument instead.
        payload = {key: value for key, value in metrics.items() if not key.endswith("/step")}
        self._adapter.log(data=payload, step=step)

    def finish(self) -> None:
        if self._adapter is None:
            return
        self._adapter.finish()
83+
84+
85+
class MlflowBackend(TrackingBackend):
    """Adapter that forwards to ``mlflow_utils``."""

    @staticmethod
    def _utils():
        # Lazy import keeps mlflow an optional dependency until the backend is used.
        from . import mlflow_utils

        return mlflow_utils

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        self._utils().init_mlflow(args, primary=primary, **kwargs)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        self._utils().log_metrics(metrics, step=step)

    def finish(self) -> None:
        self._utils().finish()
102+
103+
104+
# Registry mapping backend name -> (backend class, args attribute that enables it).

BACKEND_REGISTRY: dict[str, tuple[type[TrackingBackend], str]] = {
    "wandb": (WandbBackend, "use_wandb"),
    "tensorboard": (TensorboardBackend, "use_tensorboard"),
    "mlflow": (MlflowBackend, "use_mlflow"),
}


class TrackingManager:
    """Fans ``init`` / ``log`` / ``finish`` out to every enabled backend.

    Used internally by ``logging_utils``.  Backends are selected from
    ``BACKEND_REGISTRY`` based on their ``use_<name>`` flag on ``args``.
    """

    def __init__(self) -> None:
        self._backends: list[TrackingBackend] = []

    def init(self, args, *, primary: bool = True, **kwargs) -> None:
        """Instantiate and initialise each backend whose flag is set on ``args``."""
        for name, (backend_cls, flag_attr) in BACKEND_REGISTRY.items():
            if not getattr(args, flag_attr, False):
                continue
            logger.info("Initialising tracking backend: %s", name)
            backend = backend_cls()
            backend.init(args, primary=primary, **kwargs)
            self._backends.append(backend)

    def log(self, metrics: dict[str, Any], step: int | None = None) -> None:
        """Forward one metrics dict to every active backend."""
        for backend in self._backends:
            backend.log(metrics, step=step)

    def finish(self) -> None:
        """Finish all backends; a failure in one never blocks the others."""
        for backend in self._backends:
            try:
                backend.finish()
            except Exception:
                logger.exception("Error finishing tracking backend %s", type(backend).__name__)
        self._backends.clear()
Lines changed: 135 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,135 @@
1+
"""
2+
MLflow tracking backend for slime.
3+
4+
5+
MLflow docs for future reference:
6+
- Tracking overview : https://mlflow.org/docs/latest/ml/tracking/
7+
- Python API : https://mlflow.org/docs/latest/python_api/mlflow.html
8+
- Remote tracking : https://mlflow.org/docs/latest/tracking/server.html
9+
"""
10+
11+
from __future__ import annotations
12+
13+
import logging
14+
import os
15+
import re
16+
from copy import deepcopy
17+
from typing import Any
18+
19+
logger = logging.getLogger(__name__)
20+
21+
22+
# Helpers/utils
23+
def _sanitize_key(key: str) -> str:
24+
return re.sub(r"[^a-zA-Z0-9_\-./\s]", "_", key)
25+
26+
27+
def _compute_config_for_logging(args) -> dict[str, str]:
    """Build a flat param dict from *args*, mirroring ``wandb_utils._compute_config_for_logging``."""
    config = deepcopy(args.__dict__)

    # Record a small whitelist of environment variables alongside the CLI args.
    whitelist_env_vars = ("SLURM_JOB_ID",)
    config["env_vars"] = {name: os.environ[name] for name in whitelist_env_vars if name in os.environ}

    return _flatten_dict(config)
35+
36+
37+
def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict[str, str]:
38+
# Recursively flatten nested dicts into ``dotted.key`` → ``str(value)`` pairs.
39+
items: list[tuple[str, str]] = []
40+
for k, v in d.items():
41+
new_key = f"{parent_key}{sep}{k}" if parent_key else k
42+
if isinstance(v, dict):
43+
items.extend(_flatten_dict(v, new_key, sep).items())
44+
else:
45+
items.append((new_key, str(v)))
46+
return dict(items)
47+
48+
49+
def init_mlflow(args, *, primary: bool = True, **kwargs) -> None:
    """Set up MLflow tracking for this process.

    No-op (and clears ``args.mlflow_run_id``) when ``--use-mlflow`` is off.
    The primary rank starts a fresh run; secondary ranks attach to it.
    """
    if not args.use_mlflow:
        # Make sure downstream code never sees a stale run id.
        args.mlflow_run_id = None
        return

    # Deferred import: mlflow is an optional dependency.
    import mlflow

    uri = args.mlflow_tracking_uri or os.environ.get("MLFLOW_TRACKING_URI")
    if uri:
        mlflow.set_tracking_uri(uri)
        logger.info("MLflow tracking URI: %s", uri)

    mlflow.set_experiment(args.mlflow_experiment_name)

    if primary:
        _init_mlflow_primary(args, args.mlflow_experiment_name)
    else:
        _init_mlflow_secondary(args)
68+
69+
70+
def _init_mlflow_primary(args, experiment_name: str) -> None:
    """Start a new MLflow run on the primary rank and log the full config as params."""
    import mlflow

    run_name = args.mlflow_run_name or args.wandb_group

    # Tag the run so it can be traced back to the scheduler job and rank.
    tags = {"rank": str(args.rank)}
    slurm_job_id = os.environ.get("SLURM_JOB_ID")
    if slurm_job_id:
        tags["slurm_job_id"] = slurm_job_id

    active = mlflow.start_run(run_name=run_name, tags=tags)
    mlflow.log_params(_compute_config_for_logging(args))

    # Publish the run id on args so secondary ranks can attach to the same run.
    args.mlflow_run_id = active.info.run_id
    logger.info(
        "MLflow run started: %s (experiment=%s, name=%s)",
        active.info.run_id,
        experiment_name,
        run_name,
    )
86+
87+
88+
def _init_mlflow_secondary(args) -> None:
89+
"""Attach to an existing MLflow run created by the primary rank."""
90+
import mlflow
91+
92+
run_id = args.mlflow_run_id or os.environ.get("MLFLOW_RUN_ID")
93+
if run_id is None:
94+
return
95+
96+
mlflow.start_run(run_id=run_id)
97+
logger.info("MLflow secondary attached to run: %s", run_id)
98+
99+
100+
# ---------------------------------------------------------------------------
101+
# Logging
102+
# ---------------------------------------------------------------------------
103+
104+
def log_metrics(metrics: dict[str, Any], step: int | None = None) -> None:
    """Log numeric metrics to the active MLflow run (no-op when none is active).

    ``*/step`` bookkeeping keys and non-numeric values are skipped; the
    remaining keys are sanitised for MLflow.
    """
    import mlflow

    if mlflow.active_run() is None:
        return

    payload: dict[str, float] = {}
    for key, value in metrics.items():
        if key.endswith("/step"):
            # The step is passed explicitly below, not as a metric.
            continue
        try:
            numeric = float(value)
        except (TypeError, ValueError):
            # Silently drop values MLflow cannot store as metrics.
            continue
        payload[_sanitize_key(key)] = numeric

    if payload:
        mlflow.log_metrics(payload, step=None if step is None else int(step))
121+
122+
123+
# ---------------------------------------------------------------------------
124+
# Cleanup
125+
# ---------------------------------------------------------------------------
126+
127+
def finish() -> None:
    """End the active MLflow run, if any."""
    import mlflow

    active = mlflow.active_run()
    if active is None:
        return

    mlflow.end_run()
    logger.info("MLflow run ended: %s", active.info.run_id)

0 commit comments

Comments
 (0)