Skip to content
129 changes: 107 additions & 22 deletions keras_remote/core/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,12 +63,14 @@ class TpuSpec:


# Catalog of supported GPU types, keyed by the short name accepted by
# parse_accelerator().  Every key is listed exactly once: a repeated key in a
# dict literal silently keeps only the last value, which hides stale rows.
#
# GpuSpec arguments are presumably (accelerator type, base machine type,
# supported GPU counts) -- the field names are declared outside this view;
# the tuple is what the parser reads back as `.counts`.
GPUS: dict[str, GpuSpec] = {
    "l4": GpuSpec("nvidia-l4", "g2-standard-4", (1, 2, 4, 8)),
    "t4": GpuSpec("nvidia-tesla-t4", "n1-standard-4", (1, 2, 4)),
    "v100": GpuSpec("nvidia-tesla-v100", "n1-standard-8", (1, 2, 4, 8)),
    # The 40 GB A100 is the only A100 variant with a 16-GPU machine shape
    # (a2-megagpu-16g).
    "a100": GpuSpec("nvidia-tesla-a100", "a2-highgpu-1g", (1, 2, 4, 8, 16)),
    # A2 Ultra (80 GB) machine shapes stop at 8 GPUs (a2-ultragpu-8g), so 16
    # must not be advertised here; a generic 16-GPU request should resolve to
    # the 40 GB "a100" entry instead.
    "a100-80gb": GpuSpec("nvidia-a100-80gb", "a2-ultragpu-1g", (1, 2, 4, 8)),
    "h100": GpuSpec("nvidia-h100-80gb", "a3-highgpu-1g", (1, 2, 4, 8)),
    "p4": GpuSpec("nvidia-tesla-p4", "n1-standard-4", (1, 2, 4)),
    "p100": GpuSpec("nvidia-tesla-p100", "n1-standard-4", (1, 2, 4)),
}

_GPU_ALIASES: dict[str, str] = {
Expand All @@ -88,6 +90,10 @@ class TpuSpec:
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct2-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct2-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct2-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct2-hightpu-4t", 128),
},
),
"v3": TpuSpec(
Expand All @@ -97,6 +103,29 @@ class TpuSpec:
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct3p-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct3p-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct3p-hightpu-4t", 64),
512: TpuTopologySpec("16x32", "ct3p-hightpu-4t", 128),
1024: TpuTopologySpec("32x32", "ct3p-hightpu-4t", 256),
2048: TpuTopologySpec("32x64", "ct3p-hightpu-4t", 512),
},
),
"v4": TpuSpec(
"tpu-v4-podslice",
4,
{
4: TpuTopologySpec("2x2x1", "ct4p-hightpu-4t", 1),
8: TpuTopologySpec("2x2x2", "ct4p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct4p-hightpu-4t", 4),
32: TpuTopologySpec("2x4x4", "ct4p-hightpu-4t", 8),
64: TpuTopologySpec("4x4x4", "ct4p-hightpu-4t", 16),
128: TpuTopologySpec("4x4x8", "ct4p-hightpu-4t", 32),
256: TpuTopologySpec("4x8x8", "ct4p-hightpu-4t", 64),
512: TpuTopologySpec("8x8x8", "ct4p-hightpu-4t", 128),
1024: TpuTopologySpec("8x8x16", "ct4p-hightpu-4t", 256),
2048: TpuTopologySpec("8x16x16", "ct4p-hightpu-4t", 512),
4096: TpuTopologySpec("16x16x16", "ct4p-hightpu-4t", 1024),
},
),
"v5litepod": TpuSpec(
Expand All @@ -106,6 +135,11 @@ class TpuSpec:
1: TpuTopologySpec("1x1", "ct5lp-hightpu-1t", 1),
4: TpuTopologySpec("2x2", "ct5lp-hightpu-4t", 1),
8: TpuTopologySpec("2x4", "ct5lp-hightpu-8t", 1),
16: TpuTopologySpec("4x4", "ct5lp-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct5lp-hightpu-4t", 8),
64: TpuTopologySpec("8x8", "ct5lp-hightpu-4t", 16),
128: TpuTopologySpec("8x16", "ct5lp-hightpu-4t", 32),
256: TpuTopologySpec("16x16", "ct5lp-hightpu-4t", 64),
},
),
"v5p": TpuSpec(
Expand All @@ -126,31 +160,72 @@ class TpuSpec:
),
}

_TPU_ALIASES: dict[str, str] = {
"v5e": "v5litepod",
"ghostlite": "v5litepod",
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I don't think *fish names are okay to use externally

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

removed

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fyi: The fish names are exposed externally at https://cloud.google.com/skus/sku-groups/vertex-prediction.

}


# ── Parser ────────────────────────────────────────────────────────

# Patterns that split an accelerator string into a name and a size.  Each
# pattern is defined exactly once: rebinding a name, or placing two string
# literals side by side inside re.compile(), silently changes what is matched
# (adjacent literals are concatenated into a single pattern).

