@@ -63,12 +63,14 @@ class TpuSpec:
6363
6464
# Catalog of supported GPU models, keyed by the short name users type.
# Each GpuSpec is built positionally; the arguments appear to be
# (accelerator type label, machine type, allowed GPU counts) -- the last one
# is read as ``.counts`` by parse_accelerator. Confirm the first two against
# the GpuSpec definition.
GPUS: dict[str, GpuSpec] = {
    "l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4, 8)),
    "t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
    "v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
    "a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8, 16)),
    # NOTE(review): verify that a 16-GPU shape really exists for the
    # a2-ultragpu family before relying on count 16 here.
    "a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8, 16)),
    "h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
    "p4": GpuSpec("nvidia-tesla-p4", "n1-standard-4", (1, 2, 4)),
    "p100": GpuSpec("nvidia-tesla-p100", "n1-standard-4", (1, 2, 4)),
}
7375
7476_GPU_ALIASES : dict [str , str ] = {
@@ -88,6 +90,10 @@ class TpuSpec:
8890 4 : TpuTopologySpec ("2x2" , "ct2-hightpu-4t" , 1 ),
8991 16 : TpuTopologySpec ("4x4" , "ct2-hightpu-4t" , 4 ),
9092 32 : TpuTopologySpec ("4x8" , "ct2-hightpu-4t" , 8 ),
93+ 64 : TpuTopologySpec ("8x8" , "ct2-hightpu-4t" , 16 ),
94+ 128 : TpuTopologySpec ("8x16" , "ct2-hightpu-4t" , 32 ),
95+ 256 : TpuTopologySpec ("16x16" , "ct2-hightpu-4t" , 64 ),
96+ 512 : TpuTopologySpec ("16x32" , "ct2-hightpu-4t" , 128 ),
9197 },
9298 ),
9399 "v3" : TpuSpec (
@@ -97,6 +103,29 @@ class TpuSpec:
97103 4 : TpuTopologySpec ("2x2" , "ct3-hightpu-4t" , 1 ),
98104 16 : TpuTopologySpec ("4x4" , "ct3p-hightpu-4t" , 4 ),
99105 32 : TpuTopologySpec ("4x8" , "ct3p-hightpu-4t" , 8 ),
106+ 64 : TpuTopologySpec ("8x8" , "ct3p-hightpu-4t" , 16 ),
107+ 128 : TpuTopologySpec ("8x16" , "ct3p-hightpu-4t" , 32 ),
108+ 256 : TpuTopologySpec ("16x16" , "ct3p-hightpu-4t" , 64 ),
109+ 512 : TpuTopologySpec ("16x32" , "ct3p-hightpu-4t" , 128 ),
110+ 1024 : TpuTopologySpec ("32x32" , "ct3p-hightpu-4t" , 256 ),
111+ 2048 : TpuTopologySpec ("32x64" , "ct3p-hightpu-4t" , 512 ),
112+ },
113+ ),
114+ "v4" : TpuSpec (
115+ "tpu-v4-podslice" ,
116+ 4 ,
117+ {
118+ 4 : TpuTopologySpec ("2x2x1" , "ct4p-hightpu-4t" , 1 ),
119+ 8 : TpuTopologySpec ("2x2x2" , "ct4p-hightpu-4t" , 2 ),
120+ 16 : TpuTopologySpec ("2x2x4" , "ct4p-hightpu-4t" , 4 ),
121+ 32 : TpuTopologySpec ("2x4x4" , "ct4p-hightpu-4t" , 8 ),
122+ 64 : TpuTopologySpec ("4x4x4" , "ct4p-hightpu-4t" , 16 ),
123+ 128 : TpuTopologySpec ("4x4x8" , "ct4p-hightpu-4t" , 32 ),
124+ 256 : TpuTopologySpec ("4x8x8" , "ct4p-hightpu-4t" , 64 ),
125+ 512 : TpuTopologySpec ("8x8x8" , "ct4p-hightpu-4t" , 128 ),
126+ 1024 : TpuTopologySpec ("8x8x16" , "ct4p-hightpu-4t" , 256 ),
127+ 2048 : TpuTopologySpec ("8x16x16" , "ct4p-hightpu-4t" , 512 ),
128+ 4096 : TpuTopologySpec ("16x16x16" , "ct4p-hightpu-4t" , 1024 ),
100129 },
101130 ),
102131 "v5litepod" : TpuSpec (
@@ -106,6 +135,11 @@ class TpuSpec:
106135 1 : TpuTopologySpec ("1x1" , "ct5lp-hightpu-1t" , 1 ),
107136 4 : TpuTopologySpec ("2x2" , "ct5lp-hightpu-4t" , 1 ),
108137 8 : TpuTopologySpec ("2x4" , "ct5lp-hightpu-8t" , 1 ),
138+ 16 : TpuTopologySpec ("4x4" , "ct5lp-hightpu-4t" , 4 ),
139+ 32 : TpuTopologySpec ("4x8" , "ct5lp-hightpu-4t" , 8 ),
140+ 64 : TpuTopologySpec ("8x8" , "ct5lp-hightpu-4t" , 16 ),
141+ 128 : TpuTopologySpec ("8x16" , "ct5lp-hightpu-4t" , 32 ),
142+ 256 : TpuTopologySpec ("16x16" , "ct5lp-hightpu-4t" , 64 ),
109143 },
110144 ),
111145 "v5p" : TpuSpec (
@@ -126,31 +160,72 @@ class TpuSpec:
126160 ),
127161}
128162
# Alternate TPU names accepted by parse_accelerator; each maps to the
# canonical key used in TPUS.
_TPU_ALIASES: dict[str, str] = {
    "v5e": "v5litepod",
    "ghostlite": "v5litepod",
}
167+
129168
130169# ── Parser ────────────────────────────────────────────────────────
131170
# "<gpu>x<count>" or "<gpu>-<count>". The name group is lazy so that a
# hyphenated GPU name keeps its own suffix: "a100-80gbx8" -> ("a100-80gb", "8").
_MULTI_GPU_RE = re.compile(r"^(.+?)(?:x|-)(\d+)$")  # "a100x4", "l4-2"

