Skip to content
168 changes: 129 additions & 39 deletions keras_remote/core/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ class TpuSpec:


# Canonical GPU catalogue: user-facing key -> GpuSpec of
# (GCE accelerator type, default machine type, supported per-VM GPU counts).
# NOTE: keys must be unique — a duplicate key in a dict literal silently
# shadows the earlier entry, so stale rows are removed rather than kept.
GPUS: dict[str, GpuSpec] = {
    "l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4, 8)),
    "t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
    "v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
    "a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8, 16)),
    "a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8, 16)),
    "h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
    "p4": GpuSpec("nvidia-tesla-p4", "n1-standard-4", (1, 2, 4)),
    "p100": GpuSpec("nvidia-tesla-p100", "n1-standard-4", (1, 2, 4)),
}

_GPU_ALIASES: dict[str, str] = {
Expand All @@ -88,6 +90,10 @@ class TpuSpec:
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct2-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct2-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct2-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct2-hightpu-4t", 128),
},
),
"v3": TpuSpec(
Expand All @@ -97,6 +103,29 @@ class TpuSpec:
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct3p-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct3p-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct3p-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct3p-hightpu-4t", 128),
1024: TpuTopologySpec("32x32", "ct3p-hightpu-4t", 256),
2048: TpuTopologySpec("32x64", "ct3p-hightpu-4t", 512),
},
),
"v4": TpuSpec(
"tpu-v4-podslice",
4,
{
4: TpuTopologySpec("2x2x1", "ct4p-hightpu-4t", 1),
8: TpuTopologySpec("2x2x2", "ct4p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct4p-hightpu-4t", 4),
32: TpuTopologySpec("2x4x4", "ct4p-hightpu-4t", 8),
64: TpuTopologySpec("4x4x4", "ct4p-hightpu-4t", 16),
128: TpuTopologySpec("4x4x8", "ct4p-hightpu-4t", 32),
256: TpuTopologySpec("4x8x8", "ct4p-hightpu-4t", 64),
512: TpuTopologySpec("8x8x8", "ct4p-hightpu-4t", 128),
1024: TpuTopologySpec("8x8x16", "ct4p-hightpu-4t", 256),
2048: TpuTopologySpec("8x16x16", "ct4p-hightpu-4t", 512),
4096: TpuTopologySpec("16x16x16", "ct4p-hightpu-4t", 1024),
},
),
"v5litepod": TpuSpec(
Expand All @@ -106,6 +135,11 @@ class TpuSpec:
1: TpuTopologySpec("1x1", "ct5lp-hightpu-1t", 1),
4: TpuTopologySpec("2x2", "ct5lp-hightpu-4t", 1),
8: TpuTopologySpec("2x4", "ct5lp-hightpu-8t", 1),
16: TpuTopologySpec("4x4", "ct5lp-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct5lp-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct5lp-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct5lp-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct5lp-hightpu-4t", 64),
},
),
"v5p": TpuSpec(
Expand All @@ -126,70 +160,126 @@ class TpuSpec:
),
}

_TPU_ALIASES: dict[str, str] = {
"v5e": "v5litepod",
}


# ── Parser ────────────────────────────────────────────────────────

_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
_MULTI_GPU_RE = re.compile(r"^(.+?)(?:x|-)(\d+)$") # "a100x4", "l4-2"
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8", "v5litepod-16"
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

_MULTI_GPU_RE = r"^(.+?)(?:x|-)(\d+)$" matches any name-number pattern, which is same shape as _TPU_CHIPS_RE = r"^([a-z0-9_]+)-(\d+)$".

E.g., "l4-2" matches both regexes. This currently works only because the multi-GPU check was moved to the end of the function, so the TPU branches fall through first (since "l4" is not a key in TPUS).

If anyone reorders the checks in the future, TPU parsing will intercept GPU (with dash) strings or vice versa.

Do you think we should leave it as-is for now, or should we use this opportunity to implement what was discussed offline — using `tpu:...` and `gpu:...` prefixes for accelerator names, which would also help popularise the TPU branding?

Copy link
Collaborator Author

@divyashreepathihalli divyashreepathihalli Mar 9, 2026

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Great idea. I've updated the core parsing logic to fully support and encourage explicit gpu: and tpu: prefixes. However, I also retained fallback parsing so that legacy unprefixed strings (like v5litepod or l4) continue to work as before!

