@@ -437,22 +437,38 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
437437 pool_labels = config_dict .get ("labels" , {}).copy ()
438438
439439 # Derive the accelerator node label GKE injects from the pool's accelerator config
440- accelerators = config_dict .get ("accelerators" , [])
441- if accelerators :
442- accel_type = accelerators [0 ].get ("acceleratorType" , "" )
440+ accel_config_list = config_dict .get ("accelerators" , [])
441+ if accel_config_list :
442+ accel_type = accel_config_list [0 ].get ("acceleratorType" , "" )
443443 if accel_type .startswith ("tpu-" ):
444444 pool_labels ["cloud.google.com/gke-tpu-accelerator" ] = accel_type
445445 else :
446446 pool_labels ["cloud.google.com/gke-accelerator" ] = accel_type
447447
448448 # TPU mapping fallback
449449 machine_type = config_dict .get ("machineType" , "" )
450+
451+ # Check resource labels for TPU type (common in v5e/v5litepod)
452+ resource_labels = config_dict .get ("resourceLabels" , {})
453+ if "goog-gke-accelerator-type" in resource_labels :
454+ pool_labels ["cloud.google.com/gke-tpu-accelerator" ] = resource_labels [
455+ "goog-gke-accelerator-type"
456+ ]
457+
450458 if machine_type .startswith ("ct" ):
451459 # Preflight approximation: copy the requested topology from the selector rather than reading it from the node pool config
452460 pool_labels ["cloud.google.com/gke-tpu-topology" ] = selector .get (
453461 "cloud.google.com/gke-tpu-topology" , ""
454462 )
455463
464+ # Infer accelerator count from machine type using registry
465+ # This is robust because it uses the same source of truth as the Pod spec generation
466+ for tpu_spec in accelerators .TPUS .values ():
467+ for chips , topo_spec in tpu_spec .topologies .items ():
468+ if topo_spec .machine_type == machine_type :
469+ pool_labels ["cloud.google.com/gke-accelerator-count" ] = str (chips )
470+ break
471+
456472 if all (pool_labels .get (k ) == str (v ) for k , v in selector .items ()):
457473 return True
458474 return False
0 commit comments