Skip to content
2 changes: 1 addition & 1 deletion keras_remote/cli/infra/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -232,7 +232,7 @@ def _create_gpu_node_pool(cluster, gpu: GpuConfig, zone, project_id, pool_name):
guest_accelerators=[
gcp.container.NodePoolNodeConfigGuestAcceleratorArgs(
type=gpu.gke_label,
count=1,
count=gpu.count,
),
],
labels={RESOURCE_NAME_PREFIX: "true"},
Expand Down
287 changes: 222 additions & 65 deletions keras_remote/core/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -40,8 +40,7 @@ class GpuSpec:
"""Registry entry for a GPU type."""

gke_label: str
machine_type: str
counts: tuple[int, ...]
counts: dict[int, str] # count -> machine_type


@dataclass(frozen=True)
Expand All @@ -63,12 +62,77 @@ class TpuSpec:


GPUS: dict[str, GpuSpec] = {
"l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4)),
"t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
"v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
"a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8)),
"a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8)),
"h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
"l4": GpuSpec(
"nvidia-l4",
{
1: "g2-standard-4",
2: "g2-standard-24",
4: "g2-standard-48",
8: "g2-standard-96",
},
),
"t4": GpuSpec(
"nvidia-tesla-t4",
{
1: "n1-standard-4",
2: "n1-standard-8",
4: "n1-standard-16",
},
),
"v100": GpuSpec(
"nvidia-tesla-v100",
{
1: "n1-standard-8",
2: "n1-standard-16",
4: "n1-standard-32",
8: "n1-standard-64",
},
),
"a100": GpuSpec(
"nvidia-tesla-a100",
{
1: "a2-highgpu-1g",
2: "a2-highgpu-2g",
4: "a2-highgpu-4g",
8: "a2-highgpu-8g",
16: "a2-megagpu-16g",
},
),
"a100-80gb": GpuSpec(
"nvidia-a100-80gb",
{
1: "a2-ultragpu-1g",
2: "a2-ultragpu-2g",
4: "a2-ultragpu-4g",
8: "a2-ultragpu-8g",
16: "a2-ultragpu-16g",
},
),
"h100": GpuSpec(
"nvidia-h100-80gb",
{
1: "a3-highgpu-1g",
2: "a3-highgpu-2g",
4: "a3-highgpu-4g",
8: "a3-highgpu-8g",
},
),
"p4": GpuSpec(
"nvidia-tesla-p4",
{
1: "n1-standard-4",
2: "n1-standard-8",
4: "n1-standard-16",
},
),
"p100": GpuSpec(
"nvidia-tesla-p100",
{
1: "n1-standard-4",
2: "n1-standard-8",
4: "n1-standard-16",
},
),
}

_GPU_ALIASES: dict[str, str] = {
Expand All @@ -81,22 +145,36 @@ class TpuSpec:
# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
# v5p uses 3-D topologies (AxBxC); v2, v3, v5litepod, v6e use 2-D (AxB).
TPUS: dict[str, TpuSpec] = {
"v2": TpuSpec(
"tpu-v2-podslice",
4,
{
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
},
),
"v3": TpuSpec(
"tpu-v3-podslice",
4,
{
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct3p-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct3p-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct3p-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct3p-hightpu-4t", 128),
1024: TpuTopologySpec("32x32", "ct3p-hightpu-4t", 256),
2048: TpuTopologySpec("32x64", "ct3p-hightpu-4t", 512),
},
),
"v4": TpuSpec(
"tpu-v4-podslice",
4,
{
4: TpuTopologySpec("2x2x1", "ct4p-hightpu-4t", 1),
8: TpuTopologySpec("2x2x2", "ct4p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct4p-hightpu-4t", 4),
32: TpuTopologySpec("2x4x4", "ct4p-hightpu-4t", 8),
64: TpuTopologySpec("4x4x4", "ct4p-hightpu-4t", 16),
128: TpuTopologySpec("4x4x8", "ct4p-hightpu-4t", 32),
256: TpuTopologySpec("4x8x8", "ct4p-hightpu-4t", 64),
512: TpuTopologySpec("8x8x8", "ct4p-hightpu-4t", 128),
1024: TpuTopologySpec("8x8x16", "ct4p-hightpu-4t", 256),
2048: TpuTopologySpec("8x16x16", "ct4p-hightpu-4t", 512),
4096: TpuTopologySpec("16x16x16", "ct4p-hightpu-4t", 1024),
},
),
"v5litepod": TpuSpec(
Expand All @@ -106,6 +184,11 @@ class TpuSpec:
1: TpuTopologySpec("1x1", "ct5lp-hightpu-1t", 1),
4: TpuTopologySpec("2x2", "ct5lp-hightpu-4t", 1),
8: TpuTopologySpec("2x4", "ct5lp-hightpu-8t", 1),
16: TpuTopologySpec("4x4", "ct5lp-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct5lp-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct5lp-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct5lp-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct5lp-hightpu-4t", 64),
},
),
"v5p": TpuSpec(
Expand All @@ -114,6 +197,7 @@ class TpuSpec:
{
8: TpuTopologySpec("2x2x2", "ct5p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct5p-hightpu-4t", 4),
32: TpuTopologySpec("2x4x4", "ct5p-hightpu-4t", 8),
},
),
"v6e": TpuSpec(
Expand All @@ -126,14 +210,39 @@ class TpuSpec:
),
}

