Skip to content

Commit c6b4dad

Browse files
Add Pathways Integration (#31)
* initial pathways integration * add pathways example * fix core * code reformat * auto detect lws version * code reformat * address gemini comment * address Gemini comment * address comments * code reformat * map internal APIs to standard Multi-Host Eenv specs * address jeff's comment
1 parent ddf770f commit c6b4dad

File tree

8 files changed

+501
-20
lines changed

8 files changed

+501
-20
lines changed

examples/pathways_example.py

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,49 @@
1+
import os
2+
3+
os.environ["KERAS_BACKEND"] = "jax"
4+
5+
import keras
6+
import numpy as np
7+
from keras import layers
8+
9+
import keras_remote
10+
11+
12+
# A simple model that will be executed remotely on pathways
13+
@keras_remote.run(accelerator="v5litepod-1", backend="pathways")
14+
def train_simple_model():
15+
print("Running Pathways job on JAX Backend!")
16+
17+
# Create a simple dataset
18+
x = np.random.rand(1000, 10)
19+
y = np.random.randint(0, 2, size=(1000, 1))
20+
21+
# A simple sequential model
22+
model = keras.Sequential(
23+
[
24+
keras.Input(shape=(10,)),
25+
layers.Dense(32, activation="relu"),
26+
layers.Dense(16, activation="relu"),
27+
layers.Dense(1, activation="sigmoid"),
28+
]
29+
)
30+
31+
model.compile(
32+
optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
33+
)
34+
35+
print("Model Architecture:")
36+
model.summary()
37+
38+
# Train the model
39+
print("\nStarting Training...")
40+
history = model.fit(x, y, epochs=5, batch_size=32, validation_split=0.2)
41+
42+
print("\nTraining completed successfully on Pathways!")
43+
return history.history
44+
45+
46+
if __name__ == "__main__":
47+
print("Submitting Pathways training job...")
48+
result_history = train_simple_model()
49+
print("Final validation accuracy:", result_history["val_accuracy"][-1])

keras_remote/backend/execution.py

Lines changed: 32 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import cloudpickle
1515
from absl import logging
1616

17-
from keras_remote.backend import gke_client
17+
from keras_remote.backend import gke_client, pathways_client
1818
from keras_remote.constants import get_default_zone, zone_to_region
1919
from keras_remote.infra import container_builder
2020
from keras_remote.utils import packager, storage
@@ -105,13 +105,17 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
105105
...
106106

107107

108-
class GKEBackend:
109-
"""Backend adapter for GKE."""
108+
class BaseK8sBackend:
109+
"""Base class for Kubernetes-based backends."""
110110

111111
def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
112112
self.cluster = cluster
113113
self.namespace = namespace
114114

115+
116+
class GKEBackend(BaseK8sBackend):
117+
"""Backend adapter for standard GKE Jobs."""
118+
115119
def submit_job(self, ctx: JobContext) -> Any:
116120
"""Submit job to GKE cluster."""
117121
return gke_client.submit_k8s_job(
@@ -134,6 +138,31 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
134138
gke_client.cleanup_job(job_name, namespace=self.namespace)
135139

136140

141+
class PathwaysBackend(BaseK8sBackend):
142+
"""Backend adapter for ML Pathways using LeaderWorkerSet."""
143+
144+
def submit_job(self, ctx: JobContext) -> Any:
145+
"""Submit LWS job to GKE cluster."""
146+
return pathways_client.submit_pathways_job(
147+
display_name=ctx.display_name,
148+
container_uri=ctx.image_uri,
149+
accelerator=ctx.accelerator,
150+
project=ctx.project,
151+
job_id=ctx.job_id,
152+
bucket_name=ctx.bucket_name,
153+
namespace=self.namespace,
154+
)
155+
156+
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
157+
"""Wait for Pathways LWS completion."""
158+
pathways_client.wait_for_job(ctx.job_id, namespace=self.namespace)
159+
160+
def cleanup_job(self, job: Any, ctx: JobContext) -> None:
161+
"""Clean up LWS resources."""
162+
job_name = pathways_client._get_job_name(ctx.job_id)
163+
pathways_client.cleanup_job(job_name, namespace=self.namespace)
164+
165+
137166
def _find_requirements(start_dir: str) -> Optional[str]:
138167
"""Search up directory tree for requirements.txt."""
139168
search_dir = start_dir

keras_remote/backend/gke_client.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -300,7 +300,7 @@ def _create_job_spec(
300300

301301
pod_template = client.V1PodTemplateSpec(
302302
metadata=client.V1ObjectMeta(
303-
labels={"app": "keras-remote", "job-id": job_id}
303+
labels={"app": "keras-remote", "job-id": job_id, "job-name": job_name}
304304
),
305305
spec=client.V1PodSpec(**pod_spec_kwargs),
306306
)

0 commit comments

Comments
 (0)