Skip to content

Commit 70bd83e

Browse files
fix TPU node pool scale to zero (#75)
* fix TPU node pool scale to zero * code reformat * address gemini comment
1 parent c6f013d commit 70bd83e

File tree

1 file changed

+11
-3
lines changed

1 file changed

+11
-3
lines changed

keras_remote/backend/gke_client.py

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -437,16 +437,24 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
437437
pool_labels = config_dict.get("labels", {}).copy()
438438

439439
# Map GKE injected node labels for accelerators mapping
440-
accelerators = config_dict.get("accelerators", [])
441-
if accelerators:
442-
accel_type = accelerators[0].get("acceleratorType", "")
440+
accel_config_list = config_dict.get("accelerators", [])
441+
if accel_config_list:
442+
accel_type = accel_config_list[0].get("acceleratorType", "")
443443
if accel_type.startswith("tpu-"):
444444
pool_labels["cloud.google.com/gke-tpu-accelerator"] = accel_type
445445
else:
446446
pool_labels["cloud.google.com/gke-accelerator"] = accel_type
447447

448448
# TPU mapping fallback
449449
machine_type = config_dict.get("machineType", "")
450+
451+
# Check resource labels for TPU type (common in v5e/v5litepod)
452+
resource_labels = config_dict.get("resourceLabels", {})
453+
if "goog-gke-accelerator-type" in resource_labels:
454+
pool_labels["cloud.google.com/gke-tpu-accelerator"] = resource_labels[
455+
"goog-gke-accelerator-type"
456+
]
457+
450458
if machine_type.startswith("ct"):
451459
# We roughly map TPU topology presence for preflight
452460
pool_labels["cloud.google.com/gke-tpu-topology"] = selector.get(

0 commit comments

Comments
 (0)