Skip to content

Commit 88f008b

Browse files
Displays accelerator node pool information as part of status output (#46)
* Displays accelerator pool info with status command * removes unused method
1 parent 96e1b02 commit 88f008b

File tree

5 files changed

+261
-14
lines changed

5 files changed

+261
-14
lines changed

keras_remote/cli/commands/status.py

Lines changed: 7 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -2,13 +2,17 @@
22

33
import click
44
from pulumi.automation import CommandError
5-
from rich.table import Table
65

76
from keras_remote.cli.config import InfraConfig
87
from keras_remote.cli.constants import DEFAULT_ZONE
98
from keras_remote.cli.infra.program import create_program
109
from keras_remote.cli.infra.stack_manager import get_stack
11-
from keras_remote.cli.output import banner, console, warning
10+
from keras_remote.cli.output import (
11+
banner,
12+
console,
13+
infrastructure_state,
14+
warning,
15+
)
1216
from keras_remote.cli.prerequisites_check import check_all
1317
from keras_remote.cli.prompts import resolve_project
1418

@@ -57,13 +61,4 @@ def status(project, zone):
5761
warning("No infrastructure found. Run 'keras-remote up' first.")
5862
return
5963

60-
table = Table(title="Infrastructure State")
61-
table.add_column("Resource", style="bold")
62-
table.add_column("Value", style="green")
63-
64-
for key, output in outputs.items():
65-
table.add_row(key, str(output.value))
66-
67-
console.print()
68-
console.print(table)
69-
console.print()
64+
infrastructure_state(outputs)

keras_remote/cli/infra/program.py

Lines changed: 29 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -131,6 +131,35 @@ def pulumi_program():
131131
f"{ar_location}-docker.pkg.dev/{project_id}/keras-remote",
132132
)
133133

134+
# 7. Accelerator node pool exports
135+
if isinstance(accelerator, GpuConfig):
136+
pulumi.export(
137+
"accelerator",
138+
{
139+
"type": "GPU",
140+
"name": accelerator.name,
141+
"count": accelerator.count,
142+
"machine_type": accelerator.machine_type,
143+
"node_pool": "gpu-pool",
144+
"node_count": 1,
145+
},
146+
)
147+
elif isinstance(accelerator, TpuConfig):
148+
pulumi.export(
149+
"accelerator",
150+
{
151+
"type": "TPU",
152+
"name": accelerator.name,
153+
"chips": accelerator.chips,
154+
"topology": accelerator.topology,
155+
"machine_type": accelerator.machine_type,
156+
"node_pool": f"tpu-{accelerator.name}-pool",
157+
"node_count": accelerator.num_nodes,
158+
},
159+
)
160+
else:
161+
pulumi.export("accelerator", None)
162+
134163
return pulumi_program
135164

136165

keras_remote/cli/infra/program_test.py

Lines changed: 62 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,10 @@
1-
"""Tests for keras_remote.cli.infra.program — TPU node pool creation."""
1+
"""Tests for keras_remote.cli.infra.program — node pool and exports."""
22

33
from unittest import mock
44

55
from absl.testing import absltest, parameterized
66

7-
from keras_remote.core.accelerators import TpuConfig
7+
from keras_remote.core.accelerators import GpuConfig, TpuConfig
88

99
# Patch the pulumi_gcp module before importing program, so the module-level
1010
# import inside program.py picks up the mock.
@@ -86,5 +86,65 @@ def test_pool_name_includes_tpu_name(self, gcp_mock):
8686
self.assertEqual(positional_args[0], "tpu-v5p-pool")
8787

8888

