Skip to content

Commit 64e6058

Browse files
simplify accelerator names
1 parent 22252fd commit 64e6058

File tree

2 files changed

+177
-24
lines changed

2 files changed

+177
-24
lines changed

keras_remote/core/accelerators.py

Lines changed: 107 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -63,12 +63,14 @@ class TpuSpec:
6363

6464

6565
GPUS: dict[str, GpuSpec] = {
66-
"l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4)),
66+
"l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4, 8)),
6767
"t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
6868
"v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
69-
"a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8)),
70-
"a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8)),
69+
"a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8, 16)),
70+
"a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8, 16)),
7171
"h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
72+
"p4": GpuSpec("nvidia-tesla-p4", "n1-standard-4", (1, 2, 4)),
73+
"p100": GpuSpec("nvidia-tesla-p100", "n1-standard-4", (1, 2, 4)),
7274
}
7375

7476
_GPU_ALIASES: dict[str, str] = {
@@ -88,6 +90,10 @@ class TpuSpec:
8890
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
8991
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
9092
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
93+
64: TpuTopologySpec("8x8", "ct2-hightpu-4t", 16),
94+
128: TpuTopologySpec("8x16", "ct2-hightpu-4t", 32),
95+
256: TpuTopologySpec("16x16", "ct2-hightpu-4t", 64),
96+
512: TpuTopologySpec("16x32", "ct2-hightpu-4t", 128),
9197
},
9298
),
9399
"v3": TpuSpec(
@@ -97,6 +103,29 @@ class TpuSpec:
97103
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
98104
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
99105
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
106+
64: TpuTopologySpec("8x8", "ct3p-hightpu-4t", 16),
107+
128: TpuTopologySpec("8x16", "ct3p-hightpu-4t", 32),
108+
256: TpuTopologySpec("16x16", "ct3p-hightpu-4t", 64),
109+
512: TpuTopologySpec("16x32", "ct3p-hightpu-4t", 128),
110+
1024: TpuTopologySpec("32x32", "ct3p-hightpu-4t", 256),
111+
2048: TpuTopologySpec("32x64", "ct3p-hightpu-4t", 512),
112+
},
113+
),
114+
"v4": TpuSpec(
115+
"tpu-v4-podslice",
116+
4,
117+
{
118+
4: TpuTopologySpec("2x2x1", "ct4p-hightpu-4t", 1),
119+
8: TpuTopologySpec("2x2x2", "ct4p-hightpu-4t", 2),
120+
16: TpuTopologySpec("2x2x4", "ct4p-hightpu-4t", 4),
121+
32: TpuTopologySpec("2x4x4", "ct4p-hightpu-4t", 8),
122+
64: TpuTopologySpec("4x4x4", "ct4p-hightpu-4t", 16),
123+
128: TpuTopologySpec("4x4x8", "ct4p-hightpu-4t", 32),
124+
256: TpuTopologySpec("4x8x8", "ct4p-hightpu-4t", 64),
125+
512: TpuTopologySpec("8x8x8", "ct4p-hightpu-4t", 128),
126+
1024: TpuTopologySpec("8x8x16", "ct4p-hightpu-4t", 256),
127+
2048: TpuTopologySpec("8x16x16", "ct4p-hightpu-4t", 512),
128+
4096: TpuTopologySpec("16x16x16", "ct4p-hightpu-4t", 1024),
100129
},
101130
),
102131
"v5litepod": TpuSpec(
@@ -106,6 +135,11 @@ class TpuSpec:
106135
1: TpuTopologySpec("1x1", "ct5lp-hightpu-1t", 1),
107136
4: TpuTopologySpec("2x2", "ct5lp-hightpu-4t", 1),
108137
8: TpuTopologySpec("2x4", "ct5lp-hightpu-8t", 1),
138+
16: TpuTopologySpec("4x4", "ct5lp-hightpu-4t", 4),
139+
32: TpuTopologySpec("4x8", "ct5lp-hightpu-4t", 8),
140+
64: TpuTopologySpec("8x8", "ct5lp-hightpu-4t", 16),
141+
128: TpuTopologySpec("8x16", "ct5lp-hightpu-4t", 32),
142+
256: TpuTopologySpec("16x16", "ct5lp-hightpu-4t", 64),
109143
},
110144
),
111145
"v5p": TpuSpec(
@@ -126,31 +160,72 @@ class TpuSpec:
126160
),
127161
}
128162

163+
_TPU_ALIASES: dict[str, str] = {
164+
"v5e": "v5litepod",
165+
"ghostlite": "v5litepod",
166+
}
167+
129168

130169
# ── Parser ────────────────────────────────────────────────────────
131170

