|
1 | 1 | import os |
2 | | -import socket |
3 | 2 |
|
4 | 3 | os.environ["KERAS_BACKEND"] = "jax" |
5 | 4 |
|
6 | 5 | import jax |
7 | 6 | import keras |
8 | 7 | import numpy as np |
9 | 8 |
|
10 | | -from keras_remote import core as keras_remote |
| 9 | +import keras_remote |
11 | 10 |
|
12 | 11 |
|
13 | | -@keras_remote.run(accelerator="v2-8") |
| 12 | +@keras_remote.run(accelerator="cpu") |
14 | 13 | def train_keras_jax_model(): |
15 | | - host = socket.gethostname() |
16 | | - print(f"Running on host: {host}") |
17 | 14 | print(f"Keras version: {keras.__version__}") |
18 | 15 | print(f"Keras backend: {keras.config.backend()}") |
19 | 16 | print(f"JAX version: {jax.__version__}") |
@@ -54,13 +51,13 @@ def train_keras_jax_model(): |
54 | 51 | ) |
55 | 52 |
|
56 | 53 | print("Starting model.fit...") |
57 | | - model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2) |
| 54 | + history = model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2) |
58 | 55 | print("Model.fit finished.") |
59 | 56 |
|
60 | | - return f"Keras JAX training complete on {host}" |
| 57 | + return history.history["loss"][-1] |
61 | 58 |
|
62 | 59 |
|
63 | 60 | if __name__ == "__main__": |
64 | 61 | print("Starting Keras JAX demo...") |
65 | | - result = train_keras_jax_model() |
66 | | - print(f"Demo result: {result}") |
| 62 | + loss = train_keras_jax_model() |
| 63 | + print(f"Final training loss: {loss}") |
0 commit comments