Skip to content
49 changes: 49 additions & 0 deletions examples/pathways_example.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import os

os.environ["KERAS_BACKEND"] = "jax"

import keras
import numpy as np
from keras import layers

import keras_remote


# A simple model that will be executed remotely
@keras_remote.run(accelerator="v5litepod-1", backend="pathways")
def train_simple_model():
    """Fit a small binary classifier on random data and return its history.

    The function is wrapped by ``keras_remote.run`` with
    ``backend="pathways"``, so the body below is what gets submitted as the
    remote job.

    Returns:
        The ``history`` attribute of the object returned by ``model.fit``
        (a mapping of metric names to per-epoch values).
    """
    print("Running Pathways job on JAX Backend!")

    # Synthetic dataset: random features with random 0/1 labels.
    num_samples, num_features = 1000, 10
    features = np.random.rand(num_samples, num_features)
    labels = np.random.randint(0, 2, size=(num_samples, 1))

    # Feed-forward stack ending in a single sigmoid unit for binary output.
    layer_stack = [
        keras.Input(shape=(num_features,)),
        layers.Dense(32, activation="relu"),
        layers.Dense(16, activation="relu"),
        layers.Dense(1, activation="sigmoid"),
    ]
    classifier = keras.Sequential(layer_stack)
    classifier.compile(
        loss="binary_crossentropy",
        optimizer="adam",
        metrics=["accuracy"],
    )

    print("Model Architecture:")
    classifier.summary()

    # Hold out the last 20% of the samples for validation during fitting.
    print("\nStarting Training...")
    training_run = classifier.fit(
        features,
        labels,
        batch_size=32,
        epochs=5,
        validation_split=0.2,
    )

    print("\nTraining completed successfully on Pathways!")
    return training_run.history
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Just to confirm, there are no user code changes to run on Pathways within their remote function? All it needs is backend="pathways"? That's pretty cool

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Yes. And if `backend` is None, it auto-detects whether the user requested a multi-node TPU and picks the Pathways backend.



if __name__ == "__main__":
    # Script entry point: submit the job, then report the validation
    # accuracy recorded for the final epoch in the returned history.
    print("Submitting Pathways training job...")
    final_val_accuracy = train_simple_model()["val_accuracy"][-1]
    print("Final validation accuracy:", final_val_accuracy)
35 changes: 32 additions & 3 deletions keras_remote/backend/execution.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
import cloudpickle
from absl import logging

from keras_remote.backend import gke_client
from keras_remote.backend import gke_client, pathways_client
from keras_remote.constants import get_default_zone, zone_to_region
from keras_remote.infra import container_builder
from keras_remote.utils import packager, storage
Expand Down Expand Up @@ -105,13 +105,17 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
...


class GKEBackend:
"""Backend adapter for GKE."""
class BaseK8sBackend:
    """Common state for the Kubernetes-based backend adapters.

    Holds the cluster identifier and the namespace that concrete backends
    pass to their client calls.
    """

    def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
        """Record the target cluster and namespace.

        Args:
            cluster: Cluster identifier, stored as given; defaults to None.
            namespace: Kubernetes namespace; defaults to "default".
        """
        # Plain attribute storage; both values are kept exactly as passed in.
        self.cluster, self.namespace = cluster, namespace


class GKEBackend(BaseK8sBackend):
"""Backend adapter for standard GKE Jobs."""

def submit_job(self, ctx: JobContext) -> Any:
"""Submit job to GKE cluster."""
return gke_client.submit_k8s_job(
Expand All @@ -134,6 +138,31 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
gke_client.cleanup_job(job_name, namespace=self.namespace)


class PathwaysBackend(BaseK8sBackend):
    """Adapter that runs jobs through ``pathways_client`` (LeaderWorkerSet)."""

    def submit_job(self, ctx: JobContext) -> Any:
        """Submit the job described by ``ctx`` and return the client's result.

        Args:
            ctx: Job context supplying the display name, image URI,
                accelerator, project, job id and bucket name.

        Returns:
            Whatever ``pathways_client.submit_pathways_job`` returns.
        """
        # Collect the request fields first so the call site stays short.
        submit_kwargs = {
            "display_name": ctx.display_name,
            "container_uri": ctx.image_uri,
            "accelerator": ctx.accelerator,
            "project": ctx.project,
            "job_id": ctx.job_id,
            "bucket_name": ctx.bucket_name,
            "namespace": self.namespace,
        }
        return pathways_client.submit_pathways_job(**submit_kwargs)

    def wait_for_job(self, job: Any, ctx: JobContext) -> None:
        """Wait on the job identified by ``ctx.job_id`` in this namespace.

        ``job`` is accepted for interface compatibility and is not used here.
        """
        job_id = ctx.job_id
        pathways_client.wait_for_job(job_id, namespace=self.namespace)

    def cleanup_job(self, job: Any, ctx: JobContext) -> None:
        """Remove the resources created for the job in this namespace.

        ``job`` is accepted for interface compatibility and is not used here;
        the resource name is derived from ``ctx.job_id``.
        """
        pathways_client.cleanup_job(
            pathways_client._get_job_name(ctx.job_id),
            namespace=self.namespace,
        )


def _find_requirements(start_dir: str) -> Optional[str]:
"""Search up directory tree for requirements.txt."""
search_dir = start_dir
Expand Down
2 changes: 1 addition & 1 deletion keras_remote/backend/gke_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -300,7 +300,7 @@ def _create_job_spec(

pod_template = client.V1PodTemplateSpec(
metadata=client.V1ObjectMeta(
labels={"app": "keras-remote", "job-id": job_id}
labels={"app": "keras-remote", "job-id": job_id, "job-name": job_name}
),
spec=client.V1PodSpec(**pod_spec_kwargs),
)
Expand Down
Loading