Commit 6913365 (parent 23a4a6c): adds Data Handling section

README.md (88 additions, 0 deletions)
See [examples/Dockerfile.prebuilt](examples/Dockerfile.prebuilt) for a template.

## Handling Data
Keras Remote provides a declarative and performant Data API that seamlessly makes your local and cloud data available to your remote functions.

The Data API is designed to be read-only: it reliably delivers data to your pods at the start of a job. For saving model outputs or checkpointing, write directly to GCS from within your function.

Under the hood, the Data API optimizes your workflows with two key features:

* **Smart Caching:** Local data is content-hashed and uploaded to a cache bucket only once. Subsequent job runs that use byte-identical data hit the cache and skip the upload entirely, drastically speeding up execution.
* **Automatic Zip Exclusion:** When you reference a data path inside your current working directory, Keras Remote automatically excludes that directory from the project's zipped payload to avoid uploading the same data twice.
There are three main ways to handle data, depending on your workflow:

### 1. Dynamic Data (The `Data` Class)

The simplest and most Pythonic approach is to pass `Data` objects as regular function arguments. The `Data` class wraps a local file or directory path, or a Google Cloud Storage (GCS) URI.

On the remote pod, these objects are automatically resolved into plain string paths pointing to the downloaded files, so your function code never needs to know about GCS or cloud storage APIs.
```python
import pandas as pd
import keras_remote
from keras_remote import Data

@keras_remote.run(accelerator="v6e-8")
def train(data_dir):
    # data_dir is resolved to a local path on the remote machine
    df = pd.read_csv(f"{data_dir}/train.csv")
    # ...

# Uploads the local directory to the remote pod automatically
train(Data("./my_dataset/"))

# Cache hit: subsequent runs with the same data skip the upload!
train(Data("./my_dataset/"))
```
**Note on GCS Directories:** When referencing a GCS directory with the `Data` class, you must include a trailing slash (e.g., `Data("gs://my-bucket/dataset/")`). If you omit the trailing slash, the system treats the URI as a single file object.
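The trailing-slash rule amounts to a simple string check. A minimal sketch of the distinction, with `is_gcs_directory` as an invented name for illustration:

```python
def is_gcs_directory(uri: str) -> bool:
    """Trailing-slash rule for GCS URIs: a trailing slash marks
    a directory prefix; anything else is a single object."""
    return uri.startswith("gs://") and uri.endswith("/")

# Directory prefix: every object under it is downloaded
is_gcs_directory("gs://my-bucket/dataset/")   # True
# Single object: only this one file is downloaded
is_gcs_directory("gs://my-bucket/dataset")    # False
```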
You can also pass multiple `Data` arguments, or nest them inside lists and dictionaries (e.g., `train(datasets=[Data("./d1"), Data("./d2")])`).
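The nested resolution described above can be pictured as a recursive walk over the arguments. A hypothetical sketch, not keras_remote's code: the `Data` stand-in, the `resolve` helper, and the `/tmp/keras_remote` download root are all invented for illustration.

```python
class Data:
    # Minimal stand-in for keras_remote.Data, for illustration only
    def __init__(self, path):
        self.path = path

def resolve(obj, local_root="/tmp/keras_remote"):
    """Recursively replace Data objects with plain string paths,
    mirroring how arguments are resolved on the remote pod."""
    if isinstance(obj, Data):
        # Hypothetical: the data would be downloaded under local_root
        name = obj.path.rstrip("/").split("/")[-1]
        return f"{local_root}/{name}"
    if isinstance(obj, list):
        return [resolve(v, local_root) for v in obj]
    if isinstance(obj, dict):
        return {k: resolve(v, local_root) for k, v in obj.items()}
    return obj  # non-Data values pass through unchanged
```

By the time your function body runs, every `Data` object anywhere in the argument tree has already become an ordinary string path.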
### 2. Static Data (The `volumes` Parameter)

For established training scripts where data requirements are static, use the `volumes` parameter of the `@keras_remote.run` decorator. It mounts data at fixed, hardcoded absolute filesystem paths, letting you drop `keras_remote` into existing codebases without altering the function signature.
```python
import pandas as pd
import keras_remote
from keras_remote import Data

@keras_remote.run(
    accelerator="v6e-8",
    volumes={
        "/data": Data("./my_dataset/"),
        "/weights": Data("gs://my-bucket/pretrained-weights/"),
    },
)
def train():
    # Data is guaranteed to be available at these absolute paths
    df = pd.read_csv("/data/train.csv")
    model.load_weights("/weights/model.h5")  # model defined elsewhere
    # ...

# No data arguments needed!
train()
```
### 3. Direct GCS Streaming (For Large Datasets)

If your dataset is very large (e.g., > 10 GB), downloading the entire dataset to the remote pod's local disk is inefficient. Instead, skip the `Data` wrapper entirely and pass a GCS URI string directly, then use a framework with native GCS streaming support (such as `tf.data` or `grain`) to read the data on the fly.
```python
import grain.python as grain
import keras_remote

@keras_remote.run(accelerator="v6e-8")
def train(data_uri):
    # Native GCS reading, no download overhead
    data_source = grain.ArrayRecordDataSource(data_uri)
    # ...

# Pass as a plain string, no Data() wrapper needed
train("gs://my-bucket/arrayrecords/")
```
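The choice between the `Data` wrapper and direct streaming comes down to dataset size. A tiny decision helper capturing the ~10 GB rule of thumb above; the function name and threshold default are illustrative, not part of keras_remote:

```python
def transfer_strategy(dataset_bytes: int, threshold_gb: float = 10.0) -> str:
    """Pick between the Data wrapper (download to local disk)
    and direct GCS streaming, using a size rule of thumb."""
    if dataset_bytes > threshold_gb * 1024 ** 3:
        return "stream"    # pass a plain gs:// URI, read on the fly
    return "download"      # wrap in Data() and benefit from caching
```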
## Configuration

### Environment Variables
