Commit 58a5b58
Merge branch 'keras-team:main' into simplify-hw-names
2 parents: 437c1bf + 8042207

File tree: 21 files changed, +2191 −89 lines

.github/workflows/e2e-tests.yaml

Lines changed: 3 additions & 1 deletion

```diff
@@ -44,6 +44,8 @@ jobs:
 
       - name: Set up gcloud
         uses: google-github-actions/setup-gcloud@v2
+        with:
+          install_components: "gke-gcloud-auth-plugin"
 
       - name: Get GKE credentials
         uses: google-github-actions/get-gke-credentials@v2
@@ -61,4 +63,4 @@ jobs:
           KERAS_REMOTE_PROJECT: ${{ secrets.GCP_PROJECT }}
           KERAS_REMOTE_ZONE: ${{ secrets.GKE_ZONE }}
           KERAS_REMOTE_CLUSTER: ${{ secrets.GKE_CLUSTER }}
-        run: python -m unittest discover -s tests/e2e -p "*_test.py" -v
+        run: python -m pytest tests/e2e/ -v -n auto
```
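The new test command parallelizes across workers with `-n auto`, an option provided by the pytest-xdist plugin rather than pytest itself; a minimal local equivalent of the CI step (package names are the standard PyPI ones, assumed here) would be:

```shell
# -n auto (one worker per CPU core) comes from the pytest-xdist plugin,
# so the environment needs both packages installed.
pip install pytest pytest-xdist
python -m pytest tests/e2e/ -v -n auto
```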

AGENTS.md

Lines changed: 51 additions & 4 deletions

````diff
@@ -10,6 +10,7 @@ Keras Remote lets users execute Keras/JAX workloads on cloud TPUs and GPUs via a
 keras_remote/
 ├── core/    # @run decorator, accelerator registry & parser
 ├── backend/ # Job execution backends (GKE, Pathways)
+├── data/    # Data class for declaring data dependencies
 ├── infra/   # Docker container building & caching
 ├── runner/  # Remote worker entrypoint (runs inside container)
 ├── utils/   # Serialization (packager) and Cloud Storage helpers
@@ -26,7 +27,7 @@ keras_remote/
 @keras_remote.run() called
   → JobContext.from_params()  # Resolve config from args/env vars
   → ensure_credentials()      # Verify/auto-configure gcloud, ADC, kubeconfig
-  → _prepare_artifacts()      # Serialize function (cloudpickle), zip working dir
+  → _prepare_artifacts()      # Upload Data, serialize function, zip working dir
   → _build_container()        # Build or retrieve cached Docker image
   → _upload_artifacts()       # Upload payload.pkl, context.zip to GCS
   → backend.submit_job()      # Create K8s Job (GKE) or LeaderWorkerSet (Pathways)
@@ -46,9 +47,10 @@ keras_remote/
 | `backend/gke_client.py` | K8s Job creation, status polling, pod log retrieval |
 | `backend/pathways_client.py` | LeaderWorkerSet creation for multi-host TPUs |
 | `infra/container_builder.py` | Content-hashed Docker image building via Cloud Build |
-| `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()` |
-| `utils/storage.py` | GCS upload/download/cleanup for job artifacts |
-| `runner/remote_runner.py` | Runs inside container: deserialize, execute, upload result |
+| `data/data.py` | `Data` class, content hashing, data ref serialization |
+| `utils/packager.py` | `save_payload()` (cloudpickle), `zip_working_dir()`, Data ref extraction |
+| `utils/storage.py` | GCS upload/download/cleanup for job artifacts and Data cache |
+| `runner/remote_runner.py` | Runs inside container: resolve Data refs/volumes, execute, upload result |
 | `cli/commands/pool.py` | Node pool add/remove/list commands |
 | `cli/infra/post_deploy.py` | kubectl, LWS CRD, GPU driver setup after stack.up() |
 | `cli/constants.py` | CLI defaults, paths, API list |
@@ -59,8 +61,53 @@ keras_remote/
 - **`JobContext`** (`backend/execution.py`): Mutable dataclass carrying all job state through the pipeline — inputs, generated IDs, artifact paths, image URI.
 - **`BaseK8sBackend`** (`backend/execution.py`): Base class with `submit_job`, `wait_for_job`, `cleanup_job`. Subclassed by `GKEBackend` and `PathwaysBackend`.
 - **`GpuConfig` / `TpuConfig`** (`core/accelerators.py`): Frozen dataclasses for accelerator metadata. Single source of truth used by runtime, container builder, and CLI.
+- **`Data`** (`data/data.py`): Wraps a local path or GCS URI. Passed as a function argument or via the `volumes` decorator parameter. Resolved to a plain filesystem path on the remote pod. Content-hashed for upload caching.
 - **`InfraConfig` / `NodePoolConfig`** (`cli/config.py`): CLI provisioning configuration. `InfraConfig` holds project, zone, cluster name, and a list of `NodePoolConfig` entries. `NodePoolConfig` pairs a unique pool name (e.g., `gpu-l4-a3f2`) with a `GpuConfig` or `TpuConfig`.
 
+## Data API
+
+The `Data` class (`keras_remote.Data`) declares data dependencies for remote functions. It accepts local file/directory paths or GCS URIs (`gs://...`).
+
+### Two usage patterns
+
+**Function arguments** — `Data` objects passed as args/kwargs are uploaded to GCS, serialized as data ref dicts in the payload, and resolved to local paths on the pod:
+
+```python
+@keras_remote.run(accelerator="v3-8")
+def train(data_dir, config_path):
+    ...  # data_dir and config_path are plain strings
+
+train(Data("./dataset/"), Data("./config.json"))
+```
+
+**Volumes** — `Data` objects in the `volumes=` decorator parameter are downloaded to fixed mount paths before execution:
+
+```python
+@keras_remote.run(accelerator="v3-8", volumes={"/data": Data("./dataset/")})
+def train():
+    files = os.listdir("/data")  # available at mount path
+```
+
+Both patterns can be combined. `Data` objects can also be nested inside lists, dicts, and other containers — they are recursively discovered and resolved.
+
+### Content-addressed caching
+
+Local `Data` objects are content-hashed (SHA-256 over sorted file contents). Uploads go to `gs://{bucket}/{namespace}/data-cache/{hash}/`. A `.cache_marker` sentinel enables O(1) cache-hit checks. Identical data is uploaded only once.
+
+### Pipeline integration
+
+During `_prepare_artifacts()`:
+
+1. Upload `Data` from `volumes` and function args via `storage.upload_data()` (content-addressed)
+2. Replace `Data` objects in args/kwargs with serializable `__data_ref__` dicts
+3. Local `Data` paths inside the caller directory are auto-excluded from `context.zip`
+
+On the remote pod (`remote_runner.py`):
+
+1. `resolve_volumes()` — download volume data to mount paths
+2. `resolve_data_refs()` — recursively resolve `__data_ref__` dicts in args/kwargs to local paths
+3. Single-file `Data` resolves to the file path; directory `Data` resolves to the directory path
+
 ## Conventions
 
 ### Code Style
````
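The content-addressed caching scheme documented above (SHA-256 over sorted file contents, uploads keyed by hash under `data-cache/`) can be sketched as follows. This is an illustrative standalone sketch, not the actual `data/data.py` implementation — the helper names `content_hash` and `cache_prefix` and the use of `os.walk` are assumptions:

```python
import hashlib
import os


def content_hash(path: str) -> str:
    """SHA-256 over sorted file contents (illustrative sketch)."""
    h = hashlib.sha256()
    if os.path.isfile(path):
        with open(path, "rb") as f:
            h.update(f.read())
    else:
        # Walk in sorted order and mix in relative paths so the
        # digest is deterministic and sensitive to file renames.
        for root, _dirs, files in sorted(os.walk(path)):
            for name in sorted(files):
                full = os.path.join(root, name)
                h.update(os.path.relpath(full, path).encode())
                with open(full, "rb") as f:
                    h.update(f.read())
    return h.hexdigest()


def cache_prefix(bucket: str, namespace: str, path: str) -> str:
    """GCS layout stated in the doc: gs://{bucket}/{namespace}/data-cache/{hash}/"""
    return f"gs://{bucket}/{namespace}/data-cache/{content_hash(path)}/"
```

Because the prefix depends only on content, re-uploading identical data resolves to the same GCS location, which is what makes the `.cache_marker` O(1) hit check possible.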

examples/example_data_api.py

Lines changed: 109 additions & 0 deletions (new file)

```python
import json
import os
import tempfile

import keras_remote
from keras_remote import Data

# Setup: create temporary dummy data
tmp_dir = tempfile.mkdtemp(prefix="kr-data-example-")
dataset_dir = os.path.join(tmp_dir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)

# A small CSV file used by several tests below.
train_csv = os.path.join(dataset_dir, "train.csv")
with open(train_csv, "w") as f:
    f.write("feature,label\n1,100\n2,200\n3,300\n")

# A JSON config file used by the single-file and mixed tests.
config_json = os.path.join(tmp_dir, "config.json")
with open(config_json, "w") as f:
    json.dump({"lr": 0.01, "epochs": 10}, f)

print(f"Created temp data in {tmp_dir}\n")


# Data as function arg (local directory)
@keras_remote.run(accelerator="cpu")
def test_data_arg(data_dir):
    files = sorted(os.listdir(data_dir))
    with open(f"{data_dir}/train.csv") as f:
        content = f.read()
    return {"files": files, "content": content}


result = test_data_arg(Data(dataset_dir))
print(f"Test 1 (dir arg): {result}")
assert result["files"] == ["train.csv"]
assert "1,100" in result["content"]


# Data as function arg (single file)
@keras_remote.run(accelerator="cpu")
def test_file_arg(config_path):
    with open(config_path) as f:
        return json.load(f)


result = test_file_arg(Data(config_json))
print(f"Test 2 (file arg): {result}")
assert result["lr"] == 0.01

# Cache hit (re-run same data, check logs for "cache hit")
result = test_file_arg(Data(config_json))
print(f"Test 3 (cache hit): {result}")
assert result["lr"] == 0.01


# volumes (fixed-path mount)
@keras_remote.run(
    accelerator="cpu",
    volumes={"/data": Data(dataset_dir)},
)
def test_volumes():
    files = sorted(os.listdir("/data"))
    with open("/data/train.csv") as f:
        content = f.read()
    return {"files": files, "content": content}


result = test_volumes()
print(f"Test 4 (volumes): {result}")
assert result["files"] == ["train.csv"]


# Mixed — volumes + Data arg + plain arg
@keras_remote.run(
    accelerator="cpu",
    volumes={"/weights": Data(dataset_dir)},
)
def test_mixed(config_path, lr=0.001):
    with open(config_path) as f:
        cfg = json.load(f)
    has_weights = os.path.isdir("/weights")
    return {"config": cfg, "lr": lr, "has_weights": has_weights}


result = test_mixed(Data(config_json), lr=0.01)
print(f"Test 5 (mixed): {result}")
assert result["config"]["lr"] == 0.01
assert result["lr"] == 0.01
assert result["has_weights"] is True


# Data in nested structure
@keras_remote.run(accelerator="cpu")
def test_nested(datasets):
    return [sorted(os.listdir(d)) for d in datasets]


result = test_nested(
    datasets=[
        Data(dataset_dir),
        Data(dataset_dir),
    ]
)
print(f"Test 6 (nested): {result}")
assert len(result) == 2

print("\nAll E2E tests passed!")
```

keras_remote/__init__.py

Lines changed: 1 addition & 0 deletions

```diff
@@ -6,3 +6,4 @@
 os.environ.setdefault("GRPC_ENABLE_FORK_SUPPORT", "0")
 
 from keras_remote.core.core import run as run
+from keras_remote.data import Data as Data
```

keras_remote/backend/execution.py

Lines changed: 61 additions & 9 deletions

```diff
@@ -18,7 +18,9 @@
 from keras_remote.backend import gke_client, pathways_client
 from keras_remote.constants import get_default_zone, zone_to_region
 from keras_remote.credentials import ensure_credentials
+from keras_remote.data import _make_data_ref
 from keras_remote.infra import container_builder
+from keras_remote.infra.infra import get_default_project
 from keras_remote.utils import packager, storage
 
 
@@ -46,6 +48,9 @@ class JobContext:
     region: str = field(init=False)
     display_name: str = field(init=False)
 
+    # Data volumes {mount_path: Data}
+    volumes: Optional[dict] = None
+
     # Artifact paths (set during prepare phase)
     payload_path: Optional[str] = None
     context_path: Optional[str] = None
@@ -68,14 +73,13 @@ def from_params(
         zone: Optional[str],
         project: Optional[str],
         env_vars: dict,
+        volumes: Optional[dict] = None,
     ) -> "JobContext":
         """Factory method with default resolution for zone/project."""
         if not zone:
             zone = get_default_zone()
         if not project:
-            project = os.environ.get("KERAS_REMOTE_PROJECT") or os.environ.get(
-                "GOOGLE_CLOUD_PROJECT"
-            )
+            project = get_default_project()
         if not project:
             raise ValueError(
                 "project must be specified or set KERAS_REMOTE_PROJECT"
@@ -91,13 +95,14 @@ def from_params(
             container_image=container_image,
             zone=zone,
             project=project,
+            volumes=volumes,
         )
 
 
 class BaseK8sBackend:
     """Base class for Kubernetes-based backends."""
 
-    def __init__(self, cluster: Optional[str] = None, namespace: str = "default"):
+    def __init__(self, cluster: str, namespace: str = "default"):
         self.cluster = cluster
         self.namespace = namespace
 
@@ -203,6 +208,14 @@ def _find_requirements(start_dir: str) -> Optional[str]:
     return None
 
 
+def _maybe_exclude(data_path, caller_path, exclude_paths):
+    """Add data_path to exclude_paths if it's inside the caller directory."""
+    data_abs = os.path.normpath(data_path)
+    caller_abs = os.path.normpath(caller_path)
+    if data_abs.startswith(caller_abs + os.sep) or data_abs == caller_abs:
+        exclude_paths.add(data_abs)
+
+
 def _prepare_artifacts(
     ctx: JobContext, tmpdir: str, caller_frame_depth: int = 3
 ) -> None:
@@ -212,21 +225,58 @@ def _prepare_artifacts(
     # Get caller directory
     frame = inspect.stack()[caller_frame_depth]
     module = inspect.getmodule(frame[0])
-    if module:
+    caller_path: str
+    if module and module.__file__:
         caller_path = os.path.dirname(os.path.abspath(module.__file__))
     else:
         caller_path = os.getcwd()
 
-    # Serialize function + args
+    # Process Data objects
+    exclude_paths: set[str] = set()
+    ref_map = {}  # id(Data) -> ref dict (for arg replacement)
+    volume_refs = []  # list of ref dicts (for volumes)
+
+    # Process volumes
+    if ctx.volumes:
+        for mount_path, data_obj in ctx.volumes.items():
+            gcs_uri = storage.upload_data(ctx.bucket_name, data_obj, ctx.project)
+            volume_refs.append(
+                _make_data_ref(gcs_uri, data_obj.is_dir, mount_path=mount_path)
+            )
+            if not data_obj.is_gcs:
+                _maybe_exclude(data_obj.path, caller_path, exclude_paths)
+
+    # Process Data in function args
+    data_refs = packager.extract_data_refs(ctx.args, ctx.kwargs)
+    for data_obj, _position in data_refs:
+        gcs_uri = storage.upload_data(ctx.bucket_name, data_obj, ctx.project)
+        ref_map[id(data_obj)] = _make_data_ref(gcs_uri, data_obj.is_dir)
+        if not data_obj.is_gcs:
+            _maybe_exclude(data_obj.path, caller_path, exclude_paths)
+
+    # Replace Data with refs in args/kwargs
+    if ref_map:
+        ctx.args, ctx.kwargs = packager.replace_data_with_refs(
+            ctx.args, ctx.kwargs, ref_map
+        )
+
+    # Serialize function + args (with volume refs)
     ctx.payload_path = os.path.join(tmpdir, "payload.pkl")
     packager.save_payload(
-        ctx.func, ctx.args, ctx.kwargs, ctx.env_vars, ctx.payload_path
+        ctx.func,
+        ctx.args,
+        ctx.kwargs,
+        ctx.env_vars,
+        ctx.payload_path,
+        volumes=volume_refs or None,
     )
     logging.info("Payload serialized to %s", ctx.payload_path)
 
-    # Zip working directory
+    # Zip working directory (excluding Data paths)
     ctx.context_path = os.path.join(tmpdir, "context.zip")
-    packager.zip_working_dir(caller_path, ctx.context_path)
+    packager.zip_working_dir(
+        caller_path, ctx.context_path, exclude_paths=exclude_paths
+    )
    logging.info("Context packaged to %s", ctx.context_path)
 
     # Find requirements.txt
@@ -258,6 +308,8 @@ def _build_container(ctx: JobContext) -> None:
 
 def _upload_artifacts(ctx: JobContext) -> None:
     """Phase 3: Upload artifacts to Cloud Storage."""
+    if ctx.payload_path is None or ctx.context_path is None:
+        raise ValueError("payload_path and context_path must be set before upload")
     logging.info("Uploading artifacts to Cloud Storage (job: %s)...", ctx.job_id)
     storage.upload_artifacts(
         bucket_name=ctx.bucket_name,
```
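The diff replaces `Data` objects in args/kwargs with `__data_ref__` dicts before pickling, and the runner resolves them back to local paths on the pod. The recursive walk can be sketched as below; the exact dict shape and the injected `download` callback are assumptions for illustration — the real `remote_runner.py` resolver also performs the GCS download itself:

```python
def resolve_data_refs(obj, download):
    """Recursively replace __data_ref__ dicts with local paths.

    `download` maps a GCS URI to a local filesystem path; injecting it
    lets the traversal be shown without any cloud dependency.
    """
    if isinstance(obj, dict):
        if "__data_ref__" in obj:
            # Single-file Data resolves to the file path; directory
            # Data resolves to the directory path (both via download).
            return download(obj["__data_ref__"])
        return {k: resolve_data_refs(v, download) for k, v in obj.items()}
    if isinstance(obj, (list, tuple)):
        return type(obj)(resolve_data_refs(v, download) for v in obj)
    return obj
```

This mirrors why nested containers "just work" in the example file's `test_nested`: resolution is a plain structural recursion, so any mix of lists, tuples, and dicts is handled uniformly.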
