Skip to content

Commit 7765e7a

Browse files
update warning to error
1 parent d0f5bc8 commit 7765e7a

File tree

3 files changed

+46
-14
lines changed

3 files changed

+46
-14
lines changed

keras_remote/backend/gke_client.py

Lines changed: 37 additions & 8 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,8 @@
22

33
import time
44
from contextlib import suppress
5+
import subprocess
6+
import json
57

68
from absl import logging
79
from kubernetes import client, config
@@ -394,6 +396,25 @@ def _print_pod_logs(core_v1, job_name, namespace):
394396
logging.info("Pod %s logs:\n%s", pod.metadata.name, logs)
395397

396398

399+
def _validate_node_pool_exists(selector: dict) -> bool:
400+
"""Use gcloud to verify that a GKE NodePool matches the pod node selector."""
401+
try:
402+
# Requires gcloud CLI and valid credentials.
403+
out = subprocess.check_output([
404+
"gcloud", "container", "node-pools", "list",
405+
"--format", "json"], text=True, stderr=subprocess.DEVNULL)
406+
pools = json.loads(out)
407+
for pool in pools:
408+
config = pool.get("config", {})
409+
labels = config.get("labels", {})
410+
# Check if all keys/values in the selector exist in this pool's labels
411+
if all(labels.get(k) == str(v) for k, v in selector.items()):
412+
return True
413+
return False
414+
except Exception:
415+
# If gcloud is missing or unauthenticated, degrade gracefully and assume pool exists
416+
return True
417+
397418
def _check_pod_scheduling(core_v1, job_name, namespace):
398419
"""Check for pod scheduling issues and raise helpful errors."""
399420
with suppress(ApiException):
@@ -406,22 +427,30 @@ def _check_pod_scheduling(core_v1, job_name, namespace):
406427
if condition.type == "PodScheduled" and condition.status == "False":
407428
msg = condition.message or ""
408429
if "Insufficient nvidia.com/gpu" in msg:
430+
selector = pod.spec.node_selector or {}
431+
if not _validate_node_pool_exists(selector):
432+
selector_str = ", ".join([f"{k}: {v}" for k, v in selector.items()]) if selector else "None"
433+
raise RuntimeError(f"No GKE node pool exists with selector '{selector_str}'. "
434+
"Please use 'keras-remote pool add' to configure this accelerator.")
409435
logging.info(
410436
f"Pod {pod.metadata.name} is Pending: Insufficient nvidia.com/gpu. "
411-
"Waiting for GKE Cluster Autoscaler to provision a new node... (scale-to-zero)"
437+
"Waiting for GKE Cluster Autoscaler to provision a new node... (scale-to-zero)\n"
438+
" Note: If this hangs indefinitely, ensure your GCP project has adequate quota."
412439
)
413440
elif (
414441
"didn't match Pod's node affinity/selector" in msg
415442
or "node selector" in msg.lower()
416443
):
417-
selector = pod.spec.node_selector
418-
selector_str = (
419-
", ".join([f"{k}: {v}" for k, v in selector.items()])
420-
if selector
421-
else "None"
422-
)
444+
selector = pod.spec.node_selector or {}
445+
selector_str = ", ".join([f"{k}: {v}" for k, v in selector.items()]) if selector else "None"
446+
447+
if not _validate_node_pool_exists(selector):
448+
raise RuntimeError(f"No GKE node pool exists with selector '{selector_str}'. "
449+
"Please use 'keras-remote pool add' to configure this accelerator.")
450+
423451
logging.info(
424452
f"Pod {pod.metadata.name} is Pending: No currently running nodes "
425453
f"match accelerator selector '{selector_str}'. "
426-
"Waiting for GKE Cluster Autoscaler to provision a new node... (scale-to-zero)"
454+
"Waiting for GKE Cluster Autoscaler to provision a new node... (scale-to-zero)\n"
455+
" Note: If this hangs indefinitely, ensure your GCP project has adequate quota."
427456
)

keras_remote/cli/infra/program.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -249,9 +249,9 @@ def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id):
249249
cluster=cluster.name,
250250
location=zone,
251251
project=project_id,
252-
initial_node_count=0,
252+
initial_node_count=tpu.num_nodes if tpu.num_nodes > 1 else 0,
253253
autoscaling=gcp.container.NodePoolAutoscalingArgs(
254-
min_node_count=0,
254+
min_node_count=tpu.num_nodes if tpu.num_nodes > 1 else 0,
255255
max_node_count=tpu.num_nodes,
256256
),
257257
management=gcp.container.NodePoolManagementArgs(

keras_remote/cli/infra/program_test.py

Lines changed: 7 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -68,13 +68,14 @@ def test_node_count_matches_config(self, gcp_mock):
6868

6969
program._create_tpu_node_pool(cluster, tpu, "us-central2-b", "my-project")
7070

71-
# Due to scale-to-zero, initial_node_count is 0 and max is stored in autoscaling
71+
# Due to multi-host TPU workaround, initial_node_count is equal to num_nodes
7272
call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
73-
self.assertEqual(call_kwargs.get("initial_node_count"), 0)
73+
self.assertEqual(call_kwargs.get("initial_node_count"), 4)
7474
autoscaling_kwargs = (
7575
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
7676
)
7777
self.assertEqual(autoscaling_kwargs.get("max_node_count"), 4)
78+
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 4)
7879

7980
@mock.patch.object(program, "gcp")
8081
def test_pool_name_includes_tpu_name(self, gcp_mock):
@@ -300,13 +301,15 @@ def test_node_pool_scale_to_zero(
300301
cluster, accelerator, "us-central2-b", "my-project"
301302
)
302303

304+
is_multi_host = getattr(accelerator, "num_nodes", 1) > 1
305+
303306
call_kwargs = gcp_mock.container.NodePool.call_args.kwargs
304-
self.assertEqual(call_kwargs.get("initial_node_count"), 0)
307+
self.assertEqual(call_kwargs.get("initial_node_count"), expected_max_count if is_multi_host else 0)
305308

306309
autoscaling_kwargs = (
307310
gcp_mock.container.NodePoolAutoscalingArgs.call_args.kwargs
308311
)
309-
self.assertEqual(autoscaling_kwargs.get("min_node_count"), 0)
312+
self.assertEqual(autoscaling_kwargs.get("min_node_count"), expected_max_count if is_multi_host else 0)
310313
self.assertEqual(
311314
autoscaling_kwargs.get("max_node_count"), expected_max_count
312315
)

0 commit comments

Comments
 (0)