Skip to content

Commit ef4b74d

Browse files
Fixes TPU detection logic (#77)
1 parent 70bd83e commit ef4b74d

File tree

1 file changed

+8
-0
lines changed

1 file changed

+8
-0
lines changed

keras_remote/backend/gke_client.py

Lines changed: 8 additions & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -461,6 +461,14 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
461461
"cloud.google.com/gke-tpu-topology", ""
462462
)
463463

464+
# Infer accelerator count from machine type using registry
465+
# This is robust because it uses the same source of truth as the Pod spec generation
466+
for tpu_spec in accelerators.TPUS.values():
467+
for chips, topo_spec in tpu_spec.topologies.items():
468+
if topo_spec.machine_type == machine_type:
469+
pool_labels["cloud.google.com/gke-accelerator-count"] = str(chips)
470+
break
471+
464472
if all(pool_labels.get(k) == str(v) for k, v in selector.items()):
465473
return True
466474
return False

0 commit comments

Comments
 (0)