Skip to content

Commit a41579a

Browse files
Add Data API examples, e2e tests, and CI updates (#67)
1 parent a4569fe commit a41579a

File tree

5 files changed

+501
-2
lines changed

5 files changed

+501
-2
lines changed

.github/workflows/e2e-tests.yaml

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -44,6 +44,8 @@ jobs:
4444

4545
- name: Set up gcloud
4646
uses: google-github-actions/setup-gcloud@v2
47+
with:
48+
install_components: "gke-gcloud-auth-plugin"
4749

4850
- name: Get GKE credentials
4951
uses: google-github-actions/get-gke-credentials@v2
@@ -61,4 +63,4 @@ jobs:
6163
KERAS_REMOTE_PROJECT: ${{ secrets.GCP_PROJECT }}
6264
KERAS_REMOTE_ZONE: ${{ secrets.GKE_ZONE }}
6365
KERAS_REMOTE_CLUSTER: ${{ secrets.GKE_CLUSTER }}
64-
run: python -m unittest discover -s tests/e2e -p "*_test.py" -v
66+
run: python -m pytest tests/e2e/ -v -n auto

examples/example_data_api.py

Lines changed: 109 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,109 @@
"""End-to-end examples for the keras_remote Data API.

Exercises the main ways data can be shipped to a remote function:
as a ``Data`` argument (directory or single file), via fixed-path
``volumes`` mounts, mixed with plain arguments, and nested inside
containers.  Each section prints and asserts on the returned values.
"""

import atexit
import json
import os
import shutil
import tempfile

import keras_remote
from keras_remote import Data

# Setup: create temporary dummy data.  Register cleanup so repeated
# runs don't accumulate "kr-data-example-*" directories in the temp
# root (mkdtemp alone never removes what it creates).
tmp_dir = tempfile.mkdtemp(prefix="kr-data-example-")
atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)

dataset_dir = os.path.join(tmp_dir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)

# A small CSV file used by several tests below.
train_csv = os.path.join(dataset_dir, "train.csv")
with open(train_csv, "w") as f:
    f.write("feature,label\n1,100\n2,200\n3,300\n")

# A JSON config file used by the single-file and mixed tests.
config_json = os.path.join(tmp_dir, "config.json")
with open(config_json, "w") as f:
    json.dump({"lr": 0.01, "epochs": 10}, f)

print(f"Created temp data in {tmp_dir}\n")


# Data as function arg (local directory)
@keras_remote.run(accelerator="cpu")
def test_data_arg(data_dir):
    files = sorted(os.listdir(data_dir))
    with open(f"{data_dir}/train.csv") as f:
        content = f.read()
    return {"files": files, "content": content}


result = test_data_arg(Data(dataset_dir))
print(f"Test 1 (dir arg): {result}")
assert result["files"] == ["train.csv"]
assert "1,100" in result["content"]


# Data as function arg (single file)
@keras_remote.run(accelerator="cpu")
def test_file_arg(config_path):
    with open(config_path) as f:
        return json.load(f)


result = test_file_arg(Data(config_json))
print(f"Test 2 (file arg): {result}")
assert result["lr"] == 0.01

# Cache hit (re-run same data, check logs for "cache hit")
result = test_file_arg(Data(config_json))
print(f"Test 3 (cache hit): {result}")
assert result["lr"] == 0.01


# volumes (fixed-path mount)
@keras_remote.run(
    accelerator="cpu",
    volumes={"/data": Data(dataset_dir)},
)
def test_volumes():
    files = sorted(os.listdir("/data"))
    with open("/data/train.csv") as f:
        content = f.read()
    return {"files": files, "content": content}


result = test_volumes()
print(f"Test 4 (volumes): {result}")
assert result["files"] == ["train.csv"]


# Mixed — volumes + Data arg + plain arg
@keras_remote.run(
    accelerator="cpu",
    volumes={"/weights": Data(dataset_dir)},
)
def test_mixed(config_path, lr=0.001):
    with open(config_path) as f:
        cfg = json.load(f)
    has_weights = os.path.isdir("/weights")
    return {"config": cfg, "lr": lr, "has_weights": has_weights}


result = test_mixed(Data(config_json), lr=0.01)
print(f"Test 5 (mixed): {result}")
assert result["config"]["lr"] == 0.01
assert result["lr"] == 0.01
assert result["has_weights"] is True


# Data in nested structure
@keras_remote.run(accelerator="cpu")
def test_nested(datasets):
    return [sorted(os.listdir(d)) for d in datasets]


result = test_nested(
    datasets=[
        Data(dataset_dir),
        Data(dataset_dir),
    ]
)
print(f"Test 6 (nested): {result}")
assert len(result) == 2

print("\nAll E2E tests passed!")

keras_remote/infra/container_builder.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -7,6 +7,7 @@
77
import tarfile
88
import tempfile
99
import time
10+
import uuid
1011

1112
from absl import logging
1213
from google.api_core import exceptions as google_exceptions
@@ -310,7 +311,7 @@ def _upload_build_source(tarball_path, bucket_name, project):
310311
bucket = client.bucket(bucket_name)
311312

312313
# Upload tarball
313-
blob_name = f"source-{int(time.time())}.tar.gz"
314+
blob_name = f"source-{int(time.time())}-{uuid.uuid4().hex[:8]}.tar.gz"
314315
blob = bucket.blob(blob_name)
315316
blob.upload_from_filename(tarball_path)
316317

pyproject.toml

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -34,6 +34,7 @@ cli = [
3434
]
3535
test = [
3636
"coverage>=7.0",
37+
"pytest-xdist>=3.0",
3738
]
3839
dev = [
3940
"pre-commit",

0 commit comments

Comments
 (0)