Skip to content

Commit dac47b8

Browse files
address review comments
1 parent d7fbf7e commit dac47b8

File tree

3 files changed

+182
-105
lines changed

3 files changed

+182
-105
lines changed

keras_remote/cli/infra/program.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -232,7 +232,7 @@ def _create_gpu_node_pool(cluster, gpu: GpuConfig, zone, project_id, pool_name):
232232
guest_accelerators=[
233233
gcp.container.NodePoolNodeConfigGuestAcceleratorArgs(
234234
type=gpu.gke_label,
235-
count=1,
235+
count=gpu.count,
236236
),
237237
],
238238
labels={RESOURCE_NAME_PREFIX: "true"},

keras_remote/core/accelerators.py

Lines changed: 114 additions & 73 deletions
Original file line numberDiff line numberDiff line change
@@ -40,8 +40,7 @@ class GpuSpec:
4040
"""Registry entry for a GPU type."""
4141

4242
gke_label: str
43-
machine_type: str
44-
counts: tuple[int, ...]
43+
counts: dict[int, str] # count -> machine_type
4544

4645

4746
@dataclass(frozen=True)
@@ -63,14 +62,53 @@ class TpuSpec:
6362

6463

6564
GPUS: dict[str, GpuSpec] = {
66-
"l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4, 8)),
67-
"t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
68-
"v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
69-
"a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8, 16)),
70-
"a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8, 16)),
71-
"h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
72-
"p4": GpuSpec("nvidia-tesla-p4", "n1-standard-4", (1, 2, 4)),
73-
"p100": GpuSpec("nvidia-tesla-p100", "n1-standard-4", (1, 2, 4)),
65+
"l4": GpuSpec("nvidia-l4", {
66+
1: "g2-standard-4",
67+
2: "g2-standard-24",
68+
4: "g2-standard-48",
69+
8: "g2-standard-96",
70+
}),
71+
"t4": GpuSpec("nvidia-tesla-t4", {
72+
1: "n1-standard-4",
73+
2: "n1-standard-8",
74+
4: "n1-standard-16",
75+
}),
76+
"v100": GpuSpec("nvidia-tesla-v100", {
77+
1: "n1-standard-8",
78+
2: "n1-standard-16",
79+
4: "n1-standard-32",
80+
8: "n1-standard-64",
81+
}),
82+
"a100": GpuSpec("nvidia-tesla-a100", {
83+
1: "a2-highgpu-1g",
84+
2: "a2-highgpu-2g",
85+
4: "a2-highgpu-4g",
86+
8: "a2-highgpu-8g",
87+
16: "a2-megagpu-16g",
88+
}),
89+
"a100-80gb": GpuSpec("nvidia-a100-80gb", {
90+
1: "a2-ultragpu-1g",
91+
2: "a2-ultragpu-2g",
92+
4: "a2-ultragpu-4g",
93+
8: "a2-ultragpu-8g",
94+
16: "a2-ultragpu-16g",
95+
}),
96+
"h100": GpuSpec("nvidia-h100-80gb", {
97+
1: "a3-highgpu-1g",
98+
2: "a3-highgpu-2g",
99+
4: "a3-highgpu-4g",
100+
8: "a3-highgpu-8g",
101+
}),
102+
"p4": GpuSpec("nvidia-tesla-p4", {
103+
1: "n1-standard-4",
104+
2: "n1-standard-8",
105+
4: "n1-standard-16",
106+
}),
107+
"p100": GpuSpec("nvidia-tesla-p100", {
108+
1: "n1-standard-4",
109+
2: "n1-standard-8",
110+
4: "n1-standard-16",
111+
}),
74112
}
75113

