# mlflow
def add_mlflow_arguments(parser):
    """Register the MLflow tracking CLI flags on *parser* and return it.

    Mirrors ``add_wandb_arguments``: every option is inert unless
    ``--use-mlflow`` is passed, so existing launch scripts are unaffected.
    """
    parser.add_argument(
        "--use-mlflow",
        action="store_true",
        default=False,
        help="Enable the MLflow tracking backend.",
    )
    parser.add_argument(
        "--mlflow-tracking-uri",
        type=str,
        default=None,
        help="MLflow tracking server URI. Defaults to MLFLOW_TRACKING_URI env var, or local mlruns/ directory.",
    )
    parser.add_argument(
        "--mlflow-experiment-name",
        type=str,
        default="slime",
        help="MLflow experiment name.",
    )
    parser.add_argument(
        "--mlflow-run-name",
        type=str,
        default=None,
        help="MLflow run name. Defaults to --wandb-group if not set.",
    )
    parser.add_argument(
        "--mlflow-run-id",
        type=str,
        default=None,
        help="Existing MLflow run id for secondary ranks to attach to.",
    )
    return parser
# Single process-wide manager; backends are registered on it by init_tracking().
_manager = TrackingManager()


def init_tracking(args, primary: bool = True, **kwargs):
    """Initialise every tracking backend enabled in *args* (wandb / tensorboard / mlflow)."""
    _manager.init(args, primary=primary, **kwargs)


def log(args, metrics, step_key: str):
    """Fan *metrics* out to all active backends.

    The step is taken from ``metrics[step_key]``; ``.get()`` tolerates a
    missing key, in which case backends receive ``step=None``.
    """
    _manager.log(metrics, step=metrics.get(step_key))


def finish_tracking():
    """Close every active backend and drop them from the manager."""
    _manager.finish()
