Commit 4300d82

Adds pool command group to CLI for node pool management (#58)
* Adds pool command group to CLI for node pool management
* address reviews
* address reviews
* fix console outputs
* update agent docs
1 parent 786604d commit 4300d82

File tree

15 files changed: +1071 −246 lines changed


.gemini/styleguide.md

Lines changed: 35 additions & 2 deletions
@@ -120,22 +120,55 @@ This prevents confusing situations where a user sets an env var that works in on
 
 ---
 
+## CLI commands must be idempotent and follow the reconciliation pattern.
+
+Every mutating CLI command (`up`, `pool add`, `pool remove`, etc.) must follow the refresh-read-merge-apply pattern:
+
+1. `stack.refresh()` — sync local state with cloud reality
+2. `get_current_node_pools()` — read current pools from stack exports
+3. Build `InfraConfig` — merge existing state with desired changes
+4. `stack.up()` — apply only the diff
+
+This ensures:
+
+- Re-running after partial failure is always safe
+- Existing resources are never accidentally recreated (Pulumi tracks by URN)
+- External drift is detected and corrected
+
+When adding a new CLI command that modifies infrastructure, follow this pattern rather than directly creating or deleting resources.
+
+---
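The refresh-read-merge-apply steps above can be sketched with toy stand-ins. `FakeStack` and `reconcile` below are hypothetical, not the real Pulumi Automation API; they only illustrate why the pattern is idempotent by construction:

```python
from dataclasses import dataclass, field


@dataclass
class FakeStack:
    """Toy stand-in for a Pulumi stack: holds a set of pool names."""

    pools: set = field(default_factory=set)

    def refresh(self):
        pass  # in the real CLI this syncs local state with cloud reality

    def up(self, desired: set):
        # Diff desired state against current state and apply only the delta.
        added = desired - self.pools
        removed = self.pools - desired
        self.pools = set(desired)
        return {"create": len(added), "delete": len(removed)}


def reconcile(stack: FakeStack, add=(), remove=()):
    """Refresh-read-merge-apply: re-running with the same args is a no-op."""
    stack.refresh()                               # 1. sync with reality
    current = set(stack.pools)                    # 2. read current pools
    desired = (current | set(add)) - set(remove)  # 3. merge desired changes
    return stack.up(desired)                      # 4. apply only the diff


stack = FakeStack({"cpu-default"})
first = reconcile(stack, add=["gpu-l4"])   # creates one pool
second = reconcile(stack, add=["gpu-l4"])  # re-run: nothing to change
```

Because step 3 always merges against freshly read state, a second run produces an empty diff rather than a duplicate pool.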
+
+## Prefer graceful degradation over hard failures in CLI operations.
+
+Partial failures in multi-step CLI operations should not abort the entire flow:
+
+- If `stack.refresh()` fails, log a warning and continue with stale state
+- If `stack.up()` fails, set a failure flag but still run post-deploy steps
+- If a post-deploy step fails (kubectl, LWS, GPU drivers), log a warning and continue with remaining steps
+
+The user can always re-run the same command to recover, since all operations are idempotent.
+
+---
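A minimal sketch of the warn-and-continue behavior described above (`run_post_deploy` is a hypothetical helper, not code from this repo):

```python
import logging


def run_post_deploy(steps):
    """Run each named post-deploy step; log and continue on failure."""
    failures = []
    for name, step in steps:
        try:
            step()
        except Exception as e:
            # Graceful degradation: record the failure, keep going.
            logging.warning("Step %s failed: %s; continuing", name, e)
            failures.append(name)
    return failures


def ok():
    pass


def boom():
    raise RuntimeError("GPU driver install failed")


failed = run_post_deploy([("kubectl", ok), ("gpu-drivers", boom), ("lws", ok)])
```

Here `lws` still runs even though `gpu-drivers` failed; since the whole command is idempotent, re-running it retries only what is still missing.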
+
 ## Don't neglect error messages, docstrings, and documentation.
 
 - **Catch user errors early.** Validate GCP project existence and quota before starting a long build.
 - **Provide detailed feedback.**
-- Bad: `Error: 403 Forbidden`
-- Good: `Permission denied. Please ensure your account 'user@example.com' has the 'Storage Object Admin' role on bucket 'gs://my-bucket'.`
+  - Bad: `Error: 403 Forbidden`
+  - Good: `Permission denied. Please ensure your account 'user@example.com' has the 'Storage Object Admin' role on bucket 'gs://my-bucket'.`
 - **Show, don't tell.** Documentation should show code examples of running functions, not just list arguments.
 
 ### Error messages: a case study
 
 Bad:
+
 ```
 RuntimeError: Job failed.
 ```
 
 Good:
+
 ```
 RuntimeError: The remote job failed with exit code 1.
 Logs from the worker:
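The "Good" message style above can be produced by validating early and raising an actionable error. This is an illustrative sketch (the function name and parameters are hypothetical, not from the repo):

```python
def check_bucket_access(account: str, bucket: str, has_role: bool):
    """Raise an actionable error instead of surfacing a bare 403."""
    if not has_role:
        # Name the account, the missing role, and the resource,
        # so the user knows exactly what to fix.
        raise PermissionError(
            f"Permission denied. Please ensure your account '{account}' has "
            f"the 'Storage Object Admin' role on bucket '{bucket}'."
        )


try:
    check_bucket_access("user@example.com", "gs://my-bucket", has_role=False)
except PermissionError as e:
    msg = str(e)
```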

AGENTS.md

Lines changed: 37 additions & 9 deletions
@@ -14,8 +14,8 @@ keras_remote/
 ├── runner/          # Remote worker entrypoint (runs inside container)
 ├── utils/           # Serialization (packager) and Cloud Storage helpers
 ├── cli/             # CLI for infrastructure provisioning (Pulumi-based)
-│   ├── commands/    # up, down, status, config
-│   └── infra/       # Pulumi programs and stack management
+│   ├── commands/    # up, down, status, config, pool (add/remove/list)
+│   └── infra/       # Pulumi programs, stack management, post-deploy steps
 ├── credentials.py   # Credential verification & auto-setup (shared by core & CLI)
 └── constants.py     # Zone/region utilities
 ```
@@ -49,14 +49,17 @@ keras_remote/
 | `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()` |
 | `utils/storage.py` | GCS upload/download/cleanup for job artifacts |
 | `runner/remote_runner.py` | Runs inside container: deserialize, execute, upload result |
+| `cli/commands/pool.py` | Node pool add/remove/list commands |
+| `cli/infra/post_deploy.py` | kubectl, LWS CRD, GPU driver setup after stack.up() |
+| `cli/constants.py` | CLI defaults, paths, API list |
 | `cli/main.py` | CLI entry point (`keras-remote` command) |
 
 ## Key Abstractions
 
 - **`JobContext`** (`backend/execution.py`): Mutable dataclass carrying all job state through the pipeline — inputs, generated IDs, artifact paths, image URI.
 - **`BaseK8sBackend`** (`backend/execution.py`): Base class with `submit_job`, `wait_for_job`, `cleanup_job`. Subclassed by `GKEBackend` and `PathwaysBackend`.
 - **`GpuConfig` / `TpuConfig`** (`core/accelerators.py`): Frozen dataclasses for accelerator metadata. Single source of truth used by runtime, container builder, and CLI.
-- **`InfraConfig`** (`cli/config.py`): CLI provisioning configuration (project, zone, cluster, accelerator).
+- **`InfraConfig` / `NodePoolConfig`** (`cli/config.py`): CLI provisioning configuration. `InfraConfig` holds project, zone, cluster name, and a list of `NodePoolConfig` entries. `NodePoolConfig` pairs a unique pool name (e.g., `gpu-l4-a3f2`) with a `GpuConfig` or `TpuConfig`.
 
 ## Conventions
 
@@ -74,15 +74,40 @@ Every customizable resource name must follow the same resolution model across al
 - **CLI commands**: `--flag` (with `envvar=`) → env var → interactive prompt or default
 - **`config show`**: displays current value and source for every configurable name
 
-| Env Var | `@run()` param | CLI flag | `config show` | Default |
-| --- | --- | --- | --- | --- |
-| `KERAS_REMOTE_PROJECT` | `project=` | `--project` | Yes | *(required)* |
-| `KERAS_REMOTE_ZONE` | `zone=` | `--zone` | Yes | `us-central1-a` |
-| `KERAS_REMOTE_CLUSTER` | `cluster=` | `--cluster` | Yes | `keras-remote-cluster` |
-| `KERAS_REMOTE_GKE_NAMESPACE` | `namespace=` | *(runtime only)* | Yes | `default` |
+| Env Var | `@run()` param | CLI flag | `config show` | Default |
+| ---------------------------- | -------------- | ---------------- | ------------- | ---------------------- |
+| `KERAS_REMOTE_PROJECT` | `project=` | `--project` | Yes | _(required)_ |
+| `KERAS_REMOTE_ZONE` | `zone=` | `--zone` | Yes | `us-central1-a` |
+| `KERAS_REMOTE_CLUSTER` | `cluster=` | `--cluster` | Yes | `keras-remote-cluster` |
+| `KERAS_REMOTE_GKE_NAMESPACE` | `namespace=` | _(runtime only)_ | Yes | `default` |
 
 When adding a new configurable resource name, ensure it is wired into **all three paths** (decorator, CLI flags on every relevant command, and `config show`). The `GOOGLE_CLOUD_PROJECT` env var is also accepted as a fallback for project ID (after `KERAS_REMOTE_PROJECT`).
 
+Additional CLI-only env vars:
+
+| Env Var | Default | Description |
+| ------------------------ | ------------------------ | ---------------------------- |
+| `KERAS_REMOTE_STATE_DIR` | `~/.keras-remote/pulumi` | Pulumi local state directory |
+
+### CLI State Management
+
+The CLI manages three layers of state: in-memory config (`InfraConfig`), Pulumi local state files (`~/.keras-remote/pulumi/`), and GCP cloud resources. Each GCP project gets its own Pulumi stack (stack name = project ID).
+
+Every mutating command (`up`, `pool add`, `pool remove`, etc.) follows this reconciliation pattern:
+
+1. `stack.refresh()` — pull cloud reality into local state
+2. `get_current_node_pools()` — read current pools from stack exports
+3. Build new `InfraConfig` — merge existing pools with desired changes
+4. `create_program(config)` — generate Pulumi program from desired state
+5. `stack.up()` — diff desired vs current, apply only changes
+
+Key behaviors:
+
+- **`up` re-runs** preserve existing pools and ignore `--accelerator` (defer to `pool add/remove`)
+- **All commands are idempotent** — safe to re-run after partial failure
+- **Graceful degradation** — partial failures (refresh, post-deploy steps) log warnings but don't abort the operation
+- **Pool state round-trips** through Pulumi stack exports (`accelerators` key) as a list of dicts, reconstructed via `_export_to_node_pool()`
+
 ### Testing
 
 - **Framework**: `absl.testing` (not pytest)
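The "pool state round-trips through stack exports" behavior in the diff above amounts to serializing each pool to a plain dict and rebuilding it later. A self-contained sketch with toy stand-ins (`FakeGpuConfig`, `FakePool`, and both helper functions are hypothetical, mirroring but not copying `_export_to_node_pool()`):

```python
from dataclasses import asdict, dataclass


@dataclass(frozen=True)
class FakeGpuConfig:
    """Toy stand-in for the frozen GpuConfig dataclass."""

    gpu_type: str
    count: int


@dataclass
class FakePool:
    """Toy stand-in for NodePoolConfig: a unique name plus an accelerator."""

    name: str
    accelerator: FakeGpuConfig


def pools_to_export(pools):
    """Serialize pools into a list of plain dicts, as stack exports require."""
    return [{"name": p.name, **asdict(p.accelerator)} for p in pools]


def export_to_pool(d):
    """Rebuild one pool from an exported dict (the round-trip direction)."""
    return FakePool(d["name"], FakeGpuConfig(d["gpu_type"], d["count"]))


pools = [FakePool("gpu-l4-a3f2", FakeGpuConfig("nvidia-l4", 1))]
round_tripped = [export_to_pool(d) for d in pools_to_export(pools)]
```

Keeping the export format to plain dicts is what lets pool state survive in Pulumi's JSON-based stack outputs between CLI invocations.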

keras_remote/cli/commands/pool.py

Lines changed: 245 additions & 0 deletions
@@ -0,0 +1,245 @@
```python
"""keras-remote pool commands — add, remove, and list accelerator node pools."""

import click
import pulumi.automation as auto

from keras_remote.cli.config import InfraConfig, NodePoolConfig
from keras_remote.cli.constants import DEFAULT_CLUSTER_NAME, DEFAULT_ZONE
from keras_remote.cli.infra.program import create_program
from keras_remote.cli.infra.stack_manager import (
    get_current_node_pools,
    get_stack,
)
from keras_remote.cli.output import (
    banner,
    console,
    infrastructure_state,
    success,
    warning,
)
from keras_remote.cli.prerequisites_check import check_all
from keras_remote.cli.prompts import resolve_project
from keras_remote.core import accelerators
from keras_remote.core.accelerators import generate_pool_name


def _common_options(f):
    """Shared options for pool subcommands."""
    f = click.option(
        "--project",
        envvar="KERAS_REMOTE_PROJECT",
        default=None,
        help="GCP project ID [env: KERAS_REMOTE_PROJECT]",
    )(f)
    f = click.option(
        "--zone",
        envvar="KERAS_REMOTE_ZONE",
        default=None,
        help=f"GCP zone [env: KERAS_REMOTE_ZONE, default: {DEFAULT_ZONE}]",
    )(f)
    f = click.option(
        "--cluster",
        "cluster_name",
        envvar="KERAS_REMOTE_CLUSTER",
        default=None,
        help="GKE cluster name [default: keras-remote-cluster]",
    )(f)
    return f


def _resolve_common(project, zone, cluster_name):
    """Resolve common options to concrete values."""
    return (
        project or resolve_project(),
        zone or DEFAULT_ZONE,
        cluster_name or DEFAULT_CLUSTER_NAME,
    )


@click.group()
def pool():
    """Manage accelerator node pools."""


def _load_pools(project, zone, cluster_name):
    """Check prerequisites, refresh stack state, and return existing pools."""
    check_all()
    project, zone, cluster_name = _resolve_common(project, zone, cluster_name)

    base_config = InfraConfig(
        project=project, zone=zone, cluster_name=cluster_name
    )
    try:
        program = create_program(base_config)
        stack = get_stack(program, base_config)
    except auto.errors.CommandError as e:
        raise click.ClickException(
            f"No Pulumi stack found for project '{project}': {e}\n"
            "Run 'keras-remote up' to provision infrastructure first."
        ) from e

    console.print("\nRefreshing state...\n")
    try:
        stack.refresh(on_output=print)
    except auto.errors.CommandError as e:
        warning(f"Failed to refresh stack state: {e}")

    existing_pools = get_current_node_pools(stack)
    return project, zone, cluster_name, existing_pools


def _apply_pool_update(project, zone, cluster_name, node_pools):
    """Run a Pulumi update with the given node pool list.

    Returns:
        True if the update succeeded, False if it encountered an error.
    """
    config = InfraConfig(
        project=project,
        zone=zone,
        cluster_name=cluster_name,
        node_pools=node_pools,
    )
    program = create_program(config)
    stack = get_stack(program, config)

    console.print("\n[bold]Updating infrastructure...[/bold]\n")
    try:
        result = stack.up(on_output=print)
        console.print()
        success(f"Pulumi update complete. {result.summary.resource_changes}")
        return True
    except auto.errors.CommandError as e:
        console.print()
        warning(f"Pulumi update encountered an issue: {e}")
        return False


@pool.command("add")
@_common_options
@click.option(
    "--accelerator",
    required=True,
    help="Accelerator spec: t4, l4, a100, a100-80gb, h100, "
    "v5litepod, v5p, v6e, v3 (with optional count/topology)",
)
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
def pool_add(project, zone, cluster_name, accelerator, yes):
    """Add an accelerator node pool to the cluster."""
    banner("keras-remote Pool Add")

    # Parse the accelerator spec first to fail fast on bad input.
    try:
        accel_config = accelerators.parse_accelerator(accelerator)
    except ValueError as e:
        raise click.BadParameter(str(e), param_hint="--accelerator") from e

    if accel_config is None:
        raise click.BadParameter(
            "Cannot add a CPU node pool. Use 'keras-remote up' instead.",
            param_hint="--accelerator",
        )

    new_pool_name = generate_pool_name(accel_config)
    new_pool = NodePoolConfig(new_pool_name, accel_config)

    project, zone, cluster_name, existing_pools = _load_pools(
        project, zone, cluster_name
    )
    all_pools = existing_pools + [new_pool]

    console.print(f"\nAdding pool [bold]{new_pool_name}[/bold] ({accelerator})")
    console.print(f"Total pools after add: {len(all_pools)}\n")

    if not yes:
        click.confirm("Proceed?", abort=True)

    update_succeeded = _apply_pool_update(project, zone, cluster_name, all_pools)

    console.print()
    if update_succeeded:
        banner("Pool Added")
    else:
        banner("Pool Update Failed")
        console.print()
        console.print(
            "You may re-run the command to retry, or use"
            " [bold]keras-remote pool list[/bold] to check current state."
        )
    console.print()


@pool.command("remove")
@_common_options
@click.argument("pool_name")
@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt")
def pool_remove(project, zone, cluster_name, pool_name, yes):
    """Remove an accelerator node pool from the cluster."""
    banner("keras-remote Pool Remove")

    project, zone, cluster_name, existing_pools = _load_pools(
        project, zone, cluster_name
    )

    remaining = [p for p in existing_pools if p.name != pool_name]
    if len(remaining) == len(existing_pools):
        existing_names = [p.name for p in existing_pools]
        raise click.ClickException(
            f"Node pool '{pool_name}' not found. "
            f"Existing pools: {', '.join(existing_names) or '(none)'}"
        )

    console.print(f"\nRemoving pool [bold]{pool_name}[/bold]")
    console.print(f"Remaining pools after remove: {len(remaining)}\n")

    if not yes:
        click.confirm("Proceed?", abort=True)

    update_succeeded = _apply_pool_update(project, zone, cluster_name, remaining)

    console.print()
    if update_succeeded:
        banner("Pool Removed")
    else:
        banner("Pool Update Failed")
        console.print()
        console.print(
            "You may re-run the command to retry, or use"
            " [bold]keras-remote pool list[/bold] to check current state."
        )
    console.print()


@pool.command("list")
@_common_options
def pool_list(project, zone, cluster_name):
    """List accelerator node pools on the cluster."""
    banner("keras-remote Node Pools")

    check_all()
    project, zone, cluster_name = _resolve_common(project, zone, cluster_name)

    base_config = InfraConfig(
        project=project, zone=zone, cluster_name=cluster_name
    )

    try:
        program = create_program(base_config)
        stack = get_stack(program, base_config)
    except auto.errors.CommandError as e:
        warning(f"No Pulumi stack found for project '{project}': {e}")
        console.print("Run 'keras-remote up' to provision infrastructure.")
        return

    console.print("\nRefreshing state...\n")
    try:
        stack.refresh(on_output=print)
    except auto.errors.CommandError as e:
        warning(f"Failed to refresh stack state: {e}")

    outputs = stack.outputs()
    if not outputs:
        warning("No infrastructure found. Run 'keras-remote up' first.")
        return

    infrastructure_state(outputs)
```
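The `_common_options` helper above shares flags across subcommands by calling each decorator on the function directly instead of stacking `@` syntax. A dependency-free sketch of that composition trick (the `option` decorator here is a toy stand-in for `click.option`, not click itself):

```python
def option(name):
    """Toy option decorator: records the option name on the function."""
    def deco(f):
        f.options = getattr(f, "options", []) + [name]
        return f
    return deco


def common_options(f):
    # Apply shared options by invoking each decorator on f,
    # mirroring how _common_options wraps click.option.
    f = option("--project")(f)
    f = option("--zone")(f)
    f = option("--cluster")(f)
    return f


@common_options
def cmd():
    pass
```

Because each decorator returns the (annotated) function, one `@common_options` line gives every pool subcommand the same `--project`/`--zone`/`--cluster` set without repeating the definitions.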

0 commit comments
