Skip to content

Commit 3a69ed8

Browse files
code reformat
1 parent 01ac7b7 commit 3a69ed8

File tree

5 files changed

+368
-359
lines changed

5 files changed

+368
-359
lines changed

examples/pathways_example.py

Lines changed: 37 additions & 34 deletions
Original file line numberDiff line numberDiff line change
@@ -3,46 +3,49 @@
33
os.environ["KERAS_BACKEND"] = "jax"
44

55
import keras
6-
from keras import layers
76
import numpy as np
7+
from keras import layers
8+
89
import keras_remote
910

11+
1012
# A simple model that will be executed remotely
1113
@keras_remote.run(
12-
accelerator="v5litepod-1",
14+
accelerator="v5litepod-1",
1315
)
1416
def train_simple_model():
15-
print("Running Pathways job on JAX Backend!")
16-
17-
# Create a simple dataset
18-
x = np.random.rand(1000, 10)
19-
y = np.random.randint(0, 2, size=(1000, 1))
20-
21-
# A simple sequential model
22-
model = keras.Sequential([
23-
keras.Input(shape=(10,)),
24-
layers.Dense(32, activation='relu'),
25-
layers.Dense(16, activation='relu'),
26-
layers.Dense(1, activation='sigmoid')
27-
])
28-
29-
model.compile(
30-
optimizer='adam',
31-
loss='binary_crossentropy',
32-
metrics=['accuracy']
33-
)
34-
35-
print("Model Architecture:")
36-
model.summary()
37-
38-
# Train the model
39-
print("\nStarting Training...")
40-
history = model.fit(x, y, epochs=5, batch_size=32, validation_split=0.2)
41-
42-
print("\nTraining completed successfully on Pathways!")
43-
return history.history
17+
print("Running Pathways job on JAX Backend!")
18+
19+
# Create a simple dataset
20+
x = np.random.rand(1000, 10)
21+
y = np.random.randint(0, 2, size=(1000, 1))
22+
23+
# A simple sequential model
24+
model = keras.Sequential(
25+
[
26+
keras.Input(shape=(10,)),
27+
layers.Dense(32, activation="relu"),
28+
layers.Dense(16, activation="relu"),
29+
layers.Dense(1, activation="sigmoid"),
30+
]
31+
)
32+
33+
model.compile(
34+
optimizer="adam", loss="binary_crossentropy", metrics=["accuracy"]
35+
)
36+
37+
print("Model Architecture:")
38+
model.summary()
39+
40+
# Train the model
41+
print("\nStarting Training...")
42+
history = model.fit(x, y, epochs=5, batch_size=32, validation_split=0.2)
43+
44+
print("\nTraining completed successfully on Pathways!")
45+
return history.history
46+
4447

4548
if __name__ == "__main__":
46-
print("Submitting Pathways training job...")
47-
result_history = train_simple_model()
48-
print("Final validation accuracy:", result_history['val_accuracy'][-1])
49+
print("Submitting Pathways training job...")
50+
result_history = train_simple_model()
51+
print("Final validation accuracy:", result_history["val_accuracy"][-1])

keras_remote/backend/execution.py

Lines changed: 45 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,7 @@
1414
import cloudpickle
1515
from absl import logging
1616

17-
from keras_remote.backend import gke_client
17+
from keras_remote.backend import gke_client, pathways_client
1818
from keras_remote.constants import get_default_zone, zone_to_region
1919
from keras_remote.infra import container_builder
2020
from keras_remote.utils import packager, storage
@@ -106,59 +106,61 @@ def cleanup_job(self, job: Any, ctx: JobContext) -> None:
106106

107107

108108
class BaseK8sBackend:
109-
"""Base class for Kubernetes-based backends."""
109+
"""Base class for Kubernetes-based backends."""
110+
111+
def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
112+
self.cluster = cluster
113+
self.namespace = namespace
110114

111-
def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
112-
self.cluster = cluster
113-
self.namespace = namespace
114115

