Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions keras_remote/backend/gke_client_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,13 +44,13 @@ def test_gpu_a100x4(self):
self.assertEqual(result["resource_requests"], {"nvidia.com/gpu": "4"})

def test_tpu_v3_8(self):
result = _parse_accelerator("v3-8")
result = _parse_accelerator("v3-4")
self.assertIn(
"cloud.google.com/gke-tpu-accelerator", result["node_selector"]
)
self.assertIn("cloud.google.com/gke-tpu-topology", result["node_selector"])
self.assertEqual(result["resource_limits"], {"google.com/tpu": "8"})
self.assertEqual(result["resource_requests"], {"google.com/tpu": "8"})
self.assertEqual(result["resource_limits"], {"google.com/tpu": "4"})
self.assertEqual(result["resource_requests"], {"google.com/tpu": "4"})
self.assertEqual(result["jax_platform"], "tpu")
self.assertLen(result["tolerations"], 1)
self.assertEqual(result["tolerations"][0]["key"], "google.com/tpu")
Expand Down
15 changes: 11 additions & 4 deletions keras_remote/cli/infra/program.py
Original file line number Diff line number Diff line change
Expand Up @@ -160,6 +160,16 @@ def _create_gpu_node_pool(cluster, gpu: GpuConfig, zone, project_id):
def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id):
"""Create a TPU GKE node pool."""
pool_name = f"tpu-{tpu.name}-pool"
# Single-host TPU slices (1 node) must not specify placement_policy;
# multi-host slices require COMPACT placement with an explicit topology.
placement = (
gcp.container.NodePoolPlacementPolicyArgs(
type="COMPACT",
tpu_topology=tpu.topology,
)
if tpu.num_nodes > 1
else None
)
gcp.container.NodePool(
pool_name,
name=pool_name,
Expand All @@ -172,8 +182,5 @@ def _create_tpu_node_pool(cluster, tpu: TpuConfig, zone, project_id):
oauth_scopes=_BASE_OAUTH_SCOPES,
labels={RESOURCE_NAME_PREFIX: "true"},
),
placement_policy=gcp.container.NodePoolPlacementPolicyArgs(
type="COMPACT",
tpu_topology=tpu.topology,
),
placement_policy=placement,
)
90 changes: 90 additions & 0 deletions keras_remote/cli/infra/program_test.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,90 @@
"""Tests for keras_remote.cli.infra.program — TPU node pool creation."""

from unittest import mock

from absl.testing import absltest, parameterized

from keras_remote.core.accelerators import TpuConfig

# Patch the pulumi_gcp module before importing program, so the module-level
# import inside program.py picks up the mock.
# NOTE: mock.patch.dict restores sys.modules on exit, but `program` keeps its
# module-level `gcp` binding to the MagicMock — which is what the tests patch
# via mock.patch.object(program, "gcp") below.
with mock.patch.dict("sys.modules", {"pulumi_gcp": mock.MagicMock()}):
    from keras_remote.cli.infra import program


class TestCreateTpuNodePool(parameterized.TestCase):
    """Verify _create_tpu_node_pool sets placement_policy correctly.

    Single-host TPU slices (num_nodes == 1) must omit placement_policy;
    multi-host slices require COMPACT placement with an explicit topology.
    """

    def _create_pool(self, gcp_mock, tpu):
        """Invoke _create_tpu_node_pool with a stub cluster.

        Returns the keyword arguments of the resulting NodePool call.
        ``call_args.kwargs`` is the same mapping as ``call_args[1]``, so a
        single attribute read suffices (the previous
        ``kwargs.get(key, call_args[1].get(key))`` fallback was dead code
        that eagerly performed the identical lookup twice).
        """
        cluster = mock.MagicMock()
        cluster.name = "test-cluster"
        program._create_tpu_node_pool(cluster, tpu, "us-central2-b", "my-project")
        return gcp_mock.container.NodePool.call_args.kwargs

    @parameterized.named_parameters(
        dict(
            testcase_name="v5p_multi_host",
            tpu=TpuConfig("v5p", 8, "2x2x2", "tpu-v5p-slice", "ct5p-hightpu-4t", 2),
            expect_placement=True,
        ),
        dict(
            testcase_name="v6e_multi_host",
            tpu=TpuConfig("v6e", 8, "2x4", "tpu-v6e-slice", "ct6e-standard-4t", 2),
            expect_placement=True,
        ),
        dict(
            testcase_name="v3_single_host",
            tpu=TpuConfig("v3", 4, "2x2", "tpu-v3-podslice", "ct3-hightpu-4t", 1),
            expect_placement=False,
        ),
        dict(
            testcase_name="v5litepod_single_host",
            tpu=TpuConfig(
                "v5litepod", 4, "2x2", "tpu-v5-lite-podslice", "ct5lp-hightpu-4t", 1
            ),
            expect_placement=False,
        ),
    )
    @mock.patch.object(program, "gcp")
    def test_placement_policy(self, gcp_mock, tpu, expect_placement):
        """placement_policy is set iff the slice spans more than one node."""
        placement = self._create_pool(gcp_mock, tpu).get("placement_policy")

        if expect_placement:
            self.assertIsNotNone(placement)
            gcp_mock.container.NodePoolPlacementPolicyArgs.assert_called_once_with(
                type="COMPACT",
                tpu_topology=tpu.topology,
            )
        else:
            self.assertIsNone(placement)

    @mock.patch.object(program, "gcp")
    def test_node_count_matches_config(self, gcp_mock):
        """node_count passed to NodePool equals TpuConfig.num_nodes."""
        tpu = TpuConfig("v5p", 16, "2x2x4", "tpu-v5p-slice", "ct5p-hightpu-4t", 4)
        kwargs = self._create_pool(gcp_mock, tpu)
        self.assertEqual(kwargs.get("node_count"), 4)

    @mock.patch.object(program, "gcp")
    def test_pool_name_includes_tpu_name(self, gcp_mock):
        """Pool resource name follows the "tpu-<name>-pool" convention."""
        tpu = TpuConfig("v5p", 8, "2x2x2", "tpu-v5p-slice", "ct5p-hightpu-4t", 2)
        self._create_pool(gcp_mock, tpu)
        positional_args = gcp_mock.container.NodePool.call_args.args
        self.assertEqual(positional_args[0], "tpu-v5p-pool")