_TPU_TOPO_RE = re.compile(
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
) # "v5litepod-2x2", "v5p-2x2x2"

DEFAULT_GPU = "l4"
DEFAULT_TPU = "v5litepod"

_PREFERRED_GPUS = [
"h100",
"a100-80gb",
"a100",
"l4",
"v100",
"t4",
"p100",
"p4",
]
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3", "v2"]


def _resolve_gpu_alias(name: str) -> str:
    """Map a GPU alias (e.g. "nvidia-l4") to its canonical key, else pass through."""
    if name in _GPU_ALIASES:
        return _GPU_ALIASES[name]
    return name


def _resolve_tpu_alias(name: str) -> str:
    """Map a TPU alias (e.g. "v5e") to its canonical family key, else pass through."""
    try:
        return _TPU_ALIASES[name]
    except KeyError:
        return name


def parse_accelerator(accel_str: str) -> Accelerator:
"""Parse an accelerator string into a fully resolved config.

Returns GpuConfig, TpuConfig, or None (for "cpu").

Accepted formats:
GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
CPU: "cpu"
GPU: "l4", "gpu", "gpu-4", "a100x4", "l4-2", "a100-80gbx8"
TPU: "v3-8", "tpu", "tpu-8", "v5litepod-2x2", "v5litepod"
CPU: "cpu", "cpu-8"
"""
s = accel_str.strip().lower()

if s == "cpu":
if s == "cpu" or (s.startswith("cpu-") and s[4:].isdigit()):
return None

# Direct GPU name: "l4", "a100-80gb"
if s in GPUS:
return make_gpu(s, 1)
if s == "gpu":
return make_gpu(DEFAULT_GPU, 1)

# GPU alias: "nvidia-l4"
if s in _GPU_ALIASES:
return make_gpu(_GPU_ALIASES[s], 1)
if s == "tpu":
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)

# Multi-GPU: "a100x4", "l4x2"
m = _MULTI_GPU_RE.match(s)
if m:
name = m.group(1)
if name in GPUS:
return make_gpu(name, int(m.group(2)))
if name in _GPU_ALIASES:
return make_gpu(_GPU_ALIASES[name], int(m.group(2)))
if s.startswith("gpu-") and s[4:].isdigit():
count = int(s[4:])
search_order = _PREFERRED_GPUS
for gpu_name in search_order:
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
return make_gpu(gpu_name, count)
valid_counts = sorted(set(c for spec in GPUS.values() for c in spec.counts))
raise ValueError(
f"No GPU supports count {count}. Supported counts across all GPUs: {valid_counts}"
)

if s.startswith("tpu-") and s[4:].isdigit():
chips = int(s[4:])
search_order = _PREFERRED_TPUS
for tpu_name in search_order:
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
return make_tpu(tpu_name, chips)
valid_chips = sorted(
set(c for spec in TPUS.values() for c in spec.topologies)
)
raise ValueError(
f"No TPU supports {chips} chips. Supported chip counts across all TPUs: {valid_chips}"
)

# Direct GPU name: "l4", "a100-80gb"
name = _resolve_gpu_alias(s)
if name in GPUS:
return make_gpu(name, 1)

# Direct TPU name (bare): "v5litepod" → default chips
if s in TPUS:
return make_tpu(s, TPUS[s].default_chips)
name = _resolve_tpu_alias(s)
if name in TPUS:
return make_tpu(name, TPUS[name].default_chips)

# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
m = _TPU_TOPO_RE.match(s)
if m and m.group(1) in TPUS:
name = m.group(1)
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)
if m:
name = _resolve_tpu_alias(m.group(1))
if name in TPUS:
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)

# TPU with chip count: "v3-8", "v5litepod-4"
m = _TPU_CHIPS_RE.match(s)
if m and m.group(1) in TPUS:
return make_tpu(m.group(1), int(m.group(2)))
if m:
name = _resolve_tpu_alias(m.group(1))
if name in TPUS:
return make_tpu(name, int(m.group(2)))

# Multi-GPU: "a100x4", "l4x2"
m = _MULTI_GPU_RE.match(s)
if m:
name = _resolve_gpu_alias(m.group(1))
if name in GPUS:
return make_gpu(name, int(m.group(2)))