89+
def _make_config(accelerator=None):
90+
"""Create a mock InfraConfig for testing."""
91+
config = mock.MagicMock()
92+
config.project = "test-project"
93+
config.zone = "us-central1-a"
94+
config.cluster_name = "test-cluster"
95+
config.accelerator = accelerator
96+
return config
97+
98+
99+
def _run_program_and_get_exports(config):
100+
"""Run the Pulumi program and return a dict of exported key -> value."""
101+
with (
102+
mock.patch.object(program, "pulumi") as pulumi_mock,
103+
mock.patch.object(program, "gcp"),
104+
):
105+
program_fn = program.create_program(config)
106+
program_fn()
107+
return {
108+
call.args[0]: call.args[1] for call in pulumi_mock.export.call_args_list
109+
}
110+
111+
112+
class TestAcceleratorExports(absltest.TestCase):
113+
"""Verify accelerator metadata is exported correctly."""
114+
115+
def test_gpu_exports(self):
116+
gpu = GpuConfig("l4", 1, "nvidia-l4", "g2-standard-4")
117+
exports = _run_program_and_get_exports(_make_config(gpu))
118+
119+
self.assertIn("accelerator", exports)
120+
accel = exports["accelerator"]
121+
self.assertEqual(accel["type"], "GPU")
122+
self.assertEqual(accel["name"], "l4")
123+
self.assertEqual(accel["count"], 1)
124+
self.assertEqual(accel["machine_type"], "g2-standard-4")
125+
self.assertEqual(accel["node_pool"], "gpu-pool")
126+
self.assertEqual(accel["node_count"], 1)
127+
128+
def test_tpu_exports(self):
129+
tpu = TpuConfig("v5p", 8, "2x2x2", "tpu-v5p-slice", "ct5p-hightpu-4t", 2)
130+
exports = _run_program_and_get_exports(_make_config(tpu))
131+
132+
self.assertIn("accelerator", exports)
133+
accel = exports["accelerator"]
134+
self.assertEqual(accel["type"], "TPU")
135+
self.assertEqual(accel["name"], "v5p")
136+
self.assertEqual(accel["chips"], 8)
137+
self.assertEqual(accel["topology"], "2x2x2")
138+
self.assertEqual(accel["machine_type"], "ct5p-hightpu-4t")
139+
self.assertEqual(accel["node_pool"], "tpu-v5p-pool")
140+
self.assertEqual(accel["node_count"], 2)
141+
142+
def test_cpu_only_exports_none(self):
143+
exports = _run_program_and_get_exports(_make_config(None))
144+
145+
self.assertIn("accelerator", exports)
146+
self.assertIsNone(exports["accelerator"])
147+
148+
89149
if __name__ == "__main__":
90150
absltest.main()

keras_remote/cli/output.py

Lines changed: 62 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,68 @@ def error(msg):
2929
console.print(f"[red]{msg}[/red]")
3030

3131

32+
_INFRA_LABELS = {
33+
"project": "Project",
34+
"zone": "Zone",
35+
"cluster_name": "Cluster Name",
36+
"cluster_endpoint": "Cluster Endpoint",
37+
"ar_registry": "Artifact Registry",
38+
}
39+
40+
_GPU_LABELS = {
41+
"name": "GPU Type",
42+
"count": "GPU Count",
43+
"machine_type": "Machine Type",
44+
"node_pool": "Node Pool",
45+
"node_count": "Node Count",
46+
}
47+
48+
_TPU_LABELS = {
49+
"name": "TPU Type",
50+
"chips": "TPU Chips",
51+
"topology": "Topology",
52+
"machine_type": "Machine Type",
53+
"node_pool": "Node Pool",
54+
"node_count": "Node Count",
55+
}
56+
57+
58+
def infrastructure_state(outputs):
59+
"""Display infrastructure state from Pulumi stack outputs.
60+
61+
Args:
62+
outputs: dict of key -> pulumi.automation.OutputValue from stack.outputs().
63+
"""
64+
table = Table(title="Infrastructure State")
65+
table.add_column("Resource", style="bold")
66+
table.add_column("Value", style="green")
67+
68+
for key, label in _INFRA_LABELS.items():
69+
if key in outputs:
70+
table.add_row(label, str(outputs[key].value))
71+
72+
if "accelerator" not in outputs:
73+
table.add_row(
74+
"Accelerator",
75+
"[dim]Unknown (run 'keras-remote up' to refresh)[/dim]",
76+
)
77+
elif outputs["accelerator"].value is None:
78+
table.add_row("Accelerator", "CPU only")
79+
else:
80+
accel = outputs["accelerator"].value
81+
accel_type = accel.get("type", "Unknown")
82+
table.add_row("", "")
83+
table.add_row("Accelerator", accel_type)
84+
labels = _GPU_LABELS if accel_type == "GPU" else _TPU_LABELS
85+
for key, label in labels.items():
86+
if key in accel:
87+
table.add_row(f" {label}", str(accel[key]))
88+
89+
console.print()
90+
console.print(table)
91+
console.print()
92+
93+
3294
def config_summary(config):
3395
"""Display a configuration summary table."""
3496
table = Table(title="Configuration Summary")

