Run Keras and JAX workloads on cloud TPUs and GPUs with a simple decorator. No infrastructure management required.
import keras_remote
@keras_remote.run(accelerator="v6e-8")
def train_model():
import keras
model = keras.Sequential([...])
model.fit(x_train, y_train)
return model.history.history["loss"][-1]
# Executes on TPU v6e-8, returns the result
final_loss = train_model()When you call a decorated function, Keras Remote handles the entire remote execution pipeline:
- Packages your function, local code, and data dependencies
- Builds a container with your dependencies via Cloud Build (cached after first build — subsequent runs skip this step)
- Runs the job on a GKE cluster with the requested accelerator (TPU or GPU)
- Returns the result to your local machine — logs are streamed in real time, and the function's return value is delivered back as if it ran locally
If the remote function raises an exception, it is re-raised locally with the original traceback, so debugging works the same as local development.
You need a GKE cluster with accelerator node pools to run jobs. The keras-remote CLI handles this setup for you.
- Simple decorator API — Add
@keras_remote.run()to any function to execute it remotely - Automatic infrastructure — No manual VM provisioning or teardown required
- Result serialization — Functions return actual values, not just logs
- Fast iteration — Container images are cached by dependency hash; unchanged dependencies skip the build entirely (subsequent runs start in less than a minute)
- Data API — Declarative
Dataclass with smart caching for local files and GCS data - Environment variable forwarding — Propagate local env vars to the remote environment with wildcard patterns (
capture_env_vars=["KAGGLE_*"]) - Built-in monitoring — View job status and logs in Google Cloud Console
- Automatic cleanup — Resources are released when jobs complete
- Transparent errors — Remote exceptions are re-raised locally with the original traceback
- Python 3.11+
- Google Cloud SDK (
gcloud) — install guide - A Google Cloud project with billing enabled
Authenticate with Google Cloud:
gcloud auth login
gcloud auth application-default loginNote: The Pulumi CLI (used for infrastructure provisioning) is bundled and managed automatically. It will be installed to
~/.keras-remote/pulumion first use if not already present.
git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e ".[cli]"This installs both the @keras_remote.run() decorator and the keras-remote CLI for managing infrastructure.
If your GKE cluster and Artifact Registry are already provisioned, you can install without the CLI:
pip install -e .
Run the one-time setup to create the required cloud resources:
keras-remote upThis interactively prompts for your GCP project and accelerator type, then:
- Enables required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE)
- Creates an Artifact Registry repository for container images
- Provisions a GKE cluster with an accelerator node pool
- Configures Docker authentication and kubectl access
You can also run non-interactively:
keras-remote up --project=my-project --accelerator=t4 --yesCleanup reminder: When you're done, run
keras-remote downto tear down all resources and avoid ongoing charges. See CLI Commands.
Set your project ID so the library knows where to run jobs:
export KERAS_REMOTE_PROJECT="your-project-id"Add this to your shell profile (~/.bashrc, ~/.zshrc, etc.) to persist it. See Configuration for the full list of environment variables.
import keras_remote
@keras_remote.run(accelerator="v6e-8")
def hello_tpu():
import jax
return f"Running on {jax.devices()}"
result = hello_tpu()
print(result)First run timing: The initial execution takes longer (~5 minutes) because it builds a container image with your dependencies. Subsequent runs with unchanged dependencies use the cached image and start in less than a minute.
import keras_remote
@keras_remote.run(accelerator="v6e-8")
def train_model():
import keras
import numpy as np
model = keras.Sequential([
keras.layers.Dense(64, activation="relu", input_shape=(10,)),
keras.layers.Dense(1)
])
model.compile(optimizer="adam", loss="mse")
x_train = np.random.randn(1000, 10)
y_train = np.random.randn(1000, 1)
history = model.fit(x_train, y_train, epochs=5, verbose=0)
return history.history["loss"][-1]
final_loss = train_model()
print(f"Final loss: {final_loss}")Keras Remote provides a declarative Data API to seamlessly make your local and cloud data available to remote functions.
The Data API is read-only — it delivers data to your pods at the start of a job. For saving model outputs or checkpointing, write directly to GCS from within your function.
Under the hood, the Data API provides two key optimizations:
- Smart Caching: Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs with byte-identical data skip the upload entirely.
- Automatic Zip Exclusion: When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice.
There are three approaches depending on your workflow:
The simplest approach — pass Data objects as regular function arguments. The Data class wraps a local file/directory path or a Google Cloud Storage (GCS) URI.
On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, so your function code never needs to know about GCS or cloud storage APIs.
import pandas as pd
import keras_remote
from keras_remote import Data
@keras_remote.run(accelerator="v6e-8")
def train(data_dir):
# data_dir is resolved to a local path on the remote machine
df = pd.read_csv(f"{data_dir}/train.csv")
# ...
# Uploads the local directory to the remote pod automatically
train(Data("./my_dataset/"))
# Cache hit: subsequent runs with the same data skip the upload!
train(Data("./my_dataset/"))GCS Directories: When referencing a GCS directory with the
Dataclass, include a trailing slash (e.g.,Data("gs://my-bucket/dataset/")). Without the trailing slash, the system treats it as a single file object.
You can also pass multiple Data arguments, or nest them inside lists and dictionaries (e.g., train(datasets=[Data("./d1"), Data("./d2")])).
For established training scripts where data requirements are fixed, use the volumes parameter in the decorator. This mounts data at hardcoded absolute filesystem paths, allowing you to use Keras Remote with existing codebases without altering the function signature.
import pandas as pd
import keras_remote
from keras_remote import Data
@keras_remote.run(
accelerator="v6e-8",
volumes={
"/data": Data("./my_dataset/"),
"/weights": Data("gs://my-bucket/pretrained-weights/")
}
)
def train():
# Data is guaranteed to be available at these absolute paths
df = pd.read_csv("/data/train.csv")
model.load_weights("/weights/model.h5")
# ...
# No data arguments needed!
train()If your dataset is very large (e.g., > 10GB), it is inefficient to download the entire dataset to the pod's local disk. Instead, skip the Data wrapper and pass a GCS URI string directly. Use frameworks with native GCS streaming support (like tf.data or grain) to read the data on the fly.
import grain.python as grain
import keras_remote
@keras_remote.run(accelerator="v6e-8")
def train(data_uri):
# Native GCS reading, no download overhead
data_source = grain.ArrayRecordDataSource(data_uri)
# ...
# Pass as a plain string, no Data() wrapper needed
train("gs://my-bucket/arrayrecords/")Create a requirements.txt in your project directory:
tensorflow-datasets
pillow
scikit-learn
Alternatively, dependencies declared in pyproject.toml are also supported:
[project]
dependencies = [
"tensorflow-datasets",
"pillow",
"scikit-learn",
]Keras Remote automatically detects and installs dependencies on the remote worker.
If both files exist in the same directory, requirements.txt takes precedence.
Note: JAX packages (
jax,jaxlib,libtpu,libtpu-nightly) are automatically filtered from your dependencies to prevent overriding the accelerator-specific JAX installation. To keep a JAX line, append# kr:keepto it.
Skip container build time by using prebuilt images:
@keras_remote.run(
accelerator="v6e-8",
container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0"
)
def train():
...Build your own prebuilt image using the project's Dockerfile template as a starting point.
Use capture_env_vars to propagate local environment variables to the remote pod. This supports exact names and wildcard patterns:
import keras_remote
@keras_remote.run(
accelerator="v5litepod-1",
capture_env_vars=["KAGGLE_*", "GOOGLE_CLOUD_*"]
)
def train_gemma():
import keras_hub
gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_1b")
# KAGGLE_USERNAME and KAGGLE_KEY are available for model downloads
# ...This is useful for forwarding API keys, credentials, or configuration without hardcoding them.
Multi-host TPU configurations (those requiring more than one node, such as v2-16, v3-32, or v5p-16) automatically use the Pathways backend. You can also set the backend explicitly:
@keras_remote.run(accelerator="v3-32", backend="pathways")
def distributed_train():
...You can run multiple independent clusters within the same GCP project — for example, one for GPU workloads and another for TPUs. Each cluster gets its own isolated set of cloud resources (GKE cluster, Artifact Registry, storage buckets) backed by a separate infrastructure stack, so they never interfere with each other.
Create clusters by passing --cluster to keras-remote up:
# Default cluster (named "keras-remote-cluster")
keras-remote up --project=my-project --accelerator=v6e-8
# A separate GPU cluster
keras-remote up --project=my-project --cluster=gpu-cluster --accelerator=a100Target a cluster in your code with the cluster parameter or the KERAS_REMOTE_CLUSTER environment variable:
# Run on the GPU cluster
@keras_remote.run(accelerator="a100", cluster="gpu-cluster")
def train_on_gpu():
...
# Or set the env var to avoid repeating the cluster name
# export KERAS_REMOTE_CLUSTER="gpu-cluster"
@keras_remote.run(accelerator="a100")
def train_on_gpu():
...All CLI commands accept --cluster as well, so you can manage each cluster independently:
keras-remote status --cluster=gpu-cluster
keras-remote pool add --cluster=gpu-cluster --accelerator=h100
keras-remote down --cluster=gpu-clusterFor more examples, see the examples/ directory.
| Variable | Required | Default | Description |
|---|---|---|---|
KERAS_REMOTE_PROJECT |
Yes | — | Google Cloud project ID |
KERAS_REMOTE_ZONE |
No | us-central1-a |
Default compute zone |
KERAS_REMOTE_CLUSTER |
No | keras-remote-cluster |
GKE cluster name |
KERAS_REMOTE_GKE_NAMESPACE |
No | default |
Kubernetes namespace |
KERAS_REMOTE_LOG_LEVEL |
No | INFO |
Log verbosity (DEBUG, INFO, WARNING, ERROR, FATAL) |
Keras Remote uses absl-py for logging. Set KERAS_REMOTE_LOG_LEVEL=DEBUG for verbose output when debugging issues.
@keras_remote.run(
accelerator="v6e-8", # TPU/GPU type (default: "v6e-8")
container_image=None, # Custom container URI
zone=None, # Override default zone
project=None, # Override default project
capture_env_vars=None, # Env var names/patterns to forward (supports * wildcard)
cluster=None, # GKE cluster name
backend=None, # "gke", "pathways", or None (auto-detect)
namespace="default", # Kubernetes namespace
volumes=None, # Dict mapping absolute paths to Data objects
)Each accelerator and topology requires setting up its own node pool as a prerequisite.
| Type | Configurations |
|---|---|
| TPU v3 | v3-4, v3-16, v3-32, v3-64, v3-128, v3-256, v3-512, v3-1024, v3-2048 |
| TPU v4 | v4-4, v4-8, v4-16, v4-32, v4-64, v4-128, v4-256, v4-512, v4-1024, v4-2048, v4-4096 |
| TPU v5 Litepod | v5litepod-1, v5litepod-4, v5litepod-8, v5litepod-16, v5litepod-32, v5litepod-64, v5litepod-128, v5litepod-256 |
| TPU v5p | v5p-8, v5p-16, v5p-32 |
| TPU v6e | v6e-8, v6e-16 |
| Type | Aliases | Multi-GPU Counts |
|---|---|---|
| NVIDIA T4 | t4, nvidia-tesla-t4 |
1, 2, 4 |
| NVIDIA L4 | l4, nvidia-l4 |
1, 2, 4, 8 |
| NVIDIA V100 | v100, nvidia-tesla-v100 |
1, 2, 4, 8 |
| NVIDIA A100 | a100, nvidia-tesla-a100 |
1, 2, 4, 8, 16 |
| NVIDIA A100 80GB | a100-80gb, nvidia-a100-80gb |
1, 2, 4, 8, 16 |
| NVIDIA H100 | h100, nvidia-h100-80gb |
1, 2, 4, 8 |
| NVIDIA P4 | p4, nvidia-tesla-p4 |
1, 2, 4 |
| NVIDIA P100 | p100, nvidia-tesla-p100 |
1, 2, 4 |
For multi-GPU configurations on GKE, append the count: a100x4, l4x2, etc.
Use accelerator="cpu" to run on a CPU-only node (no accelerator attached).
The keras-remote CLI manages your cloud infrastructure. Install it with pip install -e ".[cli]".
Provision all required cloud resources (one-time setup):
keras-remote up
keras-remote up --project=my-project --accelerator=t4 --yesRemove all Keras Remote resources to avoid ongoing charges:
keras-remote down
keras-remote down --yes # Skip confirmation promptThis removes the GKE cluster and node pools, Artifact Registry repository and container images, and Cloud Storage buckets.
View current infrastructure state:
keras-remote statusView current configuration:
keras-remote configManage accelerator node pools after initial setup:
# Add a node pool for a specific accelerator
keras-remote pool add --accelerator=v6e-8
# List current node pools
keras-remote pool list
# Remove a node pool by name
keras-remote pool remove <pool-name>- Cloud Build: console.cloud.google.com/cloud-build/builds
- GKE Workloads: console.cloud.google.com/kubernetes/workload
# List GKE jobs
kubectl get jobs -n defaultexport KERAS_REMOTE_PROJECT="your-project-id"Enable required APIs and create the Artifact Registry repository:
gcloud services enable compute.googleapis.com \
cloudbuild.googleapis.com artifactregistry.googleapis.com \
storage.googleapis.com container.googleapis.com \
--project=$KERAS_REMOTE_PROJECT
gcloud artifacts repositories create keras-remote \
--repository-format=docker \
--location=us \
--project=$KERAS_REMOTE_PROJECTGrant required IAM roles:
gcloud projects add-iam-policy-binding $KERAS_REMOTE_PROJECT \
--member="user:your-email@example.com" \
--role="roles/storage.admin"Check Cloud Build logs:
gcloud builds list --project=$KERAS_REMOTE_PROJECT --limit=5Run keras-remote status to check the health of your infrastructure. For manual verification:
# Check authentication
gcloud auth list
# Check project
echo $KERAS_REMOTE_PROJECT
# Check APIs
gcloud services list --enabled --project=$KERAS_REMOTE_PROJECT \
| grep -E "(cloudbuild|artifactregistry|storage|container)"Contributions are welcome. Please read our contributing guidelines before submitting pull requests.
All contributions must follow our Code of Conduct.
This project is licensed under the Apache License 2.0. See LICENSE for details.
Maintained by the Keras team at Google.