132-
_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
133-
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
171+
_MULTI_GPU_RE = re.compile(r"^(.+?)(?:x|-)(\d+)$") # "a100x4", "l4-2"
172+
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8", "ghostlite-16"
134173
_TPU_TOPO_RE = re.compile(
135-
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
174+
r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
136175
) # "v5litepod-2x2", "v5p-2x2x2"
137176

177+
DEFAULT_GPU = "l4"
178+
DEFAULT_TPU = "v5litepod"
179+
180+
_PREFERRED_GPUS = ["h100", "a100-80gb", "a100", "l4", "v100", "t4", "p100", "p4"]
181+
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3", "v2"]
182+
138183

139184
def parse_accelerator(accel_str: str) -> Accelerator:
140185
"""Parse an accelerator string into a fully resolved config.
141186
142187
Returns GpuConfig, TpuConfig, or None (for "cpu").
143188
144189
Accepted formats:
145-
GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
146-
TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
147-
CPU: "cpu"
190+
GPU: "l4", "gpu", "gpu-4", "a100x4", "l4-2", "a100-80gbx8"
191+
TPU: "v3-8", "tpu", "tpu-8", "v5litepod-2x2", "v5litepod"
192+
CPU: "cpu", "cpu-8"
148193
"""
149194
s = accel_str.strip().lower()
150195

151-
if s == "cpu":
196+
if s == "cpu" or (s.startswith("cpu-") and s[4:].isdigit()):
152197
return None
153198

199+
if s == "gpu":
200+
return make_gpu(DEFAULT_GPU, 1)
201+
202+
if s == "tpu":
203+
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)
204+
205+
if s.startswith("gpu-") and s[4:].isdigit():
206+
count = int(s[4:])
207+
if count in GPUS[DEFAULT_GPU].counts:
208+
return make_gpu(DEFAULT_GPU, count)
209+
for gpu_name in _PREFERRED_GPUS:
210+
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
211+
return make_gpu(gpu_name, count)
212+
valid_counts = sorted(set(c for spec in GPUS.values() for c in spec.counts))
213+
raise ValueError(
214+
f"No GPU supports count {count}. Supported counts across all GPUs: {valid_counts}"
215+
)
216+
217+
if s.startswith("tpu-") and s[4:].isdigit():
218+
chips = int(s[4:])
219+
if chips in TPUS[DEFAULT_TPU].topologies:
220+
return make_tpu(DEFAULT_TPU, chips)
221+
for tpu_name in _PREFERRED_TPUS:
222+
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
223+
return make_tpu(tpu_name, chips)
224+
valid_chips = sorted(set(c for spec in TPUS.values() for c in spec.topologies))
225+
raise ValueError(
226+
f"No TPU supports {chips} chips. Supported chip counts across all TPUs: {valid_chips}"
227+
)
228+
154229
# Direct GPU name: "l4", "a100-80gb"
155230
if s in GPUS:
156231
return make_gpu(s, 1)
@@ -171,25 +246,35 @@ def parse_accelerator(accel_str: str) -> Accelerator:
171246
# Direct TPU name (bare): "v5litepod" → default chips
172247
if s in TPUS:
173248
return make_tpu(s, TPUS[s].default_chips)
249+
if s in _TPU_ALIASES:
250+
name = _TPU_ALIASES[s]
251+
return make_tpu(name, TPUS[name].default_chips)
174252

175253
# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
176254
m = _TPU_TOPO_RE.match(s)
177-
if m and m.group(1) in TPUS:
255+
if m:
178256
name = m.group(1)
179-
topo_str = m.group(2)
180-
for chips, topo_spec in TPUS[name].topologies.items():
181-
if topo_spec.topology == topo_str:
182-
return make_tpu(name, chips)
183-
valid = [ts.topology for ts in TPUS[name].topologies.values()]
184-
raise ValueError(
185-
f"Topology '{topo_str}' not supported for '{name}'. "
186-
f"Supported: {', '.join(valid)}."
187-
)
257+
if name in _TPU_ALIASES:
258+
name = _TPU_ALIASES[name]
259+
if name in TPUS:
260+
topo_str = m.group(2)
261+
for chips, topo_spec in TPUS[name].topologies.items():
262+
if topo_spec.topology == topo_str:
263+
return make_tpu(name, chips)
264+
valid = [ts.topology for ts in TPUS[name].topologies.values()]
265+
raise ValueError(
266+
f"Topology '{topo_str}' not supported for '{name}'. "
267+
f"Supported: {', '.join(valid)}."
268+
)
188269