# User-facing TPU aliases mapped to their canonical keys in TPUS.
# NOTE(review): "v5e" appears to be shorthand for "v5litepod" — confirm
# against GCP TPU naming before extending this table.
_TPU_ALIASES: dict[str, str] = {
    "v5e": "v5litepod",
}

# ── Parser ────────────────────────────────────────────────────────

_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
_MULTI_GPU_RE = re.compile(r"^([^x]+)(?:x)(\d+)$") # "a100x4"
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8"
_TPU_TOPO_RE = re.compile(
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
) # "v5litepod-2x2", "v5p-2x2x2"
r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
) # "v5litepod-2x2"

# Hardware assigned for bare "gpu" / "tpu" requests with no explicit type.
DEFAULT_GPU = "l4"
DEFAULT_TPU = "v5litepod"

# Preference order for dynamic requests such as "gpu:4" or "tpu:8": the
# first entry whose registry spec supports the requested device count wins,
# so entries are listed most-capable first (see parse_accelerator docstring).
_PREFERRED_GPUS = [
    "h100",
    "a100-80gb",
    "a100",
    "l4",
    "v100",
    "t4",
    "p100",
    "p4",
]
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3"]


def _resolve_gpu_alias(name: str) -> str:
    """Translate a user-supplied GPU alias (e.g. "nvidia-l4") to its
    canonical registry name; names with no alias pass through unchanged."""
    if name in _GPU_ALIASES:
        return _GPU_ALIASES[name]
    return name


def _resolve_tpu_alias(name: str) -> str:
    """Translate a user-supplied TPU alias (e.g. "v5e") to its canonical
    registry name; names with no alias pass through unchanged."""
    canonical = _TPU_ALIASES.get(name)
    return name if canonical is None else canonical


def parse_accelerator(accel_str: str) -> Accelerator:
Expand All @@ -142,60 +251,108 @@ def parse_accelerator(accel_str: str) -> Accelerator:
Returns GpuConfig, TpuConfig, or None (for "cpu").

Accepted formats:
GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
CPU: "cpu"
- Generic: "gpu", "tpu", "cpu" (resolves to defaults)
- Dynamic Count: "gpu:4", "tpu:8", "cpu:8" (assigns most capable hardware matching the count)
- Explicit GPU Name: "gpu:l4", "l4", "gpu:a100-80gb" (resolves to 1 instance of the specified GPU)
- Multi-GPU Name: "gpu:a100x4", "a100x4", "gpu:l4-2" (resolves to N instances of the specified GPU)
- Explicit TPU Name: "tpu:v5litepod", "v5litepod" (resolves to the default topology/chips for the TPU)
- Explicit TPU Topology/Chips: "tpu:v3-8", "tpu:v5litepod-2x2", "v3-8" (resolves to the specified TPU slice)

Note: Prefixes ('gpu:' and 'tpu:') are recommended for complete disambiguation but are completely optional.

