Skip to content

Commit eca05f0

Browse files
fix TPU node pool scale to zero
1 parent c6f013d commit eca05f0

File tree

1 file changed

+19
-3
lines changed

1 file changed

+19
-3
lines changed

keras_remote/backend/gke_client.py

Lines changed: 19 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,22 +437,38 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
437437
pool_labels = config_dict.get("labels", {}).copy()
438438

439439
# Map the accelerator config to the node labels GKE injects for accelerators
440-
accelerators = config_dict.get("accelerators", [])
441-
if accelerators:
442-
accel_type = accelerators[0].get("acceleratorType", "")
440+
accel_config_list = config_dict.get("accelerators", [])
441+
if accel_config_list:
442+
accel_type = accel_config_list[0].get("acceleratorType", "")
443443
if accel_type.startswith("tpu-"):
444444
pool_labels["cloud.google.com/gke-tpu-accelerator"] = accel_type
445445
else:
446446
pool_labels["cloud.google.com/gke-accelerator"] = accel_type
447447

448448
# TPU mapping fallback
449449
machine_type = config_dict.get("machineType", "")
450+
451+
# Check resource labels for TPU type (common in v5e/v5litepod)
452+
resource_labels = config_dict.get("resourceLabels", {})
453+
if "goog-gke-accelerator-type" in resource_labels:
454+
pool_labels["cloud.google.com/gke-tpu-accelerator"] = resource_labels[
455+
"goog-gke-accelerator-type"
456+
]
457+
450458
if machine_type.startswith("ct"):
451459
# Preflight approximation: echo the selector's requested TPU topology into the pool labels
452460
pool_labels["cloud.google.com/gke-tpu-topology"] = selector.get(
453461
"cloud.google.com/gke-tpu-topology", ""
454462
)
455463

464+
# Infer accelerator count from machine type using registry
465+
# This is robust because it uses the same source of truth as the Pod spec generation
466+
for tpu_spec in accelerators.TPUS.values():
467+
for chips, topo_spec in tpu_spec.topologies.items():
468+
if topo_spec.machine_type == machine_type:
469+
pool_labels["cloud.google.com/gke-accelerator-count"] = str(chips)
470+
break
471+
456472
if all(pool_labels.get(k) == str(v) for k, v in selector.items()):
457473
return True
458474
return False

0 commit comments

Comments
 (0)