# Allow running this test module directly (e.g. `python program_test.py`).
if __name__ == "__main__":
    absltest.main()
38 changes: 26 additions & 12 deletions keras_remote/core/accelerators.py
Original file line number Diff line number Diff line change
Expand Up @@ -74,21 +74,28 @@ class TpuSpec:
spec.gke_label: name for name, spec in GPUS.items()
}

# Topology reference — verify new entries against:
# https://docs.cloud.google.com/kubernetes-engine/docs/concepts/plan-tpus
# Formula: num_nodes = product(topology_dims) / chips_per_VM
# Machine-type suffix "-Nt" → N chips per VM (e.g. ct5p-hightpu-4t → 4 chips).
# v5p uses 3-D topologies (AxBxC); v2, v3, v5litepod, v6e use 2-D (AxB).
TPUS: dict[str, TpuSpec] = {
"v2": TpuSpec(
"tpu-v2-podslice",
8,
4,
{
8: TpuTopologySpec("2x2", "ct2-hightpu-4t", 2),
32: TpuTopologySpec("4x4", "ct2-hightpu-4t", 8),
4: TpuTopologySpec("2x2", "ct2-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct2-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct2-hightpu-4t", 8),
},
),
"v3": TpuSpec(
"tpu-v3-podslice",
8,
4,
{
8: TpuTopologySpec("2x2", "ct3p-hightpu-4t", 2),
32: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 8),
4: TpuTopologySpec("2x2", "ct3-hightpu-4t", 1),
16: TpuTopologySpec("4x4", "ct3p-hightpu-4t", 4),
32: TpuTopologySpec("4x8", "ct3p-hightpu-4t", 8),
},
),
"v5litepod": TpuSpec(
Expand All @@ -104,16 +111,16 @@ class TpuSpec:
"tpu-v5p-slice",
8,
{
8: TpuTopologySpec("2x2", "ct5p-hightpu-4t", 2),
16: TpuTopologySpec("2x4", "ct5p-hightpu-4t", 4),
8: TpuTopologySpec("2x2x2", "ct5p-hightpu-4t", 2),
16: TpuTopologySpec("2x2x4", "ct5p-hightpu-4t", 4),
},
),
"v6e": TpuSpec(
"tpu-v6e-slice",
8,
{
8: TpuTopologySpec("2x2", "ct6e-standard-4t", 2),
16: TpuTopologySpec("2x4", "ct6e-standard-4t", 4),
8: TpuTopologySpec("2x4", "ct6e-standard-4t", 2),
16: TpuTopologySpec("4x4", "ct6e-standard-4t", 4),
},
),
}
Expand All @@ -123,7 +130,9 @@ class TpuSpec:

_MULTI_GPU_RE = re.compile(r"^(.+?)x(\d+)$") # "a100x4"
_TPU_CHIPS_RE = re.compile(r"^(v\d+\w*)-(\d+)$") # "v3-8"
_TPU_TOPO_RE = re.compile(r"^(v\d+\w*)-(\d+x\d+)$") # "v5litepod-2x2"
_TPU_TOPO_RE = re.compile(
r"^(v\d+\w*)-(\d+x\d+(?:x\d+)?)$"
) # "v5litepod-2x2", "v5p-2x2x2"


def parse_accelerator(accel_str: str) -> Accelerator:
Expand Down Expand Up @@ -162,14 +171,19 @@ def parse_accelerator(accel_str: str) -> Accelerator:
if s in TPUS:
return _make_tpu(s, TPUS[s].default_chips)

# TPU with topology string: "v5litepod-2x2"
# TPU with topology string: "v5litepod-2x2", "v5p-2x2x2"
m = _TPU_TOPO_RE.match(s)
if m and m.group(1) in TPUS:
name = m.group(1)
topo_str = m.group(2)
for chips, topo_spec in TPUS[name].topologies.items():
if topo_spec.topology == topo_str:
return _make_tpu(name, chips)
valid = [ts.topology for ts in TPUS[name].topologies.values()]
raise ValueError(
f"Topology '{topo_str}' not supported for '{name}'. "
f"Supported: {', '.join(valid)}."
)

