Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 20 additions & 11 deletions keras_remote/core/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,28 @@ class TpuSpec:
spec.gke_label: name for name, spec in GPUS.items()
}

# Topology reference — verify new entries against:
# https://docs.cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus
# Formula: num_nodes = product(topology_dims) / chips_per_VM
# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
# v5p uses 3-D topologies (AxBxC); v2, v3, v5litepod, v6e use 2-D (AxB).
TPUS: dict[str, TpuSpec] = {
"v2": TpuSpec(
"tpu-v2-podslice",
8,
4,
{
8: TpuTopologySpec("2x2", "ct2-hightpu-4t", 2),
32: TpuTopologySpec("4x4", "ct2-hightpu-4t", 8),
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
},
),
"v3": TpuSpec(
"tpu-v3-podslice",
8,
4,
{
8: TpuTopologySpec("2x2", "ct3p-hightpu-4t", 2),
32: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 8),
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
},
),
"v5litepod": TpuSpec(
Expand All @@ -104,16 +111,16 @@ class TpuSpec:
"tpu-v5p-slice",
8,
{
8: TpuTopologySpec("2x2", "ct5p-hightpu-4t", 2),
16: TpuTopologySpec("2x4", "ct5p-hightpu-4t", 4),
8: TpuTopologySpec("2x2x2", "ct5p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct5p-hightpu-4t", 4),
},
),
"v6e": TpuSpec(
"tpu-v6e-slice",
8,
{
8: TpuTopologySpec("2x2", "ct6e-standard-4t", 2),
16: TpuTopologySpec("2x4", "ct6e-standard-4t", 4),
8: TpuTopologySpec("2x4", "ct6e-standard-4t", 2),
16: TpuTopologySpec("4x4", "ct6e-standard-4t", 4),
},
),
}
Expand All @@ -123,7 +130,9 @@ class TpuSpec:

_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
_TPU_TOPO_RE = re.compile(r"^(v\d+\w*)-(\d+x\d+)$") # "v5litepod-2x2"
_TPU_TOPO_RE = re.compile(
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
) # "v5litepod-2x2", "v5p-2x2x2"


def parse_accelerator(accel_str: str) -> Accelerator:
Expand Down
32 changes: 23 additions & 9 deletions keras_remote/core/accelerators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ def test_all_tpu_types_parse_with_default_chips(self, tpu_name):


class TestParseTpuChipCount(absltest.TestCase):
def test_v3_8(self):
result = parse_accelerator("v3-8")
def test_v3_4(self):
result = parse_accelerator("v3-4")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.name, "v3")
self.assertEqual(result.chips, 8)
self.assertEqual(result.chips, 4)
self.assertEqual(result.topology, "2x2")

def test_v3_32(self):
result = parse_accelerator("v3-32")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.name, "v3")
self.assertEqual(result.chips, 32)
self.assertEqual(result.topology, "4x4")
self.assertEqual(result.topology, "4x8")

def test_v5litepod_1(self):
result = parse_accelerator("v5litepod-1")
Expand All @@ -118,20 +118,34 @@ def test_v5litepod_1x1(self):


class TestParseTpuErrors(absltest.TestCase):
def test_v3_16_invalid_chips(self):
def test_v3_8_invalid_chips(self):
with self.assertRaisesRegex(ValueError, "not supported"):
parse_accelerator("v3-16")
parse_accelerator("v3-8")

def test_v5litepod_3x3_invalid_topology(self):
with self.assertRaisesRegex(ValueError, "Unknown accelerator"):
parse_accelerator("v5litepod-3x3")


class TestParseTpuConfigFields(absltest.TestCase):
def test_v3_8_full_config(self):
result = parse_accelerator("v3-8")
def test_v3_4_full_config(self):
result = parse_accelerator("v3-4")
self.assertEqual(result.gke_accelerator, "tpu-v3-podslice")
self.assertEqual(result.machine_type, "ct3p-hightpu-4t")
self.assertEqual(result.machine_type, "ct3-hightpu-4t")
self.assertEqual(result.num_nodes, 1)

def test_v5p_default(self):
result = parse_accelerator("v5p")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.chips, 8)
self.assertEqual(result.topology, "2x2x2")

def test_v5p_3d_topology(self):
result = parse_accelerator("v5p-2x2x2")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.name, "v5p")
self.assertEqual(result.chips, 8)
self.assertEqual(result.topology, "2x2x2")
self.assertEqual(result.num_nodes, 2)


Expand Down
Loading