|
| 1 | +"""keras-remote pool commands — add, remove, and list accelerator node pools.""" |
| 2 | + |
| 3 | +import click |
| 4 | +import pulumi.automation as auto |
| 5 | + |
| 6 | +from keras_remote.cli.config import InfraConfig, NodePoolConfig |
| 7 | +from keras_remote.cli.constants import DEFAULT_CLUSTER_NAME, DEFAULT_ZONE |
| 8 | +from keras_remote.cli.infra.program import create_program |
| 9 | +from keras_remote.cli.infra.stack_manager import ( |
| 10 | + get_current_node_pools, |
| 11 | + get_stack, |
| 12 | +) |
| 13 | +from keras_remote.cli.output import ( |
| 14 | + banner, |
| 15 | + console, |
| 16 | + infrastructure_state, |
| 17 | + success, |
| 18 | + warning, |
| 19 | +) |
| 20 | +from keras_remote.cli.prerequisites_check import check_all |
| 21 | +from keras_remote.cli.prompts import resolve_project |
| 22 | +from keras_remote.core import accelerators |
| 23 | +from keras_remote.core.accelerators import generate_pool_name |
| 24 | + |
| 25 | + |
| 26 | +def _common_options(f): |
| 27 | + """Shared options for pool subcommands.""" |
| 28 | + f = click.option( |
| 29 | + "--project", |
| 30 | + envvar="KERAS_REMOTE_PROJECT", |
| 31 | + default=None, |
| 32 | + help="GCP project ID [env: KERAS_REMOTE_PROJECT]", |
| 33 | + )(f) |
| 34 | + f = click.option( |
| 35 | + "--zone", |
| 36 | + envvar="KERAS_REMOTE_ZONE", |
| 37 | + default=None, |
| 38 | + help=f"GCP zone [env: KERAS_REMOTE_ZONE, default: {DEFAULT_ZONE}]", |
| 39 | + )(f) |
| 40 | + f = click.option( |
| 41 | + "--cluster", |
| 42 | + "cluster_name", |
| 43 | + envvar="KERAS_REMOTE_CLUSTER", |
| 44 | + default=None, |
| 45 | + help="GKE cluster name [default: keras-remote-cluster]", |
| 46 | + )(f) |
| 47 | + return f |
| 48 | + |
| 49 | + |
| 50 | +def _resolve_common(project, zone, cluster_name): |
| 51 | + """Resolve common options to concrete values.""" |
| 52 | + return ( |
| 53 | + project or resolve_project(), |
| 54 | + zone or DEFAULT_ZONE, |
| 55 | + cluster_name or DEFAULT_CLUSTER_NAME, |
| 56 | + ) |
| 57 | + |
| 58 | + |
| 59 | +@click.group() |
| 60 | +def pool(): |
| 61 | + """Manage accelerator node pools.""" |
| 62 | + |
| 63 | + |
| 64 | +def _load_pools(project, zone, cluster_name): |
| 65 | + """Check prerequisites, refresh stack state, and return existing pools.""" |
| 66 | + check_all() |
| 67 | + project, zone, cluster_name = _resolve_common(project, zone, cluster_name) |
| 68 | + |
| 69 | + base_config = InfraConfig( |
| 70 | + project=project, zone=zone, cluster_name=cluster_name |
| 71 | + ) |
| 72 | + try: |
| 73 | + program = create_program(base_config) |
| 74 | + stack = get_stack(program, base_config) |
| 75 | + except auto.errors.CommandError as e: |
| 76 | + raise click.ClickException( |
| 77 | + f"No Pulumi stack found for project '{project}': {e}\n" |
| 78 | + "Run 'keras-remote up' to provision infrastructure first." |
| 79 | + ) from e |
| 80 | + |
| 81 | + console.print("\nRefreshing state...\n") |
| 82 | + try: |
| 83 | + stack.refresh(on_output=print) |
| 84 | + except auto.errors.CommandError as e: |
| 85 | + warning(f"Failed to refresh stack state: {e}") |
| 86 | + |
| 87 | + existing_pools = get_current_node_pools(stack) |
| 88 | + return project, zone, cluster_name, existing_pools |
| 89 | + |
| 90 | + |
| 91 | +def _apply_pool_update(project, zone, cluster_name, node_pools): |
| 92 | + """Run a Pulumi update with the given node pool list. |
| 93 | +
|
| 94 | + Returns: |
| 95 | + True if the update succeeded, False if it encountered an error. |
| 96 | + """ |
| 97 | + config = InfraConfig( |
| 98 | + project=project, |
| 99 | + zone=zone, |
| 100 | + cluster_name=cluster_name, |
| 101 | + node_pools=node_pools, |
| 102 | + ) |
| 103 | + program = create_program(config) |
| 104 | + stack = get_stack(program, config) |
| 105 | + |
| 106 | + console.print("\n[bold]Updating infrastructure...[/bold]\n") |
| 107 | + try: |
| 108 | + result = stack.up(on_output=print) |
| 109 | + console.print() |
| 110 | + success(f"Pulumi update complete. {result.summary.resource_changes}") |
| 111 | + return True |
| 112 | + except auto.errors.CommandError as e: |
| 113 | + console.print() |
| 114 | + warning(f"Pulumi update encountered an issue: {e}") |
| 115 | + return False |
| 116 | + |
| 117 | + |
| 118 | +@pool.command("add") |
| 119 | +@_common_options |
| 120 | +@click.option( |
| 121 | + "--accelerator", |
| 122 | + required=True, |
| 123 | + help="Accelerator spec: t4, l4, a100, a100-80gb, h100, " |
| 124 | + "v5litepod, v5p, v6e, v3 (with optional count/topology)", |
| 125 | +) |
| 126 | +@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") |
| 127 | +def pool_add(project, zone, cluster_name, accelerator, yes): |
| 128 | + """Add an accelerator node pool to the cluster.""" |
| 129 | + banner("keras-remote Pool Add") |
| 130 | + |
| 131 | + # Parse the accelerator spec first to fail fast on bad input. |
| 132 | + try: |
| 133 | + accel_config = accelerators.parse_accelerator(accelerator) |
| 134 | + except ValueError as e: |
| 135 | + raise click.BadParameter(str(e), param_hint="--accelerator") from e |
| 136 | + |
| 137 | + if accel_config is None: |
| 138 | + raise click.BadParameter( |
| 139 | + "Cannot add a CPU node pool. Use 'keras-remote up' instead.", |
| 140 | + param_hint="--accelerator", |
| 141 | + ) |
| 142 | + |
| 143 | + new_pool_name = generate_pool_name(accel_config) |
| 144 | + new_pool = NodePoolConfig(new_pool_name, accel_config) |
| 145 | + |
| 146 | + project, zone, cluster_name, existing_pools = _load_pools( |
| 147 | + project, zone, cluster_name |
| 148 | + ) |
| 149 | + all_pools = existing_pools + [new_pool] |
| 150 | + |
| 151 | + console.print(f"\nAdding pool [bold]{new_pool_name}[/bold] ({accelerator})") |
| 152 | + console.print(f"Total pools after add: {len(all_pools)}\n") |
| 153 | + |
| 154 | + if not yes: |
| 155 | + click.confirm("Proceed?", abort=True) |
| 156 | + |
| 157 | + update_succeeded = _apply_pool_update(project, zone, cluster_name, all_pools) |
| 158 | + |
| 159 | + console.print() |
| 160 | + if update_succeeded: |
| 161 | + banner("Pool Added") |
| 162 | + else: |
| 163 | + banner("Pool Update Failed") |
| 164 | + console.print() |
| 165 | + console.print( |
| 166 | + "You may re-run the command to retry, or use" |
| 167 | + " [bold]keras-remote pool list[/bold] to check current state." |
| 168 | + ) |
| 169 | + console.print() |
| 170 | + |
| 171 | + |
| 172 | +@pool.command("remove") |
| 173 | +@_common_options |
| 174 | +@click.argument("pool_name") |
| 175 | +@click.option("--yes", "-y", is_flag=True, help="Skip confirmation prompt") |
| 176 | +def pool_remove(project, zone, cluster_name, pool_name, yes): |
| 177 | + """Remove an accelerator node pool from the cluster.""" |
| 178 | + banner("keras-remote Pool Remove") |
| 179 | + |
| 180 | + project, zone, cluster_name, existing_pools = _load_pools( |
| 181 | + project, zone, cluster_name |
| 182 | + ) |
| 183 | + |
| 184 | + remaining = [p for p in existing_pools if p.name != pool_name] |
| 185 | + if len(remaining) == len(existing_pools): |
| 186 | + existing_names = [p.name for p in existing_pools] |
| 187 | + raise click.ClickException( |
| 188 | + f"Node pool '{pool_name}' not found. " |
| 189 | + f"Existing pools: {', '.join(existing_names) or '(none)'}" |
| 190 | + ) |
| 191 | + |
| 192 | + console.print(f"\nRemoving pool [bold]{pool_name}[/bold]") |
| 193 | + console.print(f"Remaining pools after remove: {len(remaining)}\n") |
| 194 | + |
| 195 | + if not yes: |
| 196 | + click.confirm("Proceed?", abort=True) |
| 197 | + |
| 198 | + update_succeeded = _apply_pool_update(project, zone, cluster_name, remaining) |
| 199 | + |
| 200 | + console.print() |
| 201 | + if update_succeeded: |
| 202 | + banner("Pool Removed") |
| 203 | + else: |
| 204 | + banner("Pool Update Failed") |
| 205 | + console.print() |
| 206 | + console.print( |
| 207 | + "You may re-run the command to retry, or use" |
| 208 | + " [bold]keras-remote pool list[/bold] to check current state." |
| 209 | + ) |
| 210 | + console.print() |
| 211 | + |
| 212 | + |
| 213 | +@pool.command("list") |
| 214 | +@_common_options |
| 215 | +def pool_list(project, zone, cluster_name): |
| 216 | + """List accelerator node pools on the cluster.""" |
| 217 | + banner("keras-remote Node Pools") |
| 218 | + |
| 219 | + check_all() |
| 220 | + project, zone, cluster_name = _resolve_common(project, zone, cluster_name) |
| 221 | + |
| 222 | + base_config = InfraConfig( |
| 223 | + project=project, zone=zone, cluster_name=cluster_name |
| 224 | + ) |
| 225 | + |
| 226 | + try: |
| 227 | + program = create_program(base_config) |
| 228 | + stack = get_stack(program, base_config) |
| 229 | + except auto.errors.CommandError as e: |
| 230 | + warning(f"No Pulumi stack found for project '{project}': {e}") |
| 231 | + console.print("Run 'keras-remote up' to provision infrastructure.") |
| 232 | + return |
| 233 | + |
| 234 | + console.print("\nRefreshing state...\n") |
| 235 | + try: |
| 236 | + stack.refresh(on_output=print) |
| 237 | + except auto.errors.CommandError as e: |
| 238 | + warning(f"Failed to refresh stack state: {e}") |
| 239 | + |
| 240 | + outputs = stack.outputs() |
| 241 | + if not outputs: |
| 242 | + warning("No infrastructure found. Run 'keras-remote up' first.") |
| 243 | + return |
| 244 | + |
| 245 | + infrastructure_state(outputs) |
0 commit comments