Adds Data Handling section to the README #78
@@ -72,9 +72,10 @@ This adds the `keras-remote up`, `keras-remote down`, `keras-remote status`, and

- Python 3.11+
- Google Cloud SDK (`gcloud`)
- Run `gcloud auth login` and `gcloud auth application-default login`
- [Pulumi CLI](https://www.pulumi.com/docs/install/) (required for the `[cli]` install only)
- A Google Cloud project with billing enabled

Note: The Pulumi CLI is bundled and managed automatically. It will be installed to `~/.keras-remote/pulumi` on first use if not already present.

## Quick Start

### 1. Configure Google Cloud
@@ -203,15 +204,102 @@ def train():

See [examples/Dockerfile.prebuilt](examples/Dockerfile.prebuilt) for a template.

## Handling Data

Keras Remote provides a declarative, high-performance Data API that makes your local and cloud data seamlessly available to your remote functions.

The Data API is read-only: it reliably delivers data to your pods at the start of a job. To save model outputs or checkpoints, write directly to GCS from within your function.
> **Reviewer:** Should we provide guidance on where in GCS to write? Or leave it completely up to the user?
>
> **Author:** I think it's okay to leave it to the user, since we don't really have any framework-specific guidance to provide at this stage.
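Since the Data API is read-only, saving outputs is up to your function. A minimal sketch of writing a checkpoint to GCS, assuming the `google-cloud-storage` package is installed on the pod; `checkpoint_uri` and `upload_checkpoint` are hypothetical helpers and the bucket layout is an arbitrary convention, not something keras-remote prescribes:

```python
def checkpoint_uri(bucket: str, job_name: str, step: int) -> str:
    # Deterministic GCS location for a checkpoint; this layout is our
    # own choice, not part of keras_remote
    return f"gs://{bucket}/{job_name}/checkpoints/step-{step:06d}.weights.h5"

def upload_checkpoint(local_path: str, uri: str) -> None:
    # Upload a locally written checkpoint file to GCS (requires the
    # google-cloud-storage package and application-default credentials)
    from google.cloud import storage

    bucket_name, blob_name = uri[len("gs://"):].split("/", 1)
    storage.Client().bucket(bucket_name).blob(blob_name).upload_from_filename(local_path)
```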
|
|
||
| Under the hood, the Data API optimizes your workflows with two key features: | ||
|
|
||
| - **Smart Caching:** Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs that use byte-identical data will hit the cache and skip the upload entirely, drastically speeding up execution. | ||
| - **Automatic Zip Exclusion:** When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice. | ||
|
|
||
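The caching idea can be approximated by hashing file bytes (plus relative paths, for directories), so byte-identical data always maps to the same cache key. A rough sketch of the concept, not keras-remote's actual implementation:

```python
import hashlib
import os

def content_hash(path: str) -> str:
    # Hash file contents; for a directory, also mix in each file's
    # relative path so renames change the key but re-runs do not
    h = hashlib.sha256()
    if os.path.isfile(path):
        with open(path, "rb") as f:
            for chunk in iter(lambda: f.read(8192), b""):
                h.update(chunk)
    else:
        for root, dirs, files in os.walk(path):
            dirs.sort()  # deterministic traversal order
            for name in sorted(files):
                full = os.path.join(root, name)
                h.update(os.path.relpath(full, path).encode())
                with open(full, "rb") as f:
                    for chunk in iter(lambda: f.read(8192), b""):
                        h.update(chunk)
    return h.hexdigest()
```

A cache lookup then reduces to checking whether an object named after the digest already exists in the cache bucket.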
There are three main ways to handle data, depending on your workflow:

### 1. Dynamic Data (The `Data` Class)

The simplest and most Pythonic approach is to pass `Data` objects as regular function arguments. The `Data` class wraps a local file or directory path, or a Google Cloud Storage (GCS) URI.

On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, so your function code never needs to know about GCS or cloud storage APIs.
> **Reviewer:** This is great!
```python
import pandas as pd
import keras_remote
from keras_remote import Data

@keras_remote.run(accelerator="v6e-8")
def train(data_dir):
    # data_dir is resolved to a dynamic local path on the remote machine
    df = pd.read_csv(f"{data_dir}/train.csv")
    # ...

# Uploads the local directory to the remote pod automatically
train(Data("./my_dataset/"))

# Cache hit: subsequent runs with the same data skip the upload!
train(Data("./my_dataset/"))
```
**Note on GCS Directories:** When referencing a GCS directory with the `Data` class, you must include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). If you omit the trailing slash, the system treats the URI as a single file object.
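The trailing-slash rule boils down to a simple classification; `gcs_kind` below is a hypothetical helper illustrating the documented behavior, not a keras_remote API:

```python
def gcs_kind(uri: str) -> str:
    # A trailing slash marks a GCS directory; anything else is treated
    # as a single file object (per the note above)
    if not uri.startswith("gs://"):
        raise ValueError(f"not a GCS URI: {uri!r}")
    return "directory" if uri.endswith("/") else "file"
```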
You can also pass multiple `Data` arguments, or nest them inside lists and dictionaries (e.g., `train(datasets=[Data("./d1"), Data("./d2")])`).
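Resolving nested arguments can be pictured as a recursive walk that swaps each `Data` object for its local path on the pod. A toy sketch with a stand-in `Data` class (the real one lives in keras_remote, and the resolver logic here is an assumption for illustration):

```python
class Data:
    # Stand-in for keras_remote.Data, just enough for this sketch
    def __init__(self, path):
        self.path = path

def resolve(obj, resolver):
    # Replace every Data instance, including those nested inside
    # lists and dicts, with the path the resolver maps it to
    if isinstance(obj, Data):
        return resolver(obj)
    if isinstance(obj, list):
        return [resolve(x, resolver) for x in obj]
    if isinstance(obj, dict):
        return {k: resolve(v, resolver) for k, v in obj.items()}
    return obj
```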
### 2. Static Data (The `volumes` Parameter)

For established training scripts whose data requirements are static, use the `volumes` parameter of the `@keras_remote.run` decorator. It mounts data at fixed, hardcoded absolute filesystem paths, letting you drop `keras_remote` into existing codebases without altering the function signature.
```python
import pandas as pd
import keras_remote
from keras_remote import Data

@keras_remote.run(
    accelerator="v6e-8",
    volumes={
        "/data": Data("./my_dataset/"),
        "/weights": Data("gs://my-bucket/pretrained-weights/"),
    },
)
def train():
    # Data is guaranteed to be available at these absolute paths
    df = pd.read_csv("/data/train.csv")
    model.load_weights("/weights/model.h5")
    # ...

# No data arguments needed!
train()
```
### 3. Direct GCS Streaming (For Large Datasets)

If your dataset is very large (e.g., >10 GB), downloading the entire dataset to the remote pod's local disk is inefficient. Instead, skip the `Data` wrapper entirely and pass a GCS URI string directly, then use a framework with native GCS streaming support (such as `tf.data` or `grain`) to read the data on the fly.
```python
import grain.python as grain
import keras_remote

@keras_remote.run(accelerator="v6e-8")
def train(data_uri):
    # Native GCS reading, no download overhead
    data_source = grain.ArrayRecordDataSource(data_uri)
    # ...

# Pass as a plain string; no Data() wrapper needed
train("gs://my-bucket/arrayrecords/")
```
## Configuration

### Environment Variables

| Variable               | Required | Default         | Description             |
| ---------------------- | -------- | --------------- | ----------------------- |
| `KERAS_REMOTE_PROJECT` | Yes      | —               | Google Cloud project ID |
| `KERAS_REMOTE_ZONE`    | No       | `us-central1-a` | Default compute zone    |
| `KERAS_REMOTE_CLUSTER` | No       | —               | GKE cluster name        |
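In Python terms, the table above amounts to one required variable and two optional ones with defaults. A minimal sketch of how they might be consumed (`load_config` is illustrative, not the library's actual loader):

```python
import os

def load_config() -> dict:
    # KERAS_REMOTE_PROJECT is required; the others fall back to the
    # documented defaults from the table above
    project = os.environ.get("KERAS_REMOTE_PROJECT")
    if not project:
        raise RuntimeError("KERAS_REMOTE_PROJECT must be set")
    return {
        "project": project,
        "zone": os.environ.get("KERAS_REMOTE_ZONE", "us-central1-a"),
        "cluster": os.environ.get("KERAS_REMOTE_CLUSTER"),
    }
```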
### Decorator Parameters
@@ -345,10 +433,10 @@ keras-remote down

This removes:

- GKE cluster and accelerator node pools
- Artifact Registry repository and container images
- Cloud Storage buckets (jobs and builds)

Use `--yes` to skip the confirmation prompt.

## Contributing