# "<tpu>-<chips>". The name is lowercase alphanumerics/underscore so that
# aliases (not only "v<N>..." names) can be captured and resolved later.
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$")  # "v3-8", "ghostlite-16"

# "<tpu>-<topology>" where topology is 2-D ("2x2") or 3-D ("2x2x2").
_TPU_TOPO_RE = re.compile(
    r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
)  # "v5litepod-2x2", "v5p-2x2x2"
137176
# Families used by parse_accelerator for the generic "gpu" / "tpu" requests
# (and tried first for "gpu-<n>" / "tpu-<n>").
DEFAULT_GPU = "l4"
DEFAULT_TPU = "v5litepod"

# Fallback search order for "gpu-<n>" / "tpu-<n>" when the default family
# does not offer the requested size: the first listed family that is present
# in the catalog and supports the size is chosen.
_PREFERRED_GPUS = ["h100", "a100-80gb", "a100", "l4", "v100", "t4", "p100", "p4"]
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3", "v2"]
182+
138183
139184def parse_accelerator (accel_str : str ) -> Accelerator :
140185 """Parse an accelerator string into a fully resolved config.
141186
142187 Returns GpuConfig, TpuConfig, or None (for "cpu").
143188
144189 Accepted formats:
145- GPU: "l4", "nvidia-l4 ", "a100x4", "a100-80gbx8"
146- TPU: "v3-8" (chip count) , "v5litepod-2x2" (topology) , "v5litepod" (default)
147- CPU: "cpu"
190+ GPU: "l4", "gpu", "gpu-4 ", "a100x4", "l4-2 ", "a100-80gbx8"
191+ TPU: "v3-8", "tpu", "tpu-8" , "v5litepod-2x2", "v5litepod"
192+ CPU: "cpu", "cpu-8"
148193 """
149194 s = accel_str .strip ().lower ()
150195
151- if s == "cpu" :
196+ if s == "cpu" or ( s . startswith ( "cpu-" ) and s [ 4 :]. isdigit ()) :
152197 return None
153198
199+ if s == "gpu" :
200+ return make_gpu (DEFAULT_GPU , 1 )
201+
202+ if s == "tpu" :
203+ return make_tpu (DEFAULT_TPU , TPUS [DEFAULT_TPU ].default_chips )
204+
205+ if s .startswith ("gpu-" ) and s [4 :].isdigit ():
206+ count = int (s [4 :])
207+ if count in GPUS [DEFAULT_GPU ].counts :
208+ return make_gpu (DEFAULT_GPU , count )
209+ for gpu_name in _PREFERRED_GPUS :
210+ if gpu_name in GPUS and count in GPUS [gpu_name ].counts :
211+ return make_gpu (gpu_name , count )
212+ valid_counts = sorted (set (c for spec in GPUS .values () for c in spec .counts ))
213+ raise ValueError (
214+ f"No GPU supports count { count } . Supported counts across all GPUs: { valid_counts } "
215+ )
216+
217+ if s .startswith ("tpu-" ) and s [4 :].isdigit ():
218+ chips = int (s [4 :])
219+ if chips in TPUS [DEFAULT_TPU ].topologies :
220+ return make_tpu (DEFAULT_TPU , chips )
221+ for tpu_name in _PREFERRED_TPUS :
222+ if tpu_name in TPUS and chips in TPUS [tpu_name ].topologies :
223+ return make_tpu (tpu_name , chips )
224+ valid_chips = sorted (set (c for spec in TPUS .values () for c in spec .topologies ))
225+ raise ValueError (
226+ f"No TPU supports { chips } chips. Supported chip counts across all TPUs: { valid_chips } "
227+ )
228+
154229 # Direct GPU name: "l4", "a100-80gb"
155230 if s in GPUS :
156231 return make_gpu (s , 1 )
@@ -171,25 +246,35 @@ def parse_accelerator(accel_str: str) -> Accelerator:
171246 # Direct TPU name (bare): "v5litepod" → default chips
172247 if s in TPUS :
173248 return make_tpu (s , TPUS [s ].default_chips )
249+ if s in _TPU_ALIASES :
250+ name = _TPU_ALIASES [s ]
251+ return make_tpu (name , TPUS [name ].default_chips )
174252
175253 # TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
176254 m = _TPU_TOPO_RE .match (s )
177- if m and m . group ( 1 ) in TPUS :
255+ if m :
178256 name = m .group (1 )
179- topo_str = m .group (2 )
180- for chips , topo_spec in TPUS [name ].topologies .items ():
181- if topo_spec .topology == topo_str :
182- return make_tpu (name , chips )
183- valid = [ts .topology for ts in TPUS [name ].topologies .values ()]
184- raise ValueError (
185- f"Topology '{ topo_str } ' not supported for '{ name } '. "
186- f"Supported: { ', ' .join (valid )} ."
187- )
257+ if name in _TPU_ALIASES :
258+ name = _TPU_ALIASES [name ]
259+ if name in TPUS :
260+ topo_str = m .group (2 )
261+ for chips , topo_spec in TPUS [name ].topologies .items ():
262+ if topo_spec .topology == topo_str :
263+ return make_tpu (name , chips )
264+ valid = [ts .topology for ts in TPUS [name ].topologies .values ()]
265+ raise ValueError (
266+ f"Topology '{ topo_str } ' not supported for '{ name } '. "
267+ f"Supported: { ', ' .join (valid )} ."
268+ )
188269
189270 # TPU with chip count: "v3-8", "v5litepod-4"
190271 m = _TPU_CHIPS_RE .match (s )
191- if m and m .group (1 ) in TPUS :
192- return make_tpu (m .group (1 ), int (m .group (2 )))
272+ if m :
273+ name = m .group (1 )
274+ if name in _TPU_ALIASES :
275+ name = _TPU_ALIASES [name ]
276+ if name in TPUS :
277+ return make_tpu (name , int (m .group (2 )))
193278
194279 raise ValueError (
195280 f"Unknown accelerator: '{ accel_str } '. "
0 commit comments