Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -14,3 +14,5 @@ wheels/
.claude
.vscode
logs/
# created by `uv run mkdocs serve` and such.
site/
7 changes: 7 additions & 0 deletions cluv/__main__.py
Original file line number Diff line number Diff line change
Expand Up @@ -174,6 +174,13 @@ def add_sync_args(
"Use a comma to separate multiple clusters."
),
)
sync_parser.add_argument(
"--sync-datasets",
dest="sync_datasets",
action=argparse.BooleanOptionalAction,
default=True,
help="Push datasets from data_source to each cluster. Requires data_source in config.",
)
# TODO: Try to add a 'remainder' arg to pass extra args to `uv sync` on the remote cluster, but it seems to be a bit tricky.
# sync_parser.add_argument(
# "--",
Expand Down
228 changes: 163 additions & 65 deletions cluv/cli/sync.py

Large diffs are not rendered by default.

104 changes: 94 additions & 10 deletions cluv/config.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,46 +2,114 @@

from __future__ import annotations

import dataclasses
import functools
import logging
import tomllib
from dataclasses import field
from pathlib import Path

from pydantic import BaseModel
from pydantic.dataclasses import dataclass

from cluv.utils import current_cluster, resolve_env_vars

logger = logging.getLogger(__name__)


class ClusterConfig(BaseModel):
@dataclass(frozen=True)
class PartialClusterConfig:
"""Per-cluster configuration options."""

env: dict[str, str] = {}
env: dict[str, str] = field(default_factory=dict)
"""Environment variables to set when running Slurm commands on this cluster."""

results_path: str | None = None # TODO: Change to `Path` instead. Fix any pydantic errors.
"""Path to the results directory for a specific cluster."""

datasets_path: str | None = None # TODO: Change to `Path` instead. Fix any pydantic errors.
"""Different path where the datasets should be replicated on this cluster.

When `None`, this defaults to the top-level config's `datasets_path`.

This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`.
"""


@dataclass(frozen=True)
class ClusterConfig:
"""Per-cluster configuration options."""

env: dict[str, str]
"""Environment variables to set when running Slurm commands on this cluster."""

results_path: Path
"""Path to the results directory for a specific cluster."""

datasets_path: Path | None
"""Different path where the datasets should be replicated on this cluster.

When `None`, this defaults to the top-level config's `datasets_path`.

This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`.
"""

def resolve_env_vars_in_paths(self):
return ClusterConfig(
env=self.env,
results_path=resolve_env_vars(self.results_path),
datasets_path=resolve_env_vars(self.datasets_path) if self.datasets_path else None,
)


class CluvConfig(BaseModel):
"""Configuration options for Cluv, loaded from the pyproject.toml file."""

env: dict[str, str] = {}
"""Global environment variables set on all clusters when running Slurm commands."""

results_path: str
"""Path to the results directory, relative to the project root.
"""Default path to the results directory for all clusters (may contain env vars like $SCRATCH)."""

!!! info
On Slurm clusters, this will be a symlink to a folder in `$SCRATCH/<results_path>/<project_name>`.
results_symlink: str | None = None
"""Name of the symlink created in the project directory pointing to results_path.

When None, defaults to Path(results_path).name (the last component of results_path).
"""

env: dict[str, str] = {}
"""Global environment variables set on all clusters when running Slurm commands."""
data_source: str | None = None
"""`hostname:/path` of where to get the data from."""

datasets_path: str | None = None
"""Path to a dataset directory, for example, `'$SCRATCH/my_dataset'`