115116
class GKEBackend(BaseK8sBackend):
116-
"""Backend adapter for standard GKE Jobs."""
117-
def submit_job(self, ctx: JobContext) -> Any:
118-
"""Submit job to GKE cluster."""
119-
return gke_client.submit_k8s_job(
120-
display_name=ctx.display_name,
121-
container_uri=ctx.image_uri,
122-
accelerator=ctx.accelerator,
123-
project=ctx.project,
124-
job_id=ctx.job_id,
125-
bucket_name=ctx.bucket_name,
126-
namespace=self.namespace,
127-
)
117+
"""Backend adapter for standard GKE Jobs."""
128118

129-
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
130-
"""Wait for GKE job completion."""
131-
gke_client.wait_for_job(job, namespace=self.namespace)
119+
def submit_job(self, ctx: JobContext) -> Any:
120+
"""Submit job to GKE cluster."""
121+
return gke_client.submit_k8s_job(
122+
display_name=ctx.display_name,
123+
container_uri=ctx.image_uri,
124+
accelerator=ctx.accelerator,
125+
project=ctx.project,
126+
job_id=ctx.job_id,
127+
bucket_name=ctx.bucket_name,
128+
namespace=self.namespace,
129+
)
130+
131+
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
132+
"""Wait for GKE job completion."""
133+
gke_client.wait_for_job(job, namespace=self.namespace)
132134

133-
def cleanup_job(self, job: Any, ctx: JobContext) -> None:
134-
"""Clean up K8s job resources."""
135-
job_name = job.metadata.name
136-
gke_client.cleanup_job(job_name, namespace=self.namespace)
135+
def cleanup_job(self, job: Any, ctx: JobContext) -> None:
136+
"""Clean up K8s job resources."""
137+
job_name = job.metadata.name
138+
gke_client.cleanup_job(job_name, namespace=self.namespace)
137139

138140

139141
class PathwaysBackend(BaseK8sBackend):
140-
"""Backend adapter for ML Pathways using LeaderWorkerSet."""
141-
def submit_job(self, ctx: JobContext) -> Any:
142-
"""Submit LWS job to GKE cluster."""
143-
return pathways_client.submit_pathways_job(
144-
display_name=ctx.display_name,
145-
container_uri=ctx.image_uri,
146-
accelerator=ctx.accelerator,
147-
project=ctx.project,
148-
job_id=ctx.job_id,
149-
bucket_name=ctx.bucket_name,
150-
namespace=self.namespace,
151-
)
142+
"""Backend adapter for ML Pathways using LeaderWorkerSet."""
152143

153-
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
154-
"""Wait for Pathways LWS completion."""
155-
pathways_client.wait_for_job(job, ctx.job_id, namespace=self.namespace)
144+
def submit_job(self, ctx: JobContext) -> Any:
145+
"""Submit LWS job to GKE cluster."""
146+
return pathways_client.submit_pathways_job(
147+
display_name=ctx.display_name,
148+
container_uri=ctx.image_uri,
149+
accelerator=ctx.accelerator,
150+
project=ctx.project,
151+
job_id=ctx.job_id,
152+
bucket_name=ctx.bucket_name,
153+
namespace=self.namespace,
154+
)
156155

157-
def cleanup_job(self, job: Any, ctx: JobContext) -> None:
158-
"""Clean up LWS resources."""
159-
job_name = f"keras-pathways-{ctx.job_id}"
160-
pathways_client.cleanup_job(job_name, namespace=self.namespace)
156+
def wait_for_job(self, job: Any, ctx: JobContext) -> None:
157+
"""Wait for Pathways LWS completion."""
158+
pathways_client.wait_for_job(job, ctx.job_id, namespace=self.namespace)
161159

160+
def cleanup_job(self, job: Any, ctx: JobContext) -> None:
161+
"""Clean up LWS resources."""
162+
job_name = f"keras-pathways-{ctx.job_id}"
163+
pathways_client.cleanup_job(job_name, namespace=self.namespace)
162164

163165

164166
def _find_requirements(start_dir: str) -> Optional[str]:

0 commit comments

Comments
 (0)