@@ -40,8 +40,7 @@ class GpuSpec:
4040 """Registry entry for a GPU type."""
4141
4242 gke_label : str
43- machine_type : str
44- counts : tuple [int , ...]
43+ counts : dict [int , str ] # count -> machine_type
4544
4645
4746@dataclass (frozen = True )
@@ -63,14 +62,53 @@ class TpuSpec:
6362
6463
6564GPUS : dict [str , GpuSpec ] = {
66- "l4" : GpuSpec ("nvidia-l4" , "g2-standard-4" , (1 , 2 , 4 , 8 )),
67- "t4" : GpuSpec ("nvidia-tesla-t4" , "n1-standard-4" , (1 , 2 , 4 )),
68- "v100" : GpuSpec ("nvidia-tesla-v100" , "n1-standard-8" , (1 , 2 , 4 , 8 )),
69- "a100" : GpuSpec ("nvidia-tesla-a100" , "a2-highgpu-1g" , (1 , 2 , 4 , 8 , 16 )),
70- "a100-80gb" : GpuSpec ("nvidia-a100-80gb" , "a2-ultragpu-1g" , (1 , 2 , 4 , 8 , 16 )),
71- "h100" : GpuSpec ("nvidia-h100-80gb" , "a3-highgpu-1g" , (1 , 2 , 4 , 8 )),
72- "p4" : GpuSpec ("nvidia-tesla-p4" , "n1-standard-4" , (1 , 2 , 4 )),
73- "p100" : GpuSpec ("nvidia-tesla-p100" , "n1-standard-4" , (1 , 2 , 4 )),
65+ "l4" : GpuSpec ("nvidia-l4" , {
66+ 1 : "g2-standard-4" ,
67+ 2 : "g2-standard-24" ,
68+ 4 : "g2-standard-48" ,
69+ 8 : "g2-standard-96" ,
70+ }),
71+ "t4" : GpuSpec ("nvidia-tesla-t4" , {
72+ 1 : "n1-standard-4" ,
73+ 2 : "n1-standard-8" ,
74+ 4 : "n1-standard-16" ,
75+ }),
76+ "v100" : GpuSpec ("nvidia-tesla-v100" , {
77+ 1 : "n1-standard-8" ,
78+ 2 : "n1-standard-16" ,
79+ 4 : "n1-standard-32" ,
80+ 8 : "n1-standard-64" ,
81+ }),
82+ "a100" : GpuSpec ("nvidia-tesla-a100" , {
83+ 1 : "a2-highgpu-1g" ,
84+ 2 : "a2-highgpu-2g" ,
85+ 4 : "a2-highgpu-4g" ,
86+ 8 : "a2-highgpu-8g" ,
87+ 16 : "a2-megagpu-16g" ,
88+ }),
89+ "a100-80gb" : GpuSpec ("nvidia-a100-80gb" , {
90+ 1 : "a2-ultragpu-1g" ,
91+ 2 : "a2-ultragpu-2g" ,
92+ 4 : "a2-ultragpu-4g" ,
93+ 8 : "a2-ultragpu-8g" ,
94+ 16 : "a2-ultragpu-16g" ,
95+ }),
96+ "h100" : GpuSpec ("nvidia-h100-80gb" , {
97+ 1 : "a3-highgpu-1g" ,
98+ 2 : "a3-highgpu-2g" ,
99+ 4 : "a3-highgpu-4g" ,
100+ 8 : "a3-highgpu-8g" ,
101+ }),
102+ "p4" : GpuSpec ("nvidia-tesla-p4" , {
103+ 1 : "n1-standard-4" ,
104+ 2 : "n1-standard-8" ,
105+ 4 : "n1-standard-16" ,
106+ }),
107+ "p100" : GpuSpec ("nvidia-tesla-p100" , {
108+ 1 : "n1-standard-4" ,
109+ 2 : "n1-standard-8" ,
110+ 4 : "n1-standard-16" ,
111+ }),
74112}
75113
76114_GPU_ALIASES : dict [str , str ] = {
@@ -83,19 +121,6 @@ class TpuSpec:
83121# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
84122# v5p uses 3-D topologies (AxBxC); v3, v5litepod, v6e use 2-D (AxB).
85123TPUS : dict [str , TpuSpec ] = {
86- "v2" : TpuSpec (
87- "tpu-v2-podslice" ,
88- 4 ,
89- {
90- 4 : TpuTopologySpec ("2x2" , "ct2-hightpu-4t" , 1 ),
91- 16 : TpuTopologySpec ("4x4" , "ct2-hightpu-4t" , 4 ),
92- 32 : TpuTopologySpec ("4x8" , "ct2-hightpu-4t" , 8 ),
93- 64 : TpuTopologySpec ("8x8" , "ct2-hightpu-4t" , 16 ),
94- 128 : TpuTopologySpec ("8x16" , "ct2-hightpu-4t" , 32 ),
95- 256 : TpuTopologySpec ("16x16" , "ct2-hightpu-4t" , 64 ),
96- 512 : TpuTopologySpec ("16x32" , "ct2-hightpu-4t" , 128 ),
97- },
98- ),
99124 "v3" : TpuSpec (
100125 "tpu-v3-podslice" ,
101126 4 ,
@@ -148,6 +173,7 @@ class TpuSpec:
148173 {
149174 8 : TpuTopologySpec ("2x2x2" , "ct5p-hightpu-4t" , 2 ),
150175 16 : TpuTopologySpec ("2x2x4" , "ct5p-hightpu-4t" , 4 ),
176+ 32 : TpuTopologySpec ("2x4x4" , "ct5p-hightpu-4t" , 8 ),
151177 },
152178 ),
153179 "v6e" : TpuSpec (
@@ -165,13 +191,11 @@ class TpuSpec:
165191}
166192
167193
168- # ── Parser ────────────────────────────────────────────────────────
169-
170- _MULTI_GPU_RE = re .compile (r"^(.+?)(?:x|-)(\d+)$" ) # "a100x4", "l4-2"
171- _TPU_CHIPS_RE = re .compile (r"^([a-z0-9_]+)-(\d+)$" ) # "v3-8", "v5litepod-16"
194+ _MULTI_GPU_RE = re .compile (r"^([^x]+)(?:x)(\d+)$" ) # "a100x4"
195+ _TPU_CHIPS_RE = re .compile (r"^([a-z0-9_]+)-(\d+)$" ) # "v3-8"
172196_TPU_TOPO_RE = re .compile (
173197 r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
174- ) # "v5litepod-2x2", "v5p-2x2x2"
198+ ) # "v5litepod-2x2"
175199
176200DEFAULT_GPU = "l4"
177201DEFAULT_TPU = "v5litepod"
@@ -186,7 +210,7 @@ class TpuSpec:
186210 "p100" ,
187211 "p4" ,
188212]
189- _PREFERRED_TPUS = ["v6e" , "v5p" , "v5litepod" , "v4" , "v3" , "v2" ]
213+ _PREFERRED_TPUS = ["v6e" , "v5p" , "v5litepod" , "v4" , "v3" ]
190214
191215
192216def _resolve_gpu_alias (name : str ) -> str :
@@ -203,13 +227,25 @@ def parse_accelerator(accel_str: str) -> Accelerator:
203227 Returns GpuConfig, TpuConfig, or None (for "cpu").
204228
205229 Accepted formats:
206- GPU: "l4", "gpu", "gpu-4", "a100x4", "l4-2", "a100-80gbx8"
207- TPU: "v3-8", "tpu", "tpu-8", "v5litepod-2x2", "v5litepod"
208- CPU: "cpu", "cpu-8"
230+ - Generic: "gpu", "tpu", "cpu" (resolves to defaults)
231+ - Dynamic Count: "gpu:4", "tpu:8" (assigns the most capable hardware supporting that count); "cpu:8" (CPU only, returns None)
232+ - Explicit GPU Name: "gpu:l4", "l4", "gpu:a100-80gb" (resolves to 1 instance of the specified GPU)
233+ - Multi-GPU Name: "gpu:a100x4", "a100x4", "gpu:l4x2" (resolves to N instances of the specified GPU; only the 'xN' suffix is accepted)
234+ - Explicit TPU Name: "tpu:v5litepod", "v5litepod" (resolves to the default topology/chips for the TPU)
235+ - Explicit TPU Topology/Chips: "tpu:v3-8", "tpu:v5litepod-2x2", "v3-8" (resolves to the specified TPU slice)
236+
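237+ Note: The 'gpu:' and 'tpu:' prefixes are optional for named accelerators but recommended for disambiguation. A bare count (e.g. "8") is tried as a GPU count first and falls back to TPU chips only if no GPU type supports it, so use 'tpu:<N>' to request TPU chips by count.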
238+
239+ Dynamic Resolution:
240+ When using generic formats like "gpu:<N>" or "tpu:<N>", the parser
241+ dynamically assigns the most capable hardware type that supports the
242+ requested device count `N`. Hardware is selected based on an internal
243+ preference hierarchy (e.g., H100 > A100 > L4 for GPUs, and
244+ v6e > v5p > v5litepod for TPUs).
209245 """
210246 s = accel_str .strip ().lower ()
211247
212- if s == "cpu" or (s .startswith ("cpu- " ) and s [4 :].isdigit ()):
248+ if s == "cpu" or (s .startswith ("cpu: " ) and s [4 :].isdigit ()):
213249 return None
214250
215251 if s == "gpu" :
@@ -218,42 +254,56 @@ def parse_accelerator(accel_str: str) -> Accelerator:
218254 if s == "tpu" :
219255 return make_tpu (DEFAULT_TPU , TPUS [DEFAULT_TPU ].default_chips )
220256
221- if s .startswith ("gpu-" ) and s [4 :].isdigit ():
222- count = int (s [4 :])
223- search_order = _PREFERRED_GPUS
224- for gpu_name in search_order :
257+ # 1) Try parsing as GPU
258+ is_gpu_explicit = s .startswith ("gpu:" )
259+ gpu_str = s [4 :] if is_gpu_explicit else s
260+
261+ if gpu_str .isdigit ():
262+ count = int (gpu_str )
263+ for gpu_name in _PREFERRED_GPUS :
225264 if gpu_name in GPUS and count in GPUS [gpu_name ].counts :
226265 return make_gpu (gpu_name , count )
227- valid_counts = sorted (set (c for spec in GPUS .values () for c in spec .counts ))
228- raise ValueError (
229- f"No GPU supports count { count } . Supported counts across all GPUs: { valid_counts } "
230- )
231-
232- if s .startswith ("tpu-" ) and s [4 :].isdigit ():
233- chips = int (s [4 :])
234- search_order = _PREFERRED_TPUS
235- for tpu_name in search_order :
236- if tpu_name in TPUS and chips in TPUS [tpu_name ].topologies :
237- return make_tpu (tpu_name , chips )
238- valid_chips = sorted (
239- set (c for spec in TPUS .values () for c in spec .topologies )
240- )
241- raise ValueError (
242- f"No TPU supports { chips } chips. Supported chip counts across all TPUs: { valid_chips } "
243- )
266+ if is_gpu_explicit :
267+ valid_counts = sorted (set (c for spec in GPUS .values () for c in spec .counts ))
268+ raise ValueError (
269+ f"No GPU supports count { count } . Supported counts: { valid_counts } "
270+ )
244271
245- # Direct GPU name: "l4", "a100-80gb"
246- name = _resolve_gpu_alias (s )
272+ name = _resolve_gpu_alias (gpu_str )
247273 if name in GPUS :
248274 return make_gpu (name , 1 )
249275
250- # Direct TPU name (bare): "v5litepod" → default chips
251- name = _resolve_tpu_alias (s )
276+ m = _MULTI_GPU_RE .match (gpu_str )
277+ if m :
278+ name = _resolve_gpu_alias (m .group (1 ))
279+ if name in GPUS :
280+ return make_gpu (name , int (m .group (2 )))
281+
282+ if is_gpu_explicit :
283+ raise ValueError (f"Unknown GPU accelerator: '{ accel_str } '" )
284+
285+ # 2) Try parsing as TPU
286+ is_tpu_explicit = s .startswith ("tpu:" )
287+ tpu_str = s [4 :] if is_tpu_explicit else s
288+
289+ if tpu_str .isdigit ():
290+ chips = int (tpu_str )
291+ for tpu_name in _PREFERRED_TPUS :
292+ if tpu_name in TPUS and chips in TPUS [tpu_name ].topologies :
293+ return make_tpu (tpu_name , chips )
294+ if is_tpu_explicit :
295+ valid_chips = sorted (
296+ set (c for spec in TPUS .values () for c in spec .topologies )
297+ )
298+ raise ValueError (
299+ f"No TPU supports { chips } chips. Supported chip counts: { valid_chips } "
300+ )
301+
302+ name = _resolve_tpu_alias (tpu_str )
252303 if name in TPUS :
253304 return make_tpu (name , TPUS [name ].default_chips )
254305
255- # TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
256- m = _TPU_TOPO_RE .match (s )
306+ m = _TPU_TOPO_RE .match (tpu_str )
257307 if m :
258308 name = _resolve_tpu_alias (m .group (1 ))
259309 if name in TPUS :
@@ -267,25 +317,16 @@ def parse_accelerator(accel_str: str) -> Accelerator:
267317 f"Supported: { ', ' .join (valid )} ."
268318 )
269319
270- # TPU with chip count: "v3-8", "v5litepod-4"
271- m = _TPU_CHIPS_RE .match (s )
320+ m = _TPU_CHIPS_RE .match (tpu_str )
272321 if m :
273322 name = _resolve_tpu_alias (m .group (1 ))
274323 if name in TPUS :
275324 return make_tpu (name , int (m .group (2 )))
276325
277- # Multi-GPU: "a100x4", "l4x2"
278- m = _MULTI_GPU_RE .match (s )
279- if m :
280- name = _resolve_gpu_alias (m .group (1 ))
281- if name in GPUS :
282- return make_gpu (name , int (m .group (2 )))
283-
284326 raise ValueError (
285327 f"Unknown accelerator: '{ accel_str } '. "
286- f"GPUs: { ', ' .join (GPUS )} (use 'xN' for multi-GPU, e.g. 'a100x4'). "
287- f"TPUs: { ', ' .join (TPUS )} (use '-N' for chips, e.g. 'v3-8', "
288- f"or '-NxM' for topology, e.g. 'v5litepod-2x2')."
328+ f"GPUs: { ', ' .join (GPUS )} (use 'gpu:name' or 'gpu:namexN'). "
329+ f"TPUs: { ', ' .join (TPUS )} (use 'tpu:name' or 'tpu:name-N')."
289330 )
290331
291332
@@ -324,7 +365,7 @@ def make_gpu(name: str, count: int) -> GpuConfig:
324365 name = name ,
325366 count = count ,
326367 gke_label = spec .gke_label ,
327- machine_type = spec .machine_type ,
368+ machine_type = spec .counts [ count ] ,
328369 )
329370
330371
0 commit comments