diff --git a/README.md b/README.md index 00466e7..39e40a6 100644 --- a/README.md +++ b/README.md @@ -21,46 +21,58 @@ final_loss = train_model() ## Table of Contents +- [How It Works](#how-it-works) - [Features](#features) -- [Installation](#installation) -- [Quick Start](#quick-start) -- [Usage Examples](#usage-examples) -- [Handling Data](#handling-data) -- [Configuration](#configuration) -- [Supported Accelerators](#supported-accelerators) -- [Monitoring](#monitoring) +- [Getting Started](#getting-started) +- [Usage Guide](#usage-guide) +- [Reference](#reference) - [Troubleshooting](#troubleshooting) -- [Resource Cleanup](#resource-cleanup) - [Contributing](#contributing) - [License](#license) +## How It Works + +When you call a decorated function, Keras Remote handles the entire remote execution pipeline: + +1. **Packages** your function, local code, and data dependencies +2. **Builds a container** with your dependencies via Cloud Build (cached after first build — subsequent runs skip this step) +3. **Runs the job** on a GKE cluster with the requested accelerator (TPU or GPU) +4. **Returns the result** to your local machine — logs are streamed in real time, and the function's return value is delivered back as if it ran locally + +If the remote function raises an exception, it is re-raised locally with the original traceback, so debugging works the same as local development. + +You need a GKE cluster with accelerator node pools to run jobs. The `keras-remote` CLI handles this setup for you. 
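
The result-and-exception round trip described in steps 1–4 can be sketched with plain `pickle`. This is a toy illustration of the idea only, not Keras Remote's actual wire format, and `run_remotely`/`unpack_locally` are hypothetical names invented for this sketch:

```python
import pickle
import traceback

def run_remotely(fn, *args):
    """What the remote worker conceptually does: capture the result or the exception."""
    try:
        return pickle.dumps({"ok": True, "value": fn(*args)})
    except Exception as exc:  # forward any failure back to the caller
        return pickle.dumps(
            {"ok": False, "error": exc, "remote_traceback": traceback.format_exc()}
        )

def unpack_locally(payload):
    """What the local stub conceptually does: return the value or re-raise the error."""
    outcome = pickle.loads(payload)
    if outcome["ok"]:
        return outcome["value"]
    print(outcome["remote_traceback"])  # surface the original remote traceback
    raise outcome["error"]

def add(x, y):
    return x + y

print(unpack_locally(run_remotely(add, 5, 7)))  # 12
```

Either way the call site behaves like an ordinary local function call: a return value comes back as a value, and a remote failure surfaces as the original exception.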
+ ## Features - **Simple decorator API** — Add `@keras_remote.run()` to any function to execute it remotely - **Automatic infrastructure** — No manual VM provisioning or teardown required - **Result serialization** — Functions return actual values, not just logs -- **Container caching** — Subsequent runs start in 2-4 minutes after initial build +- **Fast iteration** — Container images are cached by dependency hash; unchanged dependencies skip the build entirely (subsequent runs start in less than a minute) +- **Data API** — Declarative `Data` class with smart caching for local files and GCS data +- **Environment variable forwarding** — Propagate local env vars to the remote environment with wildcard patterns (`capture_env_vars=["KAGGLE_*"]`) - **Built-in monitoring** — View job status and logs in Google Cloud Console - **Automatic cleanup** — Resources are released when jobs complete - **Transparent errors** — Remote exceptions are re-raised locally with the original traceback -## Installation +## Getting Started + +### Prerequisites -### Library Only +- Python 3.11+ +- Google Cloud SDK (`gcloud`) — [install guide](https://cloud.google.com/sdk/docs/install) +- A Google Cloud project with billing enabled -Install the core package to use the `@keras_remote.run()` decorator in your code: +Authenticate with Google Cloud: ```bash -git clone https://github.com/keras-team/keras-remote.git -cd keras-remote -pip install -e . +gcloud auth login +gcloud auth application-default login ``` -This is sufficient if your infrastructure (GKE cluster, Artifact Registry, etc.) is already provisioned. +> **Note:** The Pulumi CLI (used for infrastructure provisioning) is bundled and managed automatically. It will be installed to `~/.keras-remote/pulumi` on first use if not already present. 
-### Library + CLI - -Install with the `cli` extra to also get the `keras-remote` command for managing infrastructure: +### Install ```bash git clone https://github.com/keras-team/keras-remote.git @@ -68,35 +80,24 @@ cd keras-remote pip install -e ".[cli]" ``` -This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, `keras-remote config`, and `keras-remote pool` commands for provisioning and managing cloud resources. - -### Requirements - -- Python 3.11+ -- Google Cloud SDK (`gcloud`) - - Run `gcloud auth login` and `gcloud auth application-default login` -- A Google Cloud project with billing enabled +This installs both the `@keras_remote.run()` decorator and the `keras-remote` CLI for managing infrastructure. -Note: The Pulumi CLI is bundled and managed automatically. It will be installed to `~/.keras-remote/pulumi` on first use if not already present. +> If your GKE cluster and Artifact Registry are already provisioned, you can install without the CLI: `pip install -e .` -## Quick Start +### Provision Infrastructure -### 1. 
Configure Google Cloud - -Run the CLI setup command: +Run the one-time setup to create the required cloud resources: ```bash keras-remote up ``` -This will interactively: +This interactively prompts for your GCP project and accelerator type, then: -- Prompt for your GCP project ID -- Let you choose an accelerator type (CPU, GPU, or TPU) -- Enable required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE) -- Create the Artifact Registry repository -- Provision a GKE cluster with optional accelerator node pools -- Configure Docker authentication and kubectl access +- Enables required APIs (Cloud Build, Artifact Registry, Cloud Storage, GKE) +- Creates an Artifact Registry repository for container images +- Provisions a GKE cluster with an accelerator node pool +- Configures Docker authentication and kubectl access You can also run non-interactively: @@ -104,41 +105,19 @@ You can also run non-interactively: keras-remote up --project=my-project --accelerator=t4 --yes ``` -To view current infrastructure state: +> **Cleanup reminder:** When you're done, run `keras-remote down` to tear down all resources and avoid ongoing charges. See [CLI Commands](#cli-commands). -```bash -keras-remote status -``` - -To view configuration: - -```bash -keras-remote config -``` +### Configure -To manage accelerator node pools after initial setup: - -```bash -# Add a node pool for a specific accelerator -keras-remote pool add --accelerator=v6e-8 - -# List current node pools -keras-remote pool list - -# Remove a node pool by name -keras-remote pool remove -``` - -### 2. Set Environment Variables - -Add to your shell profile (`~/.bashrc`, `~/.zshrc`, etc.): +Set your project ID so the library knows where to run jobs: ```bash export KERAS_REMOTE_PROJECT="your-project-id" -export KERAS_REMOTE_ZONE="us-central1-a" # Optional ``` -### 3. Run Your First Job +Add this to your shell profile (`~/.bashrc`, `~/.zshrc`, etc.) to persist it. 
See [Configuration](#configuration) for the full list of environment variables. + +### Run Your First Job ```python import keras_remote @@ -152,22 +131,11 @@ result = hello_tpu() print(result) ``` -## Usage Examples - -### Basic Computation - -```python -import keras_remote - -@keras_remote.run(accelerator="v6e-8") -def compute(x, y): - return x + y +> **First run timing:** The initial execution takes longer (~5 minutes) because it builds a container image with your dependencies. Subsequent runs with unchanged dependencies use the cached image and start in less than a minute. -result = compute(5, 7) -print(f"Result: {result}") # Output: Result: 12 -``` +## Usage Guide -### Keras Model Training +### Training a Keras Model ```python import keras_remote @@ -193,65 +161,24 @@ final_loss = train_model() print(f"Final loss: {final_loss}") ``` -### Custom Dependencies - -Create a `requirements.txt` in your project directory: - -```text -tensorflow-datasets -pillow -scikit-learn -``` - -Alternatively, dependencies declared in `pyproject.toml` are also supported: - -```toml -[project] -dependencies = [ - "tensorflow-datasets", - "pillow", - "scikit-learn", -] -``` - -Keras Remote automatically detects and installs dependencies on the remote worker. -If both files exist in the same directory, `requirements.txt` takes precedence. +### Working with Data -> **Note:** JAX packages (`jax`, `jaxlib`, `libtpu`, `libtpu-nightly`) are automatically filtered from your `requirements.txt` to prevent overriding the accelerator-specific JAX installation. To keep a JAX line, append `# kr:keep` to it. +Keras Remote provides a declarative Data API to seamlessly make your local and cloud data available to remote functions. -### Prebuilt Container Images +The Data API is read-only — it delivers data to your pods at the start of a job. For saving model outputs or checkpointing, write directly to GCS from within your function. 
-Skip container build time by using prebuilt images: +Under the hood, the Data API provides two key optimizations: -```python -@keras_remote.run( - accelerator="v6e-8", - container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0" -) -def train(): - ... -``` - -Build your own prebuilt image using the project's Dockerfile template as a starting point. - -## Handling Data - -Keras Remote provides a declarative and performant Data API to seamlessly make your local and cloud data available to your remote functions. - -The Data API is designed to be read-only. It reliably delivers data to your pods at the start of a job. For saving model outputs or checkpointing, you should write directly to GCS from within your function. - -Under the hood, the Data API optimizes your workflows with two key features: - -- **Smart Caching:** Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs that use byte-identical data will hit the cache and skip the upload entirely, drastically speeding up execution. +- **Smart Caching:** Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs with byte-identical data skip the upload entirely. - **Automatic Zip Exclusion:** When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice. -There are three main ways to handle data depending on your workflow: +There are three approaches depending on your workflow: -### 1. Dynamic Data (The `Data` Class) +#### Dynamic Data (The `Data` Class) -The simplest and most Pythonic approach is to pass `Data` objects as regular function arguments. The `Data` class wraps a local file/directory path or a Google Cloud Storage (GCS) URI. +The simplest approach — pass `Data` objects as regular function arguments. The `Data` class wraps a local file/directory path or a Google Cloud Storage (GCS) URI. 
-On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, meaning your function code never needs to know about GCS or cloud storage APIs. +On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, so your function code never needs to know about GCS or cloud storage APIs. ```python import pandas as pd @@ -260,7 +187,7 @@ from keras_remote import Data @keras_remote.run(accelerator="v6e-8") def train(data_dir): - # data_dir is resolved to a dynamic local path on the remote machine + # data_dir is resolved to a local path on the remote machine df = pd.read_csv(f"{data_dir}/train.csv") # ... @@ -271,13 +198,13 @@ train(Data("./my_dataset/")) train(Data("./my_dataset/")) ``` -**Note on GCS Directories:** When referencing a GCS directory with the `Data` class, you must include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). If you omit the trailing slash, the system will treat it as a single file object. +> **GCS Directories:** When referencing a GCS directory with the `Data` class, include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). Without the trailing slash, the system treats it as a single file object. You can also pass multiple `Data` arguments, or nest them inside lists and dictionaries (e.g., `train(datasets=[Data("./d1"), Data("./d2")])`). -### 2. Static Data (The `volumes` Parameter) +#### Static Data (The `volumes` Parameter) -For established training scripts where data requirements are static, you can use the `volumes` parameter in the `@keras_remote.run` decorator. This mounts data at fixed, hardcoded absolute filesystem paths, allowing you to drop `keras_remote` into existing codebases without altering the function signature. +For established training scripts where data requirements are fixed, use the `volumes` parameter in the decorator. 
This mounts data at hardcoded absolute filesystem paths, allowing you to use Keras Remote with existing codebases without altering the function signature. ```python import pandas as pd @@ -299,12 +226,11 @@ def train(): # No data arguments needed! train() - ``` -### 3. Direct GCS Streaming (For Large Datasets) +#### Direct GCS Streaming (For Large Datasets) -If your dataset is very large (e.g., > 10GB), it is inefficient to download the entire dataset to the remote pod's local disk. Instead, skip the `Data` wrapper entirely and pass a GCS URI string directly. You can then use frameworks with native GCS streaming support (like `tf.data` or `grain`) to read the data on the fly. +If your dataset is very large (e.g., > 10GB), it is inefficient to download the entire dataset to the pod's local disk. Instead, skip the `Data` wrapper and pass a GCS URI string directly. Use frameworks with native GCS streaming support (like `tf.data` or `grain`) to read the data on the fly. ```python import grain.python as grain @@ -318,22 +244,135 @@ def train(data_uri): # Pass as a plain string, no Data() wrapper needed train("gs://my-bucket/arrayrecords/") +``` + +### Custom Dependencies + +Create a `requirements.txt` in your project directory: + +```text +tensorflow-datasets +pillow +scikit-learn +``` +Alternatively, dependencies declared in `pyproject.toml` are also supported: + +```toml +[project] +dependencies = [ + "tensorflow-datasets", + "pillow", + "scikit-learn", +] ``` -## Configuration +Keras Remote automatically detects and installs dependencies on the remote worker. +If both files exist in the same directory, `requirements.txt` takes precedence. + +> **Note:** JAX packages (`jax`, `jaxlib`, `libtpu`, `libtpu-nightly`) are automatically filtered from your dependencies to prevent overriding the accelerator-specific JAX installation. To keep a JAX line, append `# kr:keep` to it. 
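
To make the filtering rule concrete, here is a rough sketch of how such a filter could behave. `filter_requirements` is a hypothetical helper written for illustration, not Keras Remote's actual implementation:

```python
import re

# Packages dropped unless the line carries the `# kr:keep` marker (assumption:
# matching is on the distribution name, not on substrings).
JAX_PACKAGES = {"jax", "jaxlib", "libtpu", "libtpu-nightly"}

def filter_requirements(lines):
    kept = []
    for line in lines:
        # Take the distribution name: everything before a version specifier,
        # extras bracket, or environment marker.
        name = re.split(r"[=<>!~\[; ]", line.strip(), maxsplit=1)[0].lower()
        if name in JAX_PACKAGES and "# kr:keep" not in line:
            continue  # would override the accelerator-specific JAX install
        kept.append(line)
    return kept

reqs = ["tensorflow-datasets", "jax==0.4.30", "jax[tpu]  # kr:keep", "pillow"]
print(filter_requirements(reqs))
# ['tensorflow-datasets', 'jax[tpu]  # kr:keep', 'pillow']
```

Note that under this sketch a package such as `jaxtyping` passes through untouched, since only exact distribution names in the filter list are affected.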
-### Environment Variables
-
-| Variable | Required | Default | Description |
-| ---------------------------- | -------- | ---------------------- | ----------------------- |
-| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID |
-| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone |
-| `KERAS_REMOTE_CLUSTER` | No | `keras-remote-cluster` | GKE cluster name |
-| `KERAS_REMOTE_GKE_NAMESPACE` | No | `default` | Kubernetes namespace |
+### Prebuilt Container Images
+
+Skip container build time by using prebuilt images:
+
+```python
+@keras_remote.run(
+    accelerator="v6e-8",
+    container_image="us-docker.pkg.dev/my-project/keras-remote/prebuilt:v1.0"
+)
+def train():
+    ...
+```
+
+Build your own prebuilt image using the project's Dockerfile template as a starting point.
+
+### Forwarding Environment Variables
+
+Use `capture_env_vars` to propagate local environment variables to the remote pod. This supports exact names and wildcard patterns:
+
+```python
+import keras_remote
+
+@keras_remote.run(
+    accelerator="v5litepod-1",
+    capture_env_vars=["KAGGLE_*", "GOOGLE_CLOUD_*"]
+)
+def train_gemma():
+    import keras_hub
+    gemma_lm = keras_hub.models.Gemma3CausalLM.from_preset("gemma3_1b")
+    # KAGGLE_USERNAME and KAGGLE_KEY are available for model downloads
+    # ...
+```
+
+This is useful for forwarding API keys, credentials, or configuration without hardcoding them.
+
+### Multi-Host TPU (Pathways)
+
+Multi-host TPU configurations (those requiring more than one node, such as `v3-32` or `v5p-16`) automatically use the [Pathways](https://cloud.google.com/tpu/docs/pathways-overview) backend. You can also set the backend explicitly:
+
+```python
+@keras_remote.run(accelerator="v3-32", backend="pathways")
+def distributed_train():
+    ...
+```
+
+### Multiple Clusters
+
+You can run multiple independent clusters within the same GCP project — for example, one for GPU workloads and another for TPUs. 
Each cluster gets its own isolated set of cloud resources (GKE cluster, Artifact Registry, storage buckets) backed by a separate infrastructure stack, so they never interfere with each other. + +**Create clusters** by passing `--cluster` to `keras-remote up`: + +```bash +# Default cluster (named "keras-remote-cluster") +keras-remote up --project=my-project --accelerator=v6e-8 + +# A separate GPU cluster +keras-remote up --project=my-project --cluster=gpu-cluster --accelerator=a100 +``` + +**Target a cluster** in your code with the `cluster` parameter or the `KERAS_REMOTE_CLUSTER` environment variable: + +```python +# Run on the GPU cluster +@keras_remote.run(accelerator="a100", cluster="gpu-cluster") +def train_on_gpu(): + ... + +# Or set the env var to avoid repeating the cluster name +# export KERAS_REMOTE_CLUSTER="gpu-cluster" +@keras_remote.run(accelerator="a100") +def train_on_gpu(): + ... +``` + +All CLI commands accept `--cluster` as well, so you can manage each cluster independently: + +```bash +keras-remote status --cluster=gpu-cluster +keras-remote pool add --cluster=gpu-cluster --accelerator=h100 +keras-remote down --cluster=gpu-cluster +``` + +For more examples, see the [`examples/`](examples/) directory. + +## Reference + +### Configuration + +#### Environment Variables + +| Variable | Required | Default | Description | +| ---------------------------- | -------- | ---------------------- | ------------------------------------------------------------ | +| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID | +| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone | +| `KERAS_REMOTE_CLUSTER` | No | `keras-remote-cluster` | GKE cluster name | +| `KERAS_REMOTE_GKE_NAMESPACE` | No | `default` | Kubernetes namespace | | `KERAS_REMOTE_LOG_LEVEL` | No | `INFO` | Log verbosity (`DEBUG`, `INFO`, `WARNING`, `ERROR`, `FATAL`) | -### Decorator Parameters +Keras Remote uses `absl-py` for logging. 
Set `KERAS_REMOTE_LOG_LEVEL=DEBUG` for verbose output when debugging issues. + +#### Decorator Parameters ```python @keras_remote.run( @@ -349,56 +388,102 @@ train("gs://my-bucket/arrayrecords/") ) ``` -## Supported Accelerators +### Supported Accelerators -Note: each accelerator and topology requires -[setting up its own NodePool](#quick-start) as a prerequisite. +Each accelerator and topology requires [setting up its own node pool](#keras-remote-pool) as a prerequisite. -### TPUs +#### TPUs -| Type | Configurations | -| -------------- | ------------------------------------------- | -| TPU v2 | `v2-4`, `v2-16`, `v2-32` | -| TPU v3 | `v3-4`, `v3-16`, `v3-32` | -| TPU v5 Litepod | `v5litepod-1`, `v5litepod-4`, `v5litepod-8` | -| TPU v5p | `v5p-8`, `v5p-16` | -| TPU v6e | `v6e-8`, `v6e-16` | +| Type | Configurations | +| -------------- | ----------------------------------------------------------------------------------------------------------------------------- | +| TPU v3 | `v3-4`, `v3-16`, `v3-32`, `v3-64`, `v3-128`, `v3-256`, `v3-512`, `v3-1024`, `v3-2048` | +| TPU v4 | `v4-4`, `v4-8`, `v4-16`, `v4-32`, `v4-64`, `v4-128`, `v4-256`, `v4-512`, `v4-1024`, `v4-2048`, `v4-4096` | +| TPU v5 Litepod | `v5litepod-1`, `v5litepod-4`, `v5litepod-8`, `v5litepod-16`, `v5litepod-32`, `v5litepod-64`, `v5litepod-128`, `v5litepod-256` | +| TPU v5p | `v5p-8`, `v5p-16`, `v5p-32` | +| TPU v6e | `v6e-8`, `v6e-16` | -### GPUs +#### GPUs | Type | Aliases | Multi-GPU Counts | | ---------------- | ------------------------------- | ---------------- | | NVIDIA T4 | `t4`, `nvidia-tesla-t4` | 1, 2, 4 | -| NVIDIA L4 | `l4`, `nvidia-l4` | 1, 2, 4 | +| NVIDIA L4 | `l4`, `nvidia-l4` | 1, 2, 4, 8 | | NVIDIA V100 | `v100`, `nvidia-tesla-v100` | 1, 2, 4, 8 | -| NVIDIA A100 | `a100`, `nvidia-tesla-a100` | 1, 2, 4, 8 | -| NVIDIA A100 80GB | `a100-80gb`, `nvidia-a100-80gb` | 1, 2, 4, 8 | +| NVIDIA A100 | `a100`, `nvidia-tesla-a100` | 1, 2, 4, 8, 16 | +| NVIDIA A100 80GB | `a100-80gb`, 
`nvidia-a100-80gb` | 1, 2, 4, 8, 16 | | NVIDIA H100 | `h100`, `nvidia-h100-80gb` | 1, 2, 4, 8 | +| NVIDIA P4 | `p4`, `nvidia-tesla-p4` | 1, 2, 4 | +| NVIDIA P100 | `p100`, `nvidia-tesla-p100` | 1, 2, 4 | For multi-GPU configurations on GKE, append the count: `a100x4`, `l4x2`, etc. -### CPU +#### CPU Use `accelerator="cpu"` to run on a CPU-only node (no accelerator attached). -### Multi-Host TPU (Pathways) +### CLI Commands -Multi-host TPU configurations (those requiring more than one node, such as `v2-16`, `v3-32`, or `v5p-16`) automatically use the [Pathways](https://cloud.google.com/tpu/docs/pathways-overview) backend. You can also set the backend explicitly: +The `keras-remote` CLI manages your cloud infrastructure. Install it with `pip install -e ".[cli]"`. -```python -@keras_remote.run(accelerator="v3-32", backend="pathways") -def distributed_train(): - ... +#### `keras-remote up` + +Provision all required cloud resources (one-time setup): + +```bash +keras-remote up +keras-remote up --project=my-project --accelerator=t4 --yes +``` + +#### `keras-remote down` + +Remove all Keras Remote resources to avoid ongoing charges: + +```bash +keras-remote down +keras-remote down --yes # Skip confirmation prompt +``` + +This removes the GKE cluster and node pools, Artifact Registry repository and container images, and Cloud Storage buckets. 
+ +#### `keras-remote status` + +View current infrastructure state: + +```bash +keras-remote status +``` + +#### `keras-remote config` + +View current configuration: + +```bash +keras-remote config +``` + +#### `keras-remote pool` + +Manage accelerator node pools after initial setup: + +```bash +# Add a node pool for a specific accelerator +keras-remote pool add --accelerator=v6e-8 + +# List current node pools +keras-remote pool list + +# Remove a node pool by name +keras-remote pool remove ``` -## Monitoring +### Monitoring -### Google Cloud Console +#### Google Cloud Console - **Cloud Build:** [console.cloud.google.com/cloud-build/builds](https://console.cloud.google.com/cloud-build/builds) - **GKE Workloads:** [console.cloud.google.com/kubernetes/workload](https://console.cloud.google.com/kubernetes/workload) -### Command Line +#### Command Line ```bash # List GKE jobs @@ -449,18 +534,10 @@ Check Cloud Build logs: gcloud builds list --project=$KERAS_REMOTE_PROJECT --limit=5 ``` -### Debug Logging - -Keras Remote uses `absl-py` for logging. You can control the log verbosity by setting the `KERAS_REMOTE_LOG_LEVEL` environment variable: - -```bash -export KERAS_REMOTE_LOG_LEVEL="DEBUG" -``` - -Supported levels are `DEBUG`, `INFO`, `WARNING`, `ERROR`, and `FATAL`. The default is `INFO`. - ### Verify Setup +Run `keras-remote status` to check the health of your infrastructure. 
For manual verification: + ```bash # Check authentication gcloud auth list @@ -471,28 +548,8 @@ echo $KERAS_REMOTE_PROJECT # Check APIs gcloud services list --enabled --project=$KERAS_REMOTE_PROJECT \ | grep -E "(cloudbuild|artifactregistry|storage|container)" - -# Check Artifact Registry -gcloud artifacts repositories describe keras-remote \ - --location=us --project=$KERAS_REMOTE_PROJECT -``` - -## Resource Cleanup - -Remove all Keras Remote resources to avoid charges: - -```bash -keras-remote down ``` -This removes: - -- GKE cluster and accelerator node pools -- Artifact Registry repository and container images -- Cloud Storage buckets (jobs and builds) - -Use `--yes` to skip the confirmation prompt. - ## Contributing Contributions are welcome. Please read our [contributing guidelines](docs/contributing.md) before submitting pull requests.