Skip to content

Commit 8f3495a

Browse files
Dynamic hardware detection (#71)
* simplify accelerator names
* code reformat
* address Gemini review
* address Gemini comments
* address review comments
* remove ***
* code reformat
* address review comments
* code reformat
1 parent 5d2e304 commit 8f3495a

File tree

3 files changed

+344
-89
lines changed

3 files changed

+344
-89
lines changed

keras_remote/cli/infra/program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _create_gpu_node_pool(cluster, gpu: GpuConfig, zone, project_id, pool_name):
232232
guest_accelerators=[
233233
gcp.container.NodePoolNodeConfigGuestAcceleratorArgs(
234234
type=gpu.gke_label,
235-
count=1,
235+
count=gpu.count,
236236
),
237237
],
238238
labels={RESOURCE_NAME_PREFIX: "true"},

keras_remote/core/accelerators.py

Lines changed: 222 additions & 65 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ class GpuSpec:
4040
"""Registry entry for a GPU type."""
4141

4242
gke_label: str
43-
machine_type: str
44-
counts: tuple[int, ...]
43+
counts: dict[int, str] # count -> machine_type
4544

4645

4746
@dataclass(frozen=True)
@@ -63,12 +62,77 @@ class TpuSpec:
6362

6463

6564
GPUS: dict[str, GpuSpec] = {
66-
"l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4)),
67-
"t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
68-
"v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
69-
"a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8)),
70-
"a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8)),
71-
"h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
65+
"l4": GpuSpec(
66+
"nvidia-l4",
67+
{
68+
1: "g2-standard-4",
69+
2: "g2-standard-24",
70+
4: "g2-standard-48",
71+
8: "g2-standard-96",
72+
},
73+
),
74+
"t4": GpuSpec(
75+
"nvidia-tesla-t4",
76+
{
77+
1: "n1-standard-4",
78+
2: "n1-standard-8",
79+
4: "n1-standard-16",
80+
},
81+
),
82+
"v100": GpuSpec(
83+
"nvidia-tesla-v100",
84+
{
85+
1: "n1-standard-8",
86+
2: "n1-standard-16",
87+
4: "n1-standard-32",
88+
8: "n1-standard-64",
89+
},
90+
),
91+
"a100": GpuSpec(
92+
"nvidia-tesla-a100",
93+
{
94+
1: "a2-highgpu-1g",
95+
2: "a2-highgpu-2g",
96+
4: "a2-highgpu-4g",
97+
8: "a2-highgpu-8g",
98+
16: "a2-megagpu-16g",
99+
},
100+
),
101+
"a100-80gb": GpuSpec(
102+
"nvidia-a100-80gb",
103+
{
104+
1: "a2-ultragpu-1g",
105+
2: "a2-ultragpu-2g",
106+
4: "a2-ultragpu-4g",
107+
8: "a2-ultragpu-8g",
108+
16: "a2-ultragpu-16g",
109+
},
110+
),
111+
"h100": GpuSpec(
112+
"nvidia-h100-80gb",
113+
{
114+
1: "a3-highgpu-1g",
115+
2: "a3-highgpu-2g",
116+
4: "a3-highgpu-4g",
117+
8: "a3-highgpu-8g",
118+
},
119+
),
120+
"p4": GpuSpec(
121+
"nvidia-tesla-p4",
122+
{
123+
1: "n1-standard-4",
124+
2: "n1-standard-8",
125+
4: "n1-standard-16",
126+
},
127+
),
128+
"p100": GpuSpec(
129+
"nvidia-tesla-p100",
130+
{
131+
1: "n1-standard-4",
132+
2: "n1-standard-8",
133+
4: "n1-standard-16",
134+
},
135+
),
72136
}
73137

74138
_GPU_ALIASES: dict[str, str] = {
@@ -81,22 +145,36 @@ class TpuSpec:
81145
# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
82146
# v5p uses 3-D topologies (AxBxC); v2, v3, v5litepod, v6e use 2-D (AxB).
83147
TPUS: dict[str, TpuSpec] = {
84-
"v2": TpuSpec(
85-
"tpu-v2-podslice",
86-
4,
87-
{
88-
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
89-
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
90-
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
91-
},
92-
),
93148
"v3": TpuSpec(
94149
"tpu-v3-podslice",
95150
4,
96151
{
97152
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
98153
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
99154
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
155+
64: TpuTopologySpec("8x8", "ct3p-hightpu-4t", 16),
156+
128: TpuTopologySpec("8x16", "ct3p-hightpu-4t", 32),
157+
256: TpuTopologySpec("16x16", "ct3p-hightpu-4t", 64),
158+
512: TpuTopologySpec("16x32", "ct3p-hightpu-4t", 128),
159+
1024: TpuTopologySpec("32x32", "ct3p-hightpu-4t", 256),
160+
2048: TpuTopologySpec("32x64", "ct3p-hightpu-4t", 512),
161+
},
162+
),
163+
"v4": TpuSpec(
164+
"tpu-v4-podslice",
165+
4,
166+
{
167+
4: TpuTopologySpec("2x2x1", "ct4p-hightpu-4t", 1),
168+
8: TpuTopologySpec("2x2x2", "ct4p-hightpu-4t", 2),
169+
16: TpuTopologySpec("2x2x4", "ct4p-hightpu-4t", 4),
170+
32: TpuTopologySpec("2x4x4", "ct4p-hightpu-4t", 8),
171+
64: TpuTopologySpec("4x4x4", "ct4p-hightpu-4t", 16),
172+
128: TpuTopologySpec("4x4x8", "ct4p-hightpu-4t", 32),
173+
256: TpuTopologySpec("4x8x8", "ct4p-hightpu-4t", 64),
174+
512: TpuTopologySpec("8x8x8", "ct4p-hightpu-4t", 128),
175+
1024: TpuTopologySpec("8x8x16", "ct4p-hightpu-4t", 256),
176+
2048: TpuTopologySpec("8x16x16", "ct4p-hightpu-4t", 512),
177+
4096: TpuTopologySpec("16x16x16", "ct4p-hightpu-4t", 1024),
100178
},
101179
),
102180
"v5litepod": TpuSpec(
@@ -106,6 +184,11 @@ class TpuSpec:
106184
1: TpuTopologySpec("1x1", "ct5lp-hightpu-1t", 1),
107185
4: TpuTopologySpec("2x2", "ct5lp-hightpu-4t", 1),
108186
8: TpuTopologySpec("2x4", "ct5lp-hightpu-8t", 1),
187+
16: TpuTopologySpec("4x4", "ct5lp-hightpu-4t", 4),
188+
32: TpuTopologySpec("4x8", "ct5lp-hightpu-4t", 8),
189+
64: TpuTopologySpec("8x8", "ct5lp-hightpu-4t", 16),
190+
128: TpuTopologySpec("8x16", "ct5lp-hightpu-4t", 32),
191+
256: TpuTopologySpec("16x16", "ct5lp-hightpu-4t", 64),
109192
},
110193
),
111194
"v5p": TpuSpec(
@@ -114,6 +197,7 @@ class TpuSpec:
114197
{
115198
8: TpuTopologySpec("2x2x2", "ct5p-hightpu-4t", 2),
116199
16: TpuTopologySpec("2x2x4", "ct5p-hightpu-4t", 4),
200+
32: TpuTopologySpec("2x4x4", "ct5p-hightpu-4t", 8),
117201
},
118202
),
119203
"v6e": TpuSpec(
@@ -126,14 +210,39 @@ class TpuSpec:
126210
),
127211
}
128212

213+
_TPU_ALIASES: dict[str, str] = {
214+
"v5e": "v5litepod",
215+
}
129216

130-
# ── Parser ────────────────────────────────────────────────────────
131217

132-
_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
133-
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
218+
_MULTI_GPU_RE = re.compile(r"^([^x]+)(?:x)(\d+)$") # "a100x4"
219+
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8"
134220
_TPU_TOPO_RE = re.compile(
135-
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
136-
) # "v5litepod-2x2", "v5p-2x2x2"
221+
r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
222+
) # "v5litepod-2x2"
223+
224+
DEFAULT_GPU = "l4"
225+
DEFAULT_TPU = "v5litepod"
226+
227+
_PREFERRED_GPUS = [
228+
"h100",
229+
"a100-80gb",
230+
"a100",
231+
"l4",
232+
"v100",
233+
"t4",
234+
"p100",
235+
"p4",
236+
]
237+
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3"]
238+
239+
240+
def _resolve_gpu_alias(name: str) -> str:
241+
return _GPU_ALIASES.get(name, name)
242+
243+
244+
def _resolve_tpu_alias(name: str) -> str:
245+
return _TPU_ALIASES.get(name, name)
137246

