diff --git a/cluv/__main__.py b/cluv/__main__.py index 1c71424..8223b14 100644 --- a/cluv/__main__.py +++ b/cluv/__main__.py @@ -139,18 +139,17 @@ def add_submit_args( def add_status_args(subparsers: Subparsers) -> argparse.ArgumentParser: status_parser = subparsers.add_parser( "status", - help="Get the status of available clusters.", + help="Get the status of clusters and jobs.", formatter_class=rich_argparse.RichHelpFormatter, ) status_parser.add_argument( - "clusters", - nargs="*", - default=None, - metavar="", - help=("Cluster(s) to query. Leave empty to query all clusters with an active connection."), + "table", + nargs="?", + choices=["clusters", "jobs", "all"], + default="all", + metavar="", + help="Which table to display: cluster overview, jobs overview, or both (default: all).", ) - # TODO: Add sub-commands to query the status with respect to different things, GPUs, storage, jobs, etc? - # Or just display everything? status_parser.set_defaults(func=status) return status_parser diff --git a/cluv/cli/status.py b/cluv/cli/status.py index 2c9ef21..a1b6794 100644 --- a/cluv/cli/status.py +++ b/cluv/cli/status.py @@ -12,7 +12,14 @@ from cluv.cli.login import get_remote_without_2fa_prompt from cluv.config import get_config -from cluv.remote import Remote +from cluv.slurm import ( + StorageStats, + parse_disk_quota, + parse_diskusage_report, + parse_partition_stats, + parse_savail, + parse_sinfo_nodes, +) logger = logging.getLogger(__name__) __all__ = ["status"] @@ -30,15 +37,6 @@ class JobStats: my_completed: int | None = None # recently completed jobs for the current user -@dataclass -class StorageStats: - """Disk usage as (used_gib, quota_gib) for $HOME and $SCRATCH.""" - home_used: float - home_quota: float - scratch_used: float - scratch_quota: float - - @dataclass class ClusterStatus: name: str @@ -50,6 +48,18 @@ class ClusterStatus: storage: StorageStats +def get_default_cluster_status(cluster: str) -> ClusterStatus: + return ClusterStatus( + name=cluster, + online=False, + gpu_idle=0, + gpu_total=0, + gpu_model="?", + jobs=JobStats(running=0, pending=0, my_running=0, my_pending=0), + storage=StorageStats(home_used=0, home_quota=0, scratch_used=0, scratch_quota=0), + ) + + # --------------------------------------------------------------------------- # Real data layer # --------------------------------------------------------------------------- @@ -91,21 +101,20 @@ class ClusterStatus: _MILA_CLUSTERS = {"mila"} -async def get_real_cluster_status(remote: Remote) -> ClusterStatus: +async def get_real_cluster_status(cluster: str) -> ClusterStatus: """Fetch live Slurm data from a remote cluster and return a ClusterStatus. Uses a single SSH round-trip. Falls back gracefully when commands are unavailable (e.g. partition-stats is DRAC-only). """ - from cluv.slurm import ( - parse_disk_quota, - parse_diskusage_report, - parse_partition_stats, - parse_savail, - parse_sinfo_nodes, - ) - cluster = remote.hostname + # Use get_remote_without_2fa_prompt directly so we never filter out the + # "current" cluster the way login() does. A working socket for mila is + # perfectly usable even when /home/mila is mounted locally. + remote = await get_remote_without_2fa_prompt(cluster) + if remote is None: + return get_default_cluster_status(cluster) + script = _REMOTE_SCRIPT_MILA if cluster in _MILA_CLUSTERS else _REMOTE_SCRIPT_DRAC try: @@ -117,15 +126,7 @@ async def get_real_cluster_status(remote: Remote) -> ClusterStatus: ) except Exception as exc: logger.warning(f"[red]Could not reach {cluster}: {exc}[/red]") - return ClusterStatus( - name=cluster, - online=False, - gpu_idle=0, - gpu_total=0, - gpu_model="?", - jobs=JobStats(running=0, pending=0, my_running=0, my_pending=0), - storage=StorageStats(home_used=0, home_quota=0, scratch_used=0, scratch_quota=0), - ) + return get_default_cluster_status(cluster) parts = raw.split(_SEP) # Pad in case some sections are missing @@ -205,33 +206,6 @@ async def get_real_cluster_status(remote: Remote) -> ClusterStatus: ) -async def get_all_cluster_statuses( - remotes: list[Remote] | None = None, -) -> tuple[list[ClusterStatus], bool]: - """Query clusters in parallel. - - If *remotes* is provided, query exactly those connections. - Otherwise, query all clusters that already have an active SSH connection - (never blocks on 2FA). - - Returns (statuses, any_live) where any_live is False when no cluster - was reachable. - """ - if remotes is None: - clusters = get_config().clusters - remotes = [ - r - for r in await asyncio.gather(*(get_remote_without_2fa_prompt(c) for c in clusters)) - if r is not None - ] - - if not remotes: - return [], False - - statuses = list(await asyncio.gather(*(get_real_cluster_status(r) for r in remotes))) - return statuses, True - - # --------------------------------------------------------------------------- # UI helpers # --------------------------------------------------------------------------- @@ -274,29 +248,24 @@ def _gpu_bar(idle: int, total: int, width: int = 10) -> Text: def _build_cluster_table(data: list[ClusterStatus]) -> Table: table = Table( - title="[bold cyan]Cluster Overview[/bold cyan]", + title="Cluster Overview", box=box.ROUNDED, show_lines=True, header_style="bold white on #1a1a2e", - title_style="bold", + title_style="bold cyan", expand=True, ) table.add_column("Cluster", style="bold", ratio=1) - table.add_column("Status", justify="center", ratio=1) - table.add_column("GPU model", justify="center", ratio=1) - table.add_column("Free GPUs", justify="left", ratio=2) + table.add_column("GPU model", justify="center", ratio=2) + table.add_column("Free GPUs", justify="left", ratio=1) table.add_column("My jobs\nrun/pend", justify="center", ratio=1) table.add_column("All jobs\nrun/pend", justify="center", ratio=1) table.add_column("$HOME", justify="left", ratio=2) table.add_column("$SCRATCH", justify="left", ratio=2) for c in data: - if not c.online: - status_cell = Text("⚠ offline", style="bold red") - else: - status_cell = Text("● online", style="bold green") - + status_cell = Text("● ", style="bold green") if c.online else Text("⚠ ", style="bold red") my_jobs = Text(f"{c.jobs.my_running} / {c.jobs.my_pending}", style="cyan") all_jobs = Text(f"{c.jobs.running} / {c.jobs.pending}", style="white") @@ -307,8 +276,7 @@ def _build_cluster_table(data: list[ClusterStatus]) -> Table: row_style = "dim" if not c.online else "" table.add_row( - Text(c.name, style="bold magenta" if c.online else "dim"), - status_cell, + status_cell + Text(c.name, style="bold magenta" if c.online else "bold bright_black"), Text(c.gpu_model, style="bright_blue"), _gpu_bar(c.gpu_idle, c.gpu_total), my_jobs, @@ -323,9 +291,10 @@ def _build_cluster_table(data: list[ClusterStatus]) -> Table: def _build_my_jobs_table(data: list[ClusterStatus]) -> Table: table = Table( - title="[bold cyan]Your Jobs Summary[/bold cyan]", + title="Jobs Overview", box=box.SIMPLE_HEAVY, header_style="bold white on #1a1a2e", + title_style="bold cyan", expand=True, ) table.add_column("Cluster", style="bold magenta") @@ -372,48 +341,43 @@ def _build_my_jobs_table(data: list[ClusterStatus]) -> Table: def _build_legend() -> Panel: legend = ( + "[green]●[/green] online " + "[red]⚠[/red] offline " "[green]▰[/green] free GPU " "[red]▱[/red] busy GPU " - "[green]█[/green]/[yellow]█[/yellow]/[red]█[/red] disk usage (low/med/high) " - "[green]●[/green] online " - "[red]⚠[/red] offline" + "[green]█[/green]/[yellow]█[/yellow]/[red]█[/red] disk usage (low/med/high)" ) return Panel(legend, title="Legend", border_style="dim", padding=(0, 1)) -async def status(clusters: list[str] | None = None): +async def status(table: str) -> None: """Gets the status of available clusters. - Gives you an overview of the state of each cluster, and displays an overview of the state of your jobs across the clusters. - Displays the number of idle nodes, or the number of idle GPUs, or something similar, for each cluster """ console = Console() - clusters = list(clusters or []) - - if clusters: - # Use get_remote_without_2fa_prompt directly so we never filter out the - # "current" cluster the way login() does. A working socket for mila is - # perfectly usable even when /home/mila is mounted locally. - remotes = [ - r - for r in await asyncio.gather(*(get_remote_without_2fa_prompt(c) for c in clusters)) - if r is not None + clusters = get_config().clusters_names + + # Query clusters in parallel + with console.status("Fetching clusters status..."): + data: list[ClusterStatus] = [ + d for d in await asyncio.gather(*(get_real_cluster_status(c) for c in clusters)) ] - data, is_live = await get_all_cluster_statuses(remotes=remotes) - else: - data, is_live = await get_all_cluster_statuses() - if not is_live: + # Show a tip message if all clusters are offline, which likely means the user hasn't logged in yet (no control sockets). + if all(not c.online for c in data): console.print( - "[yellow]No active cluster connections found. Run [bold]cluv login[/bold] first.[/yellow]" + "[yellow]No active connections to any clusters found. Run [bold]cluv login[/bold] first.[/yellow]" ) console.print() console.rule("[bold cyan]cluv status[/bold cyan]") console.print() - console.print(_build_cluster_table(data)) - console.print() - console.print(_build_my_jobs_table(data)) - console.print() - console.print(_build_legend()) - console.print() + if table in ("clusters", "all"): + console.print(_build_cluster_table(data)) + console.print(_build_legend()) + console.print() + if table in ("jobs", "all"): + console.print(_build_my_jobs_table(data)) + console.print() diff --git a/cluv/slurm.py b/cluv/slurm.py index 7988d58..934c994 100644 --- a/cluv/slurm.py +++ b/cluv/slurm.py @@ -6,8 +6,16 @@ from __future__ import annotations import re +from dataclasses import dataclass -from cluv.cli.status import StorageStats + +@dataclass +class StorageStats: + """Disk usage as (used_gib, quota_gib) for $HOME and $SCRATCH.""" + home_used: float + home_quota: float + scratch_used: float + scratch_quota: float # ---------------------------------------------------------------------------