Skip to content
159 changes: 125 additions & 34 deletions keras_remote/core/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ class TpuSpec:


# Supported GPU models, keyed by short user-facing name.
# Each GpuSpec pairs the Cloud accelerator type with its host machine
# family and the per-VM GPU counts that configuration supports.
# NOTE(review): values reconstructed from the "after" side of a diff that
# had duplicate keys for "l4"/"a100"/"a100-80gb"; the 16-GPU a100 entries
# and 8-GPU l4 entry come from the newer lines — confirm against the
# intended machine shapes.
GPUS: dict[str, GpuSpec] = {
    "l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4, 8)),
    "t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
    "v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
    "a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8, 16)),
    "a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8, 16)),
    "h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
    "p4": GpuSpec("nvidia-tesla-p4", "n1-standard-4", (1, 2, 4)),
    "p100": GpuSpec("nvidia-tesla-p100", "n1-standard-4", (1, 2, 4)),
}

_GPU_ALIASES: dict[str, str] = {
Expand All @@ -88,6 +90,10 @@ class TpuSpec:
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct2-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct2-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct2-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct2-hightpu-4t", 128),
},
),
"v3": TpuSpec(
Expand All @@ -97,6 +103,29 @@ class TpuSpec:
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct3p-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct3p-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct3p-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct3p-hightpu-4t", 128),
1024: TpuTopologySpec("32x32", "ct3p-hightpu-4t", 256),
2048: TpuTopologySpec("32x64", "ct3p-hightpu-4t", 512),
},
),
"v4": TpuSpec(
"tpu-v4-podslice",
4,
{
4: TpuTopologySpec("2x2x1", "ct4p-hightpu-4t", 1),
8: TpuTopologySpec("2x2x2", "ct4p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct4p-hightpu-4t", 4),
32: TpuTopologySpec("2x4x4", "ct4p-hightpu-4t", 8),
64: TpuTopologySpec("4x4x4", "ct4p-hightpu-4t", 16),
128: TpuTopologySpec("4x4x8", "ct4p-hightpu-4t", 32),
256: TpuTopologySpec("4x8x8", "ct4p-hightpu-4t", 64),
512: TpuTopologySpec("8x8x8", "ct4p-hightpu-4t", 128),
1024: TpuTopologySpec("8x8x16", "ct4p-hightpu-4t", 256),
2048: TpuTopologySpec("8x16x16", "ct4p-hightpu-4t", 512),
4096: TpuTopologySpec("16x16x16", "ct4p-hightpu-4t", 1024),
},
),
"v5litepod": TpuSpec(
Expand All @@ -106,6 +135,11 @@ class TpuSpec:
1: TpuTopologySpec("1x1", "ct5lp-hightpu-1t", 1),
4: TpuTopologySpec("2x2", "ct5lp-hightpu-4t", 1),
8: TpuTopologySpec("2x4", "ct5lp-hightpu-8t", 1),
16: TpuTopologySpec("4x4", "ct5lp-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct5lp-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct5lp-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct5lp-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct5lp-hightpu-4t", 64),
},
),
"v5p": TpuSpec(
Expand All @@ -126,70 +160,127 @@ class TpuSpec:
),
}

_TPU_ALIASES: dict[str, str] = {
"v5e": "v5litepod",
"ghostlite": "v5litepod",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think *fish names are okay to use externally

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi: The fish names are exposed externally at https://cloud.google.com/skus/sku-groups/vertex-prediction.

}


# ── Parser ────────────────────────────────────────────────────────

_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
_MULTI_GPU_RE = re.compile(r"^(.+?)(?:x|-)(\d+)$") # "a100x4", "l4-2"
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$") # "v3-8", "ghostlite-16"
_TPU_TOPO_RE = re.compile(
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$"
) # "v5litepod-2x2", "v5p-2x2x2"

# Model chosen when the user asks for a bare "gpu" / "tpu" with no name.
DEFAULT_GPU = "l4"
DEFAULT_TPU = "v5litepod"

# Resolution order for generic "gpu-N" / "tpu-N" requests: the first model
# in the list that supports the requested count/chips wins.
# NOTE(review): ordering looks like newest/most-capable hardware first —
# confirm that is the intended selection policy.
_PREFERRED_GPUS = [
    "h100",
    "a100-80gb",
    "a100",
    "l4",
    "v100",
    "t4",
    "p100",
    "p4",
]
# NOTE(review): "v6e" is not visible in the TPUS hunks shown here;
# presumably defined in the collapsed part of TPUS — verify.
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3", "v2"]