keras_remote/cli/output_test.py

Lines changed: 101 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,101 @@
1+
"""Tests for keras_remote.cli.output — infrastructure_state display."""
2+
3+
from unittest import mock
4+
5+
from absl.testing import absltest
6+
from rich.console import Console
7+
8+
from keras_remote.cli import output
9+
10+
11+
def _make_output(value):
12+
"""Create a mock Pulumi OutputValue."""
13+
ov = mock.MagicMock()
14+
ov.value = value
15+
return ov
16+
17+
18+
def _render_text(outputs):
19+
"""Render infrastructure_state and capture output as text."""
20+
from io import StringIO
21+
22+
sio = StringIO()
23+
buf = Console(file=sio, force_terminal=False, width=120)
24+
with mock.patch.object(output, "console", buf):
25+
output.infrastructure_state(outputs)
26+
return sio.getvalue()
27+
28+
29+
class TestInfrastructureState(absltest.TestCase):
30+
"""Verify infrastructure_state renders correctly."""
31+
32+
def _base_outputs(self):
33+
return {
34+
"project": _make_output("my-project"),
35+
"zone": _make_output("us-central1-a"),
36+
"cluster_name": _make_output("keras-remote-cluster"),
37+
"cluster_endpoint": _make_output("34.123.45.67"),
38+
"ar_registry": _make_output("us-docker.pkg.dev/my-project/keras-remote"),
39+
}
40+
41+
def test_gpu_accelerator(self):
42+
outputs = self._base_outputs()
43+
outputs["accelerator"] = _make_output(
44+
{
45+
"type": "GPU",
46+
"name": "l4",
47+
"count": 1,
48+
"machine_type": "g2-standard-4",
49+
"node_pool": "gpu-pool",
50+
"node_count": 1,
51+
}
52+
)
53+
text = _render_text(outputs)
54+
55+
self.assertIn("my-project", text)
56+
self.assertIn("GPU", text)
57+
self.assertIn("l4", text)
58+
self.assertIn("g2-standard-4", text)
59+
self.assertIn("gpu-pool", text)
60+
61+
def test_tpu_accelerator(self):
62+
outputs = self._base_outputs()
63+
outputs["accelerator"] = _make_output(
64+
{
65+
"type": "TPU",
66+
"name": "v5p",
67+
"chips": 8,
68+
"topology": "2x2x2",
69+
"machine_type": "ct5p-hightpu-4t",
70+
"node_pool": "tpu-v5p-pool",
71+
"node_count": 2,
72+
}
73+
)
74+
text = _render_text(outputs)
75+
76+
self.assertIn("TPU", text)
77+
self.assertIn("v5p", text)
78+
self.assertIn("2x2x2", text)
79+
self.assertIn("ct5p-hightpu-4t", text)
80+
self.assertIn("tpu-v5p-pool", text)
81+
82+
def test_cpu_only(self):
83+
outputs = self._base_outputs()
84+
outputs["accelerator"] = _make_output(None)
85+
text = _render_text(outputs)
86+
87+
self.assertIn("CPU only", text)
88+
self.assertNotIn("GPU", text)
89+
self.assertNotIn("TPU", text)
90+
91+
def test_missing_accelerator_key_backward_compat(self):
92+
outputs = self._base_outputs()
93+
# No "accelerator" key — simulates old stack
94+
text = _render_text(outputs)
95+
96+
self.assertIn("Unknown", text)
97+
self.assertIn("keras-remote up", text)
98+
99+
100+
if __name__ == "__main__":
101+
absltest.main()

0 commit comments

Comments
 (0)