This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`.
"""

clusters: dict[str, ClusterConfig] = {}
clusters: dict[str, PartialClusterConfig] = {}
"""Configuration options for each cluster.

The keys are cluster names; each value is a `ClusterConfig` whose `env` dict contains
environment variables to set when running Slurm commands on that cluster.
The keys are cluster names, and values are configs that override options for that cluster.
"""

@property
def clusters_names(self) -> list[str]:
return list(self.clusters.keys())

def get_cluster_config(self, cluster: str) -> ClusterConfig:
"""Returns the cluster config for a specific cluster.

The environment variables as part of paths will *not* be resolved.
"""
cluv_config = load_cluv_config(find_pyproject())
cluster_config = cluv_config.clusters[cluster]
datasets_path = cluster_config.datasets_path or cluv_config.datasets_path
return ClusterConfig(
env=cluv_config.env | cluster_config.env,
results_path=Path(cluster_config.results_path or cluv_config.results_path),
datasets_path=Path(datasets_path) if datasets_path else None,
)


@functools.cache
def get_config() -> CluvConfig:
Expand Down Expand Up @@ -82,3 +150,19 @@ def load_cluv_config(pyproject_path: Path) -> CluvConfig:
def get_cluster_choices() -> list[str]:
"""Return configured clusters or the defaults when config is missing/invalid."""
return get_config().clusters_names


def current_cluster_config() -> ClusterConfig | None:
"""Returns the `ClusterConfig` of the current cluster, or None if not currently on a cluster."""
cluster = current_cluster()
if not cluster:
return None # not on a cluster.
cluv_config = load_cluv_config(find_pyproject())
data_source = cluv_config.data_source
config = cluv_config.get_cluster_config(cluster)
if data_source:
source_cluster, data_path = data_source.split(":", 1)
if cluster == source_cluster:
# use the dataset path from the data_source setting as the datasets_path.
config = dataclasses.replace(config, datasets_path=data_path)
return config.resolve_env_vars_in_paths()
129 changes: 129 additions & 0 deletions cluv/job.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,129 @@
"""A script that reads something, and produces some output.

This is a simplified job script, used to test the syncing of the 'dataset' across clusters.
"""

import functools
import os
import re
import subprocess
from dataclasses import dataclass
from pathlib import Path

import cluv
import cluv.config
from cluv.utils import current_cluster

SLURM_JOB_ID: int | None = (
int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None
)
SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None
SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None
SLURM_PROCID = int(os.environ["SLURM_PROCID"]) if "SLURM_PROCID" in os.environ else None


in_job_packing = "SLURM_NTASKS_PER_GPU" in os.environ
in_job_array = "SLURM_ARRAY_JOB_ID" in os.environ


@dataclass(frozen=True)
class JobInfo:
"""Information about a "job"/"run".

Note, there may be multiple "runs" inside a single "job", that's why there is a distinction.
"""

cluster: str

run_id: str
"""The unique 'run identifier' for this job/run, used for checkpointing and Weights & Biases.

This will usually just be {cluster}_{SLURM_JOB_ID}, but can also vary based on whether
the job is doing job packing (with --ntasks-per-gpu) or job chunking (with --array=...%1) or
both.

Use this as the run_id for `wandb.init` or whenever you need a unique run identifier.
"""

results_path: Path

@property
def datasets_path(self) -> Path | None:
"""The path where the datasets are located for this job (based on which cluster it runs on.)"""
cluster_info = cluv.config.current_cluster_config()
assert cluster_info
return cluster_info.datasets_path


def current_job_info() -> JobInfo | None:
"""Returns information about the current job, such as its unique run id and results path.

This is useful to determine where to save checkpoints or results for this job, and to have a unique
identifier for this job that can be used in Weights & Biases or elsewhere.

The 'run id' is determined based on the cluster name and SLURM job id, and also takes into account
whether the job is doing job packing (with --ntasks-per-gpu) or job chunking (with --array=...%1).
"""
if not SLURM_JOB_ID:
return None # not in a Slurm job.
cluster = current_cluster()
run_id = _get_run_id()
# IDEA: maybe load the cluv config and set the checkpoint_dir
# from cluv.config import load_cluv_config
assert cluster, "Example must be run on a cluster."
config = cluv.config.current_cluster_config()
assert config, "Example must be run on a cluster."
assert config.results_path
assert config.datasets_path
return JobInfo(
run_id=run_id,
cluster=cluster,
results_path=config.results_path / run_id,
)


@functools.cache
def _get_max_active_jobs() -> int | None:
"""When in a job array, returns the max number of active jobs at the same time.

For example, with --array=0-20%4, this returns 4.
Returns `None` when not in a job array.
Result is cached since this calls scontrol in a subprocess.
"""
if "SLURM_ARRAY_JOB_ID" not in os.environ:
return None
output = subprocess.check_output(
["scontrol", "--oneliner", "show", "job", os.environ["SLURM_ARRAY_JOB_ID"]],
text=True,
)
match = re.search(r"ArrayTaskId=\S+%(\d+)", output)
return int(match.group(1)) if match else None


def _in_job_chunking() -> bool:
return in_job_array and _get_max_active_jobs() == 1


def _get_run_id():
cluster = current_cluster()
doing_job_packing = "SLURM_NTASKS_PER_GPU" in os.environ
doing_job_chunking = _in_job_chunking()
task_index = int(os.environ["SLURM_PROCID"])
if doing_job_chunking:
# IF we have --array=...%1, use the id of the first job.
first_job_id = int(os.environ["SLURM_ARRAY_JOB_ID"])
if doing_job_packing:
# Running with --array=0-5%1 for chunking and --ntasks-per-gpu for packing! Awesome!!
return f"{cluster}_{first_job_id}_task{task_index}"
# IDEA: If we support doing an arrays of 'chunked' jobs, then we could use this:
# IF we have --array=0-20%4, this means there are 4 jobs with 5 chunks each (weird).
# max_active_jobs = get_max_active_jobs()
# assert max_active_jobs is not None and max_active_jobs > 1
# index_in_array = int(os.environ["SLURM_ARRAY_TASK_ID"])
# return str(first_job_id + (index_in_array % max_active_jobs))
# Keeping it simple for now, only support chunking with --array=...%1, so we always use
# the id of the first job in the array.
return f"{cluster}_{first_job_id}"
if doing_job_packing:
return f"{cluster}_{SLURM_JOB_ID}_task{SLURM_PROCID}"
return f"{cluster}_{SLURM_JOB_ID}"
60 changes: 42 additions & 18 deletions cluv/remote.py
Original file line number Diff line number Diff line change
@@ -1,14 +1,16 @@
from __future__ import annotations

import asyncio
import contextlib
import dataclasses
import functools
import shlex
import subprocess
import sys
from logging import getLogger as get_logger
from typing import Callable, Literal, Self, TypeVar

from cluv.utils import console
from cluv.utils import console, console_lock

logger = get_logger(__name__)

Expand Down Expand Up @@ -70,17 +72,24 @@ async def run(
command,
)

_display = False
if display:
console.log(
(
f"({self.hostname}) $ {command}"
if input is None
else f"({self.hostname}) $ {command=}\n{input=}"
),
style="green",
_stack_offset=2, # to show a link to the code calling this, instead of here.
# Pass what to display to `run`, which uses a lock to keep the command and its output
# together in the console, instead of interleaving with other commands' outputs.
# Commands start running (and may error out) before being shown in the terminal though.
_display = (
f"({self.hostname}) $ {command}"
if input is None
else f"({self.hostname}) $ {command=}\n{input=}"
)
return await run(ssh_command, input=input, warn=warn, hide=hide, _stacklevel=3)
return await run(
ssh_command,
input=input,
warn=warn,
hide=hide,
_display=_display,
_stacklevel=3,
)

async def get_output(
self,
Expand All @@ -100,6 +109,7 @@ async def run(
warn: bool = False,
hide: Hide = False,
_stacklevel: int = 2,
_display: bool | str = False,
) -> subprocess.CompletedProcess[str]:
"""Runs the command *asynchronously* in a subprocess and returns the result.

Expand Down Expand Up @@ -167,14 +177,28 @@ async def run(
stdout=stdout.decode(),
stderr=stderr.decode(),
)
if result.stdout:
if hide not in [True, "out", "stdout"]:
print(result.stdout)
logger.debug(result.stdout)
if result.stderr:
if hide not in [True, "err", "stderr"]:
print(result.stderr, file=sys.stderr)
logger.debug(result.stderr)
async with console_lock.get() or contextlib.nullcontext():
if _display:
console.log(
_display
if isinstance(_display, str)
else (
f"$ {shlex.join(program_and_args)}"
if input is None
else f"$ {program_and_args=}\n{input=}"
),
style="green",
_stack_offset=_stacklevel
- 1, # to show a link to the code calling this, instead of here.
)
if result.stdout:
if hide not in [True, "out", "stdout"]:
print(result.stdout)
logger.debug(result.stdout)
if result.stderr:
if hide not in [True, "err", "stderr"]:
print(result.stderr, file=sys.stderr)
logger.debug(result.stderr)
return result


Expand Down
Loading
Loading