def _resolve_gpu_alias(name: str) -> str:
    """Return the canonical GPU key for *name*, or *name* itself if it is not an alias."""
    canonical = _GPU_ALIASES.get(name)
    return name if canonical is None else canonical


def _resolve_tpu_alias(name: str) -> str:
    """Return the canonical TPU key for *name* (e.g. "v5e" -> "v5litepod")."""
    if name in _TPU_ALIASES:
        return _TPU_ALIASES[name]
    return name


def parse_accelerator(accel_str: str) -> Accelerator:
"""Parse an accelerator string into a fully resolved config.

Returns GpuConfig, TpuConfig, or None (for "cpu").

Accepted formats:
GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
CPU: "cpu"
GPU: "l4", "gpu", "gpu-4", "a100x4", "l4-2", "a100-80gbx8"
TPU: "v3-8", "tpu", "tpu-8", "v5litepod-2x2", "v5litepod"
CPU: "cpu", "cpu-8"
"""
s = accel_str.strip().lower()

if s == "cpu":
if s == "cpu" or (s.startswith("cpu-") and s[4:].isdigit()):
return None

# Direct GPU name: "l4", "a100-80gb"
if s in GPUS:
return make_gpu(s, 1)
if s == "gpu":
return make_gpu(DEFAULT_GPU, 1)

# GPU alias: "nvidia-l4"
if s in _GPU_ALIASES:
return make_gpu(_GPU_ALIASES[s], 1)
if s == "tpu":
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)

if s.startswith("gpu-") and s[4:].isdigit():
count = int(s[4:])
search_order = _PREFERRED_GPUS
for gpu_name in search_order:
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
return make_gpu(gpu_name, count)
valid_counts = sorted(set(c for spec in GPUS.values() for c in spec.counts))
raise ValueError(
f"No GPU supports count {count}. Supported counts across all GPUs: {valid_counts}"
)

if s.startswith("tpu-") and s[4:].isdigit():
chips = int(s[4:])
search_order = _PREFERRED_TPUS
for tpu_name in search_order:
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
return make_tpu(tpu_name, chips)
valid_chips = sorted(
set(c for spec in TPUS.values() for c in spec.topologies)
)
raise ValueError(
f"No TPU supports {chips} chips. Supported chip counts across all TPUs: {valid_chips}"
)

# Direct GPU name: "l4", "a100-80gb"
name = _resolve_gpu_alias(s)
if name in GPUS:
return make_gpu(name, 1)

# Multi-GPU: "a100x4", "l4x2"
m = _MULTI_GPU_RE.match(s)
if m:
name = m.group(1)
name = _resolve_gpu_alias(m.group(1))
if name in GPUS:
return make_gpu(name, int(m.group(2)))
if name in _GPU_ALIASES:
return make_gpu(_GPU_ALIASES[name], int(m.group(2)))

# Direct TPU name (bare): "v5litepod" → default chips
if s in TPUS:
return make_tpu(s, TPUS[s].default_chips)
name = _resolve_tpu_alias(s)
if name in TPUS:
return make_tpu(name, TPUS[name].default_chips)

# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
m = _TPU_TOPO_RE.match(s)
if m and m.group(1) in TPUS:
name = m.group(1)
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)
if m:
name = _resolve_tpu_alias(m.group(1))
if name in TPUS:
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)

# TPU with chip count: "v3-8", "v5litepod-4"
m = _TPU_CHIPS_RE.match(s)
if m and m.group(1) in TPUS:
return make_tpu(m.group(1), int(m.group(2)))
if m:
name = _resolve_tpu_alias(m.group(1))
if name in TPUS:
return make_tpu(name, int(m.group(2)))