76114
_GPU_ALIASES: dict[str, str] = {
@@ -83,19 +121,6 @@ class TpuSpec:
83121
# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
84122
# v5p uses 3-D topologies (AxBxC); v2, v3, v5litepod, v6e use 2-D (AxB).
85123
TPUS: dict[str, TpuSpec] = {
86-
"v2": TpuSpec(
87-
"tpu-v2-podslice",
88-
4,
89-
{
90-
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
91-
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
92-
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
93-
64: TpuTopologySpec("8x8", "ct2-hightpu-4t", 16),
94-
128: TpuTopologySpec("8x16", "ct2-hightpu-4t", 32),
95-
256: TpuTopologySpec("16x16", "ct2-hightpu-4t", 64),
96-
512: TpuTopologySpec("16x32", "ct2-hightpu-4t", 128),
97-
},
98-
),
99124
"v3": TpuSpec(
100125
"tpu-v3-podslice",
101126
4,
@@ -148,6 +173,7 @@ class TpuSpec:
148173
{
149174
8: TpuTopologySpec("2x2x2", "ct5p-hightpu-4t", 2),
150175
16: TpuTopologySpec("2x2x4", "ct5p-hightpu-4t", 4),
176+
32: TpuTopologySpec("2x4x4", "ct5p-hightpu-4t", 8),
151177
},
152178
),
153179
"v6e": TpuSpec(
@@ -165,13 +191,11 @@ class TpuSpec:
165191
}
166192

167193

168-
# ── Parser ────────────────────────────────────────────────────────
169-
170-
_MULTI_GPU_RE = re.compile(r"^(.+?)(?:x|-)(\d+)$") # "a100x4", "l4-2"
171-
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8", "v5litepod-16"
194+
_MULTI_GPU_RE = re.compile(r"^([^x]+)(?:x)(\d+)$") # "a100x4"
195+
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8"
172196
_TPU_TOPO_RE = re.compile(
173197
r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
174-
) # "v5litepod-2x2", "v5p-2x2x2"
198+
) # "v5litepod-2x2"
175199

176200
DEFAULT_GPU = "l4"
177201
DEFAULT_TPU = "v5litepod"
@@ -186,7 +210,7 @@ class TpuSpec:
186210
"p100",
187211
"p4",
188212
]
189-
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3", "v2"]
213+
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3"]
190214

191215