# GPU name + count, separated by "x" or "-": "a100x4", "l4-2", "a100-80gbx8".
# NOTE: this also matches TPU-style strings such as "v3-8", so the captured
# name still has to be checked against the GPU table by the caller.
_MULTI_GPU_RE = re.compile(r"^(.+?)[x-](\d+)$")

# TPU name (or alias) + chip count: "v3-8", "ghostlite-16".
_TPU_CHIPS_RE = re.compile(r"^([a-z0-9_]+)-(\d+)$")

# TPU name (or alias) + 2-D or 3-D topology: "v5litepod-2x2", "v5p-2x2x2".
_TPU_TOPO_RE = re.compile(r"^([a-z0-9_]+)-(\d+x\d+(?:x\d+)?)$")

# Accelerator types selected by the generic "gpu" / "tpu" requests.
DEFAULT_GPU = "l4"
DEFAULT_TPU = "v5litepod"

# Search order used for "gpu-N" / "tpu-N" when the default type does not
# offer the requested size; the first type that supports N is chosen.
_PREFERRED_GPUS = ["h100", "a100-80gb", "a100", "l4", "v100", "t4", "p100", "p4"]
_PREFERRED_TPUS = ["v6e", "v5p", "v5litepod", "v4", "v3", "v2"]


def parse_accelerator(accel_str: str) -> Accelerator:
"""Parse an accelerator string into a fully resolved config.

Returns GpuConfig, TpuConfig, or None (for "cpu").

Accepted formats:
GPU: "l4", "nvidia-l4", "a100x4", "a100-80gbx8"
TPU: "v3-8" (chip count), "v5litepod-2x2" (topology), "v5litepod" (default)
CPU: "cpu"
GPU: "l4", "gpu", "gpu-4", "a100x4", "l4-2", "a100-80gbx8"
TPU: "v3-8", "tpu", "tpu-8", "v5litepod-2x2", "v5litepod"
CPU: "cpu", "cpu-8"
"""
s = accel_str.strip().lower()

if s == "cpu":
if s == "cpu" or (s.startswith("cpu-") and s[4:].isdigit()):
return None

if s == "gpu":
return make_gpu(DEFAULT_GPU, 1)

if s == "tpu":
return make_tpu(DEFAULT_TPU, TPUS[DEFAULT_TPU].default_chips)

if s.startswith("gpu-") and s[4:].isdigit():
count = int(s[4:])
if count in GPUS[DEFAULT_GPU].counts:
return make_gpu(DEFAULT_GPU, count)
for gpu_name in _PREFERRED_GPUS:
if gpu_name in GPUS and count in GPUS[gpu_name].counts:
return make_gpu(gpu_name, count)
valid_counts = sorted(set(c for spec in GPUS.values() for c in spec.counts))
raise ValueError(
f"No GPU supports count {count}. Supported counts across all GPUs: {valid_counts}"
)

if s.startswith("tpu-") and s[4:].isdigit():
chips = int(s[4:])
if chips in TPUS[DEFAULT_TPU].topologies:
return make_tpu(DEFAULT_TPU, chips)
for tpu_name in _PREFERRED_TPUS:
if tpu_name in TPUS and chips in TPUS[tpu_name].topologies:
return make_tpu(tpu_name, chips)
valid_chips = sorted(set(c for spec in TPUS.values() for c in spec.topologies))
raise ValueError(
f"No TPU supports {chips} chips. Supported chip counts across all TPUs: {valid_chips}"
)

# Direct GPU name: "l4", "a100-80gb"
if s in GPUS:
return make_gpu(s, 1)
Expand All @@ -171,25 +246,35 @@ def parse_accelerator(accel_str: str) -> Accelerator:
# Direct TPU name (bare): "v5litepod" → default chips
if s in TPUS:
return make_tpu(s, TPUS[s].default_chips)
if s in _TPU_ALIASES:
name = _TPU_ALIASES[s]
return make_tpu(name, TPUS[name].default_chips)

# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
m = _TPU_TOPO_RE.match(s)
if m and m.group(1) in TPUS:
if m:
name = m.group(1)
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)
if name in _TPU_ALIASES:
name = _TPU_ALIASES[name]
if name in TPUS:
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)

# TPU with chip count: "v3-8", "v5litepod-4"
m = _TPU_CHIPS_RE.match(s)
if m and m.group(1) in TPUS:
return make_tpu(m.group(1), int(m.group(2)))
if m:
name = m.group(1)
if name in _TPU_ALIASES:
name = _TPU_ALIASES[name]
if name in TPUS:
return make_tpu(name, int(m.group(2)))

