Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
15 changes: 6 additions & 9 deletions examples/demo_train.py → examples/simple_demo.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,16 @@
import os
import socket

os.environ["KERAS_BACKEND"] = "jax"

import jax
import keras
import numpy as np

from keras_remote import core as keras_remote
import keras_remote


@keras_remote.run(accelerator="v2-8")
@keras_remote.run(accelerator="cpu")
def train_keras_jax_model():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

medium

This function is missing a docstring. The repository style guide emphasizes the importance of documentation. Adding a docstring would improve the clarity and maintainability of this example.

Suggested change
def train_keras_jax_model():
def train_keras_jax_model():
"""Trains a simple Keras model remotely and returns the final loss."""
References
  1. The style guide states: "Don't neglect error messages, docstrings, and documentation." (link)

host = socket.gethostname()
print(f"Running on host: {host}")
print(f"Keras version: {keras.__version__}")
print(f"Keras backend: {keras.config.backend()}")
print(f"JAX version: {jax.__version__}")
Expand Down Expand Up @@ -54,13 +51,13 @@ def train_keras_jax_model():
)

print("Starting model.fit...")
model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2)
history = model.fit(x_train, y_train, epochs=1, batch_size=32, verbose=2)
print("Model.fit finished.")

return f"Keras JAX training complete on {host}"
return history.history["loss"][-1]


if __name__ == "__main__":
print("Starting Keras JAX demo...")
result = train_keras_jax_model()
print(f"Demo result: {result}")
loss = train_keras_jax_model()
print(f"Final training loss: {loss}")
Loading