192216
def _resolve_gpu_alias(name: str) -> str:
@@ -203,13 +227,25 @@ def parse_accelerator(accel_str: str) -> Accelerator:
203227
Returns GpuConfig, TpuConfig, or None (for "cpu").
204228
205229
Accepted formats:
206-
GPU: "l4", "gpu", "gpu-4", "a100x4", "l4-2", "a100-80gbx8"
207-
TPU: "v3-8", "tpu", "tpu-8", "v5litepod-2x2", "v5litepod"
208-
CPU: "cpu", "cpu-8"
230+
- Generic: "gpu", "tpu", "cpu" (resolves to defaults)
231+
- Dynamic Count: "gpu:4", "tpu:8", "cpu:8" (assigns most capable hardware matching the count)
232+
- Explicit GPU Name: "gpu:l4", "l4", "gpu:a100-80gb" (resolves to 1 instance of the specified GPU)
233+
- Multi-GPU Name: "gpu:a100x4", "a100x4", "gpu:l4-2" (resolves to N instances of the specified GPU)
234+
- Explicit TPU Name: "tpu:v5litepod", "v5litepod" (resolves to the default topology/chips for the TPU)
235+
- Explicit TPU Topology/Chips: "tpu:v3-8", "tpu:v5litepod-2x2", "v3-8" (resolves to the specified TPU slice)
236+
237+
Note: Prefixes ('gpu:' and 'tpu:') are recommended for unambiguous parsing, but are optional.
238+
239+
Dynamic Resolution:
240+
When using generic formats like "gpu:<N>" or "tpu:<N>", the parser
241+
dynamically assigns the most capable hardware type that supports the
242+
requested device count `N`. Hardware is selected based on an internal
243+
preference hierarchy (e.g., H100 > A100 > L4 for GPUs, and
244+
v6e > v5p > v5litepod for TPUs).
209245
"""
210246
s = accel_str.strip().lower()
211247

212-
if s == "cpu" or (s.startswith("cpu-") and s[4:].isdigit()):
248+
if s == "cpu" or (s.startswith("cpu:") and s[4:].isdigit()):
213249
return None
214250

215251
if s == "gpu":
@@ -218,42 +254,56 @@ def parse_accelerator(accel_str: str) -> Accelerator:
218254
if s == "tpu":
219255
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)
220256

221-
if s.startswith("gpu-") and s[4:].isdigit():
222-
count = int(s[4:])
223-
search_order = _PREFERRED_GPUS
224-
for gpu_name in search_order:
257+
# 1) Try parsing as GPU
258+
is_gpu_explicit = s.startswith("gpu:")
259+
gpu_str = s[4:] if is_gpu_explicit else s
260+
261+
if gpu_str.isdigit():
262+
count = int(gpu_str)
263+
for gpu_name in _PREFERRED_GPUS:
225264
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
226265
return make_gpu(gpu_name, count)
227-
valid_counts = sorted(set(c for spec in GPUS.values() for c in spec.counts))
228-
raise ValueError(
229-
f"No GPU supports count {count}. Supported counts across all GPUs: {valid_counts}"
230-
)
231-
232-
if s.startswith("tpu-") and s[4:].isdigit():
233-
chips = int(s[4:])
234-
search_order = _PREFERRED_TPUS
235-
for tpu_name in search_order:
236-
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
237-
return make_tpu(tpu_name, chips)
238-
valid_chips = sorted(
239-
set(c for spec in TPUS.values() for c in spec.topologies)
240-
)
241-
raise ValueError(
242-
f"No TPU supports {chips} chips. Supported chip counts across all TPUs: {valid_chips}"
243-
)
266+
if is_gpu_explicit:
267+
valid_counts = sorted(set(c for spec in GPUS.values() for c in spec.counts))
268+
raise ValueError(
269+
f"No GPU supports count {count}. Supported counts: {valid_counts}"
270+
)
244271

245-
# Direct GPU name: "l4", "a100-80gb"
246-
name = _resolve_gpu_alias(s)
272+
name = _resolve_gpu_alias(gpu_str)
247273
if name in GPUS:
248274
return make_gpu(name, 1)
249275

250-
# Direct TPU name (bare): "v5litepod" → default chips
251-
name = _resolve_tpu_alias(s)
276+
m = _MULTI_GPU_RE.match(gpu_str)
277+
if m:
278+
name = _resolve_gpu_alias(m.group(1))
279+
if name in GPUS:
280+
return make_gpu(name, int(m.group(2)))
281+
282+
if is_gpu_explicit:
283+
raise ValueError(f"Unknown GPU accelerator: '{accel_str}'")
284+
285+
# 2) Try parsing as TPU
286+
is_tpu_explicit = s.startswith("tpu:")
287+
tpu_str = s[4:] if is_tpu_explicit else s
288+
289+
if tpu_str.isdigit():
290+
chips = int(tpu_str)
291+
for tpu_name in _PREFERRED_TPUS:
292+
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
293+
return make_tpu(tpu_name, chips)
294+
if is_tpu_explicit:
295+
valid_chips = sorted(
296+
set(c for spec in TPUS.values() for c in spec.topologies)
297+
)
298+
raise ValueError(
299+
f"No TPU supports {chips} chips. Supported chip counts: {valid_chips}"
300+
)
301+
302+
name = _resolve_tpu_alias(tpu_str)
252303
if name in TPUS:
253304
return make_tpu(name, TPUS[name].default_chips)
254305

255-
# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
256-
m = _TPU_TOPO_RE.match(s)
306+
m = _TPU_TOPO_RE.match(tpu_str)
257307
if m:
258308
name = _resolve_tpu_alias(m.group(1))
259309
if name in TPUS:
@@ -267,25 +317,16 @@ def parse_accelerator(accel_str: str) -> Accelerator:
267317
f"Supported: {', '.join(valid)}."
268318
)
269319

270-
# TPU with chip count: "v3-8", "v5litepod-4"
271-
m = _TPU_CHIPS_RE.match(s)
320+
m = _TPU_CHIPS_RE.match(tpu_str)
272321
if m:
273322
name = _resolve_tpu_alias(m.group(1))
274323
if name in TPUS:
275324
return make_tpu(name, int(m.group(2)))
276325

277-
# Multi-GPU: "a100x4", "l4x2"
278-
m = _MULTI_GPU_RE.match(s)
279-
if m:
280-
name = _resolve_gpu_alias(m.group(1))
281-
if name in GPUS:
282-
return make_gpu(name, int(m.group(2)))
283-
284326
raise ValueError(
285327
f"Unknown accelerator: '{accel_str}'. "
286-
f"GPUs: {', '.join(GPUS)} (use 'xN' for multi-GPU, e.g. 'a100x4'). "
287-
f"TPUs: {', '.join(TPUS)} (use '-N' for chips, e.g. 'v3-8', "
288-
f"or '-NxM' for topology, e.g. 'v5litepod-2x2')."
328+
f"GPUs: {', '.join(GPUS)} (use 'gpu:name' or 'gpu:namexN'). "
329+
f"TPUs: {', '.join(TPUS)} (use 'tpu:name' or 'tpu:name-N')."
289330
)
290331

291332

@@ -324,7 +365,7 @@ def make_gpu(name: str, count: int) -> GpuConfig:
324365
name=name,
325366
count=count,
326367
gke_label=spec.gke_label,
327-
machine_type=spec.machine_type,
368+
machine_type=spec.counts[count],
328369
)
329370

330371

0 commit comments

Comments
 (0)