raise ValueError(
f"Unknown accelerator: '{accel_str}'. "
Expand Down
72 changes: 70 additions & 2 deletions keras_remote/core/accelerators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,6 +44,12 @@
self.assertEqual(result.name, "a100-80gb")
self.assertEqual(result.count, 4)

def test_l4_dash_2(self):
    """The dash-separated form `l4-2` resolves to two L4 GPUs."""
    config = parse_accelerator("l4-2")
    self.assertIsInstance(config, GpuConfig)
    # Check the resolved GPU type first, then the requested count.
    self.assertEqual(config.name, "l4")
    self.assertEqual(config.count, 2)


class TestParseGpuAlias(absltest.TestCase):
def test_nvidia_tesla_t4(self):
Expand All @@ -60,9 +66,9 @@


class TestParseGpuErrors(absltest.TestCase):
    """Requests for a GPU count the named type does not offer are rejected."""

    def test_l4x16_invalid_count(self):
        """`l4x16` raises ValueError mentioning "not supported"."""
        # l4 accepts 1, 2, 4 or 8 GPUs, so "l4x8" is a valid request and can
        # no longer serve as the failing example; 16 is the next size up.
        with self.assertRaisesRegex(ValueError, "not supported"):
            parse_accelerator("l4x16")


class TestParseTpuBare(parameterized.TestCase):
Expand Down Expand Up @@ -154,6 +160,68 @@
def test_cpu(self):
    """The bare string `cpu` parses to None (no accelerator config)."""
    config = parse_accelerator("cpu")
    self.assertIsNone(config)

def test_cpu_with_count(self):
    """A numeric suffix on `cpu` (here `cpu-8`) still parses to None."""
    config = parse_accelerator("cpu-8")
    self.assertIsNone(config)

class TestParseGenericAliases(absltest.TestCase):
    """Covers the generic `gpu` / `tpu` request strings and their sized forms."""

    def test_gpu_bare(self):
        """`gpu` alone selects a single l4."""
        config = parse_accelerator("gpu")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "l4")
        self.assertEqual(config.count, 1)

    def test_tpu_bare(self):
        """`tpu` alone selects v5litepod with 4 chips."""
        config = parse_accelerator("tpu")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v5litepod")
        self.assertEqual(config.chips, 4)

    def test_gpu_with_count(self):
        """`gpu-4` keeps the l4 type because l4 offers 4 GPUs."""
        config = parse_accelerator("gpu-4")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "l4")
        self.assertEqual(config.count, 4)

    def test_tpu_with_count(self):
        """`tpu-8` keeps the v5litepod type because it offers 8 chips."""
        config = parse_accelerator("tpu-8")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v5litepod")
        self.assertEqual(config.chips, 8)

    def test_gpu_with_dynamic_count(self):
        """`gpu-16` moves off l4 to an A100 variant."""
        # l4 stops at 8 GPUs, so a 16-GPU request has to be served by
        # another type; either A100 flavour is an acceptable answer.
        config = parse_accelerator("gpu-16")
        self.assertIsInstance(config, GpuConfig)
        self.assertIn(config.name, ["a100", "a100-80gb"])
        self.assertEqual(config.count, 16)

    def test_tpu_with_dynamic_count(self):
        """`tpu-4096` moves off v5litepod to v4."""
        # v5litepod stops at 256 chips, so 4096 has to come from v4.
        config = parse_accelerator("tpu-4096")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v4")
        self.assertEqual(config.chips, 4096)

Check failure on line 204 in keras_remote/core/accelerators_test.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (W293)

keras_remote/core/accelerators_test.py:204:1: W293 Blank line contains whitespace help: Remove whitespace from blank line
def test_v5e_alias(self):
    """The `v5e` alias with a chip count resolves to v5litepod."""
    config = parse_accelerator("v5e-8")
    self.assertIsInstance(config, TpuConfig)
    # The alias is rewritten to the canonical name; the chip count carries over.
    self.assertEqual(config.name, "v5litepod")
    self.assertEqual(config.chips, 8)

def test_ghostlite_alias(self):
    """The `ghostlite` alias with a chip count resolves to v5litepod."""
    config = parse_accelerator("ghostlite-16")
    self.assertIsInstance(config, TpuConfig)
    # The alias is rewritten to the canonical name; the chip count carries over.
    self.assertEqual(config.name, "v5litepod")
    self.assertEqual(config.chips, 16)

def test_gpu_unsupported_count(self):
    """`gpu-32` is rejected with a message naming the unsupported count."""
    expected_message = "No GPU supports count 32"
    with self.assertRaisesRegex(ValueError, expected_message):
        parse_accelerator("gpu-32")

Check failure on line 220 in keras_remote/core/accelerators_test.py

View workflow job for this annotation

GitHub Actions / ruff

ruff (W293)

keras_remote/core/accelerators_test.py:220:1: W293 Blank line contains whitespace help: Remove whitespace from blank line
def test_tpu_unsupported_count(self):
    """`tpu-8192` is rejected with a message naming the unsupported chip count."""
    expected_message = "No TPU supports 8192 chips"
    with self.assertRaisesRegex(ValueError, expected_message):
        parse_accelerator("tpu-8192")


class TestParseNormalizationAndErrors(absltest.TestCase):
def test_whitespace_and_case(self):
Expand Down
Loading