Skip to content
Merged
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
104 changes: 96 additions & 8 deletions README.md
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,10 @@ This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, and
- Python 3.11+
- Google Cloud SDK (`gcloud`)
- Run `gcloud auth login` and `gcloud auth application-default login`
- [Pulumi CLI](https://www.pulumi.com/docs/install/) (required for `[cli]` install only)
- A Google Cloud project with billing enabled

Note: The Pulumi CLI is bundled and managed automatically. It will be installed to `~/.keras-remote/pulumi` on first use if not already present.

## Quick Start

### 1. Configure Google Cloud
Expand Down Expand Up @@ -203,15 +204,102 @@ def train():

See [examples/Dockerfile.prebuilt](examples/Dockerfile.prebuilt) for a template.

## Handling Data

Keras Remote provides a declarative and performant Data API to seamlessly make your local and cloud data available to your remote functions.

The Data API is designed to be read-only. It reliably delivers data to your pods at the start of a job. For saving model outputs or checkpointing, you should write directly to GCS from within your function.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should we provide guidance on where in GCS to write? Or leave it completely up to the user?

Copy link
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's okay to leave it to the user since we don't really have any framework specific guidance to provide at this stage.


Under the hood, the Data API optimizes your workflows with two key features:

- **Smart Caching:** Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs that use byte-identical data will hit the cache and skip the upload entirely, drastically speeding up execution.
- **Automatic Zip Exclusion:** When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice.

There are three main ways to handle data depending on your workflow:

### 1. Dynamic Data (The `Data` Class)

The simplest and most Pythonic approach is to pass `Data` objects as regular function arguments. The `Data` class wraps a local file/directory path or a Google Cloud Storage (GCS) URI.

On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, meaning your function code never needs to know about GCS or cloud storage APIs.
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is great!


```python
import pandas as pd
import keras_remote
from keras_remote import Data

@keras_remote.run(accelerator="v6e-8")
def train(data_dir):
# data_dir is resolved to a dynamic local path on the remote machine
df = pd.read_csv(f"{data_dir}/train.csv")
# ...

# Uploads the local directory to the remote pod automatically
train(Data("./my_dataset/"))

# Cache hit: subsequent runs with the same data skip the upload!
train(Data("./my_dataset/"))
```

**Note on GCS Directories:** When referencing a GCS directory with the `Data` class, you must include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). If you omit the trailing slash, the system will treat it as a single file object.

You can also pass multiple `Data` arguments, or nest them inside lists and dictionaries (e.g., `train(datasets=[Data("./d1"), Data("./d2")])`).

### 2. Static Data (The `volumes` Parameter)

For established training scripts where data requirements are static, you can use the `volumes` parameter in the `@keras_remote.run` decorator. This mounts data at fixed, hardcoded absolute filesystem paths, allowing you to drop `keras_remote` into existing codebases without altering the function signature.

```python
import pandas as pd
import keras_remote
from keras_remote import Data

@keras_remote.run(
accelerator="v6e-8",
volumes={
"/data": Data("./my_dataset/"),
"/weights": Data("gs://my-bucket/pretrained-weights/")
}
)
def train():
# Data is guaranteed to be available at these absolute paths
df = pd.read_csv("/data/train.csv")
model.load_weights("/weights/model.h5")
# ...

# No data arguments needed!
train()

```

### 3. Direct GCS Streaming (For Large Datasets)

If your dataset is very large (e.g., > 10GB), it is inefficient to download the entire dataset to the remote pod's local disk. Instead, skip the `Data` wrapper entirely and pass a GCS URI string directly. You can then use frameworks with native GCS streaming support (like `tf.data` or `grain`) to read the data on the fly.

```python
import grain.python as grain
import keras_remote

@keras_remote.run(accelerator="v6e-8")
def train(data_uri):
# Native GCS reading, no download overhead
data_source = grain.ArrayRecordDataSource(data_uri)
# ...

# Pass as a plain string, no Data() wrapper needed
train("gs://my-bucket/arrayrecords/")

```

## Configuration

### Environment Variables

| Variable | Required | Default | Description |
| ---------------------- | -------- | --------------- | ---------------------------------- |
| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID |
| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone |
| `KERAS_REMOTE_CLUSTER` | No | — | GKE cluster name |
| Variable | Required | Default | Description |
| ---------------------- | -------- | --------------- | ----------------------- |
| `KERAS_REMOTE_PROJECT` | Yes | — | Google Cloud project ID |
| `KERAS_REMOTE_ZONE` | No | `us-central1-a` | Default compute zone |
| `KERAS_REMOTE_CLUSTER` | No | — | GKE cluster name |

### Decorator Parameters

Expand Down Expand Up @@ -345,10 +433,10 @@ keras-remote down

This removes:

- GKE cluster and accelerator node pools (via Pulumi)
- GKE cluster and accelerator node pools
- Artifact Registry repository and container images
- Cloud Storage buckets (jobs and builds)
Use `--yes` to skip the confirmation prompt.
Use `--yes` to skip the confirmation prompt.

## Contributing

Expand Down
Loading