138247

139248
def parse_accelerator(accel_str: str) -> Accelerator:
@@ -142,60 +251,108 @@ def parse_accelerator(accel_str: str) -> Accelerator:
142251
Returns GpuConfig, TpuConfig, or None (for "cpu").
143252
144253
Accepted formats:
145-
GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
146-
TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
147-
CPU: "cpu"
254+
- Generic: "gpu", "tpu", "cpu" (resolves to defaults)
255+
- Dynamic Count: "gpu:4", "tpu:8", "cpu:8" (assigns most capable hardware matching the count)
256+
- Explicit GPU Name: "gpu:l4", "l4", "gpu:a100-80gb" (resolves to 1 instance of the specified GPU)
257+
- Multi-GPU Name: "gpu:a100x4", "a100x4", "gpu:l4-2" (resolves to N instances of the specified GPU)
258+
- Explicit TPU Name: "tpu:v5litepod", "v5litepod" (resolves to the default topology/chips for the TPU)
259+
- Explicit TPU Topology/Chips: "tpu:v3-8", "tpu:v5litepod-2x2", "v3-8" (resolves to the specified TPU slice)
260+
261+
Note: Prefixes ('gpu:' and 'tpu:') are recommended for unambiguous parsing but are entirely optional.
262+
263+
Dynamic Resolution:
264+
When using generic formats like "gpu:<N>" or "tpu:<N>", the parser
265+
dynamically assigns the most capable hardware type that supports the
266+
requested device count `N`. Hardware is selected based on an internal
267+
preference hierarchy (e.g., H100 > A100 > L4 for GPUs, and
268+
v6e > v5p > v5litepod for TPUs).
148269
"""
149270
s = accel_str.strip().lower()
150271

151-
if s == "cpu":
272+
if s == "cpu" or (s.startswith("cpu:") and s[4:].isdigit()):
152273
return None
153274

154-
# Direct GPU name: "l4", "a100-80gb"
155-
if s in GPUS:
156-
return make_gpu(s, 1)
157-
158-
# GPU alias: "nvidia-l4"
159-
if s in _GPU_ALIASES:
160-
return make_gpu(_GPU_ALIASES[s], 1)
161-
162-
# Multi-GPU: "a100x4", "l4x2"
163-
m = _MULTI_GPU_RE.match(s)
275+
if s == "gpu":
276+
return make_gpu(DEFAULT_GPU, 1)
277+
278+
if s == "tpu":
279+
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)
280+
281+
# 1) Try parsing as GPU
282+
is_gpu_explicit = s.startswith("gpu:")
283+
gpu_str = s[4:] if is_gpu_explicit else s
284+
285+
if gpu_str.isdigit():
286+
count = int(gpu_str)
287+
for gpu_name in _PREFERRED_GPUS:
288+
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
289+
return make_gpu(gpu_name, count)
290+
if is_gpu_explicit:
291+
valid_counts = sorted(
292+
set(c for spec in GPUS.values() for c in spec.counts)
293+
)
294+
raise ValueError(
295+
f"No GPU supports count {count}. Supported counts: {valid_counts}"
296+
)
297+
298+
name = _resolve_gpu_alias(gpu_str)
299+
if name in GPUS:
300+
return make_gpu(name, 1)
301+
302+
m = _MULTI_GPU_RE.match(gpu_str)
164303
if m:
165-
name = m.group(1)
304+
name = _resolve_gpu_alias(m.group(1))
166305
if name in GPUS:
167306
return make_gpu(name, int(m.group(2)))
168-
if name in _GPU_ALIASES:
169-
return make_gpu(_GPU_ALIASES[name], int(m.group(2)))
170-
171-
# Direct TPU name (bare): "v5litepod" → default chips
172-
if s in TPUS:
173-
return make_tpu(s, TPUS[s].default_chips)
174-
175-
# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
176-
m = _TPU_TOPO_RE.match(s)
177-
if m and m.group(1) in TPUS:
178-
name = m.group(1)
179-
topo_str = m.group(2)
180-
for chips, topo_spec in TPUS[name].topologies.items():
181-
if topo_spec.topology == topo_str:
182-
return make_tpu(name, chips)
183-
valid = [ts.topology for ts in TPUS[name].topologies.values()]
184-
raise ValueError(
185-
f"Topology '{topo_str}' not supported for '{name}'. "
186-
f"Supported: {', '.join(valid)}."
187-
)
188307

189-
# TPU with chip count: "v3-8", "v5litepod-4"
190-
m = _TPU_CHIPS_RE.match(s)
191-
if m and m.group(1) in TPUS:
192-
return make_tpu(m.group(1), int(m.group(2)))
308+
if is_gpu_explicit:
309+
raise ValueError(f"Unknown GPU accelerator: '{accel_str}'")
310+
311+
# 2) Try parsing as TPU
312+
is_tpu_explicit = s.startswith("tpu:")
313+
tpu_str = s[4:] if is_tpu_explicit else s
314+
315+
if tpu_str.isdigit():
316+
chips = int(tpu_str)
317+
for tpu_name in _PREFERRED_TPUS:
318+
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
319+
return make_tpu(tpu_name, chips)
320+
if is_tpu_explicit:
321+
valid_chips = sorted(
322+
set(c for spec in TPUS.values() for c in spec.topologies)
323+
)
324+
raise ValueError(
325+
f"No TPU supports {chips} chips. Supported chip counts: {valid_chips}"
326+
)
327+
328+
name = _resolve_tpu_alias(tpu_str)
329+
if name in TPUS:
330+
return make_tpu(name, TPUS[name].default_chips)
331+
332+
m = _TPU_TOPO_RE.match(tpu_str)
333+
if m:
334+
name = _resolve_tpu_alias(m.group(1))
335+
if name in TPUS:
336+
topo_str = m.group(2)
337+
for chips, topo_spec in TPUS[name].topologies.items():
338+
if topo_spec.topology == topo_str:
339+
return make_tpu(name, chips)
340+
valid = [ts.topology for ts in TPUS[name].topologies.values()]
341+
raise ValueError(
342+
f"Topology '{topo_str}' not supported for '{name}'. "
343+
f"Supported: {', '.join(valid)}."
344+
)
345+
346+
m = _TPU_CHIPS_RE.match(tpu_str)
347+
if m:
348+
name = _resolve_tpu_alias(m.group(1))
349+
if name in TPUS:
350+
return make_tpu(name, int(m.group(2)))
193351

194352
raise ValueError(
195353
f"Unknown accelerator: '{accel_str}'. "
196-
f"GPUs: {', '.join(GPUS)} (use 'xN' for multi-GPU, e.g. 'a100x4'). "
197-
f"TPUs: {', '.join(TPUS)} (use '-N' for chips, e.g. 'v3-8', "
198-
f"or '-NxM' for topology, e.g. 'v5litepod-2x2')."
354+
f"GPUs: {', '.join(GPUS)} (use 'gpu:name' or 'gpu:namexN'). "
355+
f"TPUs: {', '.join(TPUS)} (use 'tpu:name' or 'tpu:name-N')."
199356
)
200357

201358

@@ -234,7 +391,7 @@ def make_gpu(name: str, count: int) -> GpuConfig:
234391
name=name,
235392
count=count,
236393
gke_label=spec.gke_label,
237-
machine_type=spec.machine_type,
394+
machine_type=spec.counts[count],
238395
)
239396

240397

0 commit comments

Comments (0)