Skip to content

Commit c24041c

Browse files
Fixes pathways integration
1 parent 04fcc53 commit c24041c

File tree

4 files changed

+98
-9
lines changed

4 files changed

+98
-9
lines changed

examples/pathways_example.py

Lines changed: 46 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -10,10 +10,55 @@
1010

1111

1212
# A simple model that will be executed remotely on pathways
13-
@keras_remote.run(accelerator="v5litepod-1", backend="pathways")
13+
@keras_remote.run(
14+
accelerator="v6e-16", backend="pathways", cluster="keras-team-dogfood"
15+
)
1416
def train_simple_model():
17+
import jax
18+
from jax import lax
19+
1520
print("Running Pathways job on JAX Backend!")
1621

22+
# Verify distributed JAX setup (Pathways auto-initialization)
23+
process_count = jax.process_count()
24+
process_index = jax.process_index()
25+
device_count = jax.device_count()
26+
local_device_count = jax.local_device_count()
27+
28+
print("JAX Distributed Environment:")
29+
print(f" Process Count: {process_count}")
30+
print(f" Process Index: {process_index}")
31+
print(f" Total Devices: {device_count}")
32+
print(f" Local Devices: {local_device_count}")
33+
34+
# Fail if not actually running on multiple hosts
35+
if process_count <= 1:
36+
raise RuntimeError(
37+
f"Pathways verification failed: Expected > 1 processes, but found {process_count}. "
38+
"This indicates the job is NOT running in a multi-host Pathways environment."
39+
)
40+
41+
# Verify collective communication (cross-host psum)
42+
try:
43+
# Use jax.pmap to sum values across all devices in the cluster
44+
x = np.ones(local_device_count)
45+
distributed_sum = jax.pmap(lambda val: lax.psum(val, "i"), axis_name="i")(x)
46+
total_sum = distributed_sum[0]
47+
48+
if total_sum != device_count:
49+
raise RuntimeError(
50+
f"Collective verification failed: Expected psum {device_count}, got {total_sum}"
51+
)
52+
print(
53+
f"Successfully verified collective communication across all {total_sum} devices!"
54+
)
55+
except Exception as e:
56+
print(f"Warning: Collective verification failed: {e}")
57+
if isinstance(e, RuntimeError) and "Collective verification failed" in str(
58+
e
59+
):
60+
raise
61+
1762
# Create a simple dataset
1863
x = np.random.rand(1000, 10)
1964
y = np.random.randint(0, 2, size=(1000, 1))

keras_remote/backend/gke_client.py

Lines changed: 12 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -238,13 +238,16 @@ def _parse_accelerator(accelerator):
238238
}
239239

240240
if isinstance(parsed, TpuConfig):
241+
# For TPU Podslices (multi-node), resource requests must be per-node.
242+
# num_nodes is 1 for single-host TPUs (v3-8, v4-8, v5litepod-1/4/8).
243+
chips_per_node = parsed.chips // parsed.num_nodes
241244
return {
242245
"node_selector": {
243246
"cloud.google.com/gke-tpu-accelerator": parsed.gke_accelerator,
244247
"cloud.google.com/gke-tpu-topology": parsed.topology,
245248
},
246-
"resource_limits": {"google.com/tpu": str(parsed.chips)},
247-
"resource_requests": {"google.com/tpu": str(parsed.chips)},
249+
"resource_limits": {"google.com/tpu": str(chips_per_node)},
250+
"resource_requests": {"google.com/tpu": str(chips_per_node)},
248251
"tolerations": [
249252
{"key": "google.com/tpu", "operator": "Exists", "effect": "NoSchedule"}
250253
],
@@ -330,8 +333,10 @@ def _create_job_spec(
330333
],
331334
env=env_vars,
332335
resources=client.V1ResourceRequirements(
333-
limits=accel_config["resource_limits"],
334-
requests=accel_config["resource_requests"],
336+
limits={k: str(v) for k, v in accel_config["resource_limits"].items()},
337+
requests={
338+
k: str(v) for k, v in accel_config["resource_requests"].items()
339+
},
335340
),
336341
)
337342

