From 2e67930bbb985f59573f9b9fb7ea84607e2d126f Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 26 May 2026 09:39:52 -0600 Subject: [PATCH 01/22] Start adding a test for syncing datasets Signed-off-by: Fabrice Normandin --- pyproject.toml | 1 + tests/data/.gitkeep | 1 + tests/data/dataset.txt | 1 + tests/example.py | 50 ++++++++++++++++++++++++++++++++++++++++++ tests/test_sync.py | 5 +++++ 5 files changed, 58 insertions(+) create mode 100644 tests/data/.gitkeep create mode 100644 tests/data/dataset.txt create mode 100644 tests/example.py create mode 100644 tests/test_sync.py diff --git a/pyproject.toml b/pyproject.toml index ba53530..8173389 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,6 +82,7 @@ show_missing = true [tool.cluv] results_path = "logs" +datasets_path = "tests/data" [tool.cluv.env] # Environment variables applied when using Slurm commands on all clusters. diff --git a/tests/data/.gitkeep b/tests/data/.gitkeep new file mode 100644 index 0000000..172660a --- /dev/null +++ b/tests/data/.gitkeep @@ -0,0 +1 @@ +dataset.csv diff --git a/tests/data/dataset.txt b/tests/data/dataset.txt new file mode 100644 index 0000000..644aa8f --- /dev/null +++ b/tests/data/dataset.txt @@ -0,0 +1 @@ +This is a dummy "dataset". diff --git a/tests/example.py b/tests/example.py new file mode 100644 index 0000000..3eb873c --- /dev/null +++ b/tests/example.py @@ -0,0 +1,50 @@ +"""A script that reads something, and produces some output. + +This is a simplified job script, used to test the syncing of the 'dataset' across clusters. +""" + +import os +import time +from dataclasses import dataclass +from pathlib import Path + +import simple_parsing + +SLURM_JOB_ID = int(os.environ["SLURM_JOB_ID"]) +SCRATCH = Path(os.environ["SCRATCH"]) +SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"]) + +# IDEA: maybe load the cluv config and set the checkpoint_dir +# from cluv.config import load_cluv_config + + +@dataclass(frozen=True) +class Args: + # NOTE: This should be the same as the `results_path` in the Cluv config. + results_path: Path = SCRATCH / "logs" / "cluv" / str(SLURM_JOB_ID) + + # NOTE: This should be the same as the `datasets_path` in the Cluv config. + datasets_path: Path = Path("tests/data/dataset.csv") + + # Time to wait before producing the result. + # Can be useful to test and simulate preemption or cancelling jobs. + wait_duration_seconds: int = 0 + + +def main(args: Args | None = None): + args = args or simple_parsing.parse(Args, description=__doc__) + print(f"Job {SLURM_JOB_ID} starts.") + + dataset = args.datasets_path.read_text() + assert dataset.strip() == 'This is a dummy "dataset".' + + time.sleep(args.wait_duration_seconds) + + print(f"Job {SLURM_JOB_ID} is about to end.") + results_file = args.results_path / "results.txt" + with results_file.open("a") as f: + f.write(f"This is the result of job {SLURM_JOB_ID}\n") + + +if __name__ == "__main__": + main() diff --git a/tests/test_sync.py b/tests/test_sync.py new file mode 100644 index 0000000..aa15a7c --- /dev/null +++ b/tests/test_sync.py @@ -0,0 +1,5 @@ +"""Tests for `cluv sync`""" + + +def test_cluv_sync_with_data_path(): + """TODO: Test for `cluv sync` with a project that has a 'data_path'.""" From c1ae3b8a51b8afaa01740a4133a6ef8d4a4dbe3f Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 26 May 2026 11:31:33 -0600 Subject: [PATCH 02/22] Add some ideas in pyproject.toml Signed-off-by: Fabrice Normandin --- cluv/config.py | 15 ++++++++ pyproject.toml | 22 +++++++++++- tests/example.py | 9 +++-- tests/test_sync.py | 6 +++- uv.lock | 88 ++++++++++++++++++++++++++++++++++++++++++++++ 5 files changed, 135 insertions(+), 5 deletions(-) diff --git a/cluv/config.py b/cluv/config.py index de5fbf0..8e5053c 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -6,6 +6,7 @@ import logging import tomllib from pathlib import Path + from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -17,6 +18,14 @@ class ClusterConfig(BaseModel): env: dict[str, str] = {} """Environment variables to set when running Slurm commands on this cluster.""" + datasets_path: str | None + """Different path where the datasets should be replicated on this cluster. + + When `None`, this defaults to the top-level config's `datasets_path`. + + This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. + """ + class CluvConfig(BaseModel): """Configuration options for Cluv, loaded from the pyproject.toml file.""" @@ -28,6 +37,12 @@ class CluvConfig(BaseModel): On Slurm clusters, this will be a symlink to a folder in `$SCRATCH//`. """ + datasets_path: str + """Path to a dataset directory, for example, `'$SCRATCH/my_dataset'` + + This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. + """ + env: dict[str, str] = {} """Global environment variables set on all clusters when running Slurm commands.""" diff --git a/pyproject.toml b/pyproject.toml index 8173389..abb66ac 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dev = [ "pytest-timeout>=2.3.1", "ruff", "pytest-skip-slow>=0.0.5", + "torchvision>=0.25.0", ] [tool.pytest.ini_options] @@ -81,8 +82,17 @@ source = ["cluv"] show_missing = true [tool.cluv] +# Where to store results on all clusters. TODO: Can be overridden in each cluster's config. results_path = "logs" -datasets_path = "tests/data" + +# Which cluster to get the data from. When unset, assumes the current cluster is the source. +data_source_cluster = "mila" +# Where to read the data from on the source cluster. +data_source_path = "/network/datasets/cifar10" + +# Where the dataset should be replicated on all clusters. +# On the source cluster (ex Mila), the folder could contain symlinks, to avoid duplicating the data. +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.env] # Environment variables applied when using Slurm commands on all clusters. @@ -91,34 +101,44 @@ SBATCH_REQUEUE = "1" [tool.cluv.clusters.mila] env = {UV_OFFLINE="0", WANDB_MODE="online"} +datasets_path = "$SCRATCH/data/cifar10" # PAICE clusters. [tool.cluv.clusters.tamia] env = {UV_OFFLINE="1", WANDB_MODE="offline"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.killarney] env = {UV_OFFLINE="1", WANDB_MODE="offline"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.vulcan] env = {UV_OFFLINE="1", WANDB_MODE="offline"} +datasets_path = "$SCRATCH/data/cifar10" # DRAC clusters. [tool.cluv.clusters.rorqual] env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.fir] env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.nibi] env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.trillium] env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.trillium-gpu] env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} +datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.narval] # Mila doesn't have an allocation on Narval anymore. # You can also use "def-yourusername" (the default partitions). env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="def-bengioy"} +datasets_path = "$SCRATCH/data/cifar10" diff --git a/tests/example.py b/tests/example.py index 3eb873c..9f6a4c3 100644 --- a/tests/example.py +++ b/tests/example.py @@ -9,6 +9,7 @@ from pathlib import Path import simple_parsing +from torchvision.datasets import CIFAR10 SLURM_JOB_ID = int(os.environ["SLURM_JOB_ID"]) SCRATCH = Path(os.environ["SCRATCH"]) @@ -20,11 +21,13 @@ @dataclass(frozen=True) class Args: + """Command-line arguments for this example.""" + # NOTE: This should be the same as the `results_path` in the Cluv config. results_path: Path = SCRATCH / "logs" / "cluv" / str(SLURM_JOB_ID) # NOTE: This should be the same as the `datasets_path` in the Cluv config. - datasets_path: Path = Path("tests/data/dataset.csv") + datasets_path: Path = SCRATCH / "data" / "cifar10" # Time to wait before producing the result. # Can be useful to test and simulate preemption or cancelling jobs. @@ -35,8 +38,8 @@ def main(args: Args | None = None): args = args or simple_parsing.parse(Args, description=__doc__) print(f"Job {SLURM_JOB_ID} starts.") - dataset = args.datasets_path.read_text() - assert dataset.strip() == 'This is a dummy "dataset".' + dataset = CIFAR10(args.datasets_path) + print(dataset) time.sleep(args.wait_duration_seconds) diff --git a/tests/test_sync.py b/tests/test_sync.py index aa15a7c..f07ab08 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -2,4 +2,8 @@ def test_cluv_sync_with_data_path(): - """TODO: Test for `cluv sync` with a project that has a 'data_path'.""" + """TODO: Test for `cluv sync` with a project that has a 'data_path'. + + + Need to check that rsync happens from `datasets_path` (the source) to the `datasets_path` (the dest) on all the clusters. + """ diff --git a/uv.lock b/uv.lock index bf2da2b..9e6d60c 100644 --- a/uv.lock +++ b/uv.lock @@ -272,6 +272,7 @@ dev = [ { name = "pytest-skip-slow" }, { name = "pytest-timeout" }, { name = "ruff" }, + { name = "torchvision" }, { name = "uv-dynamic-versioning" }, ] @@ -295,6 +296,7 @@ dev = [ { name = "pytest-skip-slow", specifier = ">=0.0.5" }, { name = "pytest-timeout", specifier = ">=2.3.1" }, { name = "ruff" }, + { name = "torchvision", specifier = ">=0.25.0" }, { name = "uv-dynamic-versioning", specifier = ">=0.2.0" }, ] @@ -1085,6 +1087,64 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f1/d9/7fb5aa316bc299258e68c73ba3bddbc499654a07f151cba08f6153988714/pathspec-1.1.1-py3-none-any.whl", hash = "sha256:a00ce642f577bf7f473932318056212bc4f8bfdf53128c78bbd5af0b9b20b189", size = 57328, upload-time = "2026-04-27T01:46:07.06Z" }, ] +[[package]] +name = "pillow" +version = "12.2.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/8c/21/c2bcdd5906101a30244eaffc1b6e6ce71a31bd0742a01eb89e660ebfac2d/pillow-12.2.0.tar.gz", hash = "sha256:a830b1a40919539d07806aa58e1b114df53ddd43213d9c8b75847eee6c0182b5", size = 46987819, upload-time = "2026-04-01T14:46:17.687Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/4a/01/53d10cf0dbad820a8db274d259a37ba50b88b24768ddccec07355382d5ad/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphoneos.whl", hash = "sha256:8297651f5b5679c19968abefd6bb84d95fe30ef712eb1b2d9b2d31ca61267f4c", size = 4100837, upload-time = "2026-04-01T14:43:41.506Z" }, + { url = "https://files.pythonhosted.org/packages/0f/98/f3a6657ecb698c937f6c76ee564882945f29b79bad496abcba0e84659ec5/pillow-12.2.0-cp313-cp313-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:50d8520da2a6ce0af445fa6d648c4273c3eeefbc32d7ce049f22e8b5c3daecc2", size = 4176528, upload-time = "2026-04-01T14:43:43.773Z" }, + { url = "https://files.pythonhosted.org/packages/69/bc/8986948f05e3ea490b8442ea1c1d4d990b24a7e43d8a51b2c7d8b1dced36/pillow-12.2.0-cp313-cp313-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:766cef22385fa1091258ad7e6216792b156dc16d8d3fa607e7545b2b72061f1c", size = 3640401, upload-time = "2026-04-01T14:43:45.87Z" }, + { url = "https://files.pythonhosted.org/packages/34/46/6c717baadcd62bc8ed51d238d521ab651eaa74838291bda1f86fe1f864c9/pillow-12.2.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:5d2fd0fa6b5d9d1de415060363433f28da8b1526c1c129020435e186794b3795", size = 5308094, upload-time = "2026-04-01T14:43:48.438Z" }, + { url = "https://files.pythonhosted.org/packages/71/43/905a14a8b17fdb1ccb58d282454490662d2cb89a6bfec26af6d3520da5ec/pillow-12.2.0-cp313-cp313-macosx_11_0_arm64.whl", hash = "sha256:56b25336f502b6ed02e889f4ece894a72612fe885889a6e8c4c80239ff6e5f5f", size = 4695402, upload-time = "2026-04-01T14:43:51.292Z" }, + { url = "https://files.pythonhosted.org/packages/73/dd/42107efcb777b16fa0393317eac58f5b5cf30e8392e266e76e51cff28c3d/pillow-12.2.0-cp313-cp313-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:f1c943e96e85df3d3478f7b691f229887e143f81fedab9b20205349ab04d73ed", size = 6280005, upload-time = "2026-04-01T14:43:54.242Z" }, + { url = "https://files.pythonhosted.org/packages/a8/68/b93e09e5e8549019e61acf49f65b1a8530765a7f812c77a7461bca7e4494/pillow-12.2.0-cp313-cp313-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:03f6fab9219220f041c74aeaa2939ff0062bd5c364ba9ce037197f4c6d498cd9", size = 8090669, upload-time = "2026-04-01T14:43:57.335Z" }, + { url = "https://files.pythonhosted.org/packages/4b/6e/3ccb54ce8ec4ddd1accd2d89004308b7b0b21c4ac3d20fa70af4760a4330/pillow-12.2.0-cp313-cp313-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:5cdfebd752ec52bf5bb4e35d9c64b40826bc5b40a13df7c3cda20a2c03a0f5ed", size = 6395194, upload-time = "2026-04-01T14:43:59.864Z" }, + { url = "https://files.pythonhosted.org/packages/67/ee/21d4e8536afd1a328f01b359b4d3997b291ffd35a237c877b331c1c3b71c/pillow-12.2.0-cp313-cp313-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:eedf4b74eda2b5a4b2b2fb4c006d6295df3bf29e459e198c90ea48e130dc75c3", size = 7082423, upload-time = "2026-04-01T14:44:02.74Z" }, + { url = "https://files.pythonhosted.org/packages/78/5f/e9f86ab0146464e8c133fe85df987ed9e77e08b29d8d35f9f9f4d6f917ba/pillow-12.2.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:00a2865911330191c0b818c59103b58a5e697cae67042366970a6b6f1b20b7f9", size = 6505667, upload-time = "2026-04-01T14:44:05.381Z" }, + { url = "https://files.pythonhosted.org/packages/ed/1e/409007f56a2fdce61584fd3acbc2bbc259857d555196cedcadc68c015c82/pillow-12.2.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:1e1757442ed87f4912397c6d35a0db6a7b52592156014706f17658ff58bbf795", size = 7208580, upload-time = "2026-04-01T14:44:08.39Z" }, + { url = "https://files.pythonhosted.org/packages/23/c4/7349421080b12fb35414607b8871e9534546c128a11965fd4a7002ccfbee/pillow-12.2.0-cp313-cp313-win32.whl", hash = "sha256:144748b3af2d1b358d41286056d0003f47cb339b8c43a9ea42f5fea4d8c66b6e", size = 6375896, upload-time = "2026-04-01T14:44:11.197Z" }, + { url = "https://files.pythonhosted.org/packages/3f/82/8a3739a5e470b3c6cbb1d21d315800d8e16bff503d1f16b03a4ec3212786/pillow-12.2.0-cp313-cp313-win_amd64.whl", hash = "sha256:390ede346628ccc626e5730107cde16c42d3836b89662a115a921f28440e6a3b", size = 7081266, upload-time = "2026-04-01T14:44:13.947Z" }, + { url = "https://files.pythonhosted.org/packages/c3/25/f968f618a062574294592f668218f8af564830ccebdd1fa6200f598e65c5/pillow-12.2.0-cp313-cp313-win_arm64.whl", hash = "sha256:8023abc91fba39036dbce14a7d6535632f99c0b857807cbbbf21ecc9f4717f06", size = 2463508, upload-time = "2026-04-01T14:44:16.312Z" }, + { url = "https://files.pythonhosted.org/packages/4d/a4/b342930964e3cb4dce5038ae34b0eab4653334995336cd486c5a8c25a00c/pillow-12.2.0-cp313-cp313t-macosx_10_13_x86_64.whl", hash = "sha256:042db20a421b9bafecc4b84a8b6e444686bd9d836c7fd24542db3e7df7baad9b", size = 5309927, upload-time = "2026-04-01T14:44:18.89Z" }, + { url = "https://files.pythonhosted.org/packages/9f/de/23198e0a65a9cf06123f5435a5d95cea62a635697f8f03d134d3f3a96151/pillow-12.2.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:dd025009355c926a84a612fecf58bb315a3f6814b17ead51a8e48d3823d9087f", size = 4698624, upload-time = "2026-04-01T14:44:21.115Z" }, + { url = "https://files.pythonhosted.org/packages/01/a6/1265e977f17d93ea37aa28aa81bad4fa597933879fac2520d24e021c8da3/pillow-12.2.0-cp313-cp313t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:88ddbc66737e277852913bd1e07c150cc7bb124539f94c4e2df5344494e0a612", size = 6321252, upload-time = "2026-04-01T14:44:23.663Z" }, + { url = "https://files.pythonhosted.org/packages/3c/83/5982eb4a285967baa70340320be9f88e57665a387e3a53a7f0db8231a0cd/pillow-12.2.0-cp313-cp313t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:d362d1878f00c142b7e1a16e6e5e780f02be8195123f164edf7eddd911eefe7c", size = 8126550, upload-time = "2026-04-01T14:44:26.772Z" }, + { url = "https://files.pythonhosted.org/packages/4e/48/6ffc514adce69f6050d0753b1a18fd920fce8cac87620d5a31231b04bfc5/pillow-12.2.0-cp313-cp313t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:2c727a6d53cb0018aadd8018c2b938376af27914a68a492f59dfcaca650d5eea", size = 6433114, upload-time = "2026-04-01T14:44:29.615Z" }, + { url = "https://files.pythonhosted.org/packages/36/a3/f9a77144231fb8d40ee27107b4463e205fa4677e2ca2548e14da5cf18dce/pillow-12.2.0-cp313-cp313t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:efd8c21c98c5cc60653bcb311bef2ce0401642b7ce9d09e03a7da87c878289d4", size = 7115667, upload-time = "2026-04-01T14:44:32.773Z" }, + { url = "https://files.pythonhosted.org/packages/c1/fc/ac4ee3041e7d5a565e1c4fd72a113f03b6394cc72ab7089d27608f8aaccb/pillow-12.2.0-cp313-cp313t-musllinux_1_2_aarch64.whl", hash = "sha256:9f08483a632889536b8139663db60f6724bfcb443c96f1b18855860d7d5c0fd4", size = 6538966, upload-time = "2026-04-01T14:44:35.252Z" }, + { url = "https://files.pythonhosted.org/packages/c0/a8/27fb307055087f3668f6d0a8ccb636e7431d56ed0750e07a60547b1e083e/pillow-12.2.0-cp313-cp313t-musllinux_1_2_x86_64.whl", hash = "sha256:dac8d77255a37e81a2efcbd1fc05f1c15ee82200e6c240d7e127e25e365c39ea", size = 7238241, upload-time = "2026-04-01T14:44:37.875Z" }, + { url = "https://files.pythonhosted.org/packages/ad/4b/926ab182c07fccae9fcb120043464e1ff1564775ec8864f21a0ebce6ac25/pillow-12.2.0-cp313-cp313t-win32.whl", hash = "sha256:ee3120ae9dff32f121610bb08e4313be87e03efeadfc6c0d18f89127e24d0c24", size = 6379592, upload-time = "2026-04-01T14:44:40.336Z" }, + { url = "https://files.pythonhosted.org/packages/c2/c4/f9e476451a098181b30050cc4c9a3556b64c02cf6497ea421ac047e89e4b/pillow-12.2.0-cp313-cp313t-win_amd64.whl", hash = "sha256:325ca0528c6788d2a6c3d40e3568639398137346c3d6e66bb61db96b96511c98", size = 7085542, upload-time = "2026-04-01T14:44:43.251Z" }, + { url = "https://files.pythonhosted.org/packages/00/a4/285f12aeacbe2d6dc36c407dfbbe9e96d4a80b0fb710a337f6d2ad978c75/pillow-12.2.0-cp313-cp313t-win_arm64.whl", hash = "sha256:2e5a76d03a6c6dcef67edabda7a52494afa4035021a79c8558e14af25313d453", size = 2465765, upload-time = "2026-04-01T14:44:45.996Z" }, + { url = "https://files.pythonhosted.org/packages/bf/98/4595daa2365416a86cb0d495248a393dfc84e96d62ad080c8546256cb9c0/pillow-12.2.0-cp314-cp314-ios_13_0_arm64_iphoneos.whl", hash = "sha256:3adc9215e8be0448ed6e814966ecf3d9952f0ea40eb14e89a102b87f450660d8", size = 4100848, upload-time = "2026-04-01T14:44:48.48Z" }, + { url = "https://files.pythonhosted.org/packages/0b/79/40184d464cf89f6663e18dfcf7ca21aae2491fff1a16127681bf1fa9b8cf/pillow-12.2.0-cp314-cp314-ios_13_0_arm64_iphonesimulator.whl", hash = "sha256:6a9adfc6d24b10f89588096364cc726174118c62130c817c2837c60cf08a392b", size = 4176515, upload-time = "2026-04-01T14:44:51.353Z" }, + { url = "https://files.pythonhosted.org/packages/b0/63/703f86fd4c422a9cf722833670f4f71418fb116b2853ff7da722ea43f184/pillow-12.2.0-cp314-cp314-ios_13_0_x86_64_iphonesimulator.whl", hash = "sha256:6a6e67ea2e6feda684ed370f9a1c52e7a243631c025ba42149a2cc5934dec295", size = 3640159, upload-time = "2026-04-01T14:44:53.588Z" }, + { url = "https://files.pythonhosted.org/packages/71/e0/fb22f797187d0be2270f83500aab851536101b254bfa1eae10795709d283/pillow-12.2.0-cp314-cp314-macosx_10_15_x86_64.whl", hash = "sha256:2bb4a8d594eacdfc59d9e5ad972aa8afdd48d584ffd5f13a937a664c3e7db0ed", size = 5312185, upload-time = "2026-04-01T14:44:56.039Z" }, + { url = "https://files.pythonhosted.org/packages/ba/8c/1a9e46228571de18f8e28f16fabdfc20212a5d019f3e3303452b3f0a580d/pillow-12.2.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:80b2da48193b2f33ed0c32c38140f9d3186583ce7d516526d462645fd98660ae", size = 4695386, upload-time = "2026-04-01T14:44:58.663Z" }, + { url = "https://files.pythonhosted.org/packages/70/62/98f6b7f0c88b9addd0e87c217ded307b36be024d4ff8869a812b241d1345/pillow-12.2.0-cp314-cp314-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:22db17c68434de69d8ecfc2fe821569195c0c373b25cccb9cbdacf2c6e53c601", size = 6280384, upload-time = "2026-04-01T14:45:01.5Z" }, + { url = "https://files.pythonhosted.org/packages/5e/03/688747d2e91cfbe0e64f316cd2e8005698f76ada3130d0194664174fa5de/pillow-12.2.0-cp314-cp314-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:7b14cc0106cd9aecda615dd6903840a058b4700fcb817687d0ee4fc8b6e389be", size = 8091599, upload-time = "2026-04-01T14:45:04.5Z" }, + { url = "https://files.pythonhosted.org/packages/f6/35/577e22b936fcdd66537329b33af0b4ccfefaeabd8aec04b266528cddb33c/pillow-12.2.0-cp314-cp314-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:8cbeb542b2ebc6fcdacabf8aca8c1a97c9b3ad3927d46b8723f9d4f033288a0f", size = 6396021, upload-time = "2026-04-01T14:45:07.117Z" }, + { url = "https://files.pythonhosted.org/packages/11/8d/d2532ad2a603ca2b93ad9f5135732124e57811d0168155852f37fbce2458/pillow-12.2.0-cp314-cp314-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:4bfd07bc812fbd20395212969e41931001fd59eb55a60658b0e5710872e95286", size = 7083360, upload-time = "2026-04-01T14:45:09.763Z" }, + { url = "https://files.pythonhosted.org/packages/5e/26/d325f9f56c7e039034897e7380e9cc202b1e368bfd04d4cbe6a441f02885/pillow-12.2.0-cp314-cp314-musllinux_1_2_aarch64.whl", hash = "sha256:9aba9a17b623ef750a4d11b742cbafffeb48a869821252b30ee21b5e91392c50", size = 6507628, upload-time = "2026-04-01T14:45:12.378Z" }, + { url = "https://files.pythonhosted.org/packages/5f/f7/769d5632ffb0988f1c5e7660b3e731e30f7f8ec4318e94d0a5d674eb65a4/pillow-12.2.0-cp314-cp314-musllinux_1_2_x86_64.whl", hash = "sha256:deede7c263feb25dba4e82ea23058a235dcc2fe1f6021025dc71f2b618e26104", size = 7209321, upload-time = "2026-04-01T14:45:15.122Z" }, + { url = "https://files.pythonhosted.org/packages/6a/7a/c253e3c645cd47f1aceea6a8bacdba9991bf45bb7dfe927f7c893e89c93c/pillow-12.2.0-cp314-cp314-win32.whl", hash = "sha256:632ff19b2778e43162304d50da0181ce24ac5bb8180122cbe1bf4673428328c7", size = 6479723, upload-time = "2026-04-01T14:45:17.797Z" }, + { url = "https://files.pythonhosted.org/packages/cd/8b/601e6566b957ca50e28725cb6c355c59c2c8609751efbecd980db44e0349/pillow-12.2.0-cp314-cp314-win_amd64.whl", hash = "sha256:4e6c62e9d237e9b65fac06857d511e90d8461a32adcc1b9065ea0c0fa3a28150", size = 7217400, upload-time = "2026-04-01T14:45:20.529Z" }, + { url = "https://files.pythonhosted.org/packages/d6/94/220e46c73065c3e2951bb91c11a1fb636c8c9ad427ac3ce7d7f3359b9b2f/pillow-12.2.0-cp314-cp314-win_arm64.whl", hash = "sha256:b1c1fbd8a5a1af3412a0810d060a78b5136ec0836c8a4ef9aa11807f2a22f4e1", size = 2554835, upload-time = "2026-04-01T14:45:23.162Z" }, + { url = "https://files.pythonhosted.org/packages/b6/ab/1b426a3974cb0e7da5c29ccff4807871d48110933a57207b5a676cccc155/pillow-12.2.0-cp314-cp314t-macosx_10_15_x86_64.whl", hash = "sha256:57850958fe9c751670e49b2cecf6294acc99e562531f4bd317fa5ddee2068463", size = 5314225, upload-time = "2026-04-01T14:45:25.637Z" }, + { url = "https://files.pythonhosted.org/packages/19/1e/dce46f371be2438eecfee2a1960ee2a243bbe5e961890146d2dee1ff0f12/pillow-12.2.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:d5d38f1411c0ed9f97bcb49b7bd59b6b7c314e0e27420e34d99d844b9ce3b6f3", size = 4698541, upload-time = "2026-04-01T14:45:28.355Z" }, + { url = "https://files.pythonhosted.org/packages/55/c3/7fbecf70adb3a0c33b77a300dc52e424dc22ad8cdc06557a2e49523b703d/pillow-12.2.0-cp314-cp314t-manylinux2014_aarch64.manylinux_2_17_aarch64.whl", hash = "sha256:5c0a9f29ca8e79f09de89293f82fc9b0270bb4af1d58bc98f540cc4aedf03166", size = 6322251, upload-time = "2026-04-01T14:45:30.924Z" }, + { url = "https://files.pythonhosted.org/packages/1c/3c/7fbc17cfb7e4fe0ef1642e0abc17fc6c94c9f7a16be41498e12e2ba60408/pillow-12.2.0-cp314-cp314t-manylinux2014_x86_64.manylinux_2_17_x86_64.whl", hash = "sha256:1610dd6c61621ae1cf811bef44d77e149ce3f7b95afe66a4512f8c59f25d9ebe", size = 8127807, upload-time = "2026-04-01T14:45:33.908Z" }, + { url = "https://files.pythonhosted.org/packages/ff/c3/a8ae14d6defd2e448493ff512fae903b1e9bd40b72efb6ec55ce0048c8ce/pillow-12.2.0-cp314-cp314t-manylinux_2_27_aarch64.manylinux_2_28_aarch64.whl", hash = "sha256:0a34329707af4f73cf1782a36cd2289c0368880654a2c11f027bcee9052d35dd", size = 6433935, upload-time = "2026-04-01T14:45:36.623Z" }, + { url = "https://files.pythonhosted.org/packages/6e/32/2880fb3a074847ac159d8f902cb43278a61e85f681661e7419e6596803ed/pillow-12.2.0-cp314-cp314t-manylinux_2_27_x86_64.manylinux_2_28_x86_64.whl", hash = "sha256:8e9c4f5b3c546fa3458a29ab22646c1c6c787ea8f5ef51300e5a60300736905e", size = 7116720, upload-time = "2026-04-01T14:45:39.258Z" }, + { url = "https://files.pythonhosted.org/packages/46/87/495cc9c30e0129501643f24d320076f4cc54f718341df18cc70ec94c44e1/pillow-12.2.0-cp314-cp314t-musllinux_1_2_aarch64.whl", hash = "sha256:fb043ee2f06b41473269765c2feae53fc2e2fbf96e5e22ca94fb5ad677856f06", size = 6540498, upload-time = "2026-04-01T14:45:41.879Z" }, + { url = "https://files.pythonhosted.org/packages/18/53/773f5edca692009d883a72211b60fdaf8871cbef075eaa9d577f0a2f989e/pillow-12.2.0-cp314-cp314t-musllinux_1_2_x86_64.whl", hash = "sha256:f278f034eb75b4e8a13a54a876cc4a5ab39173d2cdd93a638e1b467fc545ac43", size = 7239413, upload-time = "2026-04-01T14:45:44.705Z" }, + { url = "https://files.pythonhosted.org/packages/c9/e4/4b64a97d71b2a83158134abbb2f5bd3f8a2ea691361282f010998f339ec7/pillow-12.2.0-cp314-cp314t-win32.whl", hash = "sha256:6bb77b2dcb06b20f9f4b4a8454caa581cd4dd0643a08bacf821216a16d9c8354", size = 6482084, upload-time = "2026-04-01T14:45:47.568Z" }, + { url = "https://files.pythonhosted.org/packages/ba/13/306d275efd3a3453f72114b7431c877d10b1154014c1ebbedd067770d629/pillow-12.2.0-cp314-cp314t-win_amd64.whl", hash = "sha256:6562ace0d3fb5f20ed7290f1f929cae41b25ae29528f2af1722966a0a02e2aa1", size = 7225152, upload-time = "2026-04-01T14:45:50.032Z" }, + { url = "https://files.pythonhosted.org/packages/ff/6e/cf826fae916b8658848d7b9f38d88da6396895c676e8086fc0988073aaf8/pillow-12.2.0-cp314-cp314t-win_arm64.whl", hash = "sha256:aa88ccfe4e32d362816319ed727a004423aab09c5cea43c01a4b435643fa34eb", size = 2556579, upload-time = "2026-04-01T14:45:52.529Z" }, +] + [[package]] name = "platformdirs" version = "4.9.6" @@ -1585,6 +1645,34 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/66/4d/35352043ee0eaffdeff154fad67cd4a31dbed7ff8e3be1cc4549717d6d51/torch-2.10.0-cp314-cp314t-win_amd64.whl", hash = "sha256:71283a373f0ee2c89e0f0d5f446039bdabe8dbc3c9ccf35f0f784908b0acd185", size = 113995816, upload-time = "2026-01-21T16:22:05.312Z" }, ] +[[package]] +name = "torchvision" +version = "0.25.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "numpy" }, + { name = "pillow" }, + { name = "torch" }, +] +wheels = [ + { url = "https://files.pythonhosted.org/packages/f5/5b/1562a04a6a5a4cf8cf40016a0cdeda91ede75d6962cff7f809a85ae966a5/torchvision-0.25.0-cp313-cp313-macosx_12_0_arm64.whl", hash = "sha256:24e11199e4d84ba9c5ee7825ebdf1cd37ce8deec225117f10243cae984ced3ec", size = 1874918, upload-time = "2026-01-21T16:27:39.02Z" }, + { url = "https://files.pythonhosted.org/packages/36/b1/3d6c42f62c272ce34fcce609bb8939bdf873dab5f1b798fd4e880255f129/torchvision-0.25.0-cp313-cp313-manylinux_2_28_aarch64.whl", hash = "sha256:5f271136d2d2c0b7a24c5671795c6e4fd8da4e0ea98aeb1041f62bc04c4370ef", size = 2309106, upload-time = "2026-01-21T16:27:30.624Z" }, + { url = "https://files.pythonhosted.org/packages/c7/60/59bb9c8b67cce356daeed4cb96a717caa4f69c9822f72e223a0eae7a9bd9/torchvision-0.25.0-cp313-cp313-manylinux_2_28_x86_64.whl", hash = "sha256:855c0dc6d37f462482da7531c6788518baedca1e0847f3df42a911713acdfe52", size = 8071522, upload-time = "2026-01-21T16:27:29.392Z" }, + { url = "https://files.pythonhosted.org/packages/32/a5/9a9b1de0720f884ea50dbf9acb22cbe5312e51d7b8c4ac6ba9b51efd9bba/torchvision-0.25.0-cp313-cp313-win_amd64.whl", hash = "sha256:cef0196be31be421f6f462d1e9da1101be7332d91984caa6f8022e6c78a5877f", size = 4321911, upload-time = "2026-01-21T16:27:35.195Z" }, + { url = "https://files.pythonhosted.org/packages/52/99/dca81ed21ebaeff2b67cc9f815a20fdaa418b69f5f9ea4c6ed71721470db/torchvision-0.25.0-cp313-cp313t-macosx_11_0_arm64.whl", hash = "sha256:a8f8061284395ce31bcd460f2169013382ccf411148ceb2ee38e718e9860f5a7", size = 1896209, upload-time = "2026-01-21T16:27:32.159Z" }, + { url = "https://files.pythonhosted.org/packages/28/cc/2103149761fdb4eaed58a53e8437b2d716d48f05174fab1d9fcf1e2a2244/torchvision-0.25.0-cp313-cp313t-manylinux_2_28_aarch64.whl", hash = "sha256:146d02c9876858420adf41f3189fe90e3d6a409cbfa65454c09f25fb33bf7266", size = 2310735, upload-time = "2026-01-21T16:27:22.327Z" }, + { url = "https://files.pythonhosted.org/packages/76/ad/f4c985ad52ddd3b22711c588501be1b330adaeaf6850317f66751711b78c/torchvision-0.25.0-cp313-cp313t-manylinux_2_28_x86_64.whl", hash = "sha256:c4d395cb2c4a2712f6eb93a34476cdf7aae74bb6ea2ea1917f858e96344b00aa", size = 8089557, upload-time = "2026-01-21T16:27:27.666Z" }, + { url = "https://files.pythonhosted.org/packages/63/cc/0ea68b5802e5e3c31f44b307e74947bad5a38cc655231d845534ed50ddb8/torchvision-0.25.0-cp313-cp313t-win_amd64.whl", hash = "sha256:5e6b449e9fa7d642142c0e27c41e5a43b508d57ed8e79b7c0a0c28652da8678c", size = 4344260, upload-time = "2026-01-21T16:27:17.018Z" }, + { url = "https://files.pythonhosted.org/packages/9e/1f/fa839532660e2602b7e704d65010787c5bb296258b44fa8b9c1cd6175e7d/torchvision-0.25.0-cp314-cp314-macosx_11_0_arm64.whl", hash = "sha256:620a236288d594dcec7634c754484542dc0a5c1b0e0b83a34bda5e91e9b7c3a1", size = 1896193, upload-time = "2026-01-21T16:27:24.785Z" }, + { url = "https://files.pythonhosted.org/packages/80/ed/d51889da7ceaf5ff7a0574fb28f9b6b223df19667265395891f81b364ab3/torchvision-0.25.0-cp314-cp314-manylinux_2_28_aarch64.whl", hash = "sha256:0b5e7f50002a8145a98c5694a018e738c50e2972608310c7e88e1bd4c058f6ce", size = 2309331, upload-time = "2026-01-21T16:27:19.97Z" }, + { url = "https://files.pythonhosted.org/packages/90/a5/f93fcffaddd8f12f9e812256830ec9c9ca65abbf1bc369379f9c364d1ff4/torchvision-0.25.0-cp314-cp314-manylinux_2_28_x86_64.whl", hash = "sha256:632db02300e83793812eee4f61ae6a2686dab10b4cfd628b620dc47747aa9d03", size = 8088713, upload-time = "2026-01-21T16:27:15.281Z" }, + { url = "https://files.pythonhosted.org/packages/1f/eb/d0096eed5690d962853213f2ee00d91478dfcb586b62dbbb449fb8abc3a6/torchvision-0.25.0-cp314-cp314-win_amd64.whl", hash = "sha256:d1abd5ed030c708f5dbf4812ad5f6fbe9384b63c40d6bd79f8df41a4a759a917", size = 4325058, upload-time = "2026-01-21T16:27:26.165Z" }, + { url = "https://files.pythonhosted.org/packages/97/36/96374a4c7ab50dea9787ce987815614ccfe988a42e10ac1a2e3e5b60319a/torchvision-0.25.0-cp314-cp314t-macosx_11_0_arm64.whl", hash = "sha256:ad9a8a5877782944d99186e4502a614770fe906626d76e9cd32446a0ac3075f2", size = 1896207, upload-time = "2026-01-21T16:27:23.383Z" }, + { url = "https://files.pythonhosted.org/packages/b5/e2/7abb10a867db79b226b41da419b63b69c0bd5b82438c4a4ed50e084c552f/torchvision-0.25.0-cp314-cp314t-manylinux_2_28_aarch64.whl", hash = "sha256:40a122c3cf4d14b651f095e0f672b688dde78632783fc5cd3d4d5e4f6a828563", size = 2310741, upload-time = "2026-01-21T16:27:18.712Z" }, + { url = "https://files.pythonhosted.org/packages/08/e6/0927784e6ffc340b6676befde1c60260bd51641c9c574b9298d791a9cda4/torchvision-0.25.0-cp314-cp314t-manylinux_2_28_x86_64.whl", hash = "sha256:846890161b825b38aa85fc37fb3ba5eea74e7091ff28bab378287111483b6443", size = 8089772, upload-time = "2026-01-21T16:27:14.048Z" }, + { url = "https://files.pythonhosted.org/packages/b6/37/e7ca4ec820d434c0f23f824eb29f0676a0c3e7a118f1514f5b949c3356da/torchvision-0.25.0-cp314-cp314t-win_amd64.whl", hash = "sha256:f07f01d27375ad89d72aa2b3f2180f07da95dd9d2e4c758e015c0acb2da72977", size = 4425879, upload-time = "2026-01-21T16:27:12.579Z" }, +] + [[package]] name = "tqdm" version = "4.67.3" From 58a96e87100a3b6f378e4eee3723b60db9744ba5 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 26 May 2026 11:55:06 -0600 Subject: [PATCH 03/22] Some ideas Signed-off-by: Fabrice Normandin --- cluv/config.py | 29 ++++++++++++++++++++++++----- cluv/utils.py | 10 ++++++++++ pyproject.toml | 21 ++++++++++----------- 3 files changed, 44 insertions(+), 16 deletions(-) diff --git a/cluv/config.py b/cluv/config.py index 8e5053c..614c3fe 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -9,6 +9,8 @@ from pydantic import BaseModel +from cluv.utils import current_cluster, resolve_env_vars + logger = logging.getLogger(__name__) @@ -18,7 +20,10 @@ class ClusterConfig(BaseModel): env: dict[str, str] = {} """Environment variables to set when running Slurm commands on this cluster.""" - datasets_path: str | None + results_path: str | None # TODO: Change to `Path` instead. Fix any pydantic errors. + """Path to the results directory for a specific cluster.""" + + datasets_path: str | None # TODO: Change to `Path` instead. Fix any pydantic errors. """Different path where the datasets should be replicated on this cluster. When `None`, this defaults to the top-level config's `datasets_path`. @@ -30,8 +35,11 @@ class ClusterConfig(BaseModel): class CluvConfig(BaseModel): """Configuration options for Cluv, loaded from the pyproject.toml file.""" + env: dict[str, str] = {} + """Global environment variables set on all clusters when running Slurm commands.""" + results_path: str - """Path to the results directory, relative to the project root. + """Default path to the results directory for all clusters. !!! info On Slurm clusters, this will be a symlink to a folder in `$SCRATCH//`. @@ -43,9 +51,6 @@ class CluvConfig(BaseModel): This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. """ - env: dict[str, str] = {} - """Global environment variables set on all clusters when running Slurm commands.""" - clusters: dict[str, ClusterConfig] = {} """Configuration options for each cluster. @@ -97,3 +102,17 @@ def load_cluv_config(pyproject_path: Path) -> CluvConfig: def get_cluster_choices() -> list[str]: """Return configured clusters or the defaults when config is missing/invalid.""" return get_config().clusters_names + + +def current_cluster_config() -> ClusterConfig | None: + """Returns the `ClusterConfig` of the current cluster, or None if not currently on a cluster.""" + cluster = current_cluster() + if not cluster: + return None # not on a cluster. + cluv_config = load_cluv_config(find_pyproject()) + cluster_config = cluv_config.clusters[cluster] + return ClusterConfig( + env=cluv_config.env | cluster_config.env, + results_path=resolve_env_vars(cluster_config.results_path or cluv_config.results_path), + datasets_path=resolve_env_vars(cluster_config.datasets_path or cluv_config.datasets_path), + ) diff --git a/cluv/utils.py b/cluv/utils.py index ca5ee94..8198574 100644 --- a/cluv/utils.py +++ b/cluv/utils.py @@ -1,6 +1,7 @@ import os import socket import sys +from pathlib import Path import rich.console @@ -14,3 +15,12 @@ def current_cluster() -> str | None: if "CC_CLUSTER" in os.environ: return os.environ["CC_CLUSTER"] return None + + +def resolve_env_vars(string_or_path: str | Path): + path = Path(string_or_path) + parts = path.parts + new_parts = [ + os.environ.get(part.removeprefix("$")) if part.startswith("$") else part for part in parts + ] + return os.path.join(*new_parts) diff --git a/pyproject.toml b/pyproject.toml index abb66ac..6c07429 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -82,16 +82,16 @@ source = ["cluv"] show_missing = true [tool.cluv] -# Where to store results on all clusters. TODO: Can be overridden in each cluster's config. -results_path = "logs" +# Where to store results of jobs all clusters. +# TODO: Can be overridden in each cluster's config. +# If this is a simple path (like 'logs'), then on Slurm clusters a symlink is created in $SCRATCH/logs/. +results_path = "$SCRATCH/logs/cluv/$SLURM_JOB_ID" -# Which cluster to get the data from. When unset, assumes the current cluster is the source. -data_source_cluster = "mila" -# Where to read the data from on the source cluster. -data_source_path = "/network/datasets/cifar10" +# Where to read the data from, when synchronizing data to all clusters. +data_source = "mila:/network/datasets/cifar10" # Where the dataset should be replicated on all clusters. -# On the source cluster (ex Mila), the folder could contain symlinks, to avoid duplicating the data. +# On the source cluster (ex Mila), the folder will only contain symlinks, to avoid duplicating the data. datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.env] @@ -100,21 +100,20 @@ SBATCH_TIME = "3:00:00" SBATCH_REQUEUE = "1" [tool.cluv.clusters.mila] +# Overrides specific to the Mila cluster. env = {UV_OFFLINE="0", WANDB_MODE="online"} -datasets_path = "$SCRATCH/data/cifar10" # PAICE clusters. [tool.cluv.clusters.tamia] env = {UV_OFFLINE="1", WANDB_MODE="offline"} -datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.killarney] env = {UV_OFFLINE="1", WANDB_MODE="offline"} -datasets_path = "$SCRATCH/data/cifar10" +# For example, no $SCRATCH by default on Killarney, so we overwrite this here. +datasets_path = "$HOME/data/cifar10" [tool.cluv.clusters.vulcan] env = {UV_OFFLINE="1", WANDB_MODE="offline"} -datasets_path = "$SCRATCH/data/cifar10" # DRAC clusters. [tool.cluv.clusters.rorqual] From f30d14a217ced1bfacb27707b236bf7407cd9d7e Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 26 May 2026 13:51:52 -0600 Subject: [PATCH 04/22] Some ideas Signed-off-by: Fabrice Normandin --- cluv/config.py | 58 ++++++++++++++++++++++++++++++++++++++++-------- cluv/utils.py | 4 ++-- pyproject.toml | 9 +------- tests/example.py | 11 +++++++-- 4 files changed, 61 insertions(+), 21 deletions(-) diff --git a/cluv/config.py b/cluv/config.py index 614c3fe..06228fc 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -8,16 +8,18 @@ from pathlib import Path from pydantic import BaseModel +from pydantic.dataclasses import dataclass from cluv.utils import current_cluster, resolve_env_vars logger = logging.getLogger(__name__) -class ClusterConfig(BaseModel): +@dataclass(frozen=True) +class PartialClusterConfig: """Per-cluster configuration options.""" - env: dict[str, str] = {} + env: dict[str, str] """Environment variables to set when running Slurm commands on this cluster.""" results_path: str | None # TODO: Change to `Path` instead. Fix any pydantic errors. @@ -32,6 +34,32 @@ class ClusterConfig(BaseModel): """ +@dataclass(frozen=True) +class ClusterConfig: + """Per-cluster configuration options.""" + + env: dict[str, str] + """Environment variables to set when running Slurm commands on this cluster.""" + + results_path: Path + """Path to the results directory for a specific cluster.""" + + datasets_path: Path + """Different path where the datasets should be replicated on this cluster. + + When `None`, this defaults to the top-level config's `datasets_path`. + + This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. + """ + + def resolve_env_vars_in_paths(self): + return ClusterConfig( + env=self.env, + results_path=resolve_env_vars(self.results_path), + datasets_path=resolve_env_vars(self.datasets_path), + ) + + class CluvConfig(BaseModel): """Configuration options for Cluv, loaded from the pyproject.toml file.""" @@ -51,7 +79,10 @@ class CluvConfig(BaseModel): This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. """ - clusters: dict[str, ClusterConfig] = {} + data_source: str | None + """`hostname:/path` of where to get the data from.""" + + clusters: dict[str, PartialClusterConfig] = {} """Configuration options for each cluster. The keys are cluster names; each value is a `ClusterConfig` whose `env` dict contains @@ -62,6 +93,19 @@ class CluvConfig(BaseModel): def clusters_names(self) -> list[str]: return list(self.clusters.keys()) + def get_cluster_config(self, cluster: str) -> ClusterConfig: + """Returns the cluster config for a specific cluster. + + The environment variables as part of paths will *not* be resolved. + """ + cluv_config = load_cluv_config(find_pyproject()) + cluster_config = cluv_config.clusters[cluster] + return ClusterConfig( + env=cluv_config.env | cluster_config.env, + results_path=Path(cluster_config.results_path or cluv_config.results_path), + datasets_path=Path(cluster_config.datasets_path or cluv_config.datasets_path), + ) + @functools.cache def get_config() -> CluvConfig: @@ -110,9 +154,5 @@ def current_cluster_config() -> ClusterConfig | None: if not cluster: return None # not on a cluster. cluv_config = load_cluv_config(find_pyproject()) - cluster_config = cluv_config.clusters[cluster] - return ClusterConfig( - env=cluv_config.env | cluster_config.env, - results_path=resolve_env_vars(cluster_config.results_path or cluv_config.results_path), - datasets_path=resolve_env_vars(cluster_config.datasets_path or cluv_config.datasets_path), - ) + config = cluv_config.get_cluster_config(cluster) + return config.resolve_env_vars_in_paths() diff --git a/cluv/utils.py b/cluv/utils.py index 8198574..fd1a9e8 100644 --- a/cluv/utils.py +++ b/cluv/utils.py @@ -21,6 +21,6 @@ def resolve_env_vars(string_or_path: str | Path): path = Path(string_or_path) parts = path.parts new_parts = [ - os.environ.get(part.removeprefix("$")) if part.startswith("$") else part for part in parts + os.environ[part.removeprefix("$")] if part.startswith("$") else part for part in parts ] - return os.path.join(*new_parts) + return Path(*new_parts) diff --git a/pyproject.toml b/pyproject.toml index 6c07429..83774f2 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -84,8 +84,7 @@ show_missing = true [tool.cluv] # Where to store results of jobs all clusters. # TODO: Can be overridden in each cluster's config. -# If this is a simple path (like 'logs'), then on Slurm clusters a symlink is created in $SCRATCH/logs/. -results_path = "$SCRATCH/logs/cluv/$SLURM_JOB_ID" +results_path = "$SCRATCH/logs/cluv" # Where to read the data from, when synchronizing data to all clusters. data_source = "mila:/network/datasets/cifar10" @@ -118,26 +117,20 @@ env = {UV_OFFLINE="1", WANDB_MODE="offline"} # DRAC clusters. [tool.cluv.clusters.rorqual] env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} -datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.fir] env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} -datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.nibi] env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} -datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.trillium] env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} -datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.trillium-gpu] env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} -datasets_path = "$SCRATCH/data/cifar10" [tool.cluv.clusters.narval] # Mila doesn't have an allocation on Narval anymore. # You can also use "def-yourusername" (the default partitions). env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="def-bengioy"} -datasets_path = "$SCRATCH/data/cifar10" diff --git a/tests/example.py b/tests/example.py index 9f6a4c3..ab7f5f5 100644 --- a/tests/example.py +++ b/tests/example.py @@ -11,12 +11,19 @@ import simple_parsing from torchvision.datasets import CIFAR10 +import cluv +import cluv.config + SLURM_JOB_ID = int(os.environ["SLURM_JOB_ID"]) SCRATCH = Path(os.environ["SCRATCH"]) SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"]) # IDEA: maybe load the cluv config and set the checkpoint_dir # from cluv.config import load_cluv_config +config = cluv.config.current_cluster_config() +assert config, "Example must be run on a cluster." +assert config.results_path +assert config.datasets_path @dataclass(frozen=True) @@ -24,10 +31,10 @@ class Args: """Command-line arguments for this example.""" # NOTE: This should be the same as the `results_path` in the Cluv config. - results_path: Path = SCRATCH / "logs" / "cluv" / str(SLURM_JOB_ID) + results_path: Path = config.results_path # NOTE: This should be the same as the `datasets_path` in the Cluv config. - datasets_path: Path = SCRATCH / "data" / "cifar10" + datasets_path: Path = config.datasets_path # Time to wait before producing the result. # Can be useful to test and simulate preemption or cancelling jobs. From 129026f7d105602b2d9f617abe0c48d4702949ff Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 26 May 2026 14:06:05 -0600 Subject: [PATCH 05/22] Add an idea of a test Signed-off-by: Fabrice Normandin --- tests/test_sync.py | 26 +++++++++++++++++++++++++- 1 file changed, 25 insertions(+), 1 deletion(-) diff --git a/tests/test_sync.py b/tests/test_sync.py index f07ab08..9aef64d 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -1,9 +1,33 @@ """Tests for `cluv sync`""" +import subprocess -def test_cluv_sync_with_data_path(): +import pytest + +from cluv.cli.sync import sync +from cluv.remote import Remote +from cluv.utils import current_cluster + + +@pytest.mark.asyncio +async def test_cluv_sync_with_data_path(): """TODO: Test for `cluv sync` with a project that has a 'data_path'. Need to check that rsync happens from `datasets_path` (the source) to the `datasets_path` (the dest) on all the clusters. """ + assert current_cluster() == "mila" + other_cluster = "tamia" + other_cluster_remote = await Remote.connect(other_cluster) + + # Dataset isn't synced + this_cluster_files = subprocess.getoutput("ls $SCRATCH/data/cifar10") + other_cluster_files = await other_cluster_remote.get_output("ls $SCRATCH/data/cifar10") + assert this_cluster_files != other_cluster_files + + await sync([other_cluster], uv_sync_args=None) + + # Dataset is synced + this_cluster_files = subprocess.getoutput("ls $SCRATCH/data/cifar10") + other_cluster_files = await other_cluster_remote.get_output("ls $SCRATCH/data/cifar10") + assert this_cluster_files == other_cluster_files From 1a8b14886d19e0068a589a7d3e025c136387cd8d Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Wed, 27 May 2026 09:22:12 -0400 Subject: [PATCH 06/22] Fix issue with config Signed-off-by: Fabrice Normandin --- cluv/cli/sync.py | 8 ++++++-- cluv/config.py | 9 +++++---- pyproject.toml | 2 +- 3 files changed, 12 insertions(+), 7 deletions(-) diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index 350a3db..10c5048 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -194,7 +194,9 @@ async def clone_project(remote: Remote): git_remote_name = "origin" if not git_remote_name: git_remote_name = "origin" - github_repo_url = subprocess.getoutput(f"git config --get remote.{git_remote_name}.url").strip() + github_repo_url = subprocess.getoutput( + f"git config --get remote.{git_remote_name}.url" + ).strip() if not github_repo_url: raise RuntimeError( f"Could not determine Git remote URL from remote '{git_remote_name}'. " @@ -205,7 +207,9 @@ async def clone_project(remote: Remote): # Or configure the config credential-helper to store first? # Get the path to the root of the git repository - git_root_path = PurePosixPath(subprocess.getoutput("git rev-parse --show-toplevel").strip()).relative_to(Path.home()) + git_root_path = PurePosixPath( + subprocess.getoutput("git rev-parse --show-toplevel").strip() + ).relative_to(Path.home()) # If the project isn't cloned yet, clone it. _is_cloned_on_cluster = ( diff --git a/cluv/config.py b/cluv/config.py index 06228fc..767029f 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -5,6 +5,7 @@ import functools import logging import tomllib +from dataclasses import field from pathlib import Path from pydantic import BaseModel @@ -19,13 +20,13 @@ class PartialClusterConfig: """Per-cluster configuration options.""" - env: dict[str, str] + env: dict[str, str] = field(default_factory=dict) """Environment variables to set when running Slurm commands on this cluster.""" - results_path: str | None # TODO: Change to `Path` instead. Fix any pydantic errors. + results_path: str | None = None # TODO: Change to `Path` instead. Fix any pydantic errors. """Path to the results directory for a specific cluster.""" - datasets_path: str | None # TODO: Change to `Path` instead. Fix any pydantic errors. + datasets_path: str | None = None # TODO: Change to `Path` instead. Fix any pydantic errors. """Different path where the datasets should be replicated on this cluster. When `None`, this defaults to the top-level config's `datasets_path`. @@ -73,7 +74,7 @@ class CluvConfig(BaseModel): On Slurm clusters, this will be a symlink to a folder in `$SCRATCH//`. """ - datasets_path: str + datasets_path: str | None """Path to a dataset directory, for example, `'$SCRATCH/my_dataset'` This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. diff --git a/pyproject.toml b/pyproject.toml index 83774f2..85a279a 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -86,7 +86,7 @@ show_missing = true # TODO: Can be overridden in each cluster's config. results_path = "$SCRATCH/logs/cluv" -# Where to read the data from, when synchronizing data to all clusters. +# Where to read the data from when synchronizing data to all clusters. data_source = "mila:/network/datasets/cifar10" # Where the dataset should be replicated on all clusters. From 84843b6081ea32602681537e91ada3e8f0345402 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Thu, 28 May 2026 15:17:25 -0400 Subject: [PATCH 07/22] wip Signed-off-by: Fabrice Normandin --- cluv/cli/sync.py | 16 ++++++++-------- tests/test_sync.py | 9 ++++++--- 2 files changed, 14 insertions(+), 11 deletions(-) diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index 10c5048..2784bb3 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -65,22 +65,22 @@ async def sync( - Gathers results from all other clusters to the Mila cluster using rsync. """ # TODO: Figure out which Slurm cluster we're currently on. Assuming mila for now. - this_cluster = current_cluster() + here = current_cluster() + if clusters and here in clusters: + clusters.remove(here) + # When no cluster is passed, sync with clusters for which we have an active SSH connection. if not clusters: clusters = get_config().clusters_names - if this_cluster and this_cluster in clusters: - clusters.remove(this_cluster) connections = await asyncio.gather( *(get_remote_without_2fa_prompt(cluster) for cluster in clusters) ) - remotes = [conn for conn in connections if conn] - if not remotes: - console.log( + if not any(connections): + raise RuntimeError( "[red]Not currently connected to any Slurm cluster.[/red] " "Use `cluv login` to login and create reusable connections." ) - return [] + remotes = [conn for conn in connections if conn] # keep the active connections. clusters = [remote.hostname for remote in remotes] else: remotes = await login(clusters) @@ -97,7 +97,7 @@ async def sync( task_descriptions: list[str] = [] for remote in remotes: tasks.append(functools.partial(sync_task_function, remote=remote)) - task_descriptions.append(f"{this_cluster or 'local'} -> {remote.hostname}") + task_descriptions.append(f"{here or 'local'} -> {remote.hostname}") await run_async_tasks_with_progress_bar( async_task_fns=tasks, diff --git a/tests/test_sync.py b/tests/test_sync.py index 9aef64d..4432c2d 100644 --- a/tests/test_sync.py +++ b/tests/test_sync.py @@ -16,7 +16,8 @@ async def test_cluv_sync_with_data_path(): Need to check that rsync happens from `datasets_path` (the source) to the `datasets_path` (the dest) on all the clusters. """ - assert current_cluster() == "mila" + # assert current_cluster() == "mila" + assert current_cluster() is None other_cluster = "tamia" other_cluster_remote = await Remote.connect(other_cluster) @@ -28,6 +29,8 @@ async def test_cluv_sync_with_data_path(): await sync([other_cluster], uv_sync_args=None) # Dataset is synced - this_cluster_files = subprocess.getoutput("ls $SCRATCH/data/cifar10") - other_cluster_files = await other_cluster_remote.get_output("ls $SCRATCH/data/cifar10") + this_cluster_files = subprocess.getoutput("ls $SCRATCH/data/cifar10").strip().splitlines() + other_cluster_files = ( + (await other_cluster_remote.get_output("ls $SCRATCH/data/cifar10")).strip().splitlines() + ) assert this_cluster_files == other_cluster_files From 9fb8c257660aa4c9756ef7b38b6826eb1760f345 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 11:39:55 -0400 Subject: [PATCH 08/22] Improve the design of the config Signed-off-by: Fabrice Normandin --- pyproject.toml | 40 ++++++++++++++++++++++------------------ tests/example.py | 2 ++ 2 files changed, 24 insertions(+), 18 deletions(-) diff --git a/pyproject.toml b/pyproject.toml index 85a279a..5e92067 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -81,42 +81,47 @@ source = ["cluv"] [tool.coverage.report] show_missing = true +### -------------- CLUV CONFIG -------------- ### + [tool.cluv] -# Where to store results of jobs all clusters. -# TODO: Can be overridden in each cluster's config. +# Where to store job results by default. Can be overridden per cluster. results_path = "$SCRATCH/logs/cluv" - +# On clusters, Cluv creates a symlink (a shortcut) in your project folder to the results_path dir. +# This makes it easier to keep your project in $HOME and to see the results which are on $SCRATCH. +results_symlink = "logs" # Where to read the data from when synchronizing data to all clusters. data_source = "mila:/network/datasets/cifar10" - # Where the dataset should be replicated on all clusters. -# On the source cluster (ex Mila), the folder will only contain symlinks, to avoid duplicating the data. -datasets_path = "$SCRATCH/data/cifar10" +# TODO?: On the source cluster (ex Mila), the folder will only contain symlinks, to avoid +# duplicating the data. +datasets_path = "$SCRATCH/datasets/cifar10" [tool.cluv.env] # Environment variables applied when using Slurm commands on all clusters. SBATCH_TIME = "3:00:00" SBATCH_REQUEUE = "1" +# Assume that compute nodes don't have internet access by default. Override below when they do. +UV_OFFLINE="1" +WANDB_MODE="offline" + + +### -------------- Clusters Config -------------- ### [tool.cluv.clusters.mila] # Overrides specific to the Mila cluster. env = {UV_OFFLINE="0", WANDB_MODE="online"} -# PAICE clusters. [tool.cluv.clusters.tamia] -env = {UV_OFFLINE="1", WANDB_MODE="offline"} [tool.cluv.clusters.killarney] -env = {UV_OFFLINE="1", WANDB_MODE="offline"} -# For example, no $SCRATCH by default on Killarney, so we overwrite this here. -datasets_path = "$HOME/data/cifar10" +# For example, you might not have a $SCRATCH on Killarney. This can be overwritten here. +results_path = "$HOME/logs/cluv" +datasets_path = "$HOME/datasets/cifar10" [tool.cluv.clusters.vulcan] -env = {UV_OFFLINE="1", WANDB_MODE="offline"} -# DRAC clusters. [tool.cluv.clusters.rorqual] -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} +env = {SBATCH_ACCOUNT="rrg-bengioy-ad"} [tool.cluv.clusters.fir] env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} @@ -125,12 +130,11 @@ env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} [tool.cluv.clusters.trillium] -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} +env = {SBATCH_ACCOUNT="rrg-bengioy-ad"} [tool.cluv.clusters.trillium-gpu] -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} +env = {SBATCH_ACCOUNT="rrg-bengioy-ad"} [tool.cluv.clusters.narval] # Mila doesn't have an allocation on Narval anymore. -# You can also use "def-yourusername" (the default partitions). -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="def-bengioy"} +env = {SBATCH_ACCOUNT="def-bengioy"} diff --git a/tests/example.py b/tests/example.py index ab7f5f5..f21d720 100644 --- a/tests/example.py +++ b/tests/example.py @@ -25,6 +25,8 @@ assert config.results_path assert config.datasets_path +# TODO: datasets_path should be data_source on the source cluster, and cluster.datasets_path on others. + @dataclass(frozen=True) class Args: From 98a2fc1e4f7a2e708cc23ef8deb3743360c03e1c Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 15:07:20 -0400 Subject: [PATCH 09/22] Improving example for testing the sync Signed-off-by: Fabrice Normandin --- cluv/config.py | 7 +++ cluv/job.py | 129 +++++++++++++++++++++++++++++++++++++++++++++++ pyproject.toml | 1 + tests/example.py | 62 +++++++++++++---------- uv.lock | 92 +++++++++++++++++++++++++++++++++ 5 files changed, 263 insertions(+), 28 deletions(-) create mode 100644 cluv/job.py diff --git a/cluv/config.py b/cluv/config.py index 767029f..907371b 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -2,6 +2,7 @@ from __future__ import annotations +import dataclasses import functools import logging import tomllib @@ -155,5 +156,11 @@ def current_cluster_config() -> ClusterConfig | None: if not cluster: return None # not on a cluster. cluv_config = load_cluv_config(find_pyproject()) + data_source = cluv_config.data_source config = cluv_config.get_cluster_config(cluster) + if data_source: + source_cluster, data_path = data_source.split(":", 1) + if cluster == source_cluster: + # use the dataset path from the data_source setting as the datasets_path. + config = dataclasses.replace(config, datasets_path=data_path) return config.resolve_env_vars_in_paths() diff --git a/cluv/job.py b/cluv/job.py new file mode 100644 index 0000000..073e436 --- /dev/null +++ b/cluv/job.py @@ -0,0 +1,129 @@ +"""A script that reads something, and produces some output. + +This is a simplified job script, used to test the syncing of the 'dataset' across clusters. +""" + +import functools +import os +import re +import subprocess +from dataclasses import dataclass +from pathlib import Path + +import cluv +import cluv.config +from cluv.utils import current_cluster + +SLURM_JOB_ID: int | None = ( + int(os.environ["SLURM_JOB_ID"]) if "SLURM_JOB_ID" in os.environ else None +) +SCRATCH = Path(os.environ["SCRATCH"]) if "SCRATCH" in os.environ else None +SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"]) if "SLURM_TMPDIR" in os.environ else None +SLURM_PROCID = int(os.environ["SLURM_PROCID"]) if "SLURM_PROCID" in os.environ else None + + +in_job_packing = "SLURM_NTASKS_PER_GPU" in os.environ +in_job_array = "SLURM_ARRAY_JOB_ID" in os.environ + + +@dataclass(frozen=True) +class JobInfo: + """Information about a "job"/"run". + + Note, there may be multiple "runs" inside a single "job", that's why there is a distinction. + """ + + cluster: str + + run_id: str + """The unique 'run identifier' for this job/run, used for checkpointing and Weights & Biases. + + This will usually just be {cluster}_{SLURM_JOB_ID}, but can also vary based on whether + the job is doing job packing (with --ntasks-per-gpu) or job chunking (with --array=...%1) or + both. + + Use this as the run_id for `wandb.init` or whenever you need a unique run identifier. + """ + + results_path: Path + + @property + def datasets_path(self) -> Path | None: + """The path where the datasets are located for this job (based on which cluster it runs on.)""" + cluster_info = cluv.config.current_cluster_config() + assert cluster_info + return cluster_info.datasets_path + + +def current_job_info() -> JobInfo | None: + """Returns information about the current job, such as its unique run id and results path. + + This is useful to determine where to save checkpoints or results for this job, and to have a unique + identifier for this job that can be used in Weights & Biases or elsewhere. + + The 'run id' is determined based on the cluster name and SLURM job id, and also takes into account + whether the job is doing job packing (with --ntasks-per-gpu) or job chunking (with --array=...%1). + """ + if not SLURM_JOB_ID: + return None # not in a Slurm job. + cluster = current_cluster() + run_id = _get_run_id() + # IDEA: maybe load the cluv config and set the checkpoint_dir + # from cluv.config import load_cluv_config + assert cluster, "Example must be run on a cluster." + config = cluv.config.current_cluster_config() + assert config, "Example must be run on a cluster." + assert config.results_path + assert config.datasets_path + return JobInfo( + run_id=run_id, + cluster=cluster, + results_path=config.results_path / run_id, + ) + + +@functools.cache +def _get_max_active_jobs() -> int | None: + """When in a job array, returns the max number of active jobs at the same time. + + For example, with --array=0-20%4, this returns 4. + Returns `None` when not in a job array. + Result is cached since this calls scontrol in a subprocess. + """ + if "SLURM_ARRAY_JOB_ID" not in os.environ: + return None + output = subprocess.check_output( + ["scontrol", "--oneliner", "show", "job", os.environ["SLURM_ARRAY_JOB_ID"]], + text=True, + ) + match = re.search(r"ArrayTaskId=\S+%(\d+)", output) + return int(match.group(1)) if match else None + + +def _in_job_chunking() -> bool: + return in_job_array and _get_max_active_jobs() == 1 + + +def _get_run_id(): + cluster = current_cluster() + doing_job_packing = "SLURM_NTASKS_PER_GPU" in os.environ + doing_job_chunking = _in_job_chunking() + task_index = int(os.environ["SLURM_PROCID"]) + if doing_job_chunking: + # IF we have --array=...%1, use the id of the first job. + first_job_id = int(os.environ["SLURM_ARRAY_JOB_ID"]) + if doing_job_packing: + # Running with --array=0-5%1 for chunking and --ntasks-per-gpu for packing! Awesome!! + return f"{cluster}_{first_job_id}_task{task_index}" + # IDEA: If we support doing an arrays of 'chunked' jobs, then we could use this: + # IF we have --array=0-20%4, this means there are 4 jobs with 5 chunks each (weird). + # max_active_jobs = get_max_active_jobs() + # assert max_active_jobs is not None and max_active_jobs > 1 + # index_in_array = int(os.environ["SLURM_ARRAY_TASK_ID"]) + # return str(first_job_id + (index_in_array % max_active_jobs)) + # Keeping it simple for now, only support chunking with --array=...%1, so we always use + # the id of the first job in the array. + return f"{cluster}_{first_job_id}" + if doing_job_packing: + return f"{cluster}_{SLURM_JOB_ID}_{SLURM_PROCID}" + return f"{cluster}_{SLURM_JOB_ID}" diff --git a/pyproject.toml b/pyproject.toml index 5e92067..f108d7c 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -30,6 +30,7 @@ dev = [ "ruff", "pytest-skip-slow>=0.0.5", "torchvision>=0.25.0", + "wandb>=0.27.0", ] [tool.pytest.ini_options] diff --git a/tests/example.py b/tests/example.py index f21d720..22506e5 100644 --- a/tests/example.py +++ b/tests/example.py @@ -3,59 +3,65 @@ This is a simplified job script, used to test the syncing of the 'dataset' across clusters. """ +import dataclasses import os +import random import time from dataclasses import dataclass -from pathlib import Path import simple_parsing +import torch +import wandb from torchvision.datasets import CIFAR10 -import cluv -import cluv.config - -SLURM_JOB_ID = int(os.environ["SLURM_JOB_ID"]) -SCRATCH = Path(os.environ["SCRATCH"]) -SLURM_TMPDIR = Path(os.environ["SLURM_TMPDIR"]) - -# IDEA: maybe load the cluv config and set the checkpoint_dir -# from cluv.config import load_cluv_config -config = cluv.config.current_cluster_config() -assert config, "Example must be run on a cluster." -assert config.results_path -assert config.datasets_path - -# TODO: datasets_path should be data_source on the source cluster, and cluster.datasets_path on others. +from cluv.config import current_cluster_config +from cluv.job import current_job_info @dataclass(frozen=True) class Args: """Command-line arguments for this example.""" - # NOTE: This should be the same as the `results_path` in the Cluv config. - results_path: Path = config.results_path - - # NOTE: This should be the same as the `datasets_path` in the Cluv config. - datasets_path: Path = config.datasets_path - # Time to wait before producing the result. # Can be useful to test and simulate preemption or cancelling jobs. wait_duration_seconds: int = 0 + seed: int = int(os.environ.get("SLURM_PROCID", "0")) + def main(args: Args | None = None): args = args or simple_parsing.parse(Args, description=__doc__) - print(f"Job {SLURM_JOB_ID} starts.") - dataset = CIFAR10(args.datasets_path) + job_info = current_job_info() + cluster_info = current_cluster_config() + assert job_info and cluster_info, "example should be run in a slurm job." + + print(f"Job {job_info.run_id} starts.") + wandb.init( + project="cluv-example", + name=job_info.run_id, + config=vars(args) + | {"job": dataclasses.asdict(job_info)} + | {"env": {k: v for k, v in os.environ.items() if k.startswith("SLURM")}}, + resume="allow", + ) + random.seed(args.seed) + torch.manual_seed(args.seed) + + # Test that we can load a dataset from the dataset_path (that was synced by Cluv) + dataset = CIFAR10(cluster_info.datasets_path, download=False) print(dataset) - time.sleep(args.wait_duration_seconds) + for i in range(args.wait_duration_seconds): + wandb.log({"step": i, "fake_loss": random.random()}) + time.sleep(1) + + print(f"Job {job_info.run_id} is about to end.") - print(f"Job {SLURM_JOB_ID} is about to end.") - results_file = args.results_path / "results.txt" + job_info.results_path.mkdir(parents=True, exist_ok=True) + results_file = job_info.results_path / "results.txt" with results_file.open("a") as f: - f.write(f"This is the result of job {SLURM_JOB_ID}\n") + f.write(f"This is the result of job {job_info.run_id}\n") if __name__ == "__main__": diff --git a/uv.lock b/uv.lock index 9e6d60c..7fc5e60 100644 --- a/uv.lock +++ b/uv.lock @@ -274,6 +274,7 @@ dev = [ { name = "ruff" }, { name = "torchvision" }, { name = "uv-dynamic-versioning" }, + { name = "wandb" }, ] [package.metadata] @@ -298,6 +299,7 @@ dev = [ { name = "ruff" }, { name = "torchvision", specifier = ">=0.25.0" }, { name = "uv-dynamic-versioning", specifier = ">=0.2.0" }, + { name = "wandb", specifier = ">=0.27.0" }, ] [[package]] @@ -540,6 +542,30 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/f7/ec/67fbef5d497f86283db54c22eec6f6140243aae73265799baaaa19cd17fb/ghp_import-2.1.0-py3-none-any.whl", hash = "sha256:8337dd7b50877f163d4c0289bc1f1c7f127550241988d568c1db512c4324a619", size = 11034, upload-time = "2022-05-02T15:47:14.552Z" }, ] +[[package]] +name = "gitdb" +version = "4.0.12" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "smmap" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/72/94/63b0fc47eb32792c7ba1fe1b694daec9a63620db1e313033d18140c2320a/gitdb-4.0.12.tar.gz", hash = "sha256:5ef71f855d191a3326fcfbc0d5da835f26b13fbcba60c32c21091c349ffdb571", size = 394684, upload-time = "2025-01-02T07:20:46.413Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/a0/61/5c78b91c3143ed5c14207f463aecfc8f9dbb5092fb2869baf37c273b2705/gitdb-4.0.12-py3-none-any.whl", hash = "sha256:67073e15955400952c6565cc3e707c554a4eea2e428946f7a4c162fab9bd9bcf", size = 62794, upload-time = "2025-01-02T07:20:43.624Z" }, +] + +[[package]] +name = "gitpython" +version = "3.1.50" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "gitdb" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/33/f6/354ae6491228b5eb40e10d89c4d13c651fe1cf7556e35ebdded50cff57ce/gitpython-3.1.50.tar.gz", hash = "sha256:80da2d12504d52e1f998772dc5baf6e553f8d2fcfe1fcc226c9d9a2ee3372dcc", size = 219798, upload-time = "2026-05-06T04:01:26.571Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/20/7a/1c6e3562dfd8950adbb11ffbc65d21e7c89d01a6e4f137fa981056de25c5/gitpython-3.1.50-py3-none-any.whl", hash = "sha256:d352abe2908d07355014abdd21ddf798c2a961469239afec4962e9da884858f9", size = 212507, upload-time = "2026-05-06T04:01:23.799Z" }, +] + [[package]] name = "griffelib" version = "2.0.2" @@ -1175,6 +1201,21 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/84/03/0d3ce49e2505ae70cf43bc5bb3033955d2fc9f932163e84dc0779cc47f48/prompt_toolkit-3.0.52-py3-none-any.whl", hash = "sha256:9aac639a3bbd33284347de5ad8d68ecc044b91a762dc39b7c21095fcd6a19955", size = 391431, upload-time = "2025-08-27T15:23:59.498Z" }, ] +[[package]] +name = "protobuf" +version = "7.35.0" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/60/fd/5b1491d9e4b586d621c54f4c36b888714164b6875f8d6afa3f9072906a51/protobuf-7.35.0.tar.gz", hash = "sha256:a2efd84605f41e559f1881b0912b44099d0a2ac9bf46b3474823f10fb393b0e6", size = 458677, upload-time = "2026-05-19T23:02:29.197Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/83/ee/93d06e358a4aa32280b00e722d3ea0a1f25fc3cc5778d80581c9cca2c10e/protobuf-7.35.0-cp310-abi3-macosx_10_9_universal2.whl", hash = "sha256:66be6c513931c794fa92c080ffee41671390da3d79da219cf9c0c0907f035dda", size = 433225, upload-time = "2026-05-19T23:02:19.884Z" }, + { url = "https://files.pythonhosted.org/packages/8b/39/1c76c2da93f3c507e958e0aecee2391cc44d4625de6c728bbc555195b5a8/protobuf-7.35.0-cp310-abi3-manylinux2014_aarch64.whl", hash = "sha256:fcbe42a4ac09d3ec9c987ddfcd956afd0b15f1ff613bd8371bde9405ffd5c8e5", size = 328847, upload-time = "2026-05-19T23:02:22.3Z" }, + { url = "https://files.pythonhosted.org/packages/91/1a/39f7ce90a238c1a987a4d81ec26379e02ca0aff367de68e4a1fa474215b9/protobuf-7.35.0-cp310-abi3-manylinux2014_s390x.whl", hash = "sha256:4cbf5cc286130e06a6c9bbefac442431173906dfcc979712183d4adcc01b37ee", size = 344030, upload-time = "2026-05-19T23:02:23.591Z" }, + { url = "https://files.pythonhosted.org/packages/70/5b/6baf9008817964454055ff3fe65f1de0b5f1e26c80c82f7fb108b7cd4ea3/protobuf-7.35.0-cp310-abi3-manylinux2014_x86_64.whl", hash = "sha256:6c0f98f10c8a05ea30f8993dfef2de093d27b490fdae78bb60c8343795d55011", size = 327130, upload-time = "2026-05-19T23:02:24.637Z" }, + { url = "https://files.pythonhosted.org/packages/8e/e5/e46adb0badc388bfb84877a5f9f026aff63f60e611016cf64dbe77e05446/protobuf-7.35.0-cp310-abi3-win32.whl", hash = "sha256:4c4617b83ade0e279d1d2bfe04025a1adb87f9ed657de038620dc0ff959357f6", size = 428946, upload-time = "2026-05-19T23:02:25.741Z" }, + { url = "https://files.pythonhosted.org/packages/a7/ab/547fbd9e16d879dd13c167478f8ae0a83a428008ca07a5e06acdc23ad473/protobuf-7.35.0-cp310-abi3-win_amd64.whl", hash = "sha256:f05bcadf9a2a6b8dda047007075135fb7d08c73d9177aabc067e1be46881a201", size = 439996, upload-time = "2026-05-19T23:02:26.808Z" }, + { url = "https://files.pythonhosted.org/packages/b8/ef/50433d346c56657a70d27f156c7b349ac59a068b01de4eb796e747eecc43/protobuf-7.35.0-py3-none-any.whl", hash = "sha256:c13f325cf242bad135c350629eeb5d54b24228eb472fb3e2e9ebbd4c5dc20ca0", size = 171659, upload-time = "2026-05-19T23:02:27.842Z" }, +] + [[package]] name = "pycparser" version = "3.0" @@ -1530,6 +1571,19 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/c0/98/6beb4b351e472e5f4c4613f7c35a5290b8be2497e183825310c4c3a3984b/ruff-0.15.12-py3-none-win_arm64.whl", hash = "sha256:a538f7a82d061cee7be55542aca1d86d1393d55d81d4fcc314370f4340930d4f", size = 11120821, upload-time = "2026-04-24T18:16:57.979Z" }, ] +[[package]] +name = "sentry-sdk" +version = "2.61.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "certifi" }, + { name = "urllib3" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/52/4d/3c66e6045bd2071256b6b6fdcb0cc02b86ce54b2acc2ceac79af8e0efbb5/sentry_sdk-2.61.0.tar.gz", hash = "sha256:1ca9b4bb777eb5be67004edab7eb894f21c6301f1d05ed64966719ad5d1764ce", size = 458510, upload-time = "2026-05-28T09:40:28.917Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/21/5a/9794736d5802689c1a48862e6afe6b7f3e86cc37c15d4a84bc0143877dc1/sentry_sdk-2.61.0-py3-none-any.whl", hash = "sha256:ec4d30273909cb1d198e03208b16ee70e2bc5d90a16fd9f1fb2fc6a72e1f03dc", size = 483111, upload-time = "2026-05-28T09:40:27.027Z" }, +] + [[package]] name = "setuptools" version = "82.0.1" @@ -1561,6 +1615,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/b7/ce/149a00dd41f10bc29e5921b496af8b574d8413afcd5e30dfa0ed46c2cc5e/six-1.17.0-py2.py3-none-any.whl", hash = "sha256:4721f391ed90541fddacab5acf947aa0d3dc7d27b2e1e8eda2be8970586c3274", size = 11050, upload-time = "2024-12-04T17:35:26.475Z" }, ] +[[package]] +name = "smmap" +version = "5.0.3" +source = { registry = "https://pypi.org/simple" } +sdist = { url = "https://files.pythonhosted.org/packages/1f/ea/49c993d6dfdd7338c9b1000a0f36817ed7ec84577ae2e52f890d1a4ff909/smmap-5.0.3.tar.gz", hash = "sha256:4d9debb8b99007ae47165abc08670bd74cb74b5227dda7f643eccc4e9eb5642c", size = 22506, upload-time = "2026-03-09T03:43:26.1Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/c1/d4/59e74daffcb57a07668852eeeb6035af9f32cbfd7a1d2511f17d2fe6a738/smmap-5.0.3-py3-none-any.whl", hash = "sha256:c106e05d5a61449cf6ba9a1e650227ecfb141590d2a98412103ff35d89fc7b2f", size = 24390, upload-time = "2026-03-09T03:43:24.361Z" }, +] + [[package]] name = "sshconf" version = "0.2.7" @@ -1750,6 +1813,35 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/89/39/35773a629ac27d8803ff5ed86bde89d06f77041d7afa0a06cdc584ee8c6f/uv_dynamic_versioning-0.14.0-py3-none-any.whl", hash = "sha256:e087c346a786e98d41292ac2315180fb700cedfb30565fc973d64ce11a112387", size = 12172, upload-time = "2026-03-22T04:53:35.063Z" }, ] +[[package]] +name = "wandb" +version = "0.27.0" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "click" }, + { name = "gitpython" }, + { name = "packaging" }, + { name = "platformdirs" }, + { name = "protobuf" }, + { name = "pydantic" }, + { name = "pyyaml" }, + { name = "requests" }, + { name = "sentry-sdk" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/8e/31/fe53d06b75ef0a7f2f0ee5931a89f7aedc27d233840b1839616860fed256/wandb-0.27.0.tar.gz", hash = "sha256:579e75300173059f9334e1f513a79ef15f6d9ea5c74e20d695633648cdd02031", size = 41090732, upload-time = "2026-05-14T03:44:08.894Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/ea/5e/2c199e70e636ecfd217cde0bc7469f4511e1d03d0685eb92bfdfce391430/wandb-0.27.0-py3-none-macosx_12_0_arm64.whl", hash = "sha256:c156be4851485f3c4160cb6eb2e8991b4cdeffbccefc5636d33cf5e254847365", size = 24886476, upload-time = "2026-05-14T03:43:27.569Z" }, + { url = "https://files.pythonhosted.org/packages/0b/cd/a617c871cd304a9804e56a7ec2ec2c65685bf0091a2b9f91910175a149e2/wandb-0.27.0-py3-none-macosx_12_0_x86_64.whl", hash = "sha256:20179f38afb0158859a4141d29ac650d3fdbd0cf801a74ce25565c934f03776c", size = 26045779, upload-time = "2026-05-14T03:43:31.999Z" }, + { url = "https://files.pythonhosted.org/packages/10/0a/d3f159a201530b84b72ca5f98c68d1f351c2d9a1864558ed76c811407fae/wandb-0.27.0-py3-none-manylinux_2_28_aarch64.whl", hash = "sha256:626497d7975fa898d0a4a239da7a510483495ca3514510dbe75004a25963af4d", size = 25480764, upload-time = "2026-05-14T03:43:35.922Z" }, + { url = "https://files.pythonhosted.org/packages/5f/6a/8721fcdf71d42639191040a77a585d2982402b1754700cb2ecfc2ca1470a/wandb-0.27.0-py3-none-manylinux_2_28_x86_64.whl", hash = "sha256:f772da7005cc26a2a32b729a16982a583dc68b3d493df6a09d0aa5c5ca5a2060", size = 27256204, upload-time = "2026-05-14T03:43:39.765Z" }, + { url = "https://files.pythonhosted.org/packages/00/5e/279d167ba79fb7a8a43401c9f25efd0f6663ee9bd1eaf5a8578530198888/wandb-0.27.0-py3-none-musllinux_1_2_aarch64.whl", hash = "sha256:63acfc5b994e4a90e4a2fbdee6d45e664da3dd865bb1419942c8995c06c41cf1", size = 25647469, upload-time = "2026-05-14T03:43:44.817Z" }, + { url = "https://files.pythonhosted.org/packages/94/51/a69ac59300e3c813939d0764348959ed2a21e14c668cb1cebcb04010da6a/wandb-0.27.0-py3-none-musllinux_1_2_x86_64.whl", hash = "sha256:17aae6e4a88cd05c00ea8f546220918e3ebb6f8c1c36b70ef04a5ac75f0d7160", size = 27599005, upload-time = "2026-05-14T03:43:50.926Z" }, + { url = "https://files.pythonhosted.org/packages/5f/40/bf510c8758727df020f83b717ebc1fcc1739ed7f6ae1796ebef60bf6f592/wandb-0.27.0-py3-none-win32.whl", hash = "sha256:0bd5659417e386bf6538b5e2ffe6885774c6197f0e4853bfed517d5b0db457f1", size = 25036164, upload-time = "2026-05-14T03:43:54.839Z" }, + { url = "https://files.pythonhosted.org/packages/54/ff/69f88e7d90c22b79bcb911143c13e59742ee192080b21015ff83a5a1f60a/wandb-0.27.0-py3-none-win_amd64.whl", hash = "sha256:89d584b73166eecee96fb446f18d0e45b1aa45aba6a3696296f3f06d7454516b", size = 25036170, upload-time = "2026-05-14T03:43:59.227Z" }, + { url = "https://files.pythonhosted.org/packages/f6/38/f7efd7a87297a55c7e9a331a1dbb5b19e54aeacc11fe6f43f8636a73987c/wandb-0.27.0-py3-none-win_arm64.whl", hash = "sha256:a6c129c311edf210a2b4f2f4acc557eff522628125f5f28ed27df19c16c07079", size = 22972710, upload-time = "2026-05-14T03:44:03.275Z" }, +] + [[package]] name = "watchdog" version = "6.0.0" From 194bf3be7e06fdd7d2560a77cd61f1e9db969b29 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 15:20:00 -0400 Subject: [PATCH 10/22] Small tweaks Signed-off-by: Fabrice Normandin --- cluv/config.py | 16 ++++++++-------- 1 file changed, 8 insertions(+), 8 deletions(-) diff --git a/cluv/config.py b/cluv/config.py index 907371b..0d1c70d 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -46,7 +46,7 @@ class ClusterConfig: results_path: Path """Path to the results directory for a specific cluster.""" - datasets_path: Path + datasets_path: Path | None """Different path where the datasets should be replicated on this cluster. When `None`, this defaults to the top-level config's `datasets_path`. @@ -58,7 +58,7 @@ def resolve_env_vars_in_paths(self): return ClusterConfig( env=self.env, results_path=resolve_env_vars(self.results_path), - datasets_path=resolve_env_vars(self.datasets_path), + datasets_path=resolve_env_vars(self.datasets_path) if self.datasets_path else None, ) @@ -75,20 +75,19 @@ class CluvConfig(BaseModel): On Slurm clusters, this will be a symlink to a folder in `$SCRATCH//`. """ + data_source: str | None + """`hostname:/path` of where to get the data from.""" + datasets_path: str | None """Path to a dataset directory, for example, `'$SCRATCH/my_dataset'` This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. """ - data_source: str | None - """`hostname:/path` of where to get the data from.""" - clusters: dict[str, PartialClusterConfig] = {} """Configuration options for each cluster. - The keys are cluster names; each value is a `ClusterConfig` whose `env` dict contains - environment variables to set when running Slurm commands on that cluster. + The keys are cluster names, and values are configs that override options for that cluster. """ @property @@ -102,10 +101,11 @@ def get_cluster_config(self, cluster: str) -> ClusterConfig: """ cluv_config = load_cluv_config(find_pyproject()) cluster_config = cluv_config.clusters[cluster] + datasets_path = cluster_config.datasets_path or cluv_config.datasets_path return ClusterConfig( env=cluv_config.env | cluster_config.env, results_path=Path(cluster_config.results_path or cluv_config.results_path), - datasets_path=Path(cluster_config.datasets_path or cluv_config.datasets_path), + datasets_path=Path(datasets_path) if datasets_path else None, ) From fb70cb20a0597de13aad6b4626a9836738d404c0 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 15:20:52 -0400 Subject: [PATCH 11/22] Typing fix Signed-off-by: Fabrice Normandin --- tests/example.py | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/tests/example.py b/tests/example.py index 22506e5..6d6a1f2 100644 --- a/tests/example.py +++ b/tests/example.py @@ -34,7 +34,7 @@ def main(args: Args | None = None): job_info = current_job_info() cluster_info = current_cluster_config() - assert job_info and cluster_info, "example should be run in a slurm job." + assert job_info and cluster_info, "This example should be run in a slurm job." print(f"Job {job_info.run_id} starts.") wandb.init( @@ -49,6 +49,7 @@ def main(args: Args | None = None): torch.manual_seed(args.seed) # Test that we can load a dataset from the dataset_path (that was synced by Cluv) + assert cluster_info.datasets_path, "This example requires a datasets_path to be set." dataset = CIFAR10(cluster_info.datasets_path, download=False) print(dataset) From ba8f5d16f3178c8d1a81e690332926d4b6db876c Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 15:24:35 -0400 Subject: [PATCH 12/22] Make the fake loss nicer Signed-off-by: Fabrice Normandin --- tests/example.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/tests/example.py b/tests/example.py index 6d6a1f2..f6f1df9 100644 --- a/tests/example.py +++ b/tests/example.py @@ -4,13 +4,16 @@ """ import dataclasses +import math import os import random +import sys import time from dataclasses import dataclass import simple_parsing import torch +import tqdm import wandb from torchvision.datasets import CIFAR10 @@ -53,9 +56,12 @@ def main(args: Args | None = None): dataset = CIFAR10(cluster_info.datasets_path, download=False) print(dataset) - for i in range(args.wait_duration_seconds): - wandb.log({"step": i, "fake_loss": random.random()}) + for i in tqdm.tqdm(range(args.wait_duration_seconds), disable=(not sys.stdout.isatty())): + # Some fake, loss that varies a bit between seeds and decreases over time. + fake_loss = math.exp(-i / 10) + random.random() * 0.1 time.sleep(1) + wandb.log({"step": i, "loss": fake_loss}) + print(f"Step {i}: loss={fake_loss}") print(f"Job {job_info.run_id} is about to end.") From 1b2c3b156c33f66e25e4451e9bdab7c5ff62cdef Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 15:56:47 -0400 Subject: [PATCH 13/22] Claude-assisted implementation of sync datasets Signed-off-by: Fabrice Normandin --- cluv/__main__.py | 7 +++ cluv/cli/sync.py | 147 +++++++++++++++++++++++++++++++---------------- cluv/config.py | 12 ++-- 3 files changed, 113 insertions(+), 53 deletions(-) diff --git a/cluv/__main__.py b/cluv/__main__.py index 1c71424..4bae04f 100644 --- a/cluv/__main__.py +++ b/cluv/__main__.py @@ -174,6 +174,13 @@ def add_sync_args( "Use a comma to separate multiple clusters." ), ) + sync_parser.add_argument( + "--sync-datasets", + dest="sync_datasets", + action=argparse.BooleanOptionalAction, + default=True, + help="Push datasets from data_source to each cluster. Requires data_source in config.", + ) # TODO: Try to add a 'remainder' arg to pass extra args to `uv sync` on the remote cluster, but it seems to be a bit tricky. # sync_parser.add_argument( # "--", diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index 2784bb3..77d214f 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -32,7 +32,7 @@ milatools.utils.parallel_progress.console = console logger = logging.getLogger(__name__) -__all__ = ["sync", "install_uv", "clone_project", "fetch_results"] +__all__ = ["sync", "install_uv", "clone_project", "fetch_results", "sync_datasets_to_remote"] # TODO: Control the 'hide' and 'display' / etc using the --verbose flag value, in addition to the loglevel. @@ -41,7 +41,9 @@ async def sync( - clusters: list[str] | None = None, uv_sync_args: list[str] | None = None + clusters: list[str] | None = None, + uv_sync_args: list[str] | None = None, + sync_datasets: bool = True, ) -> list[Remote]: """Synchronizes the current project across clusters. @@ -96,7 +98,9 @@ async def sync( tasks: list[AsyncTaskFn] = [] task_descriptions: list[str] = [] for remote in remotes: - tasks.append(functools.partial(sync_task_function, remote=remote)) + tasks.append( + functools.partial(sync_task_function, remote=remote, sync_datasets=sync_datasets) + ) task_descriptions.append(f"{here or 'local'} -> {remote.hostname}") await run_async_tasks_with_progress_bar( @@ -110,6 +114,7 @@ async def sync( async def sync_task_function( report_progress: ReportProgressFn, remote: Remote, + sync_datasets: bool = True, ): """Syncs a single cluster, and reports progress using the provided `report_progress` function.""" project_path = PurePosixPath(find_pyproject().parent.relative_to(Path.home())) @@ -119,7 +124,16 @@ def _update_progress(progress: int, status: str, total: int): info = textwrap.shorten(status, 50, placeholder="...") report_progress(progress=progress, total=total, info=info) - num_tasks = 4 + source_host = source_path = None + should_sync_datasets = ( + sync_datasets and config.data_source is not None and config.datasets_path is not None + ) + if should_sync_datasets: + assert config.data_source + source_host, source_path = config.data_source.split(":", 1) + should_sync_datasets = remote.hostname != source_host + + num_tasks = 5 if should_sync_datasets else 4 _update_progress(0, "Checking/Installing UV", num_tasks) await install_uv(remote) @@ -128,10 +142,27 @@ def _update_progress(progress: int, status: str, total: int): await clone_project(remote) _update_progress(2, "Running 'uv sync'", num_tasks) - await remote.run(f"bash -l -c 'uv --directory={project_path} sync --quiet'") + await remote.run(f"bash --login -c 'uv --directory={project_path} sync --quiet'") + + step = 3 + if should_sync_datasets: + _update_progress(step, "Syncing datasets", num_tasks) + datasets_path_template = str(config.get_cluster_config(remote.hostname).datasets_path) + resolved_path = ( + await remote.get_output( + f"bash --login -c 'echo {datasets_path_template}'", + hide=True, + display=False, + ) + ).strip() + await remote.run(f"mkdir -p {resolved_path}", hide=True) + assert source_host and source_path + await sync_datasets_to_remote(source_host, source_path, remote, resolved_path) + step += 1 - _update_progress(3, "Fetching results", num_tasks) - await fetch_results(remote, config.results_path) + results_symlink = config.results_symlink or Path(config.results_path).name + _update_progress(step, "Fetching results", num_tasks) + await fetch_results(remote, results_symlink, config.results_path) _update_progress(num_tasks, "Done", num_tasks) @@ -253,70 +284,90 @@ async def clone_project(remote: Remote): await remote.run(f"git -C {git_root_path} pull", hide=False) -async def fetch_results(remote: Remote, results_path: Path | str): - """Fetches results from all remote clusters to the current (mila for now) cluster using rsync.""" - results_path = Path(results_path) - assert not results_path.is_absolute() - project_dir = find_pyproject().parent - - results_path_relative_to_home = (project_dir / results_path).relative_to(Path.home()) +async def sync_datasets_to_remote( + source_host: str, + source_path: str, + target_remote: Remote, + target_datasets_path: str, +): + """Push dataset from source_host:source_path to target_remote using source-push rsync. - # TODO: to simplify, for now we assume that the results are stored in a directory directly under the project directory. - # A directory with the same name (e.g. logs) is created in $SCRATCH. - # This could cause some confusion if there are multiple projects with a `logs` directory, since we'd see the logs - # from different projects in the same place. To fix this, for now we use `$SCRATCH/logs/{project_name}` as the `logs` dir. + If source_host is the current cluster, rsync runs locally. Otherwise it SSHes into the + source cluster and runs rsync from there, avoiding the local machine as an intermediary. + target_datasets_path must already have env vars resolved (no $SCRATCH etc.). + """ + rsync_args = ( + "rsync", + "--archive", + "--compress", + "--verbose", + "--progress", + "--copy-links", # follow symlinks (git-annex stores data behind symlinks) + "--exclude=.git", # skip .git — annex internals have restricted permissions + f"{source_path}/", + f"{target_remote.hostname}:{target_datasets_path}/", + ) + here = current_cluster() + if source_host == here: + await run(rsync_args, hide=False) + else: + # Note: This might go though 2fa! + source_remote = await Remote.connect(source_host) + await source_remote.run(shlex.join(rsync_args), hide=False) - # Create the results directory if it doesn't exist. - # TODO: Create that result directory as a symlink to a dir in $SCRATCH? - results_path.mkdir(parents=True, exist_ok=True) +async def fetch_results(remote: Remote, results_symlink: str, results_path: str): + """Fetches results from a remote cluster to local using rsync via the results symlink.""" + project_dir = find_pyproject().parent + symlink_relative_to_home = project_dir.relative_to(Path.home()) / results_symlink + local_results_dir = project_dir / results_symlink + local_results_dir.mkdir(parents=True, exist_ok=True) - await create_results_dir_with_symlink_to_scratch(remote, results_path) + await create_results_dir_with_symlink_to_scratch(remote, results_symlink, results_path) await run( - # Using --full-form flags (not -avz) for better readability. ( "rsync", "--archive", "--verbose", "--compress", "--copy-links", - f"{remote.hostname}:{results_path_relative_to_home}", - str((Path.home() / results_path_relative_to_home).parent), - # shlex.split( - # f"rsync --archive --verbose --compress --copy-links " - # f"{remote.hostname}:{results_path_relative_to_home} {(Path.home() / results_path_relative_to_home).parent}" - # ) + f"{remote.hostname}:{symlink_relative_to_home}", + str(local_results_dir.parent), ), warn=True, hide=False, ) -async def create_results_dir_with_symlink_to_scratch(remote: Remote, results_path: Path): - """On the remote, symlink ~// -> $SCRATCH//. +async def create_results_dir_with_symlink_to_scratch( + remote: Remote, results_symlink: str, results_path: str +): + """On the remote, create results_path and symlink project/ -> results_path. - This keeps large outputs out of $HOME and in $SCRATCH where storage limits are more generous. + results_path may contain env vars (e.g. $SCRATCH); they are resolved via the remote login shell. """ project_dir = find_pyproject().parent project_dir_relative_to_home = project_dir.relative_to(Path.home()) - symlink_path = project_dir_relative_to_home / results_path + symlink_path = project_dir_relative_to_home / results_symlink - # On some clusters (e.g. Vulcan), $SCRATCH is only defined in login shells. - scratch = ( - await remote.get_output("bash -l -c 'echo $SCRATCH'", hide=True, warn=True, display=False) + # Resolve env vars (e.g. $SCRATCH) in results_path using the remote login shell. + resolved_path = ( + await remote.get_output( + f"bash -l -c 'echo {results_path}'", hide=True, warn=True, display=False + ) ).strip() - if not scratch: - logger.warning(f"Remote {remote.hostname} does not have $SCRATCH defined.") + if not resolved_path: + logger.warning( + f"Could not resolve results_path '{results_path}' on {remote.hostname}. Skipping symlink." + ) return - scratch_dir = f"{scratch}/{results_path}/{project_dir.name}" - - # Create the target directory in $SCRATCH if it doesn't already exist. - if not await remote_test("-d", scratch_dir, remote): - result = await remote.run(f"mkdir -p {scratch_dir}", warn=True, hide=True) + # Create the target directory if it doesn't already exist. + if not await remote_test("-d", resolved_path, remote): + result = await remote.run(f"mkdir -p {resolved_path}", warn=True, hide=True) if result.returncode != 0: logger.warning( - f"Failed to create {scratch_dir} on {remote.hostname}. " + f"Failed to create {resolved_path} on {remote.hostname}. " f"Results will be stored in {symlink_path}, which may fill up $HOME." ) await remote.run(f"mkdir -p {symlink_path}", warn=True, hide=True) @@ -329,20 +380,20 @@ async def create_results_dir_with_symlink_to_scratch(remote: Remote, results_pat # If a real file/directory exists there, warn — the user may be filling up $HOME. if await remote_test("-e", symlink_path, remote): logger.warning( - f"{symlink_path} on {remote.hostname} is a real directory, not a symlink to $SCRATCH. " - f"You may end up filling up $HOME. Consider replacing it with a symlink to {scratch_dir}." + f"{symlink_path} on {remote.hostname} is a real directory, not a symlink. " + f"You may end up filling up $HOME. Consider replacing it with a symlink to {resolved_path}." ) return # Nothing at the path yet — create the symlink. result = await remote.run( - f"ln -s -T {scratch_dir} {symlink_path}", + f"ln -s -T {resolved_path} {symlink_path}", warn=True, hide=True, ) if result.returncode != 0: logger.warning( - f"Failed to create symlink {symlink_path} -> {scratch_dir} on {remote.hostname}." + f"Failed to create symlink {symlink_path} -> {resolved_path} on {remote.hostname}." ) diff --git a/cluv/config.py b/cluv/config.py index 0d1c70d..90912ac 100644 --- a/cluv/config.py +++ b/cluv/config.py @@ -69,16 +69,18 @@ class CluvConfig(BaseModel): """Global environment variables set on all clusters when running Slurm commands.""" results_path: str - """Default path to the results directory for all clusters. + """Default path to the results directory for all clusters (may contain env vars like $SCRATCH).""" - !!! info - On Slurm clusters, this will be a symlink to a folder in `$SCRATCH//`. + results_symlink: str | None = None + """Name of the symlink created in the project directory pointing to results_path. + + When None, defaults to Path(results_path).name (the last component of results_path). """ - data_source: str | None + data_source: str | None = None """`hostname:/path` of where to get the data from.""" - datasets_path: str | None + datasets_path: str | None = None """Path to a dataset directory, for example, `'$SCRATCH/my_dataset'` This folder will be synced from the current cluster to all other clusters at their respective `dataset_path`. From a37d581e692f4576e084668d8e6e2c5c5dbce213 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 16:06:28 -0400 Subject: [PATCH 14/22] Small tweaks for the run id Signed-off-by: Fabrice Normandin --- cluv/job.py | 2 +- tests/example.py | 3 ++- 2 files changed, 3 insertions(+), 2 deletions(-) diff --git a/cluv/job.py b/cluv/job.py index 073e436..4ed8951 100644 --- a/cluv/job.py +++ b/cluv/job.py @@ -125,5 +125,5 @@ def _get_run_id(): # the id of the first job in the array. return f"{cluster}_{first_job_id}" if doing_job_packing: - return f"{cluster}_{SLURM_JOB_ID}_{SLURM_PROCID}" + return f"{cluster}_{SLURM_JOB_ID}_task{SLURM_PROCID}" return f"{cluster}_{SLURM_JOB_ID}" diff --git a/tests/example.py b/tests/example.py index f6f1df9..5cfe46f 100644 --- a/tests/example.py +++ b/tests/example.py @@ -27,7 +27,7 @@ class Args: # Time to wait before producing the result. # Can be useful to test and simulate preemption or cancelling jobs. - wait_duration_seconds: int = 0 + wait_duration_seconds: int = 60 seed: int = int(os.environ.get("SLURM_PROCID", "0")) @@ -43,6 +43,7 @@ def main(args: Args | None = None): wandb.init( project="cluv-example", name=job_info.run_id, + id=job_info.run_id, config=vars(args) | {"job": dataclasses.asdict(job_info)} | {"env": {k: v for k, v in os.environ.items() if k.startswith("SLURM")}}, From a242cb5cf878e36fba9b23317fcca0f8d6366890 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 16:39:36 -0400 Subject: [PATCH 15/22] Pull to current machine then push to remotes Signed-off-by: Fabrice Normandin --- cluv/cli/sync.py | 160 ++++++++++++++++++++++++++++------------------- cluv/remote.py | 14 ++++- 2 files changed, 110 insertions(+), 64 deletions(-) diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index 77d214f..a33c71f 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -24,15 +24,15 @@ ) from cluv.cli.login import get_remote_without_2fa_prompt, login -from cluv.config import find_pyproject, get_config +from cluv.config import CluvConfig, current_cluster_config, find_pyproject, get_config from cluv.remote import Remote, get_ssh_options_for_host, run -from cluv.utils import console, current_cluster +from cluv.utils import console, current_cluster, resolve_env_vars milatools.cli.console = console milatools.utils.parallel_progress.console = console logger = logging.getLogger(__name__) -__all__ = ["sync", "install_uv", "clone_project", "fetch_results", "sync_datasets_to_remote"] +__all__ = ["sync", "install_uv", "clone_project", "fetch_results"] # TODO: Control the 'hide' and 'display' / etc using the --verbose flag value, in addition to the loglevel. @@ -98,9 +98,7 @@ async def sync( tasks: list[AsyncTaskFn] = [] task_descriptions: list[str] = [] for remote in remotes: - tasks.append( - functools.partial(sync_task_function, remote=remote, sync_datasets=sync_datasets) - ) + tasks.append(functools.partial(sync_task_function, remote=remote)) task_descriptions.append(f"{here or 'local'} -> {remote.hostname}") await run_async_tasks_with_progress_bar( @@ -108,13 +106,17 @@ async def sync( task_descriptions=task_descriptions, overall_progress_task_description="[green]Syncing project", ) + + config = get_config() + if sync_datasets and config.data_source and config.datasets_path: + await _sync_datasets(remotes, config) + return remotes async def sync_task_function( report_progress: ReportProgressFn, remote: Remote, - sync_datasets: bool = True, ): """Syncs a single cluster, and reports progress using the provided `report_progress` function.""" project_path = PurePosixPath(find_pyproject().parent.relative_to(Path.home())) @@ -124,16 +126,7 @@ def _update_progress(progress: int, status: str, total: int): info = textwrap.shorten(status, 50, placeholder="...") report_progress(progress=progress, total=total, info=info) - source_host = source_path = None - should_sync_datasets = ( - sync_datasets and config.data_source is not None and config.datasets_path is not None - ) - if should_sync_datasets: - assert config.data_source - source_host, source_path = config.data_source.split(":", 1) - should_sync_datasets = remote.hostname != source_host - - num_tasks = 5 if should_sync_datasets else 4 + num_tasks = 4 _update_progress(0, "Checking/Installing UV", num_tasks) await install_uv(remote) @@ -144,24 +137,8 @@ def _update_progress(progress: int, status: str, total: int): _update_progress(2, "Running 'uv sync'", num_tasks) await remote.run(f"bash --login -c 'uv --directory={project_path} sync --quiet'") - step = 3 - if should_sync_datasets: - _update_progress(step, "Syncing datasets", num_tasks) - datasets_path_template = str(config.get_cluster_config(remote.hostname).datasets_path) - resolved_path = ( - await remote.get_output( - f"bash --login -c 'echo {datasets_path_template}'", - hide=True, - display=False, - ) - ).strip() - await remote.run(f"mkdir -p {resolved_path}", hide=True) - assert source_host and source_path - await sync_datasets_to_remote(source_host, source_path, remote, resolved_path) - step += 1 - results_symlink = config.results_symlink or Path(config.results_path).name - _update_progress(step, "Fetching results", num_tasks) + _update_progress(3, "Fetching results", num_tasks) await fetch_results(remote, results_symlink, config.results_path) _update_progress(num_tasks, "Done", num_tasks) @@ -284,36 +261,93 @@ async def clone_project(remote: Remote): await remote.run(f"git -C {git_root_path} pull", hide=False) -async def sync_datasets_to_remote( - source_host: str, - source_path: str, - target_remote: Remote, - target_datasets_path: str, -): - """Push dataset from source_host:source_path to target_remote using source-push rsync. +async def _sync_datasets(remotes: list[Remote], config: CluvConfig): + """Pull dataset from data_source once, then push to all target remotes in parallel.""" - If source_host is the current cluster, rsync runs locally. Otherwise it SSHes into the - source cluster and runs rsync from there, avoiding the local machine as an intermediary. - target_datasets_path must already have env vars resolved (no $SCRATCH etc.). - """ - rsync_args = ( - "rsync", - "--archive", - "--compress", - "--verbose", - "--progress", - "--copy-links", # follow symlinks (git-annex stores data behind symlinks) - "--exclude=.git", # skip .git — annex internals have restricted permissions - f"{source_path}/", - f"{target_remote.hostname}:{target_datasets_path}/", - ) - here = current_cluster() - if source_host == here: - await run(rsync_args, hide=False) + if not config.data_source: + logger.debug("No data_source specified in config, skipping dataset sync.") + return + + this_cluster = current_cluster() + source_host, source_path = config.data_source.split(":", 1) + + target_remotes = [r for r in remotes if r.hostname != source_host] + if not target_remotes: + logger.debug("No target remotes to sync datasets to, skipping dataset sync.") + return # no remotes to sync to. + + # First, pull the data from the data source to this machine if we are not on the source cluster. + if this_cluster == source_host: + # If we are on the source cluster, the local 'datasets_path' is the path from 'data_source'. + datasets_path = Path(source_path) else: - # Note: This might go though 2fa! - source_remote = await Remote.connect(source_host) - await source_remote.run(shlex.join(rsync_args), hide=False) + # Pull from source to the locally-resolved datasets_path, then reuse for all pushes. + datasets_path = (current_cluster_config() or config).datasets_path + if not datasets_path: + raise RuntimeError( + f"To sync datasets from {source_host}, you must set a datasets_path in the config for this cluster ({this_cluster or 'local machine'})." + ) + try: + datasets_path = resolve_env_vars(datasets_path) + except KeyError as e: + raise RuntimeError( + f"Cannot resolve datasets_path '{config.datasets_path}' on this machine: " + f"the {e} environment variable is not set.\n" + f"To avoid copying the datasets from {source_host} to this machine, run " + f"`cluv sync` from the source cluster ({source_host}), or use the " + f"`--no-sync-datasets` flag when running `uv sync` from this machine." + ) from e + + datasets_path.mkdir(parents=True, exist_ok=True) + console.log( + f"[green]Pulling datasets:[/green] {source_host}:{source_path} -> {datasets_path}" + ) + await run( + ( + "rsync", + "--archive", + "--verbose", + "--compress", + "--copy-links", + "--exclude=.git", + "--exclude=.datalad", + f"{source_host}:{source_path}/", + f"{datasets_path}/", + ), + warn=True, + hide=False, + ) + + console.log(f"[green]Pushing datasets to:[/green] {[r.hostname for r in target_remotes]}") + await asyncio.gather( + *(_push_datasets_to_remote(datasets_path, r, config) for r in target_remotes) + ) + + +async def _push_datasets_to_remote(local_source: Path, remote: Remote, config: CluvConfig): + """Push dataset from a local path to the remote cluster's datasets_path.""" + datasets_path_template = str(config.get_cluster_config(remote.hostname).datasets_path) + resolved_path = ( + await remote.get_output( + f"bash -l -c 'echo {datasets_path_template}'", hide=True, display=False + ) + ).strip() + await remote.run(f"mkdir -p {resolved_path}", hide=True) + await run( + ( + "rsync", + "--archive", + "--verbose", + "--compress", + "--copy-links", + "--exclude=.git", + f"{local_source}/", + f"{remote.hostname}:{resolved_path}/", + ), + # warn=True, + hide=False, + _display=True, + ) async def fetch_results(remote: Remote, results_symlink: str, results_path: str): @@ -353,7 +387,7 @@ async def create_results_dir_with_symlink_to_scratch( # Resolve env vars (e.g. $SCRATCH) in results_path using the remote login shell. resolved_path = ( await remote.get_output( - f"bash -l -c 'echo {results_path}'", hide=True, warn=True, display=False + f"bash --login -c 'echo {results_path}'", hide=True, warn=True, display=False ) ).strip() if not resolved_path: diff --git a/cluv/remote.py b/cluv/remote.py index 1916c3f..3b128a2 100644 --- a/cluv/remote.py +++ b/cluv/remote.py @@ -3,6 +3,7 @@ import asyncio import dataclasses import functools +import shlex import subprocess import sys from logging import getLogger as get_logger @@ -100,6 +101,7 @@ async def run( warn: bool = False, hide: Hide = False, _stacklevel: int = 2, + _display: bool = False, ) -> subprocess.CompletedProcess[str]: """Runs the command *asynchronously* in a subprocess and returns the result. @@ -120,7 +122,17 @@ async def run( subprocess.CalledProcessError If an error occurs when running the command and `warn` is `False`. """ - + if _display: + console.log( + ( + f"$ {shlex.join(program_and_args)}" + if input is None + else f"$ {program_and_args=}\n{input=}" + ), + style="green", + _stack_offset=_stacklevel + - 1, # to show a link to the code calling this, instead of here. + ) logger.debug(f"Calling `asyncio.create_subprocess_exec` with {program_and_args=}") proc = await asyncio.create_subprocess_exec( *program_and_args, From fb2c21419afa83d8b882a6c7f2978229bf26e088 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 16:45:02 -0400 Subject: [PATCH 16/22] Display commands Signed-off-by: Fabrice Normandin --- cluv/cli/sync.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index a33c71f..e64fae8 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -314,8 +314,7 @@ async def _sync_datasets(remotes: list[Remote], config: CluvConfig): f"{source_host}:{source_path}/", f"{datasets_path}/", ), - warn=True, - hide=False, + _display=True, ) console.log(f"[green]Pushing datasets to:[/green] {[r.hostname for r in target_remotes]}") @@ -341,11 +340,10 @@ async def _push_datasets_to_remote(local_source: Path, remote: Remote, config: C "--compress", "--copy-links", "--exclude=.git", + "--exclude=.datalad", f"{local_source}/", f"{remote.hostname}:{resolved_path}/", ), - # warn=True, - hide=False, _display=True, ) From 7b6e1163e4f366221e55cd9043fb3a06dc91ffe0 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 17:07:22 -0400 Subject: [PATCH 17/22] Rename example Signed-off-by: Fabrice Normandin --- .../README.md | 0 examples/pytorch-example/pyproject.toml | 80 +++++++++++++++++++ .../scripts/job.sh | 0 .../src/pytorch_example}/__init__.py | 0 .../src/pytorch_example/main.py | 25 ++++++ examples/pytorch-setup/pyproject.toml | 56 ------------- .../pytorch-setup/src/pytorch_setup/main.py | 34 -------- pyproject.toml | 18 ++--- uv.lock | 16 ++-- 9 files changed, 122 insertions(+), 107 deletions(-) rename examples/{pytorch-setup => pytorch-example}/README.md (100%) create mode 100644 examples/pytorch-example/pyproject.toml rename examples/{pytorch-setup => pytorch-example}/scripts/job.sh (100%) rename examples/{pytorch-setup/src/pytorch_setup => pytorch-example/src/pytorch_example}/__init__.py (100%) rename tests/example.py => examples/pytorch-example/src/pytorch_example/main.py (68%) delete mode 100644 examples/pytorch-setup/pyproject.toml delete mode 100644 examples/pytorch-setup/src/pytorch_setup/main.py diff --git a/examples/pytorch-setup/README.md b/examples/pytorch-example/README.md similarity index 100% rename from examples/pytorch-setup/README.md rename to examples/pytorch-example/README.md diff --git a/examples/pytorch-example/pyproject.toml b/examples/pytorch-example/pyproject.toml new file mode 100644 index 0000000..3a4f51c --- /dev/null +++ b/examples/pytorch-example/pyproject.toml @@ -0,0 +1,80 @@ +[project] +name = "pytorch-example" +version = "0.1.0" +description = "Add your description here" +readme = "README.md" +requires-python = ">=3.13" +dependencies = [ + "cluv", + "numpy>=2.4.4", + "torch>=2.7.0,<2.11.0", + "torchvision>=0.25.0", + "wandb>=0.27.0", +] + +[build-system] +requires = ["hatchling"] +build-backend = "hatchling.build" + +### -------------- CLUV CONFIG -------------- ### + +[tool.cluv] +# Where to store job results by default. Can be overridden per cluster. +results_path = "$SCRATCH/logs/cluv" +# On clusters, Cluv creates a symlink (a shortcut) in your project folder to the results_path dir. +# This makes it easier to keep your project in $HOME and to see the results which are on $SCRATCH. +results_symlink = "logs" +# Where to read the data from when synchronizing data to all clusters. +data_source = "mila:/network/datasets/cifar10" +# Where the dataset should be replicated on all clusters. +# TODO?: On the source cluster (ex Mila), the folder will only contain symlinks, to avoid +# duplicating the data. +datasets_path = "datasets/cifar10" + +[tool.cluv.env] +# Environment variables applied when using Slurm commands on all clusters. +SBATCH_TIME = "3:00:00" +SBATCH_REQUEUE = "1" +# Assume that compute nodes don't have internet access by default. Override below when they do. +UV_OFFLINE="1" +WANDB_MODE="offline" + + +### -------------- Clusters Config -------------- ### + +[tool.cluv.clusters.mila] +# Overrides specific to the Mila cluster. +env = {UV_OFFLINE="0", WANDB_MODE="online"} +results_path = "$SCRATCH/logs/cluv" + +[tool.cluv.clusters.tamia] + +[tool.cluv.clusters.killarney] +# For example, you might not have a $SCRATCH on Killarney. This can be overwritten here. +results_path = "$HOME/logs/cluv" +datasets_path = "$HOME/datasets/cifar10" + +[tool.cluv.clusters.vulcan] + +[tool.cluv.clusters.rorqual] +env = {SBATCH_ACCOUNT="rrg-bengioy-ad"} + +[tool.cluv.clusters.fir] +env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} + +[tool.cluv.clusters.nibi] +env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} + +[tool.cluv.clusters.trillium] +env = {SBATCH_ACCOUNT="rrg-bengioy-ad"} + +[tool.cluv.clusters.trillium-gpu] +env = {SBATCH_ACCOUNT="rrg-bengioy-ad"} + +[tool.cluv.clusters.narval] +# Mila doesn't have an allocation on Narval anymore. +env = {SBATCH_ACCOUNT="def-bengioy"} + + +[tool.uv.sources] +cluv = { workspace = true } diff --git a/examples/pytorch-setup/scripts/job.sh b/examples/pytorch-example/scripts/job.sh similarity index 100% rename from examples/pytorch-setup/scripts/job.sh rename to examples/pytorch-example/scripts/job.sh diff --git a/examples/pytorch-setup/src/pytorch_setup/__init__.py b/examples/pytorch-example/src/pytorch_example/__init__.py similarity index 100% rename from examples/pytorch-setup/src/pytorch_setup/__init__.py rename to examples/pytorch-example/src/pytorch_example/__init__.py diff --git a/tests/example.py b/examples/pytorch-example/src/pytorch_example/main.py similarity index 68% rename from tests/example.py rename to examples/pytorch-example/src/pytorch_example/main.py index 5cfe46f..f312f91 100644 --- a/tests/example.py +++ b/examples/pytorch-example/src/pytorch_example/main.py @@ -13,12 +13,14 @@ import simple_parsing import torch +import torch.backends import tqdm import wandb from torchvision.datasets import CIFAR10 from cluv.config import current_cluster_config from cluv.job import current_job_info +from cluv.utils import current_cluster @dataclass(frozen=True) @@ -33,6 +35,21 @@ class Args: def main(args: Args | None = None): + cluster = current_cluster() + cuda_built = torch.backends.cuda.is_built() + cuda_avail = torch.cuda.is_available() + device_count = torch.cuda.device_count() + + print(f"Run on cluster: {cluster}") + print(f"PyTorch built with CUDA: {cuda_built}") + print(f"PyTorch detects CUDA available: {cuda_avail}") + print(f"PyTorch-detected #GPUs: {device_count}") + if device_count == 0: + print(" No GPU detected.") + else: + for i in range(device_count): + print(f" GPU {i}: {torch.cuda.get_device_name(i)}") + args = args or simple_parsing.parse(Args, description=__doc__) job_info = current_job_info() @@ -57,6 +74,14 @@ def main(args: Args | None = None): dataset = CIFAR10(cluster_info.datasets_path, download=False) print(dataset) + # model = torchvision.models.resnet18(num_classes=10) + # optimizer = torch.optim.SGD(model.parameters(), lr=0.1) + # TODO: Make this a distributed example, so that it can also run on Tamia and others with + # full-node job allocations. + # from torch.distributed.fsdp import FullyShardedDataParallel as FSDP + # from torch.nn.parallel import DistributedDataParallel + # model = DistributedDataParallel(model) + for i in tqdm.tqdm(range(args.wait_duration_seconds), disable=(not sys.stdout.isatty())): # Some fake, loss that varies a bit between seeds and decreases over time. fake_loss = math.exp(-i / 10) + random.random() * 0.1 diff --git a/examples/pytorch-setup/pyproject.toml b/examples/pytorch-setup/pyproject.toml deleted file mode 100644 index 3c92d8a..0000000 --- a/examples/pytorch-setup/pyproject.toml +++ /dev/null @@ -1,56 +0,0 @@ -[project] -name = "pytorch-setup" -version = "0.1.0" -description = "Add your description here" -readme = "README.md" -requires-python = ">=3.13" -dependencies = [ - "numpy>=2.4.4", - "torch>=2.7.0,<2.11.0", -] - -[build-system] -requires = ["hatchling"] -build-backend = "hatchling.build" - -[tool.cluv] -results_path = "logs" - -[tool.cluv.env] -# Environment variables applied when using Slurm commands on all clusters. -SBATCH_TIME = "3:00:00" -SBATCH_REQUEUE = "1" - -[tool.cluv.clusters.mila] -env = {UV_OFFLINE="0", WANDB_MODE="online"} - -# PAICE clusters. -[tool.cluv.clusters.tamia] -env = {UV_OFFLINE="1", WANDB_MODE="offline"} - -[tool.cluv.clusters.killarney] -env = {UV_OFFLINE="1", WANDB_MODE="offline"} - -[tool.cluv.clusters.vulcan] -env = {UV_OFFLINE="1", WANDB_MODE="offline"} - -# DRAC clusters. -[tool.cluv.clusters.rorqual] -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} - -[tool.cluv.clusters.fir] -env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} - -[tool.cluv.clusters.nibi] -env = {UV_OFFLINE="0", WANDB_MODE="online", SBATCH_ACCOUNT="rrg-bengioy-ad"} - -[tool.cluv.clusters.trillium] -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} - -[tool.cluv.clusters.trillium-gpu] -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="rrg-bengioy-ad"} - -[tool.cluv.clusters.narval] -# Mila doesn't have an allocation on Narval anymore. -# You can also use "def-yourusername" (the default partitions). -env = {UV_OFFLINE="1", WANDB_MODE="offline", SBATCH_ACCOUNT="def-bengioy"} diff --git a/examples/pytorch-setup/src/pytorch_setup/main.py b/examples/pytorch-setup/src/pytorch_setup/main.py deleted file mode 100644 index 5b4f1c8..0000000 --- a/examples/pytorch-setup/src/pytorch_setup/main.py +++ /dev/null @@ -1,34 +0,0 @@ -# Example from : https://github.com/mila-iqia/mila-docs/tree/master/docs/examples/frameworks/pytorch_setup - -import os -import socket - -import torch -import torch.backends - - -def current_cluster() -> str | None: - if socket.gethostname().endswith(".server.mila.quebec"): - return "mila" - if "CC_CLUSTER" in os.environ: - return os.environ["CC_CLUSTER"] - return None - -def main(): - cluster = current_cluster() - cuda_built = torch.backends.cuda.is_built() - cuda_avail = torch.cuda.is_available() - device_count = torch.cuda.device_count() - - print(f"Pytorch called on cluster: {cluster}") - print(f"PyTorch built with CUDA: {cuda_built}") - print(f"PyTorch detects CUDA available: {cuda_avail}") - print(f"PyTorch-detected #GPUs: {device_count}") - if device_count == 0: - print(" No GPU detected, not printing devices' names.") - else: - for i in range(device_count): - print(f" GPU {i}: {torch.cuda.get_device_name(i)}") - -if __name__ == "__main__": - main() diff --git a/pyproject.toml b/pyproject.toml index f108d7c..8b59be7 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,8 +29,6 @@ dev = [ "pytest-timeout>=2.3.1", "ruff", "pytest-skip-slow>=0.0.5", - "torchvision>=0.25.0", - "wandb>=0.27.0", ] [tool.pytest.ini_options] @@ -47,7 +45,7 @@ managed = true [tool.uv.workspace] members = [ - "examples/pytorch-setup", + "examples/pytorch-example", ] @@ -90,12 +88,12 @@ results_path = "$SCRATCH/logs/cluv" # On clusters, Cluv creates a symlink (a shortcut) in your project folder to the results_path dir. # This makes it easier to keep your project in $HOME and to see the results which are on $SCRATCH. results_symlink = "logs" -# Where to read the data from when synchronizing data to all clusters. -data_source = "mila:/network/datasets/cifar10" -# Where the dataset should be replicated on all clusters. -# TODO?: On the source cluster (ex Mila), the folder will only contain symlinks, to avoid -# duplicating the data. -datasets_path = "$SCRATCH/datasets/cifar10" +## Where to read the data from when synchronizing data to all clusters. +# data_source = "mila:/network/datasets/cifar10" +## Where the dataset should be replicated on all clusters. +## TODO?: On the source cluster (ex Mila), the folder will only contain symlinks, to avoid +## duplicating the data. +# datasets_path = "datasets/cifar10" [tool.cluv.env] # Environment variables applied when using Slurm commands on all clusters. @@ -117,7 +115,7 @@ env = {UV_OFFLINE="0", WANDB_MODE="online"} [tool.cluv.clusters.killarney] # For example, you might not have a $SCRATCH on Killarney. This can be overwritten here. results_path = "$HOME/logs/cluv" -datasets_path = "$HOME/datasets/cifar10" +# datasets_path = "$HOME/datasets/cifar10" [tool.cluv.clusters.vulcan] diff --git a/uv.lock b/uv.lock index 7fc5e60..9b09114 100644 --- a/uv.lock +++ b/uv.lock @@ -5,7 +5,7 @@ requires-python = ">=3.13" [manifest] members = [ "cluv", - "pytorch-setup", + "pytorch-example", ] [[package]] @@ -272,9 +272,7 @@ dev = [ { name = "pytest-skip-slow" }, { name = "pytest-timeout" }, { name = "ruff" }, - { name = "torchvision" }, { name = "uv-dynamic-versioning" }, - { name = "wandb" }, ] [package.metadata] @@ -297,9 +295,7 @@ dev = [ { name = "pytest-skip-slow", specifier = ">=0.0.5" }, { name = "pytest-timeout", specifier = ">=2.3.1" }, { name = "ruff" }, - { name = "torchvision", specifier = ">=0.25.0" }, { name = "uv-dynamic-versioning", specifier = ">=0.2.0" }, - { name = "wandb", specifier = ">=0.27.0" }, ] [[package]] @@ -1432,18 +1428,24 @@ wheels = [ ] [[package]] -name = "pytorch-setup" +name = "pytorch-example" version = "0.1.0" -source = { editable = "examples/pytorch-setup" } +source = { editable = "examples/pytorch-example" } dependencies = [ + { name = "cluv" }, { name = "numpy" }, { name = "torch" }, + { name = "torchvision" }, + { name = "wandb" }, ] [package.metadata] requires-dist = [ + { name = "cluv", editable = "." }, { name = "numpy", specifier = ">=2.4.4" }, { name = "torch", specifier = ">=2.7.0,<2.11.0" }, + { name = "torchvision", specifier = ">=0.25.0" }, + { name = "wandb", specifier = ">=0.27.0" }, ] [[package]] From eeec0f0def81ae4aa63c870d66306811d1fd3945 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 17:11:47 -0400 Subject: [PATCH 18/22] Fix config of pytorch example Signed-off-by: Fabrice Normandin --- examples/pytorch-example/pyproject.toml | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/examples/pytorch-example/pyproject.toml b/examples/pytorch-example/pyproject.toml index 3a4f51c..70270ad 100644 --- a/examples/pytorch-example/pyproject.toml +++ b/examples/pytorch-example/pyproject.toml @@ -20,7 +20,7 @@ build-backend = "hatchling.build" [tool.cluv] # Where to store job results by default. Can be overridden per cluster. -results_path = "$SCRATCH/logs/cluv" +results_path = "$SCRATCH/logs/pytorch_example" # On clusters, Cluv creates a symlink (a shortcut) in your project folder to the results_path dir. # This makes it easier to keep your project in $HOME and to see the results which are on $SCRATCH. results_symlink = "logs" @@ -29,7 +29,7 @@ data_source = "mila:/network/datasets/cifar10" # Where the dataset should be replicated on all clusters. # TODO?: On the source cluster (ex Mila), the folder will only contain symlinks, to avoid # duplicating the data. -datasets_path = "datasets/cifar10" +datasets_path = "$SCRATCH/datasets/cifar10" [tool.cluv.env] # Environment variables applied when using Slurm commands on all clusters. From 88e0ee78d010ddc73b82b137a685d28946f83f8e Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 17:42:01 -0400 Subject: [PATCH 19/22] Pull before sync, push in task fns Signed-off-by: Fabrice Normandin --- cluv/cli/sync.py | 42 ++++++++++++++++++++-------- cluv/remote.py | 73 +++++++++++++++++++++++++++++------------------- 2 files changed, 75 insertions(+), 40 deletions(-) diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index e64fae8..f5e4b09 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -9,6 +9,7 @@ import shutil import subprocess import textwrap +from contextvars import ContextVar from pathlib import Path, PurePosixPath from typing import Literal @@ -95,38 +96,45 @@ async def sync( # TODO: Add an --ignore flag to ignore some clusters? console.log(f"[green]Synchronizing with the following clusters:[/green] {clusters}") + console_lock = asyncio.Lock() + tasks: list[AsyncTaskFn] = [] task_descriptions: list[str] = [] for remote in remotes: - tasks.append(functools.partial(sync_task_function, remote=remote)) + tasks.append( + functools.partial(sync_task_function, remote=remote, console_lock=console_lock) + ) task_descriptions.append(f"{here or 'local'} -> {remote.hostname}") + config = get_config() + if sync_datasets and config.data_source and config.datasets_path: + await _pull_datasets(remotes, config, _console_lock=console_lock) + await run_async_tasks_with_progress_bar( async_task_fns=tasks, task_descriptions=task_descriptions, overall_progress_task_description="[green]Syncing project", ) - config = get_config() - if sync_datasets and config.data_source and config.datasets_path: - await _sync_datasets(remotes, config) - return remotes async def sync_task_function( report_progress: ReportProgressFn, remote: Remote, + console_lock: asyncio.Lock | None = None, ): """Syncs a single cluster, and reports progress using the provided `report_progress` function.""" project_path = PurePosixPath(find_pyproject().parent.relative_to(Path.home())) config = get_config() + # for use in the sub-functions without having to pass it around everywhere? + ContextVar("console_lock").set(console_lock) def _update_progress(progress: int, status: str, total: int): info = textwrap.shorten(status, 50, placeholder="...") report_progress(progress=progress, total=total, info=info) - num_tasks = 4 + num_tasks = 5 if config.data_source else 4 _update_progress(0, "Checking/Installing UV", num_tasks) await install_uv(remote) @@ -141,6 +149,15 @@ def _update_progress(progress: int, status: str, total: int): _update_progress(3, "Fetching results", num_tasks) await fetch_results(remote, results_symlink, config.results_path) + if config.data_source: + _update_progress(4, "Syncing datasets", num_tasks) + here = current_cluster() + local_dataset_path = (config.get_cluster_config(here) if here else config).datasets_path + if not local_dataset_path: + raise RuntimeError("data_source is set, so dataset_path should also be set!") + local_dataset_path = resolve_env_vars(local_dataset_path) + await _push_datasets_to_remote(local_dataset_path, remote, config) + _update_progress(num_tasks, "Done", num_tasks) @@ -261,7 +278,9 @@ async def clone_project(remote: Remote): await remote.run(f"git -C {git_root_path} pull", hide=False) -async def _sync_datasets(remotes: list[Remote], config: CluvConfig): +async def _pull_datasets( + remotes: list[Remote], config: CluvConfig, _console_lock: asyncio.Lock | None = None +): """Pull dataset from data_source once, then push to all target remotes in parallel.""" if not config.data_source: @@ -315,12 +334,13 @@ async def _sync_datasets(remotes: list[Remote], config: CluvConfig): f"{datasets_path}/", ), _display=True, + _console_lock=_console_lock, ) - console.log(f"[green]Pushing datasets to:[/green] {[r.hostname for r in target_remotes]}") - await asyncio.gather( - *(_push_datasets_to_remote(datasets_path, r, config) for r in target_remotes) - ) + # console.log(f"[green]Pushing datasets to:[/green] {[r.hostname for r in target_remotes]}") + # await asyncio.gather( + # *(_push_datasets_to_remote(datasets_path, r, config) for r in target_remotes) + # ) async def _push_datasets_to_remote(local_source: Path, remote: Remote, config: CluvConfig): diff --git a/cluv/remote.py b/cluv/remote.py index 3b128a2..fb4ccc7 100644 --- a/cluv/remote.py +++ b/cluv/remote.py @@ -1,6 +1,7 @@ from __future__ import annotations import asyncio +import contextlib import dataclasses import functools import shlex @@ -71,17 +72,24 @@ async def run( command, ) + _display = False if display: - console.log( - ( - f"({self.hostname}) $ {command}" - if input is None - else f"({self.hostname}) $ {command=}\n{input=}" - ), - style="green", - _stack_offset=2, # to show a link to the code calling this, instead of here. + # Pass what to display to `run`, which uses a lock to keep the command and its output + # together in the console, instead of interleaving with other commands' outputs. + # Commands start running (and may error out) before being shown in the terminal though. + _display = ( + f"({self.hostname}) $ {command}" + if input is None + else f"({self.hostname}) $ {command=}\n{input=}" ) - return await run(ssh_command, input=input, warn=warn, hide=hide, _stacklevel=3) + return await run( + ssh_command, + input=input, + warn=warn, + hide=hide, + _display=_display, + _stacklevel=3, + ) async def get_output( self, @@ -101,7 +109,8 @@ async def run( warn: bool = False, hide: Hide = False, _stacklevel: int = 2, - _display: bool = False, + _display: bool | str = False, + _console_lock: asyncio.Lock | None = None, ) -> subprocess.CompletedProcess[str]: """Runs the command *asynchronously* in a subprocess and returns the result. @@ -122,17 +131,7 @@ async def run( subprocess.CalledProcessError If an error occurs when running the command and `warn` is `False`. """ - if _display: - console.log( - ( - f"$ {shlex.join(program_and_args)}" - if input is None - else f"$ {program_and_args=}\n{input=}" - ), - style="green", - _stack_offset=_stacklevel - - 1, # to show a link to the code calling this, instead of here. - ) + logger.debug(f"Calling `asyncio.create_subprocess_exec` with {program_and_args=}") proc = await asyncio.create_subprocess_exec( *program_and_args, @@ -179,14 +178,30 @@ async def run( stdout=stdout.decode(), stderr=stderr.decode(), ) - if result.stdout: - if hide not in [True, "out", "stdout"]: - print(result.stdout) - logger.debug(result.stdout) - if result.stderr: - if hide not in [True, "err", "stderr"]: - print(result.stderr, file=sys.stderr) - logger.debug(result.stderr) + # IDEA: Maybe we could grab an asyncio lock for the standard output, to keep the output of + # commands directly below the commands themselves? + async with _console_lock or contextlib.nullcontext(): + if _display: + console.log( + _display + if isinstance(_display, str) + else ( + f"$ {shlex.join(program_and_args)}" + if input is None + else f"$ {program_and_args=}\n{input=}" + ), + style="green", + _stack_offset=_stacklevel + - 1, # to show a link to the code calling this, instead of here. + ) + if result.stdout: + if hide not in [True, "out", "stdout"]: + print(result.stdout) + logger.debug(result.stdout) + if result.stderr: + if hide not in [True, "err", "stderr"]: + print(result.stderr, file=sys.stderr) + logger.debug(result.stderr) return result From 85afbbe2057f6bd911b5dd3383cc88d338ee882b Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 17:46:41 -0400 Subject: [PATCH 20/22] Display commands and outputs (not interleaved) Signed-off-by: Fabrice Normandin --- cluv/cli/sync.py | 41 ++++++++++++++++------------------------- cluv/remote.py | 7 ++----- cluv/utils.py | 6 ++++++ 3 files changed, 24 insertions(+), 30 deletions(-) diff --git a/cluv/cli/sync.py b/cluv/cli/sync.py index f5e4b09..9ce4ad6 100644 --- a/cluv/cli/sync.py +++ b/cluv/cli/sync.py @@ -9,7 +9,6 @@ import shutil import subprocess import textwrap -from contextvars import ContextVar from pathlib import Path, PurePosixPath from typing import Literal @@ -27,7 +26,7 @@ from cluv.cli.login import get_remote_without_2fa_prompt, login from cluv.config import CluvConfig, current_cluster_config, find_pyproject, get_config from cluv.remote import Remote, get_ssh_options_for_host, run -from cluv.utils import console, current_cluster, resolve_env_vars +from cluv.utils import console, console_lock, current_cluster, resolve_env_vars milatools.cli.console = console milatools.utils.parallel_progress.console = console @@ -96,39 +95,34 @@ async def sync( # TODO: Add an --ignore flag to ignore some clusters? console.log(f"[green]Synchronizing with the following clusters:[/green] {clusters}") - console_lock = asyncio.Lock() - tasks: list[AsyncTaskFn] = [] task_descriptions: list[str] = [] for remote in remotes: - tasks.append( - functools.partial(sync_task_function, remote=remote, console_lock=console_lock) - ) + tasks.append(functools.partial(sync_task_function, remote=remote)) task_descriptions.append(f"{here or 'local'} -> {remote.hostname}") config = get_config() - if sync_datasets and config.data_source and config.datasets_path: - await _pull_datasets(remotes, config, _console_lock=console_lock) - await run_async_tasks_with_progress_bar( - async_task_fns=tasks, - task_descriptions=task_descriptions, - overall_progress_task_description="[green]Syncing project", - ) + token = console_lock.set(asyncio.Lock()) + try: + if sync_datasets and config.data_source and config.datasets_path: + await _pull_datasets(remotes, config) + + await run_async_tasks_with_progress_bar( + async_task_fns=tasks, + task_descriptions=task_descriptions, + overall_progress_task_description="[green]Syncing project", + ) + finally: + console_lock.reset(token) return remotes -async def sync_task_function( - report_progress: ReportProgressFn, - remote: Remote, - console_lock: asyncio.Lock | None = None, -): +async def sync_task_function(report_progress: ReportProgressFn, remote: Remote): """Syncs a single cluster, and reports progress using the provided `report_progress` function.""" project_path = PurePosixPath(find_pyproject().parent.relative_to(Path.home())) config = get_config() - # for use in the sub-functions without having to pass it around everywhere? - ContextVar("console_lock").set(console_lock) def _update_progress(progress: int, status: str, total: int): info = textwrap.shorten(status, 50, placeholder="...") @@ -278,9 +272,7 @@ async def clone_project(remote: Remote): await remote.run(f"git -C {git_root_path} pull", hide=False) -async def _pull_datasets( - remotes: list[Remote], config: CluvConfig, _console_lock: asyncio.Lock | None = None -): +async def _pull_datasets(remotes: list[Remote], config: CluvConfig): """Pull dataset from data_source once, then push to all target remotes in parallel.""" if not config.data_source: @@ -334,7 +326,6 @@ async def _pull_datasets( f"{datasets_path}/", ), _display=True, - _console_lock=_console_lock, ) # console.log(f"[green]Pushing datasets to:[/green] {[r.hostname for r in target_remotes]}") diff --git a/cluv/remote.py b/cluv/remote.py index fb4ccc7..852ad2f 100644 --- a/cluv/remote.py +++ b/cluv/remote.py @@ -10,7 +10,7 @@ from logging import getLogger as get_logger from typing import Callable, Literal, Self, TypeVar -from cluv.utils import console +from cluv.utils import console, console_lock logger = get_logger(__name__) @@ -110,7 +110,6 @@ async def run( hide: Hide = False, _stacklevel: int = 2, _display: bool | str = False, - _console_lock: asyncio.Lock | None = None, ) -> subprocess.CompletedProcess[str]: """Runs the command *asynchronously* in a subprocess and returns the result. @@ -178,9 +177,7 @@ async def run( stdout=stdout.decode(), stderr=stderr.decode(), ) - # IDEA: Maybe we could grab an asyncio lock for the standard output, to keep the output of - # commands directly below the commands themselves? - async with _console_lock or contextlib.nullcontext(): + async with console_lock.get() or contextlib.nullcontext(): if _display: console.log( _display diff --git a/cluv/utils.py b/cluv/utils.py index fd1a9e8..777971f 100644 --- a/cluv/utils.py +++ b/cluv/utils.py @@ -1,3 +1,5 @@ +import asyncio +import contextvars import os import socket import sys @@ -8,6 +10,10 @@ # todo: seeing some weird behaviour with stderr, the progress bars repeating themselves, etc. console = rich.console.Console(record=True, file=sys.stdout) +console_lock: contextvars.ContextVar[asyncio.Lock | None] = contextvars.ContextVar( + "console_lock", default=None +) + def current_cluster() -> str | None: if socket.gethostname().endswith(".server.mila.quebec"): From 32e58c4e444b615b7f20de955780467cec21df67 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 17:50:58 -0400 Subject: [PATCH 21/22] Add 'datasets' in .gitignore Signed-off-by: Fabrice Normandin --- examples/pytorch-example/.gitignore | 1 + 1 file changed, 1 insertion(+) create mode 100644 examples/pytorch-example/.gitignore diff --git a/examples/pytorch-example/.gitignore b/examples/pytorch-example/.gitignore new file mode 100644 index 0000000..aee11b2 --- /dev/null +++ b/examples/pytorch-example/.gitignore @@ -0,0 +1 @@ +datasets From 3a73689587d75cf30f7f304f0b5e3cbbaec65965 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Fri, 29 May 2026 17:51:48 -0400 Subject: [PATCH 22/22] Add the 'site' dir to .gitignore Signed-off-by: Fabrice Normandin --- .gitignore | 2 ++ 1 file changed, 2 insertions(+) diff --git a/.gitignore b/.gitignore index ac5edbf..b95d9f2 100644 --- a/.gitignore +++ b/.gitignore @@ -14,3 +14,5 @@ wheels/ .claude .vscode logs/ +# created by `uv run mkdocs serve` and such. +site/