Skip to content

Commit c478d56

Browse files
authored
Update simple demo (#55)
- Rename to simple_demo.py - Make CPU by default to ensure it runs without having a NodePool set up - Fix import
1 parent a751539 commit c478d56

File tree

1 file changed

+6
-9
lines changed

1 file changed

+6
-9
lines changed
Lines changed: 6 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -1,19 +1,16 @@
11
import os
2-
import socket
32

43
os.environ["KERAS_BACKEND"] = "jax"
54

65
import jax
76
import keras
87
import numpy as np
98

10-
from keras_remote import core as keras_remote
9+
import keras_remote
1110

1211

13-
@keras_remote.run(accelerator="v2-8")
12+
@keras_remote.run(accelerator="cpu")
1413
def train_keras_jax_model():
15-
host = socket.gethostname()
16-
print(f"Running on host: {host}")
1714
print(f"Keras version: {keras.__version__}")
1815
print(f"Keras backend: {keras.config.backend()}")
1916
print(f"JAX version: {jax.__version__}")
@@ -54,13 +51,13 @@ def train_keras_jax_model():
5451
)
5552

5653
print("Starting model.fit...")
57-
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2)
54+
history = model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2)
5855
print("Model.fit finished.")
5956

60-
return f"Keras JAX training complete on {host}"
57+
return history.history["loss"][-1]
6158

6259

6360
if __name__ == "__main__":
6461
print("Starting Keras JAX demo...")
65-
result = train_keras_jax_model()
66-
print(f"Demo result: {result}")
62+
loss = train_keras_jax_model()
63+
print(f"Final training loss: {loss}")

0 commit comments

Comments
 (0)