# Keras Remote

Run Keras and JAX workloads on cloud TPUs and GPUs with a simple decorator. No infrastructure management required.
```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    model = keras.Sequential([...])
    model.fit(x_train, y_train)
    return model.history.history["loss"][-1]

# Executes on TPU v3-8, returns the result
final_loss = train_model()
```

## Table of Contents

- Features
- Installation
- Quick Start
- Usage Examples
- Configuration
- Supported Accelerators
- Monitoring
- Troubleshooting
- Contributing
- License
## Features

- **Simple decorator API**: Add `@keras_remote.run()` to any function to execute it remotely
- **Automatic infrastructure**: No manual VM provisioning or teardown required
- **Result serialization**: Functions return actual values, not just logs
- **Container caching**: Subsequent runs start in 2-4 minutes after the initial build
- **Built-in monitoring**: View job status and logs in the Google Cloud Console
- **Automatic cleanup**: Resources are released when jobs complete
## Installation

Install the core package to use the `@keras_remote.run()` decorator in your code:

```shell
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e .
```

This is sufficient if your infrastructure (GKE cluster, Artifact Registry, etc.) is already provisioned.
Install with the `cli` extra to also get the `keras-remote` command for managing infrastructure:

```shell
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e ".[cli]"
```

This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, and `keras-remote config` commands for provisioning and tearing down cloud resources.
Prerequisites:

- Python 3.11+
- Google Cloud SDK (`gcloud`)
  - Run `gcloud auth login` and `gcloud auth application-default login`
- A Google Cloud project with billing enabled

> **Note:** The Pulumi CLI is bundled and managed automatically. It will be installed to `~/.keras-remote/pulumi` on first use if not already present.
## Quick Start

Run the CLI setup command:

```shell
keras-remote up
```

This will interactively:
- Prompt for your GCP project ID
- Let you choose an accelerator type (CPU, GPU, or TPU)
- Enable required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE)
- Create the Artifact Registry repository
- Provision a GKE cluster with optional accelerator node pools
- Configure Docker authentication and kubectl access
You can also run non-interactively:

```shell
keras-remote up --project=my-project --accelerator=t4 --yes
```

To view the current infrastructure state:

```shell
keras-remote status
```

To view the configuration:

```shell
keras-remote config
```

Add the environment variables to your shell profile (`~/.bashrc`, `~/.zshrc`, etc.):

```shell
export KERAS_REMOTE_PROJECT="your-project-id"
export KERAS_REMOTE_ZONE="us-central1-a"  # Optional
```

Run your first remote function:

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def hello_tpu():
    import jax
    return f"Running on {jax.devices()}"

result = hello_tpu()
print(result)
```

## Usage Examples

Functions can take arguments and return values:

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def compute(x, y):
    return x + y

result = compute(5, 7)
print(f"Result: {result}")  # Output: Result: 12
```

Train a Keras model remotely and return a metric:

```python
import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    import numpy as np

    model = keras.Sequential([
        keras.layers.Dense(64, activation="relu", input_shape=(10,)),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="mse")

    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)
    history = model.fit(x_train, y_train, epochs=5, verbose=0)
    return history.history["loss"][-1]

final_loss = train_model()
print(f"Final loss: {final_loss}")
```

Create a `requirements.txt` in your project directory:
```
tensorflow-datasets
pillow
scikit-learn
```
Keras Remote automatically detects and installs dependencies on the remote worker.
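Conceptually, the detection step amounts to reading `requirements.txt` and filtering out comments and blank lines before installing the remaining specs on the worker. A minimal stdlib-only sketch of that idea (the `read_requirements` helper is illustrative, not part of the keras_remote API):

```python
import tempfile
from pathlib import Path

def read_requirements(path):
    """Collect package specs, skipping blank lines and # comments."""
    specs = []
    for line in Path(path).read_text().splitlines():
        line = line.strip()
        if line and not line.startswith("#"):
            specs.append(line)
    return specs

with tempfile.TemporaryDirectory() as tmp:
    req = Path(tmp) / "requirements.txt"
    req.write_text("# training deps\ntensorflow-datasets\n\npillow\nscikit-learn\n")
    print(read_requirements(req))  # ['tensorflow-datasets', 'pillow', 'scikit-learn']
```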
Skip container build time by using prebuilt images:

```python
@keras_remote.run(
    accelerator="v3-8",
    container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0"
)
def train():
    ...
```

See `examples/Dockerfile.prebuilt` for a template.
### Data API

Keras Remote provides a declarative, performant Data API that seamlessly makes your local and cloud data available to your remote functions.
The Data API is designed to be read-only. It reliably delivers data to your pods at the start of a job. For saving model outputs or checkpointing, you should write directly to GCS from within your function.
Under the hood, the Data API optimizes your workflows with two key features:
- Smart Caching: Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs that use byte-identical data will hit the cache and skip the upload entirely, drastically speeding up execution.
- Automatic Zip Exclusion: When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice.
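The caching behavior described above can be approximated as: hash every file's relative path and bytes into one digest, and reuse the cached upload whenever the digest matches a previous run. A rough stdlib-only sketch under that assumption (the `content_hash` helper is illustrative, not keras_remote's internal scheme):

```python
import hashlib
import tempfile
from pathlib import Path

def content_hash(directory):
    """Digest file paths and contents; byte-identical trees yield identical digests."""
    h = hashlib.sha256()
    for f in sorted(Path(directory).rglob("*")):
        if f.is_file():
            h.update(str(f.relative_to(directory)).encode())
            h.update(f.read_bytes())
    return h.hexdigest()

with tempfile.TemporaryDirectory() as a, tempfile.TemporaryDirectory() as b:
    (Path(a) / "train.csv").write_text("x,y\n1,2\n")
    (Path(b) / "train.csv").write_text("x,y\n1,2\n")
    # Identical trees hash the same, so the second upload would hit the cache
    print(content_hash(a) == content_hash(b))  # True
```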
There are three main ways to handle data depending on your workflow:
The simplest and most Pythonic approach is to pass `Data` objects as regular function arguments. The `Data` class wraps a local file/directory path or a Google Cloud Storage (GCS) URI.

On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, meaning your function code never needs to know about GCS or cloud storage APIs.
```python
import pandas as pd

import keras_remote
from keras_remote import Data

@keras_remote.run(accelerator="v6e-8")
def train(data_dir):
    # data_dir is resolved to a dynamic local path on the remote machine
    df = pd.read_csv(f"{data_dir}/train.csv")
    # ...

# Uploads the local directory to the remote pod automatically
train(Data("./my_dataset/"))

# Cache hit: subsequent runs with the same data skip the upload!
train(Data("./my_dataset/"))
```

**Note on GCS directories:** When referencing a GCS directory with the `Data` class, you must include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). If you omit the trailing slash, the system will treat it as a single file object.
You can also pass multiple Data arguments, or nest them inside lists and dictionaries (e.g., train(datasets=[Data("./d1"), Data("./d2")])).
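Conceptually, the runtime walks the argument structure, stages each `Data` object's payload, and substitutes its local path, while everything else passes through unchanged. A simplified sketch of that traversal (the `Data` stand-in, `resolve_args` helper, and staging path are all illustrative, not the real implementation):

```python
class Data:
    """Stand-in for keras_remote.Data: wraps a local path or GCS URI."""
    def __init__(self, uri):
        self.uri = uri

def resolve_args(obj, staging="/tmp/keras-remote"):
    """Recursively swap Data objects for the paths they were staged to."""
    if isinstance(obj, Data):
        # Real code would download/mount here; we just compute the target path
        return f"{staging}/{obj.uri.rstrip('/').rsplit('/', 1)[-1]}"
    if isinstance(obj, list):
        return [resolve_args(o, staging) for o in obj]
    if isinstance(obj, dict):
        return {k: resolve_args(v, staging) for k, v in obj.items()}
    return obj

args = {"datasets": [Data("./d1"), Data("gs://my-bucket/d2/")], "epochs": 5}
print(resolve_args(args))
# {'datasets': ['/tmp/keras-remote/d1', '/tmp/keras-remote/d2'], 'epochs': 5}
```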
For established training scripts where data requirements are static, you can use the `volumes` parameter of the `@keras_remote.run` decorator. This mounts data at fixed, hardcoded absolute filesystem paths, allowing you to drop `keras_remote` into existing codebases without altering the function signature.
```python
import pandas as pd

import keras_remote
from keras_remote import Data

@keras_remote.run(
    accelerator="v6e-8",
    volumes={
        "/data": Data("./my_dataset/"),
        "/weights": Data("gs://my-bucket/pretrained-weights/")
    }
)
def train():
    # Data is guaranteed to be available at these absolute paths
    df = pd.read_csv("/data/train.csv")
    model.load_weights("/weights/model.h5")
    # ...

# No data arguments needed!
train()
```

If your dataset is very large (e.g., > 10 GB), it is inefficient to download the entire dataset to the remote pod's local disk. Instead, skip the `Data` wrapper entirely and pass a GCS URI string directly. You can then use frameworks with native GCS streaming support (such as `tf.data` or `grain`) to read the data on the fly.
```python
import grain.python as grain

import keras_remote

@keras_remote.run(accelerator="v6e-8")
def train(data_uri):
    # Native GCS reading, no download overhead
    data_source = grain.ArrayRecordDataSource(data_uri)
    # ...

# Pass as a plain string, no Data() wrapper needed
train("gs://my-bucket/arrayrecords/")
```

## Configuration

| Variable | Required | Default | Description |
|---|---|---|---|
| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID |
| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone |
| `KERAS_REMOTE_CLUSTER` | No | — | GKE cluster name |
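Presumably the client resolves each setting with a simple precedence: an explicit decorator argument wins, then the environment variable, then the built-in default. A sketch of that precedence chain (the `resolve` helper and `DEFAULTS` table are illustrative, not the library's actual code):

```python
import os

# Built-in fallbacks for optional settings (illustrative values)
DEFAULTS = {"KERAS_REMOTE_ZONE": "us-central1-a"}

def resolve(name, override=None):
    """Decorator argument > environment variable > built-in default."""
    if override is not None:
        return override
    value = os.environ.get(name, DEFAULTS.get(name))
    if value is None:
        raise RuntimeError(f"{name} is required but not set")
    return value

os.environ["KERAS_REMOTE_PROJECT"] = "my-project"
os.environ.pop("KERAS_REMOTE_ZONE", None)

print(resolve("KERAS_REMOTE_PROJECT"))               # my-project
print(resolve("KERAS_REMOTE_ZONE"))                  # us-central1-a (default)
print(resolve("KERAS_REMOTE_ZONE", "europe-west4"))  # europe-west4 (override wins)
```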
```python
@keras_remote.run(
    accelerator="v3-8",      # Required: TPU/GPU type
    container_image=None,    # Custom container URI
    zone=None,               # Override default zone
    project=None,            # Override default project
    cluster=None,            # GKE cluster name
    namespace="default"      # Kubernetes namespace
)
```

## Supported Accelerators

**Note:** Each accelerator and topology requires setting up its own node pool as a prerequisite.
| Type | Configurations |
|---|---|
| TPU v2 | v2-8, v2-32 |
| TPU v3 | v3-8, v3-32 |
| TPU v5 Litepod | v5litepod-1, v5litepod-4, v5litepod-8 |
| TPU v5p | v5p-8, v5p-16 |
| TPU v6e | v6e-8, v6e-16 |
| Type | Aliases |
|---|---|
| NVIDIA T4 | t4, nvidia-tesla-t4 |
| NVIDIA L4 | l4, nvidia-l4 |
| NVIDIA V100 | v100, nvidia-tesla-v100 |
| NVIDIA A100 | a100, nvidia-tesla-a100 |
| NVIDIA H100 | h100, nvidia-h100-80gb |
For multi-GPU configurations on GKE, append the count: a100x4, l4x2, etc.
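The `<type>x<count>` convention can be parsed with a small helper along these lines (an illustrative sketch, not the library's actual parser):

```python
import re

def parse_gpu_alias(alias):
    """Split an alias like 'a100x4' into (gpu_type, count); bare aliases mean 1."""
    m = re.fullmatch(r"([a-z0-9-]+?)(?:x(\d+))?", alias)
    if not m:
        raise ValueError(f"unrecognized accelerator alias: {alias}")
    return m.group(1), int(m.group(2) or 1)

print(parse_gpu_alias("a100x4"))  # ('a100', 4)
print(parse_gpu_alias("l4x2"))    # ('l4', 2)
print(parse_gpu_alias("t4"))      # ('t4', 1)
```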
## Monitoring

- Cloud Build: console.cloud.google.com/cloud-build/builds
- GKE Workloads: console.cloud.google.com/kubernetes/workload

```shell
# List GKE jobs
kubectl get jobs -n default
```

## Troubleshooting

Make sure your project ID is set:

```shell
export KERAS_REMOTE_PROJECT="your-project-id"
```

Enable required APIs and create the Artifact Registry repository:

```shell
gcloud services enable cloudbuild.googleapis.com \
  artifactregistry.googleapis.com storage.googleapis.com \
  container.googleapis.com --project=$KERAS_REMOTE_PROJECT

gcloud artifacts repositories create keras-remote \
  --repository-format=docker \
  --location=us \
  --project=$KERAS_REMOTE_PROJECT
```

Grant required IAM roles:
```shell
gcloud projects add-iam-policy-binding $KERAS_REMOTE_PROJECT \
  --member="user:your-email@example.com" \
  --role="roles/storage.admin"
```

Check Cloud Build logs:

```shell
gcloud builds list --project=$KERAS_REMOTE_PROJECT --limit=5
```

Enable verbose logging in Python:

```python
import logging

logging.basicConfig(level=logging.INFO)
```

Verify your setup:

```shell
# Check authentication
gcloud auth list

# Check project
echo $KERAS_REMOTE_PROJECT

# Check APIs
gcloud services list --enabled --project=$KERAS_REMOTE_PROJECT \
  | grep -E "(cloudbuild|artifactregistry|storage|container)"

# Check Artifact Registry
gcloud artifacts repositories describe keras-remote \
  --location=us --project=$KERAS_REMOTE_PROJECT
```

Remove all Keras Remote resources to avoid charges:
```shell
keras-remote down
```

This removes:

- GKE cluster and accelerator node pools
- Artifact Registry repository and container images
- Cloud Storage buckets (jobs and builds)

Use `--yes` to skip the confirmation prompt.
## Contributing

Contributions are welcome. Please read our contributing guidelines before submitting pull requests.

All contributions must follow our Code of Conduct.

## License

This project is licensed under the Apache License 2.0. See LICENSE for details.

Maintained by the Keras team at Google.