raise ValueError(
f"Unknown accelerator: '{accel_str}'. "
Expand Down
73 changes: 71 additions & 2 deletions keras_remote/core/accelerators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@ def test_a100_80gbx4(self):
self.assertEqual(result.name, "a100-80gb")
self.assertEqual(result.count, 4)

def test_l4_dash_2(self):
    """A dash-joined count ("l4-2") is accepted as a multi-GPU spec."""
    cfg = parse_accelerator("l4-2")
    self.assertIsInstance(cfg, GpuConfig)
    self.assertEqual(cfg.name, "l4")
    self.assertEqual(cfg.count, 2)


class TestParseGpuAlias(absltest.TestCase):
def test_nvidia_tesla_t4(self):
Expand All @@ -60,9 +66,9 @@ def test_nvidia_tesla_v100x4(self):


class TestParseGpuErrors(absltest.TestCase):
    """Invalid GPU count requests raise ValueError."""

    # Post-diff version: l4 now supports 8 GPUs, so the invalid-count
    # probe moved from x8 to x16 (the diff overlay showed both).
    def test_l4x16_invalid_count(self):
        with self.assertRaisesRegex(ValueError, "not supported"):
            parse_accelerator("l4x16")


class TestParseTpuBare(parameterized.TestCase):
Expand Down Expand Up @@ -154,6 +160,69 @@ class TestParseCpu(absltest.TestCase):
def test_cpu(self):
    """Plain "cpu" resolves to no accelerator config at all."""
    parsed = parse_accelerator("cpu")
    self.assertIsNone(parsed)

def test_cpu_with_count(self):
    """A core-count suffix ("cpu-8") is also treated as no accelerator."""
    parsed = parse_accelerator("cpu-8")
    self.assertIsNone(parsed)


class TestParseGenericAliases(absltest.TestCase):
    """Generic "gpu"/"tpu" requests and TPU generation aliases."""

    def test_gpu_bare(self):
        """Bare "gpu" falls back to the default model with one device."""
        cfg = parse_accelerator("gpu")
        self.assertIsInstance(cfg, GpuConfig)
        self.assertEqual(cfg.name, "l4")
        self.assertEqual(cfg.count, 1)

    def test_tpu_bare(self):
        """Bare "tpu" falls back to the default generation and chip count."""
        cfg = parse_accelerator("tpu")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v5litepod")
        self.assertEqual(cfg.chips, 4)

    def test_gpu_with_count(self):
        """"gpu-N" picks the most preferred GPU supporting N devices."""
        cfg = parse_accelerator("gpu-4")
        self.assertIsInstance(cfg, GpuConfig)
        self.assertEqual(cfg.name, "h100")
        self.assertEqual(cfg.count, 4)

    def test_tpu_with_count(self):
        """"tpu-N" picks the most preferred TPU supporting N chips."""
        cfg = parse_accelerator("tpu-8")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v6e")
        self.assertEqual(cfg.chips, 8)

    def test_gpu_with_dynamic_count(self):
        """h100 tops out at 8 per VM, so a count of 16 falls through to a100."""
        cfg = parse_accelerator("gpu-16")
        self.assertIsInstance(cfg, GpuConfig)
        self.assertIn(cfg.name, ["a100", "a100-80gb"])
        self.assertEqual(cfg.count, 16)

    def test_tpu_with_dynamic_count(self):
        """v5litepod tops out at 256 chips, so 4096 falls through to v4."""
        cfg = parse_accelerator("tpu-4096")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v4")
        self.assertEqual(cfg.chips, 4096)

    def test_v5e_alias(self):
        """"v5e" is accepted as an alias for v5litepod in chip-count form."""
        cfg = parse_accelerator("v5e-8")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v5litepod")
        self.assertEqual(cfg.chips, 8)

    def test_ghostlite_alias(self):
        """"ghostlite" is accepted as an alias for v5litepod."""
        cfg = parse_accelerator("ghostlite-16")
        self.assertIsInstance(cfg, TpuConfig)
        self.assertEqual(cfg.name, "v5litepod")
        self.assertEqual(cfg.chips, 16)

    def test_gpu_unsupported_count(self):
        """A count no GPU supports raises a descriptive ValueError."""
        with self.assertRaisesRegex(ValueError, "No GPU supports count 32"):
            parse_accelerator("gpu-32")

    def test_tpu_unsupported_count(self):
        """A chip count no TPU supports raises a descriptive ValueError."""
        with self.assertRaisesRegex(ValueError, "No TPU supports 8192 chips"):
            parse_accelerator("tpu-8192")


class TestParseNormalizationAndErrors(absltest.TestCase):
def test_whitespace_and_case(self):
Expand Down
Loading