# Keras Remote

Run Keras and JAX workloads on cloud TPUs and GPUs with a simple decorator. No infrastructure management required.
```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    model = keras.Sequential([...])
    model.fit(x_train, y_train)
    return model.history.history["loss"][-1]

# Executes on TPU v3-8, returns the result
final_loss = train_model()
```

## Table of Contents

- Features
- Installation
- Quick Start
- Usage Examples
- Configuration
- Supported Accelerators
- Monitoring
- Troubleshooting
- Contributing
- License
## Features

- **Simple decorator API** — Add `@keras_remote.run()` to any function to execute it remotely
- **Automatic infrastructure** — No manual VM provisioning or teardown required
- **Result serialization** — Functions return actual values, not just logs
- **Container caching** — Subsequent runs start in 2-4 minutes after the initial build
- **Built-in monitoring** — View job status and logs in Google Cloud Console
- **Automatic cleanup** — Resources are released when jobs complete
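The "result serialization" feature means a decorated function's return value travels back to the client as a real Python object, not as log text. Conceptually this is a serialize/deserialize round-trip. The snippet below is a local sketch of that idea only; `run_remotely` is a hypothetical stand-in, not keras_remote's actual implementation:

```python
import pickle

def run_remotely(fn, *args, **kwargs):
    # Stand-in for remote execution: call the function, then serialize
    # its return value the way a worker would before shipping it back.
    result = fn(*args, **kwargs)
    payload = pickle.dumps(result)  # worker side: bytes over the wire
    return pickle.loads(payload)    # client side: real object again

def compute(x, y):
    return {"sum": x + y}

print(run_remotely(compute, 5, 7))  # {'sum': 12}
```

This is why return values must be picklable: anything the function returns has to survive the round-trip from worker to client.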
## Installation

### Core package

Install the core package to use the `@keras_remote.run()` decorator in your code:

```shell
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e .
```

This is sufficient if your infrastructure (GKE cluster, Artifact Registry, etc.) is already provisioned.
### CLI extra

Install with the `cli` extra to also get the `keras-remote` command for managing infrastructure:

```shell
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e ".[cli]"
```

This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, and `keras-remote config` commands for provisioning and tearing down cloud resources.
### Prerequisites

- Python 3.11+
- Google Cloud SDK (`gcloud`)
  - Run `gcloud auth login` and `gcloud auth application-default login`
- Pulumi CLI (required for the `[cli]` install only)
- A Google Cloud project with billing enabled
## Quick Start

Run the CLI setup command:

```shell
keras-remote up
```

This will interactively:
- Prompt for your GCP project ID
- Let you choose an accelerator type (CPU, GPU, or TPU)
- Enable required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE)
- Create the Artifact Registry repository
- Provision a GKE cluster with optional accelerator node pools
- Configure Docker authentication and kubectl access
You can also run non-interactively:

```shell
keras-remote up --project=my-project --accelerator=t4 --yes
```

To view current infrastructure state:

```shell
keras-remote status
```

To view configuration:

```shell
keras-remote config
```

### Set environment variables

Add to your shell profile (`~/.bashrc`, `~/.zshrc`, etc.):

```shell
export KERAS_REMOTE_PROJECT="your-project-id"
export KERAS_REMOTE_ZONE="us-central1-a"  # Optional
```

## Usage Examples

### Hello TPU

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def hello_tpu():
    import jax
    return f"Running on {jax.devices()}"

result = hello_tpu()
print(result)
```

### Arguments and return values

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def compute(x, y):
    return x + y

result = compute(5, 7)
print(f"Result: {result}")  # Output: Result: 12
```

### Training a Keras model

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    import numpy as np

    model = keras.Sequential([
        keras.layers.Dense(64, activation="relu", input_shape=(10,)),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="mse")

    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)
    history = model.fit(x_train, y_train, epochs=5, verbose=0)
    return history.history["loss"][-1]

final_loss = train_model()
print(f"Final loss: {final_loss}")
```

### Extra dependencies

Create a `requirements.txt` in your project directory:
```
tensorflow-datasets
pillow
scikit-learn
```
Keras Remote automatically detects and installs dependencies on the remote worker.
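One simple way this detection could work is to read the non-comment lines from `requirements.txt` and install them on the worker. `detect_requirements` below is an illustrative sketch under that assumption, not the library's actual logic:

```python
from pathlib import Path

def detect_requirements(project_dir="."):
    """Return the dependency specs listed in requirements.txt, if any.

    Illustrative sketch only: keras_remote's real detection may differ.
    """
    req_file = Path(project_dir) / "requirements.txt"
    if not req_file.exists():
        return []
    deps = []
    for line in req_file.read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#"):  # skip blanks and comments
            deps.append(line)
    return deps
```

With the file above in place, this sketch would return `["tensorflow-datasets", "pillow", "scikit-learn"]`; with no `requirements.txt` present, it returns an empty list and the worker gets only the base image's packages.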
### Prebuilt container images

Skip container build time by using prebuilt images:

```python
@keras_remote.run(
    accelerator="v3-8",
    container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0"
)
def train():
    ...
```

See `examples/Dockerfile.prebuilt` for a template.
## Configuration

### Environment variables

| Variable | Required | Default | Description |
|---|---|---|---|
| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID |
| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone |
| `KERAS_REMOTE_CLUSTER` | No | — | GKE cluster name |
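Presumably, decorator arguments such as `zone` and `project` that are left as `None` fall back to these environment variables and then to the built-in default. A sketch of that resolution order (`resolve` is a hypothetical helper for illustration, not part of the keras_remote API):

```python
import os

def resolve(value, env_var, default=None):
    """Explicit argument wins, then the environment variable, then the default.

    Hypothetical illustration of the precedence; the real keras_remote
    resolution logic is not shown in this README.
    """
    if value is not None:
        return value
    return os.environ.get(env_var, default)

# With KERAS_REMOTE_ZONE unset, this yields the documented default.
zone = resolve(None, "KERAS_REMOTE_ZONE", default="us-central1-a")
```

An explicitly passed `zone="europe-west4-a"` would override both the environment variable and the default.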
### Decorator parameters

```python
@keras_remote.run(
    accelerator="v3-8",    # Required: TPU/GPU type
    container_image=None,  # Custom container URI
    zone=None,             # Override default zone
    project=None,          # Override default project
    cluster=None,          # GKE cluster name
    namespace="default"    # Kubernetes namespace
)
```

## Supported Accelerators

### TPUs

| Type | Configurations |
|---|---|
| TPU v2 | v2-8, v2-32 |
| TPU v3 | v3-8, v3-32 |
| TPU v5 Litepod | v5litepod-1, v5litepod-4, v5litepod-8 |
| TPU v5p | v5p-8, v5p-16 |
| TPU v6e | v6e-8, v6e-16 |
### GPUs

| Type | Aliases |
|---|---|
| NVIDIA T4 | t4, nvidia-tesla-t4 |
| NVIDIA L4 | l4, nvidia-l4 |
| NVIDIA V100 | v100, nvidia-tesla-v100 |
| NVIDIA A100 | a100, nvidia-tesla-a100 |
| NVIDIA H100 | h100, nvidia-h100-80gb |
For multi-GPU configurations on GKE, append the count: `a100x4`, `l4x2`, etc.
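The `<alias>x<count>` convention above can be split into an alias and a device count with a small parser. `parse_gpu_accelerator` below is a hypothetical helper; the README does not show keras_remote's internal parsing:

```python
import re

def parse_gpu_accelerator(spec):
    """Split a GPU spec like 'a100x4' into (alias, count).

    Hypothetical illustration of the alias-plus-count convention;
    a spec without a suffix defaults to a single device.
    """
    m = re.fullmatch(r"([a-z0-9-]+?)(?:x(\d+))?", spec)
    if m is None:
        raise ValueError(f"unrecognized accelerator spec: {spec!r}")
    alias, count = m.group(1), m.group(2)
    return alias, int(count) if count else 1

print(parse_gpu_accelerator("a100x4"))  # ('a100', 4)
print(parse_gpu_accelerator("t4"))      # ('t4', 1)
```

The lazy `+?` quantifier matters here: it keeps the alias as short as possible, so the trailing `x4` is claimed by the count group rather than absorbed into the alias (note that `x` is itself a legal alias character, as in hyphenated aliases like `nvidia-tesla-t4`).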
## Monitoring

- Cloud Build: console.cloud.google.com/cloud-build/builds
- GKE Workloads: console.cloud.google.com/kubernetes/workload

```shell
# List GKE jobs
kubectl get jobs -n default
```

## Troubleshooting

### Manual infrastructure setup

Set your project ID:

```shell
export KERAS_REMOTE_PROJECT="your-project-id"
```

Enable required APIs and create the Artifact Registry repository:

```shell
gcloud services enable cloudbuild.googleapis.com \
    artifactregistry.googleapis.com storage.googleapis.com \
    container.googleapis.com --project=$KERAS_REMOTE_PROJECT
```
```shell
gcloud artifacts repositories create keras-remote \
    --repository-format=docker \
    --location=us \
    --project=$KERAS_REMOTE_PROJECT
```

Grant required IAM roles:
```shell
gcloud projects add-iam-policy-binding $KERAS_REMOTE_PROJECT \
    --member="user:your-email@example.com" \
    --role="roles/storage.admin"
```

Check Cloud Build logs:

```shell
gcloud builds list --project=$KERAS_REMOTE_PROJECT --limit=5
```

To enable verbose logging in your script:

```python
import logging
logging.basicConfig(level=logging.INFO)
```

### Verify your setup

```shell
# Check authentication
gcloud auth list

# Check project
echo $KERAS_REMOTE_PROJECT

# Check APIs
gcloud services list --enabled --project=$KERAS_REMOTE_PROJECT \
    | grep -E "(cloudbuild|artifactregistry|storage|container)"

# Check Artifact Registry
gcloud artifacts repositories describe keras-remote \
    --location=us --project=$KERAS_REMOTE_PROJECT
```

## Cleanup

Remove all Keras Remote resources to avoid charges:
```shell
keras-remote down
```

This removes:
- GKE cluster and accelerator node pools (via Pulumi)
- Artifact Registry repository and container images
- Cloud Storage buckets (jobs and builds)
Use `--yes` to skip the confirmation prompt.
## Contributing

Contributions are welcome. Please read our contributing guidelines before submitting pull requests.
### Development setup

- Install the package with dev dependencies:

  ```shell
  pip install -e ".[dev]"
  ```

- Install pre-commit hooks:

  ```shell
  pre-commit install
  ```

This enables automatic linting and formatting checks (via Ruff) on every commit.
To run the checks manually against all files:

```shell
pre-commit run --all-files
```

### Pull request workflow

- Fork the repository
- Create a feature branch (`git checkout -b feature/amazing-feature`)
- Commit your changes (`git commit -m 'Add amazing feature'`)
- Push to the branch (`git push origin feature/amazing-feature`)
- Open a Pull Request
## License

This project is licensed under the Apache License 2.0. See LICENSE for details.
Maintained by the Keras team at Google.