189270
# TPU with chip count: "v3-8", "v5litepod-4"
190271
m = _TPU_CHIPS_RE.match(s)
191-
if m and m.group(1) in TPUS:
192-
return make_tpu(m.group(1), int(m.group(2)))
272+
if m:
273+
name = m.group(1)
274+
if name in _TPU_ALIASES:
275+
name = _TPU_ALIASES[name]
276+
if name in TPUS:
277+
return make_tpu(name, int(m.group(2)))
193278

194279
raise ValueError(
195280
f"Unknown accelerator: '{accel_str}'. "

keras_remote/core/accelerators_test.py

Lines changed: 70 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,12 @@ def test_a100_80gbx4(self):
4444
self.assertEqual(result.name, "a100-80gb")
4545
self.assertEqual(result.count, 4)
4646

47+
def test_l4_dash_2(self):
48+
result = parse_accelerator("l4-2")
49+
self.assertIsInstance(result, GpuConfig)
50+
self.assertEqual(result.name, "l4")
51+
self.assertEqual(result.count, 2)
52+
4753

4854
class TestParseGpuAlias(absltest.TestCase):
4955
def test_nvidia_tesla_t4(self):
@@ -60,9 +66,9 @@ def test_nvidia_tesla_v100x4(self):
6066

6167

6268
class TestParseGpuErrors(absltest.TestCase):
63-
def test_l4x8_invalid_count(self):
69+
def test_l4x16_invalid_count(self):
6470
with self.assertRaisesRegex(ValueError, "not supported"):
65-
parse_accelerator("l4x8")
71+
parse_accelerator("l4x16")
6672

6773

6874
class TestParseTpuBare(parameterized.TestCase):
@@ -154,6 +160,68 @@ class TestParseCpu(absltest.TestCase):
154160
def test_cpu(self):
155161
self.assertIsNone(parse_accelerator("cpu"))
156162

163+
def test_cpu_with_count(self):
164+
self.assertIsNone(parse_accelerator("cpu-8"))
165+
166+
class TestParseGenericAliases(absltest.TestCase):
167+
def test_gpu_bare(self):
168+
result = parse_accelerator("gpu")
169+
self.assertIsInstance(result, GpuConfig)
170+
self.assertEqual(result.name, "l4")
171+
self.assertEqual(result.count, 1)
172+
173+
def test_tpu_bare(self):
174+
result = parse_accelerator("tpu")
175+
self.assertIsInstance(result, TpuConfig)
176+
self.assertEqual(result.name, "v5litepod")
177+
self.assertEqual(result.chips, 4)
178+
179+
def test_gpu_with_count(self):
180+
result = parse_accelerator("gpu-4")
181+
self.assertIsInstance(result, GpuConfig)
182+
self.assertEqual(result.name, "l4")
183+
self.assertEqual(result.count, 4)
184+
185+
def test_tpu_with_count(self):
186+
result = parse_accelerator("tpu-8")
187+
self.assertIsInstance(result, TpuConfig)
188+
self.assertEqual(result.name, "v5litepod")
189+
self.assertEqual(result.chips, 8)
190+
191+
def test_gpu_with_dynamic_count(self):
192+
# l4 supports up to 8 now. 16 should fall back to a100.
193+
result = parse_accelerator("gpu-16")
194+
self.assertIsInstance(result, GpuConfig)
195+
self.assertIn(result.name, ["a100", "a100-80gb"])
196+
self.assertEqual(result.count, 16)
197+
198+
def test_tpu_with_dynamic_count(self):
199+
# v5litepod supports up to 256. 4096 should fall back to v4.
200+
result = parse_accelerator("tpu-4096")
201+
self.assertIsInstance(result, TpuConfig)
202+
self.assertEqual(result.name, "v4")
203+
self.assertEqual(result.chips, 4096)
204+
205+
def test_v5e_alias(self):
206+
result = parse_accelerator("v5e-8")
207+
self.assertIsInstance(result, TpuConfig)
208+
self.assertEqual(result.name, "v5litepod")
209+
self.assertEqual(result.chips, 8)
210+
211+
def test_ghostlite_alias(self):
212+
result = parse_accelerator("ghostlite-16")
213+
self.assertIsInstance(result, TpuConfig)
214+
self.assertEqual(result.name, "v5litepod")
215+
self.assertEqual(result.chips, 16)
216+
217+
def test_gpu_unsupported_count(self):
218+
with self.assertRaisesRegex(ValueError, "No GPU supports count 32"):
219+
parse_accelerator("gpu-32")
220+
221+
def test_tpu_unsupported_count(self):
222+
with self.assertRaisesRegex(ValueError, "No TPU supports 8192 chips"):
223+
parse_accelerator("tpu-8192")
224+
157225

158226
class TestParseNormalizationAndErrors(absltest.TestCase):
159227
def test_whitespace_and_case(self):

0 commit comments

Comments
 (0)