# TPU with chip count: "v3-8", "v5litepod-4"
m = _TPU_CHIPS_RE.match(s)
Expand Down
34 changes: 24 additions & 10 deletions keras_remote/core/accelerators_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -81,19 +81,19 @@ def test_all_tpu_types_parse_with_default_chips(self, tpu_name):


class TestParseTpuChipCount(absltest.TestCase):
def test_v3_8(self):
result = parse_accelerator("v3-8")
def test_v3_4(self):
result = parse_accelerator("v3-4")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.name, "v3")
self.assertEqual(result.chips, 8)
self.assertEqual(result.chips, 4)
self.assertEqual(result.topology, "2x2")

def test_v3_32(self):
result = parse_accelerator("v3-32")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.name, "v3")
self.assertEqual(result.chips, 32)
self.assertEqual(result.topology, "4x4")
self.assertEqual(result.topology, "4x8")

def test_v5litepod_1(self):
result = parse_accelerator("v5litepod-1")
Expand All @@ -118,20 +118,34 @@ def test_v5litepod_1x1(self):


class TestParseTpuErrors(absltest.TestCase):
def test_v3_16_invalid_chips(self):
def test_v3_8_invalid_chips(self):
with self.assertRaisesRegex(ValueError, "not supported"):
parse_accelerator("v3-16")
parse_accelerator("v3-8")

def test_v5litepod_3x3_invalid_topology(self):
with self.assertRaisesRegex(ValueError, "Unknown accelerator"):
with self.assertRaisesRegex(ValueError, "not supported"):
parse_accelerator("v5litepod-3x3")


class TestParseTpuConfigFields(absltest.TestCase):
def test_v3_8_full_config(self):
result = parse_accelerator("v3-8")
def test_v3_4_full_config(self):
result = parse_accelerator("v3-4")
self.assertEqual(result.gke_accelerator, "tpu-v3-podslice")
self.assertEqual(result.machine_type, "ct3p-hightpu-4t")
self.assertEqual(result.machine_type, "ct3-hightpu-4t")
self.assertEqual(result.num_nodes, 1)

def test_v5p_default(self):
result = parse_accelerator("v5p")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.chips, 8)
self.assertEqual(result.topology, "2x2x2")

def test_v5p_3d_topology(self):
result = parse_accelerator("v5p-2x2x2")
self.assertIsInstance(result, TpuConfig)
self.assertEqual(result.name, "v5p")
self.assertEqual(result.chips, 8)
self.assertEqual(result.topology, "2x2x2")
self.assertEqual(result.num_nodes, 2)


Expand Down
12 changes: 6 additions & 6 deletions keras_remote/infra/container_builder_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -50,7 +50,7 @@ def test_different_accelerator_different_hash(self):
req.write_text("numpy\n")

h1 = _hash_requirements(str(req), "l4", "python:3.12-slim")
h2 = _hash_requirements(str(req), "v3-8", "python:3.12-slim")
h2 = _hash_requirements(str(req), "v3-4", "python:3.12-slim")
self.assertNotEqual(h1, h2)

def test_different_base_image_different_hash(self):
Expand Down Expand Up @@ -98,7 +98,7 @@ class TestGenerateDockerfile(parameterized.TestCase):
),
dict(
testcase_name="tpu",
accelerator_type="v3-8",
accelerator_type="v3-4",
expected=["jax[tpu]", "libtpu_releases"],
not_expected=[],
),
Expand Down Expand Up @@ -206,7 +206,7 @@ def test_correct_resource_name(self):
return_value=mock_client,
):
_image_exists(
"us-docker.pkg.dev/my-proj/keras-remote/base:v3-8-abc123def456",
"us-docker.pkg.dev/my-proj/keras-remote/base:v3-4-abc123def456",
"my-proj",
)
call_args = mock_client.get_tag.call_args
Expand All @@ -215,7 +215,7 @@ def test_correct_resource_name(self):
request.name,
"projects/my-proj/locations/us"
"/repositories/keras-remote"
"/packages/base/tags/v3-8-abc123def456",
"/packages/base/tags/v3-4-abc123def456",
)


Expand Down Expand Up @@ -279,13 +279,13 @@ def _get_image_uri(self, accelerator_type, project, zone):
)

def test_image_uri_format_tpu_europe(self):
result = self._get_image_uri("v3-8", "my-proj", "europe-west4-b")
result = self._get_image_uri("v3-4", "my-proj", "europe-west4-b")

self.assertTrue(
result.startswith("europe-docker.pkg.dev/my-proj/keras-remote/base:")
)
tag = result.split(":")[-1]
self.assertRegex(tag, r"^v3-8-[0-9a-f]{12}$")
self.assertRegex(tag, r"^v3-4-[0-9a-f]{12}$")

def test_image_uri_format_gpu_us(self):
result = self._get_image_uri("a100-80gb", "proj", "us-central1-a")
Expand Down
Loading