Skip to content

Commit 11e4c28

Browse files
committed
Refactor cluster status retrieval
1 parent 4932550 commit 11e4c28

1 file changed

Lines changed: 31 additions & 57 deletions

File tree

cluv/cli/status.py

Lines changed: 31 additions & 57 deletions
Original file line numberDiff line numberDiff line change
@@ -12,7 +12,6 @@
1212

1313
from cluv.cli.login import get_remote_without_2fa_prompt
1414
from cluv.config import get_config
15-
from cluv.remote import Remote
1615
from cluv.slurm import (
1716
StorageStats,
1817
parse_disk_quota,
@@ -49,6 +48,18 @@ class ClusterStatus:
4948
storage: StorageStats
5049

5150

51+
def get_default_cluster_status(cluster: str) -> ClusterStatus:
52+
return ClusterStatus(
53+
name=cluster,
54+
online=False,
55+
gpu_idle=0,
56+
gpu_total=0,
57+
gpu_model="?",
58+
jobs=JobStats(running=0, pending=0, my_running=0, my_pending=0),
59+
storage=StorageStats(home_used=0, home_quota=0, scratch_used=0, scratch_quota=0),
60+
)
61+
62+
5263
# ---------------------------------------------------------------------------
5364
# Real data layer
5465
# ---------------------------------------------------------------------------
@@ -90,14 +101,20 @@ class ClusterStatus:
90101
_MILA_CLUSTERS = {"mila"}
91102

92103

93-
async def get_real_cluster_status(remote: Remote) -> ClusterStatus:
104+
async def get_real_cluster_status(cluster: str) -> ClusterStatus:
94105
"""Fetch live Slurm data from a remote cluster and return a ClusterStatus.
95106
96107
Uses a single SSH round-trip. Falls back gracefully when commands are
97108
unavailable (e.g. partition-stats is DRAC-only).
98109
"""
99110

100-
cluster = remote.hostname
111+
# Use get_remote_without_2fa_prompt directly so we never filter out the
112+
# "current" cluster the way login() does. A working socket for mila is
113+
# perfectly usable even when /home/mila is mounted locally.
114+
remote = await get_remote_without_2fa_prompt(cluster)
115+
if remote is None:
116+
return get_default_cluster_status(cluster)
117+
101118
script = _REMOTE_SCRIPT_MILA if cluster in _MILA_CLUSTERS else _REMOTE_SCRIPT_DRAC
102119

103120
try:
@@ -109,15 +126,7 @@ async def get_real_cluster_status(remote: Remote) -> ClusterStatus:
109126
)
110127
except Exception as exc:
111128
logger.warning(f"[red]Could not reach {cluster}: {exc}[/red]")
112-
return ClusterStatus(
113-
name=cluster,
114-
online=False,
115-
gpu_idle=0,
116-
gpu_total=0,
117-
gpu_model="?",
118-
jobs=JobStats(running=0, pending=0, my_running=0, my_pending=0),
119-
storage=StorageStats(home_used=0, home_quota=0, scratch_used=0, scratch_quota=0),
120-
)
129+
return get_default_cluster_status(cluster)
121130

122131
parts = raw.split(_SEP)
123132
# Pad in case some sections are missing
@@ -197,33 +206,6 @@ async def get_real_cluster_status(remote: Remote) -> ClusterStatus:
197206
)
198207

199208

200-
async def get_all_cluster_statuses(
201-
remotes: list[Remote] | None = None,
202-
) -> tuple[list[ClusterStatus], bool]:
203-
"""Query clusters in parallel.
204-
205-
If *remotes* is provided, query exactly those connections.
206-
Otherwise, query all clusters that already have an active SSH connection
207-
(never blocks on 2FA).
208-
209-
Returns (statuses, any_live) where any_live is False when no cluster
210-
was reachable.
211-
"""
212-
if remotes is None:
213-
clusters = get_config().clusters
214-
remotes = [
215-
r
216-
for r in await asyncio.gather(*(get_remote_without_2fa_prompt(c) for c in clusters))
217-
if r is not None
218-
]
219-
220-
if not remotes:
221-
return [], False
222-
223-
statuses = list(await asyncio.gather(*(get_real_cluster_status(r) for r in remotes)))
224-
return statuses, True
225-
226-
227209
# ---------------------------------------------------------------------------
228210
# UI helpers
229211
# ---------------------------------------------------------------------------
@@ -285,9 +267,9 @@ def _build_cluster_table(data: list[ClusterStatus]) -> Table:
285267

286268
for c in data:
287269
if not c.online:
288-
status_cell = Text("⚠ offline", style="bold red")
270+
status_cell = Text("⚠ disconnected", style="bold red")
289271
else:
290-
status_cell = Text("● online", style="bold green")
272+
status_cell = Text("● connected", style="bold green")
291273

292274
my_jobs = Text(f"{c.jobs.my_running} / {c.jobs.my_pending}", style="cyan")
293275
all_jobs = Text(f"{c.jobs.running} / {c.jobs.pending}", style="white")
@@ -381,24 +363,16 @@ async def status(table: str) -> None:
381363
"""
382364
console = Console()
383365
clusters = get_config().clusters_names
384-
clusters = list(clusters or [])
385-
386-
if clusters:
387-
# Use get_remote_without_2fa_prompt directly so we never filter out the
388-
# "current" cluster the way login() does. A working socket for mila is
389-
# perfectly usable even when /home/mila is mounted locally.
390-
remotes = [
391-
r
392-
for r in await asyncio.gather(*(get_remote_without_2fa_prompt(c) for c in clusters))
393-
if r is not None
394-
]
395-
data, is_live = await get_all_cluster_statuses(remotes=remotes)
396-
else:
397-
data, is_live = await get_all_cluster_statuses()
398366

399-
if not is_live:
367+
# Query clusters in parallel
368+
data: list[ClusterStatus] = [
369+
d for d in await asyncio.gather(*(get_real_cluster_status(c) for c in clusters))
370+
]
371+
372+
# Show a tip message if all clusters are offline, which likely means the user hasn't logged in yet (no control sockets).
373+
if all(not c.online for c in data):
400374
console.print(
401-
"[yellow]No active cluster connections found. Run [bold]cluv login[/bold] first.[/yellow]"
375+
"[yellow]No active connections to any clusters found. Run [bold]cluv login[/bold] first.[/yellow]"
402376
)
403377

404378
console.print()

0 commit comments

Comments
 (0)