@@ -40,8 +40,7 @@ class GpuSpec:
4040 """Registry entry for a GPU type."""
4141
4242 gke_label : str
43- machine_type : str
44- counts : tuple [int , ...]
43+ counts : dict [int , str ] # count -> machine_type
4544
4645
4746@dataclass (frozen = True )
@@ -63,12 +62,77 @@ class TpuSpec:
6362
6463
6564GPUS : dict [str , GpuSpec ] = {
66- "l4" : GpuSpec ("nvidia-l4" , "g2-standard-4" , (1 , 2 , 4 )),
67- "t4" : GpuSpec ("nvidia-tesla-t4" , "n1-standard-4" , (1 , 2 , 4 )),
68- "v100" : GpuSpec ("nvidia-tesla-v100" , "n1-standard-8" , (1 , 2 , 4 , 8 )),
69- "a100" : GpuSpec ("nvidia-tesla-a100" , "a2-highgpu-1g" , (1 , 2 , 4 , 8 )),
70- "a100-80gb" : GpuSpec ("nvidia-a100-80gb" , "a2-ultragpu-1g" , (1 , 2 , 4 , 8 )),
71- "h100" : GpuSpec ("nvidia-h100-80gb" , "a3-highgpu-1g" , (1 , 2 , 4 , 8 )),
65+ "l4" : GpuSpec (
66+ "nvidia-l4" ,
67+ {
68+ 1 : "g2-standard-4" ,
69+ 2 : "g2-standard-24" ,
70+ 4 : "g2-standard-48" ,
71+ 8 : "g2-standard-96" ,
72+ },
73+ ),
74+ "t4" : GpuSpec (
75+ "nvidia-tesla-t4" ,
76+ {
77+ 1 : "n1-standard-4" ,
78+ 2 : "n1-standard-8" ,
79+ 4 : "n1-standard-16" ,
80+ },
81+ ),
82+ "v100" : GpuSpec (
83+ "nvidia-tesla-v100" ,
84+ {
85+ 1 : "n1-standard-8" ,
86+ 2 : "n1-standard-16" ,
87+ 4 : "n1-standard-32" ,
88+ 8 : "n1-standard-64" ,
89+ },
90+ ),
91+ "a100" : GpuSpec (
92+ "nvidia-tesla-a100" ,
93+ {
94+ 1 : "a2-highgpu-1g" ,
95+ 2 : "a2-highgpu-2g" ,
96+ 4 : "a2-highgpu-4g" ,
97+ 8 : "a2-highgpu-8g" ,
98+ 16 : "a2-megagpu-16g" ,
99+ },
100+ ),
101+ "a100-80gb" : GpuSpec (
102+ "nvidia-a100-80gb" ,
103+ {
104+ 1 : "a2-ultragpu-1g" ,
105+ 2 : "a2-ultragpu-2g" ,
106+ 4 : "a2-ultragpu-4g" ,
107+ 8 : "a2-ultragpu-8g" ,
108+ 16 : "a2-ultragpu-16g" ,
109+ },
110+ ),
111+ "h100" : GpuSpec (
112+ "nvidia-h100-80gb" ,
113+ {
114+ 1 : "a3-highgpu-1g" ,
115+ 2 : "a3-highgpu-2g" ,
116+ 4 : "a3-highgpu-4g" ,
117+ 8 : "a3-highgpu-8g" ,
118+ },
119+ ),
120+ "p4" : GpuSpec (
121+ "nvidia-tesla-p4" ,
122+ {
123+ 1 : "n1-standard-4" ,
124+ 2 : "n1-standard-8" ,
125+ 4 : "n1-standard-16" ,
126+ },
127+ ),
128+ "p100" : GpuSpec (
129+ "nvidia-tesla-p100" ,
130+ {
131+ 1 : "n1-standard-4" ,
132+ 2 : "n1-standard-8" ,
133+ 4 : "n1-standard-16" ,
134+ },
135+ ),
72136}
73137
74138_GPU_ALIASES : dict [str , str ] = {
@@ -81,22 +145,36 @@ class TpuSpec:
81145# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
82146# v4 and v5p use 3-D topologies (AxBxC); v3, v5litepod, v6e use 2-D (AxB).
83147TPUS : dict [str , TpuSpec ] = {
84- "v2" : TpuSpec (
85- "tpu-v2-podslice" ,
86- 4 ,
87- {
88- 4 : TpuTopologySpec ("2x2" , "ct2-hightpu-4t" , 1 ),
89- 16 : TpuTopologySpec ("4x4" , "ct2-hightpu-4t" , 4 ),
90- 32 : TpuTopologySpec ("4x8" , "ct2-hightpu-4t" , 8 ),
91- },
92- ),
93148 "v3" : TpuSpec (
94149 "tpu-v3-podslice" ,
95150 4 ,
96151 {
97152 4 : TpuTopologySpec ("2x2" , "ct3-hightpu-4t" , 1 ),
98153 16 : TpuTopologySpec ("4x4" , "ct3p-hightpu-4t" , 4 ),
99154 32 : TpuTopologySpec ("4x8" , "ct3p-hightpu-4t" , 8 ),
155+ 64 : TpuTopologySpec ("8x8" , "ct3p-hightpu-4t" , 16 ),
156+ 128 : TpuTopologySpec ("8x16" , "ct3p-hightpu-4t" , 32 ),
157+ 256 : TpuTopologySpec ("16x16" , "ct3p-hightpu-4t" , 64 ),
158+ 512 : TpuTopologySpec ("16x32" , "ct3p-hightpu-4t" , 128 ),
159+ 1024 : TpuTopologySpec ("32x32" , "ct3p-hightpu-4t" , 256 ),
160+ 2048 : TpuTopologySpec ("32x64" , "ct3p-hightpu-4t" , 512 ),
161+ },
162+ ),
163+ "v4" : TpuSpec (
164+ "tpu-v4-podslice" ,
165+ 4 ,
166+ {
167+ 4 : TpuTopologySpec ("2x2x1" , "ct4p-hightpu-4t" , 1 ),
168+ 8 : TpuTopologySpec ("2x2x2" , "ct4p-hightpu-4t" , 2 ),
169+ 16 : TpuTopologySpec ("2x2x4" , "ct4p-hightpu-4t" , 4 ),
170+ 32 : TpuTopologySpec ("2x4x4" , "ct4p-hightpu-4t" , 8 ),
171+ 64 : TpuTopologySpec ("4x4x4" , "ct4p-hightpu-4t" , 16 ),
172+ 128 : TpuTopologySpec ("4x4x8" , "ct4p-hightpu-4t" , 32 ),
173+ 256 : TpuTopologySpec ("4x8x8" , "ct4p-hightpu-4t" , 64 ),
174+ 512 : TpuTopologySpec ("8x8x8" , "ct4p-hightpu-4t" , 128 ),
175+ 1024 : TpuTopologySpec ("8x8x16" , "ct4p-hightpu-4t" , 256 ),
176+ 2048 : TpuTopologySpec ("8x16x16" , "ct4p-hightpu-4t" , 512 ),
177+ 4096 : TpuTopologySpec ("16x16x16" , "ct4p-hightpu-4t" , 1024 ),
100178 },
101179 ),
102180 "v5litepod" : TpuSpec (
@@ -106,6 +184,11 @@ class TpuSpec:
106184 1 : TpuTopologySpec ("1x1" , "ct5lp-hightpu-1t" , 1 ),
107185 4 : TpuTopologySpec ("2x2" , "ct5lp-hightpu-4t" , 1 ),
108186 8 : TpuTopologySpec ("2x4" , "ct5lp-hightpu-8t" , 1 ),
187+ 16 : TpuTopologySpec ("4x4" , "ct5lp-hightpu-4t" , 4 ),
188+ 32 : TpuTopologySpec ("4x8" , "ct5lp-hightpu-4t" , 8 ),
189+ 64 : TpuTopologySpec ("8x8" , "ct5lp-hightpu-4t" , 16 ),
190+ 128 : TpuTopologySpec ("8x16" , "ct5lp-hightpu-4t" , 32 ),
191+ 256 : TpuTopologySpec ("16x16" , "ct5lp-hightpu-4t" , 64 ),
109192 },
110193 ),
111194 "v5p" : TpuSpec (
@@ -114,6 +197,7 @@ class TpuSpec:
114197 {
115198 8 : TpuTopologySpec ("2x2x2" , "ct5p-hightpu-4t" , 2 ),
116199 16 : TpuTopologySpec ("2x2x4" , "ct5p-hightpu-4t" , 4 ),
200+ 32 : TpuTopologySpec ("2x4x4" , "ct5p-hightpu-4t" , 8 ),
117201 },
118202 ),
119203 "v6e" : TpuSpec (
@@ -126,14 +210,39 @@ class TpuSpec:
126210 ),
127211}
128212
213+ _TPU_ALIASES : dict [str , str ] = {
214+ "v5e" : "v5litepod" ,
215+ }
129216
130- # ── Parser ────────────────────────────────────────────────────────
131217
132- _MULTI_GPU_RE = re .compile (r"^(.+?)x (\d+)$" ) # "a100x4"
133- _TPU_CHIPS_RE = re .compile (r"^(v\d+\w* )-(\d+)$" ) # "v3-8"
218+ _MULTI_GPU_RE = re .compile (r"^([^x]+)(?:x) (\d+)$" ) # "a100x4"
219+ _TPU_CHIPS_RE = re .compile (r"^([a-z0-9_]+ )-(\d+)$" ) # "v3-8"
134220_TPU_TOPO_RE = re .compile (
135- r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
136- ) # "v5litepod-2x2", "v5p-2x2x2"
221+ r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
222+ ) # "v5litepod-2x2", "v5p-2x2x2"
223+
224+ DEFAULT_GPU = "l4"
225+ DEFAULT_TPU = "v5litepod"
226+
227+ _PREFERRED_GPUS = [
228+ "h100" ,
229+ "a100-80gb" ,
230+ "a100" ,
231+ "l4" ,
232+ "v100" ,
233+ "t4" ,
234+ "p100" ,
235+ "p4" ,
236+ ]
237+ _PREFERRED_TPUS = ["v6e" , "v5p" , "v5litepod" , "v4" , "v3" ]
238+
239+
240+ def _resolve_gpu_alias (name : str ) -> str :
241+ return _GPU_ALIASES .get (name , name )
242+
243+
244+ def _resolve_tpu_alias (name : str ) -> str :
245+ return _TPU_ALIASES .get (name , name )
137246
138247
139248def parse_accelerator (accel_str : str ) -> Accelerator :
@@ -142,60 +251,108 @@ def parse_accelerator(accel_str: str) -> Accelerator:
142251 Returns GpuConfig, TpuConfig, or None (for "cpu").
143252
144253 Accepted formats:
145- GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
146- TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
147- CPU: "cpu"
254+ - Generic: "gpu", "tpu", "cpu" (resolves to defaults)
254+ - Dynamic Count: "gpu:4", "tpu:8" (assigns most capable hardware matching the count); "cpu:N" is accepted but the count is ignored
256+ - Explicit GPU Name: "gpu:l4", "l4", "gpu:a100-80gb" (resolves to 1 instance of the specified GPU)
257+ - Multi-GPU Name: "gpu:a100x4", "a100x4", "gpu:l4x2" (resolves to N instances of the specified GPU)
258+ - Explicit TPU Name: "tpu:v5litepod", "v5litepod" (resolves to the default topology/chips for the TPU)
259+ - Explicit TPU Topology/Chips: "tpu:v3-8", "tpu:v5litepod-2x2", "v3-8" (resolves to the specified TPU slice)
260+
261+ Note: Prefixes ('gpu:' and 'tpu:') are optional, but recommended for unambiguous parsing.
262+
263+ Dynamic Resolution:
264+ When using generic formats like "gpu:<N>" or "tpu:<N>", the parser
265+ dynamically assigns the most capable hardware type that supports the
266+ requested device count `N`. Hardware is selected based on an internal
267+ preference hierarchy (e.g., H100 > A100 > L4 for GPUs, and
268+ v6e > v5p > v5litepod for TPUs).
148269 """
149270 s = accel_str .strip ().lower ()
150271
151- if s == "cpu" :
272+ if s == "cpu" or ( s . startswith ( "cpu:" ) and s [ 4 :]. isdigit ()) :
152273 return None
153274
154- # Direct GPU name: "l4", "a100-80gb"
155- if s in GPUS :
156- return make_gpu (s , 1 )
157-
158- # GPU alias: "nvidia-l4"
159- if s in _GPU_ALIASES :
160- return make_gpu (_GPU_ALIASES [s ], 1 )
161-
162- # Multi-GPU: "a100x4", "l4x2"
163- m = _MULTI_GPU_RE .match (s )
275+ if s == "gpu" :
276+ return make_gpu (DEFAULT_GPU , 1 )
277+
278+ if s == "tpu" :
279+ return make_tpu (DEFAULT_TPU , TPUS [DEFAULT_TPU ].default_chips )
280+
281+ # 1) Try parsing as GPU
282+ is_gpu_explicit = s .startswith ("gpu:" )
283+ gpu_str = s [4 :] if is_gpu_explicit else s
284+
285+ if gpu_str .isdigit ():
286+ count = int (gpu_str )
287+ for gpu_name in _PREFERRED_GPUS :
288+ if gpu_name in GPUS and count in GPUS [gpu_name ].counts :
289+ return make_gpu (gpu_name , count )
290+ if is_gpu_explicit :
291+ valid_counts = sorted (
292+ set (c for spec in GPUS .values () for c in spec .counts )
293+ )
294+ raise ValueError (
295+ f"No GPU supports count { count } . Supported counts: { valid_counts } "
296+ )
297+
298+ name = _resolve_gpu_alias (gpu_str )
299+ if name in GPUS :
300+ return make_gpu (name , 1 )
301+
302+ m = _MULTI_GPU_RE .match (gpu_str )
164303 if m :
165- name = m .group (1 )
304+ name = _resolve_gpu_alias ( m .group (1 ) )
166305 if name in GPUS :
167306 return make_gpu (name , int (m .group (2 )))
168- if name in _GPU_ALIASES :
169- return make_gpu (_GPU_ALIASES [name ], int (m .group (2 )))
170-
171- # Direct TPU name (bare): "v5litepod" → default chips
172- if s in TPUS :
173- return make_tpu (s , TPUS [s ].default_chips )
174-
175- # TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
176- m = _TPU_TOPO_RE .match (s )
177- if m and m .group (1 ) in TPUS :
178- name = m .group (1 )
179- topo_str = m .group (2 )
180- for chips , topo_spec in TPUS [name ].topologies .items ():
181- if topo_spec .topology == topo_str :
182- return make_tpu (name , chips )
183- valid = [ts .topology for ts in TPUS [name ].topologies .values ()]
184- raise ValueError (
185- f"Topology '{ topo_str } ' not supported for '{ name } '. "
186- f"Supported: { ', ' .join (valid )} ."
187- )
188307
189- # TPU with chip count: "v3-8", "v5litepod-4"
190- m = _TPU_CHIPS_RE .match (s )
191- if m and m .group (1 ) in TPUS :
192- return make_tpu (m .group (1 ), int (m .group (2 )))
308+ if is_gpu_explicit :
309+ raise ValueError (f"Unknown GPU accelerator: '{ accel_str } '" )
310+
311+ # 2) Try parsing as TPU
312+ is_tpu_explicit = s .startswith ("tpu:" )
313+ tpu_str = s [4 :] if is_tpu_explicit else s
314+
315+ if tpu_str .isdigit ():
316+ chips = int (tpu_str )
317+ for tpu_name in _PREFERRED_TPUS :
318+ if tpu_name in TPUS and chips in TPUS [tpu_name ].topologies :
319+ return make_tpu (tpu_name , chips )
320+ if is_tpu_explicit :
321+ valid_chips = sorted (
322+ set (c for spec in TPUS .values () for c in spec .topologies )
323+ )
324+ raise ValueError (
325+ f"No TPU supports { chips } chips. Supported chip counts: { valid_chips } "
326+ )
327+
328+ name = _resolve_tpu_alias (tpu_str )
329+ if name in TPUS :
330+ return make_tpu (name , TPUS [name ].default_chips )
331+
332+ m = _TPU_TOPO_RE .match (tpu_str )
333+ if m :
334+ name = _resolve_tpu_alias (m .group (1 ))
335+ if name in TPUS :
336+ topo_str = m .group (2 )
337+ for chips , topo_spec in TPUS [name ].topologies .items ():
338+ if topo_spec .topology == topo_str :
339+ return make_tpu (name , chips )
340+ valid = [ts .topology for ts in TPUS [name ].topologies .values ()]
341+ raise ValueError (
342+ f"Topology '{ topo_str } ' not supported for '{ name } '. "
343+ f"Supported: { ', ' .join (valid )} ."
344+ )
345+
346+ m = _TPU_CHIPS_RE .match (tpu_str )
347+ if m :
348+ name = _resolve_tpu_alias (m .group (1 ))
349+ if name in TPUS :
350+ return make_tpu (name , int (m .group (2 )))
193351
194352 raise ValueError (
195353 f"Unknown accelerator: '{ accel_str } '. "
196- f"GPUs: { ', ' .join (GPUS )} (use 'xN' for multi-GPU, e.g. 'a100x4'). "
197- f"TPUs: { ', ' .join (TPUS )} (use '-N' for chips, e.g. 'v3-8', "
198- f"or '-NxM' for topology, e.g. 'v5litepod-2x2')."
354+ f"GPUs: { ', ' .join (GPUS )} (use 'gpu:name' or 'gpu:namexN'). "
355+ f"TPUs: { ', ' .join (TPUS )} (use 'tpu:name' or 'tpu:name-N')."
199356 )
200357
201358
@@ -234,7 +391,7 @@ def make_gpu(name: str, count: int) -> GpuConfig:
234391 name = name ,
235392 count = count ,
236393 gke_label = spec .gke_label ,
237- machine_type = spec .machine_type ,
394+ machine_type = spec .counts [ count ] ,
238395 )
239396
240397
0 commit comments