Register it in :data:`BACKEND_REGISTRY`. +3. Add a corresponding ``--use-`` CLI flag in ``arguments.py``. +""" + +from __future__ import annotations + +import logging +from abc import ABC, abstractmethod +from typing import Any + +logger = logging.getLogger(__name__) + + + +class TrackingBackend(ABC): + # Interface every logging backend must satisfy. + + @abstractmethod + def init(self, args, *, primary: bool = True, **kwargs) -> None: + ... + + @abstractmethod + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + ... + + @abstractmethod + def finish(self) -> None: + ... + + +# Thin adapters for backwards compatibility to keep wandb_utils and tensorboard_utils untouched. +class WandbBackend(TrackingBackend): + # Delegates to the existing ``wandb_utils`` helpers. + + def init(self, args, *, primary: bool = True, **kwargs) -> None: + from .. import wandb_utils + + if primary: + wandb_utils.init_wandb_primary(args, **kwargs) + else: + wandb_utils.init_wandb_secondary(args, **kwargs) + + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + import wandb + + wandb.log(metrics) + + def finish(self) -> None: + import wandb + + wandb.finish() + + +class TensorboardBackend(TrackingBackend): + # Delegates to the existing ``_TensorboardAdapter`` (part of the TODO). + + _adapter = None + + def init(self, args, *, primary: bool = True, **kwargs) -> None: + from ..tensorboard_utils import _TensorboardAdapter + + self._adapter = _TensorboardAdapter(args) + + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + if self._adapter is not None: + # Strip step-key entries (e.g. "train/step", "rollout/step") — + # tensorboard receives step as an explicit argument instead. 
+ data = {k: v for k, v in metrics.items() if not k.endswith("/step")} + self._adapter.log(data=data, step=step) + + def finish(self) -> None: + if self._adapter is not None: + self._adapter.finish() + + +class MlflowBackend(TrackingBackend): + """Delegates to ``mlflow_utils``.""" + + def init(self, args, *, primary: bool = True, **kwargs) -> None: + from . import mlflow_utils + + mlflow_utils.init_mlflow(args, primary=primary, **kwargs) + + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + from . import mlflow_utils + + mlflow_utils.log_metrics(metrics, step=step) + + def finish(self) -> None: + from . import mlflow_utils + + mlflow_utils.finish() + + +# Registry that maps backend name → (class, args-flag attribute) + +BACKEND_REGISTRY: dict[str, tuple[type[TrackingBackend], str]] = { + "wandb": (WandbBackend, "use_wandb"), + "tensorboard": (TensorboardBackend, "use_tensorboard"), + "mlflow": (MlflowBackend, "use_mlflow"), +} + + +class TrackingManager: + #Initialises and logs to every enabled backend; used internally by ``logging_utils``. 
+ + def __init__(self) -> None: + self._backends: list[TrackingBackend] = [] + + def init(self, args, *, primary: bool = True, **kwargs) -> None: + for name, (cls, flag) in BACKEND_REGISTRY.items(): + if getattr(args, flag, False): + logger.info("Initialising tracking backend: %s", name) + backend = cls() + backend.init(args, primary=primary, **kwargs) + self._backends.append(backend) + + def log(self, metrics: dict[str, Any], step: int | None = None) -> None: + for backend in self._backends: + backend.log(metrics, step=step) + + def finish(self) -> None: + for backend in self._backends: + try: + backend.finish() + except Exception: + logger.exception( + "Error finishing tracking backend %s", + type(backend).__name__, + ) + self._backends.clear() diff --git a/slime/utils/tracking/mlflow_utils.py b/slime/utils/tracking/mlflow_utils.py new file mode 100644 index 000000000..5ae7dffd2 --- /dev/null +++ b/slime/utils/tracking/mlflow_utils.py @@ -0,0 +1,135 @@ +""" +MLflow tracking backend for slime. 
+ + +MLflow docs for future reference: + - Tracking overview : https://mlflow.org/docs/latest/ml/tracking/ + - Python API : https://mlflow.org/docs/latest/python_api/mlflow.html + - Remote tracking : https://mlflow.org/docs/latest/tracking/server.html +""" + +from __future__ import annotations + +import logging +import os +import re +from copy import deepcopy +from typing import Any + +logger = logging.getLogger(__name__) + + +# Helpers/utils +def _sanitize_key(key: str) -> str: + return re.sub(r"[^a-zA-Z0-9_\-./\s]", "_", key) + + +def _compute_config_for_logging(args) -> dict[str, str]: + # Build a flat param dict from *args*, mirroring ``wandb_utils._compute_config_for_logging``.""" + raw = deepcopy(args.__dict__) + + whitelist_env_vars = ["SLURM_JOB_ID"] + raw["env_vars"] = {k: v for k, v in os.environ.items() if k in whitelist_env_vars} + + return _flatten_dict(raw) + + +def _flatten_dict(d: dict, parent_key: str = "", sep: str = ".") -> dict[str, str]: + # Recursively flatten nested dicts into ``dotted.key`` → ``str(value)`` pairs. 
+ items: list[tuple[str, str]] = [] + for k, v in d.items(): + new_key = f"{parent_key}{sep}{k}" if parent_key else k + if isinstance(v, dict): + items.extend(_flatten_dict(v, new_key, sep).items()) + else: + items.append((new_key, str(v))) + return dict(items) + + +def init_mlflow(args, *, primary: bool = True, **kwargs) -> None: + if not args.use_mlflow: + args.mlflow_run_id = None + return + + import mlflow + + tracking_uri = args.mlflow_tracking_uri or os.environ.get("MLFLOW_TRACKING_URI") + if tracking_uri: + mlflow.set_tracking_uri(tracking_uri) + logger.info("MLflow tracking URI: %s", tracking_uri) + + experiment_name = args.mlflow_experiment_name + mlflow.set_experiment(experiment_name) + + if primary: + _init_mlflow_primary(args, experiment_name) + else: + _init_mlflow_secondary(args) + + +def _init_mlflow_primary(args, experiment_name: str) -> None: + import mlflow + + run_name = args.mlflow_run_name or args.wandb_group + + tags = {} + slurm_job_id = os.environ.get("SLURM_JOB_ID") + if slurm_job_id: + tags["slurm_job_id"] = slurm_job_id + tags["rank"] = str(args.rank) + + run = mlflow.start_run(run_name=run_name, tags=tags) + mlflow.log_params(_compute_config_for_logging(args)) + + args.mlflow_run_id = run.info.run_id + logger.info("MLflow run started: %s (experiment=%s, name=%s)", run.info.run_id, experiment_name, run_name) + + +def _init_mlflow_secondary(args) -> None: + """Attach to an existing MLflow run created by the primary rank.""" + import mlflow + + run_id = args.mlflow_run_id or os.environ.get("MLFLOW_RUN_ID") + if run_id is None: + return + + mlflow.start_run(run_id=run_id) + logger.info("MLflow secondary attached to run: %s", run_id) + + +# --------------------------------------------------------------------------- +# Logging +# --------------------------------------------------------------------------- + +def log_metrics(metrics: dict[str, Any], step: int | None = None) -> None: + import mlflow + + if mlflow.active_run() is None: + return + 
def log_metrics(metrics: dict[str, Any], step: int | None = None) -> None:
    """Log the numeric entries of *metrics* to the active MLflow run, if any.

    Step keys (``*/step``) are dropped — the step is passed explicitly — and
    values that cannot be coerced to float are silently skipped, since MLflow
    metrics are float-only.
    """
    import mlflow

    if mlflow.active_run() is None:
        return

    numeric: dict[str, float] = {}
    for key, value in metrics.items():
        if key.endswith("/step"):
            continue
        try:
            numeric[_sanitize_key(key)] = float(value)
        except (TypeError, ValueError):
            continue

    if numeric:
        mlflow.log_metrics(numeric, step=int(step) if step is not None else None)


# --------------------------------------------------------------------------
# Cleanup
# --------------------------------------------------------------------------

def finish() -> None:
    """End the active MLflow run, if one exists."""
    import mlflow

    active = mlflow.active_run()
    if active is None:
        return

    mlflow.end_run()
    logger.info("MLflow run ended: %s", active.info.run_id)