Skip to content

Latest commit

 

History

History
393 lines (275 loc) · 9.63 KB

File metadata and controls

393 lines (275 loc) · 9.63 KB

Keras Remote

License Python

Run Keras and JAX workloads on cloud TPUs and GPUs with a simple decorator. No infrastructure management required.

import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    model = keras.Sequential([...])
    model.fit(x_train, y_train)
    return model.history.history["loss"][-1]

# Executes on TPU v3-8, returns the result
final_loss = train_model()

Table of Contents

Features

  • Simple decorator API — Add @keras_remote.run() to any function to execute it remotely
  • Automatic infrastructure — No manual VM provisioning or teardown required
  • Result serialization — Functions return actual values, not just logs
  • Container caching — Subsequent runs start in 2-4 minutes after initial build
  • Built-in monitoring — View job status and logs in Google Cloud Console
  • Automatic cleanup — Resources are released when jobs complete

Installation

Library Only

Install the core package to use the @keras_remote.run() decorator in your code:

git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e .

This is sufficient if your infrastructure (GKE cluster, Artifact Registry, etc.) is already provisioned.

Library + CLI

Install with the cli extra to also get the keras-remote command for managing infrastructure:

git clone https://github.com/keras-team/keras-remote.git
cd keras-remote
pip install -e ".[cli]"

This adds the keras-remote up, keras-remote down, keras-remote status, and keras-remote config commands for provisioning and tearing down cloud resources.

Requirements

  • Python 3.11+
  • Google Cloud SDK (gcloud)
    • Run gcloud auth login and gcloud auth application-default login
  • Pulumi CLI (required for [cli] install only)
  • A Google Cloud project with billing enabled

Quick Start

1. Configure Google Cloud

Run the CLI setup command:

keras-remote up

This will interactively:

  • Prompt for your GCP project ID
  • Let you choose an accelerator type (CPU, GPU, or TPU)
  • Enable required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE)
  • Create the Artifact Registry repository
  • Provision a GKE cluster with optional accelerator node pools
  • Configure Docker authentication and kubectl access

You can also run non-interactively:

keras-remote up --project=my-project --accelerator=t4 --yes

To view current infrastructure state:

keras-remote status

To view configuration:

keras-remote config

2. Set Environment Variables

Add to your shell profile (~/.bashrc, ~/.zshrc, etc.):

export KERAS_REMOTE_PROJECT="your-project-id"
export KERAS_REMOTE_ZONE="us-central1-a"  # Optional

3. Run Your First Job

import keras_remote

@keras_remote.run(accelerator="v3-8")
def hello_tpu():
    import jax
    return f"Running on {jax.devices()}"

result = hello_tpu()
print(result)

Usage Examples

Basic Computation

import keras_remote

@keras_remote.run(accelerator="v3-8")
def compute(x, y):
    return x + y

result = compute(5, 7)
print(f"Result: {result}")  # Output: Result: 12

Keras Model Training

import keras_remote

@keras_remote.run(accelerator="v3-8")
def train_model():
    import keras
    import numpy as np

    model = keras.Sequential([
        keras.layers.Dense(64, activation="relu", input_shape=(10,)),
        keras.layers.Dense(1)
    ])
    model.compile(optimizer="adam", loss="mse")

    x_train = np.random.randn(1000, 10)
    y_train = np.random.randn(1000, 1)

    history = model.fit(x_train, y_train, epochs=5, verbose=0)
    return history.history["loss"][-1]

final_loss = train_model()
print(f"Final loss: {final_loss}")

Custom Dependencies

Create a requirements.txt in your project directory:

tensorflow-datasets
pillow
scikit-learn

Keras Remote automatically detects and installs dependencies on the remote worker.

Prebuilt Container Images

Skip container build time by using prebuilt images:

@keras_remote.run(
    accelerator="v3-8",
    container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0"
)
def train():
    ...