Dynamic Resolution:
When using generic formats like "gpu:<N>" or "tpu:<N>", the parser
dynamically assigns the most capable hardware type that supports the
requested device count `N`. Hardware is selected based on an internal
preference hierarchy (e.g., H100 > A100 > L4 for GPUs, and
v6e > v5p > v5litepod for TPUs).
"""
s = accel_str.strip().lower()

if s == "cpu":
if s == "cpu" or (s.startswith("cpu:") and s[4:].isdigit()):
return None

# Direct GPU name: "l4", "a100-80gb"
if s in GPUS:
return make_gpu(s, 1)

# GPU alias: "nvidia-l4"
if s in _GPU_ALIASES:
return make_gpu(_GPU_ALIASES[s], 1)

# Multi-GPU: "a100x4", "l4x2"
m = _MULTI_GPU_RE.match(s)
if s == "gpu":
return make_gpu(DEFAULT_GPU, 1)

if s == "tpu":
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)

# 1) Try parsing as GPU
is_gpu_explicit = s.startswith("gpu:")
gpu_str = s[4:] if is_gpu_explicit else s

if gpu_str.isdigit():
count = int(gpu_str)
for gpu_name in _PREFERRED_GPUS:
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
return make_gpu(gpu_name, count)
if is_gpu_explicit:
valid_counts = sorted(
set(c for spec in GPUS.values() for c in spec.counts)
)
raise ValueError(
f"No GPU supports count {count}. Supported counts: {valid_counts}"
)

name = _resolve_gpu_alias(gpu_str)
if name in GPUS:
return make_gpu(name, 1)

m = _MULTI_GPU_RE.match(gpu_str)
if m:
name = m.group(1)
name = _resolve_gpu_alias(m.group(1))
if name in GPUS:
return make_gpu(name, int(m.group(2)))
if name in _GPU_ALIASES:
return make_gpu(_GPU_ALIASES[name], int(m.group(2)))

# Direct TPU name (bare): "v5litepod" → default chips
if s in TPUS:
return make_tpu(s, TPUS[s].default_chips)

# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
m = _TPU_TOPO_RE.match(s)
if m and m.group(1) in TPUS:
name = m.group(1)
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)

# TPU with chip count: "v3-8", "v5litepod-4"
m = _TPU_CHIPS_RE.match(s)
if m and m.group(1) in TPUS:
return make_tpu(m.group(1), int(m.group(2)))
if is_gpu_explicit:
raise ValueError(f"Unknown GPU accelerator: '{accel_str}'")

# 2) Try parsing as TPU
is_tpu_explicit = s.startswith("tpu:")
tpu_str = s[4:] if is_tpu_explicit else s

if tpu_str.isdigit():
chips = int(tpu_str)
for tpu_name in _PREFERRED_TPUS:
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
return make_tpu(tpu_name, chips)
if is_tpu_explicit:
valid_chips = sorted(
set(c for spec in TPUS.values() for c in spec.topologies)
)
raise ValueError(
f"No TPU supports {chips} chips. Supported chip counts: {valid_chips}"
)

name = _resolve_tpu_alias(tpu_str)
if name in TPUS:
return make_tpu(name, TPUS[name].default_chips)

m = _TPU_TOPO_RE.match(tpu_str)
if m:
name = _resolve_tpu_alias(m.group(1))
if name in TPUS:
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)

m = _TPU_CHIPS_RE.match(tpu_str)
if m:
name = _resolve_tpu_alias(m.group(1))
if name in TPUS:
return make_tpu(name, int(m.group(2)))

raise ValueError(
f"Unknown accelerator: '{accel_str}'. "
f"GPUs: {', '.join(GPUS)} (use 'xN' for multi-GPU, e.g. 'a100x4'). "
f"TPUs: {', '.join(TPUS)} (use '-N' for chips, e.g. 'v3-8', "
f"or '-NxM' for topology, e.g. 'v5litepod-2x2')."
f"GPUs: {', '.join(GPUS)} (use 'gpu:name' or 'gpu:namexN'). "
f"TPUs: {', '.join(TPUS)} (use 'tpu:name' or 'tpu:name-N')."
)


Expand Down Expand Up @@ -234,7 +391,7 @@ def make_gpu(name: str, count: int) -> GpuConfig:
name=name,
count=count,
gke_label=spec.gke_label,
machine_type=spec.machine_type,
machine_type=spec.counts[count],
)


Expand Down
Loading
Loading