raise ValueError(
f"Unknown accelerator: '{accel_str}'. "
Expand Down
67 changes: 65 additions & 2 deletions keras_remote/core/accelerators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def test_a100_80gbx4(self):
self.assertEqual(result.name, "a100-80gb")
self.assertEqual(result.count, 4)

def test_l4_dash_2(self):
    """The dash-count form "l4-2" parses as two L4 GPUs."""
    cfg = parse_accelerator("l4-2")
    self.assertIsInstance(cfg, GpuConfig)
    self.assertEqual(cfg.name, "l4")
    self.assertEqual(cfg.count, 2)


class TestParseGpuAlias(absltest.TestCase):
def test_nvidia_tesla_t4(self):
Expand All @@ -60,9 +66,9 @@ def test_nvidia_tesla_v100x4(self):


class TestParseGpuErrors(absltest.TestCase):
    """Unsupported GPU count requests raise ValueError."""

    def test_l4x16_invalid_count(self):
        # l4 tops out at 8 GPUs, so a 16-GPU request must be rejected.
        # (Diff residue interleaved the old test_l4x8 lines here; only the
        # updated test is kept.)
        with self.assertRaisesRegex(ValueError, "not supported"):
            parse_accelerator("l4x16")


class TestParseTpuBare(parameterized.TestCase):
Expand Down Expand Up @@ -154,6 +160,63 @@ class TestParseCpu(absltest.TestCase):
def test_cpu(self):
    """Plain "cpu" means no accelerator: the parser returns None."""
    result = parse_accelerator("cpu")
    self.assertIsNone(result)

def test_cpu_with_count(self):
    """A counted CPU request ("cpu-8") also maps to no accelerator."""
    result = parse_accelerator("cpu-8")
    self.assertIsNone(result)


class TestParseGenericAliases(absltest.TestCase):
    """Generic "gpu"/"tpu" names, counted variants, and family aliases."""

    def test_gpu_bare(self):
        # Bare "gpu" falls back to the default GPU, single device.
        cfg = parse_accelerator("gpu")
        self.assertIsInstance(cfg, GpuConfig)
        self.assertEqual(cfg.name, "l4")
        self.assertEqual(cfg.count, 1)

    def test_tpu_bare(self):
        # Bare "tpu" falls back to the default TPU with its default chips.
        cfg = parse_accelerator("tpu")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v5litepod")
        self.assertEqual(cfg.chips, 4)

    def test_gpu_with_count(self):
        # The preferred-GPU search picks the first family supporting the count.
        cfg = parse_accelerator("gpu-4")
        self.assertIsInstance(cfg, GpuConfig)
        self.assertEqual(cfg.name, "h100")
        self.assertEqual(cfg.count, 4)

    def test_tpu_with_count(self):
        # Likewise for TPUs: the first preferred family with the chip count.
        cfg = parse_accelerator("tpu-8")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v6e")
        self.assertEqual(cfg.chips, 8)

    def test_gpu_with_dynamic_count(self):
        # h100 supports up to 8. 16 should fall back to a100-80gb.
        cfg = parse_accelerator("gpu-16")
        self.assertIsInstance(cfg, GpuConfig)
        self.assertEqual(cfg.name, "a100-80gb")
        self.assertEqual(cfg.count, 16)

    def test_tpu_with_dynamic_count(self):
        # v5litepod supports up to 256. 4096 should fall back to v4.
        cfg = parse_accelerator("tpu-4096")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v4")
        self.assertEqual(cfg.chips, 4096)

    def test_v5e_alias(self):
        # Marketing alias "v5e" resolves to the v5litepod family.
        cfg = parse_accelerator("v5e-8")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v5litepod")
        self.assertEqual(cfg.chips, 8)

    def test_gpu_unsupported_count(self):
        # No catalogued GPU supports a 32-device VM.
        with self.assertRaisesRegex(ValueError, "No GPU supports count 32"):
            parse_accelerator("gpu-32")

    def test_tpu_unsupported_count(self):
        # No catalogued TPU topology reaches 8192 chips.
        with self.assertRaisesRegex(ValueError, "No TPU supports 8192 chips"):
            parse_accelerator("tpu-8192")


class TestParseNormalizationAndErrors(absltest.TestCase):
def test_whitespace_and_case(self):
Expand Down
Loading