See examples/Dockerfile.prebuilt for a template.

Configuration

Environment Variables

Variable Required Default Description
KERAS_REMOTE_PROJECT Yes Google Cloud project ID
KERAS_REMOTE_ZONE No us-central1-a Default compute zone
KERAS_REMOTE_CLUSTER No GKE cluster name

Decorator Parameters

@keras_remote.run(
    accelerator="v3-8",        # Required: TPU/GPU type
    container_image=None,      # Custom container URI
    zone=None,                 # Override default zone
    project=None,              # Override default project
    cluster=None,              # GKE cluster name
    namespace="default"        # Kubernetes namespace
)

Supported Accelerators

TPUs

Type Configurations
TPU v2 v2-8, v2-32
TPU v3 v3-8, v3-32
TPU v5 Litepod v5litepod-1, v5litepod-4, v5litepod-8
TPU v5p v5p-8, v5p-16
TPU v6e v6e-8, v6e-16

GPUs

Type Aliases
NVIDIA T4 t4, nvidia-tesla-t4
NVIDIA L4 l4, nvidia-l4
NVIDIA V100 v100, nvidia-tesla-v100
NVIDIA A100 a100, nvidia-tesla-a100
NVIDIA H100 h100, nvidia-h100-80gb

For multi-GPU configurations on GKE, append the count: a100x4, l4x2, etc.

Monitoring

Google Cloud Console

Command Line

# List GKE jobs
kubectl get jobs -n default

Troubleshooting

Common Issues

"Project must be specified" error

export KERAS_REMOTE_PROJECT="your-project-id"

"404 Requested entity was not found" error

Enable required APIs and create the Artifact Registry repository:

gcloud services enable cloudbuild.googleapis.com \
    artifactregistry.googleapis.com storage.googleapis.com \
    container.googleapis.com --project=$KERAS_REMOTE_PROJECT

gcloud artifacts repositories create keras-remote \
    --repository-format=docker \
    --location=us \
    --project=$KERAS_REMOTE_PROJECT

Permission denied errors

Grant required IAM roles:

gcloud projects add-iam-policy-binding $KERAS_REMOTE_PROJECT \
    --member="user:your-email@example.com" \
    --role="roles/storage.admin"

Container build failures

Check Cloud Build logs:

gcloud builds list --project=$KERAS_REMOTE_PROJECT --limit=5

Debug Logging

import logging
logging.basicConfig(level=logging.INFO)

Verify Setup

# Check authentication
gcloud auth list

# Check project
echo $KERAS_REMOTE_PROJECT

# Check APIs
gcloud services list --enabled --project=$KERAS_REMOTE_PROJECT \
    | grep -E "(cloudbuild|artifactregistry|storage|container)"

# Check Artifact Registry
gcloud artifacts repositories describe keras-remote \
    --location=us --project=$KERAS_REMOTE_PROJECT

Resource Cleanup

Remove all Keras Remote resources to avoid charges:

keras-remote down

This removes:

  • GKE cluster and accelerator node pools (via Pulumi)
  • Artifact Registry repository and container images
  • Cloud Storage buckets (jobs and builds) Use --yes to skip the confirmation prompt.

Contributing

Contributions are welcome. Please read our contributing guidelines before submitting pull requests.

Development Setup

  1. Install the package with dev dependencies:
pip install -e ".[dev]"
  1. Install pre-commit hooks:
pre-commit install

This enables automatic linting and formatting checks (via Ruff) on every commit.

To run the checks manually against all files:

pre-commit run --all-files

Submitting Changes

  1. Fork the repository
  2. Create a feature branch (git checkout -b feature/amazing-feature)
  3. Commit your changes (git commit -m 'Add amazing feature')
  4. Push to the branch (git push origin feature/amazing-feature)
  5. Open a Pull Request

License

This project is licensed under the Apache License 2.0. See LICENSE for details.


Maintained by the Keras team at Google.