@@ -466,7 +471,9 @@ def _check_node_pool_exists_cached(selector_items) -> bool:
466471
for tpu_spec in accelerators.TPUS.values():
467472
for chips, topo_spec in tpu_spec.topologies.items():
468473
if topo_spec.machine_type == machine_type:
469-
pool_labels["cloud.google.com/gke-accelerator-count"] = str(chips)
474+
pool_labels["cloud.google.com/gke-accelerator-count"] = str(
475+
chips // topo_spec.num_nodes
476+
)
470477
break
471478

472479
if all(pool_labels.get(k) == str(v) for k, v in selector.items()):

keras_remote/backend/gke_client_test.py

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -58,6 +58,15 @@ def test_tpu_v3_8(self):
5858
self.assertLen(result["tolerations"], 1)
5959
self.assertEqual(result["tolerations"][0]["key"], "google.com/tpu")
6060

61+
def test_tpu_v3_16_multi_node(self):
62+
# v3-16 has 4 nodes and 16 total chips -> 4 chips per node
63+
result = _parse_accelerator("v3-16")
64+
self.assertEqual(result["resource_limits"], {"google.com/tpu": "4"})
65+
self.assertEqual(result["resource_requests"], {"google.com/tpu": "4"})
66+
self.assertEqual(
67+
result["node_selector"]["cloud.google.com/gke-tpu-topology"], "4x4"
68+
)
69+
6170
def test_tpu_v5litepod_4(self):
6271
result = _parse_accelerator("v5litepod-4")
6372
self.assertEqual(
@@ -421,6 +430,29 @@ def test_tpu_match(self):
421430
)
422431
self.assertTrue(result)
423432

433+
def test_tpu_multi_node_match(self):
434+
"""Test that it correctly identifies a 4-chip-per-node pool for v6e-16."""
435+
self.mock_run.return_value = json.dumps(
436+
[
437+
{
438+
"config": {
439+
"machineType": "ct6e-standard-4t",
440+
"accelerators": [{"acceleratorType": "tpu-v6e-slice"}],
441+
"labels": {},
442+
}
443+
}
444+
]
445+
)
446+
447+
result = _check_node_pool_exists_cached(
448+
(
449+
("cloud.google.com/gke-tpu-accelerator", "tpu-v6e-slice"),
450+
("cloud.google.com/gke-tpu-topology", "4x4"),
451+
("cloud.google.com/gke-accelerator-count", "4"),
452+
)
453+
)
454+
self.assertTrue(result)
455+
424456
def test_no_match(self):
425457
self.mock_run.return_value = json.dumps(
426458
[

keras_remote/backend/pathways_client.py

Lines changed: 8 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -137,6 +137,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
137137
# The leader pod is suffixed with '-0' by LWS
138138
leader_pod_name = f"{job_name}-0"
139139

140+
logged_pending = set()
140141
with LogStreamer(core_v1, namespace) as streamer:
141142
while True:
142143
elapsed = time.time() - start_time
@@ -160,7 +161,7 @@ def wait_for_job(job_id, namespace="default", timeout=3600, poll_interval=10):
160161
raise RuntimeError(f"Pathways job {job_name} failed")
161162

162163
elif pod.status.phase == "Pending":
163-
_check_pod_scheduling(core_v1, job_name, namespace)
164+
_check_pod_scheduling(core_v1, job_name, namespace, logged_pending)
164165
logging.debug("Pod is Pending...")
165166

166167
elif pod.status.phase == "Running":
@@ -288,8 +289,12 @@ def _create_lws_spec(
288289
],
289290
"env": env_vars,
290291
"resources": {
291-
"limits": accel_config["resource_limits"],
292-
"requests": accel_config["resource_requests"],
292+
"limits": {
293+
k: str(v) for k, v in accel_config["resource_limits"].items()
294+
},
295+
"requests": {
296+
k: str(v) for k, v in accel_config["resource_requests"].items()
297+
},
293298
},
294299
}
295300
],

0 commit comments

Comments (0)