-
Notifications
You must be signed in to change notification settings - Fork 1
Expand file tree
/
Copy pathexample_data_api.py
More file actions
109 lines (82 loc) · 2.7 KB
/
example_data_api.py
File metadata and controls
109 lines (82 loc) · 2.7 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
import atexit
import json
import os
import shutil
import tempfile

import keras_remote
from keras_remote import Data
# Setup: create temporary dummy data
tmp_dir = tempfile.mkdtemp(prefix="kr-data-example-")
# mkdtemp() leaves deletion to the caller. Remove the scratch tree when the
# interpreter exits so repeated runs do not pile up directories; errors are
# ignored because cleanup is best-effort and must not mask a test failure.
atexit.register(shutil.rmtree, tmp_dir, ignore_errors=True)

dataset_dir = os.path.join(tmp_dir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)

# A small CSV file used by several tests below.
train_csv = os.path.join(dataset_dir, "train.csv")
with open(train_csv, "w", encoding="utf-8") as f:
    f.write("feature,label\n1,100\n2,200\n3,300\n")

# A JSON config file used by the single-file and mixed tests.
# It lives next to (not inside) the dataset directory.
config_json = os.path.join(tmp_dir, "config.json")
with open(config_json, "w", encoding="utf-8") as f:
    json.dump({"lr": 0.01, "epochs": 10}, f)

print(f"Created temp data in {tmp_dir}\n")
# Data as function arg (local directory)
@keras_remote.run(accelerator="cpu")
def test_data_arg(data_dir):
    """List ``data_dir`` and read the ``train.csv`` file inside it.

    Args:
        data_dir: Directory path to inspect. The caller in this script
            passes a ``Data`` object; presumably ``keras_remote`` turns it
            into a usable path before the body runs -- confirm against
            the library documentation.

    Returns:
        A dict with ``"files"`` (the sorted directory listing) and
        ``"content"`` (the full text of ``train.csv``).
    """
    listing = os.listdir(data_dir)
    listing.sort()
    with open(f"{data_dir}/train.csv") as csv_file:
        csv_text = csv_file.read()
    return {"files": listing, "content": csv_text}
# Test 1: call the function with the local dataset directory wrapped in Data.
result = test_data_arg(Data(dataset_dir))
print(f"Test 1 (dir arg): {result}")
# The setup step wrote exactly one file, train.csv, into the dataset directory.
assert result["files"] == ["train.csv"]
# "1,100" is the first data row written to train.csv during setup.
assert "1,100" in result["content"]
# Data as function arg (single file)
@keras_remote.run(accelerator="cpu")
def test_file_arg(config_path):
    """Parse the JSON file at ``config_path`` and return the decoded value.

    Args:
        config_path: Path of a JSON file. The caller in this script passes
            a ``Data`` object wrapping a single file; how it becomes a
            path is handled by ``keras_remote`` and is not visible here.

    Returns:
        Whatever ``json.load`` produces for the file's contents.
    """
    with open(config_path) as config_file:
        parsed = json.load(config_file)
    return parsed
# Test 2: pass a single JSON file wrapped in Data; the function returns the
# parsed config, whose "lr" was written as 0.01 during setup.
result = test_file_arg(Data(config_json))
print(f"Test 2 (file arg): {result}")
assert result["lr"] == 0.01
# Test 3: cache hit -- re-run with the same data. Nothing below asserts the
# hit itself; check the logs for "cache hit". Only the returned config is
# verified again.
result = test_file_arg(Data(config_json))
print(f"Test 3 (cache hit): {result}")
assert result["lr"] == 0.01
# volumes (fixed-path mount)
@keras_remote.run(accelerator="cpu", volumes={"/data": Data(dataset_dir)})
def test_volumes():
    """List ``/data`` and read ``/data/train.csv``.

    The decorator is given ``volumes={"/data": Data(dataset_dir)}``;
    presumably that makes the dataset directory available at ``/data`` where
    the body runs -- confirm against the ``keras_remote`` documentation.

    Returns:
        A dict with ``"files"`` (the sorted listing of ``/data``) and
        ``"content"`` (the full text of ``/data/train.csv``).
    """
    listing = os.listdir("/data")
    listing.sort()
    with open("/data/train.csv") as csv_file:
        csv_text = csv_file.read()
    return {"files": listing, "content": csv_text}
# Test 4: the function takes no arguments; its data comes from the "/data"
# entry in the decorator's ``volumes`` mapping.
result = test_volumes()
print(f"Test 4 (volumes): {result}")
assert result["files"] == ["train.csv"]
# The function also returns the file body, so verify it the same way the
# directory-argument test does: "1,100" is the first data row of train.csv.
assert "1,100" in result["content"]
# Mixed — volumes + Data arg + plain arg
@keras_remote.run(accelerator="cpu", volumes={"/weights": Data(dataset_dir)})
def test_mixed(config_path, lr=0.001):
    """Combine a volume entry, a file argument and a plain keyword argument.

    Args:
        config_path: Path of a JSON file to parse. The caller in this
            script passes a ``Data`` object; its conversion to a path is
            handled by ``keras_remote`` and is not visible here.
        lr: Plain value echoed back unchanged in the result.

    Returns:
        A dict with ``"config"`` (the parsed JSON), ``"lr"`` (the ``lr``
        argument as received) and ``"has_weights"`` (whether ``/weights``
        is an existing directory where the body runs).
    """
    with open(config_path) as config_file:
        cfg = json.load(config_file)
    return {
        "config": cfg,
        "lr": lr,
        "has_weights": os.path.isdir("/weights"),
    }
# Test 5: one Data positional argument plus a plain keyword argument; the
# "/weights" volume is declared on the function's decorator.
result = test_mixed(Data(config_json), lr=0.01)
print(f"Test 5 (mixed): {result}")
# "lr" inside the parsed config comes from config.json written during setup.
assert result["config"]["lr"] == 0.01
# The explicit lr=0.01 must override the function's default of 0.001.
assert result["lr"] == 0.01
# The function reports whether "/weights" existed as a directory.
assert result["has_weights"] is True
# Data in nested structure
@keras_remote.run(accelerator="cpu")
def test_nested(datasets):
    """Return the sorted directory listing of every entry in ``datasets``.

    Args:
        datasets: Iterable of directory paths. The caller in this script
            passes a list of ``Data`` objects; presumably ``keras_remote``
            resolves each one to a path -- confirm against the library
            documentation.

    Returns:
        A list with one sorted listing per entry, in input order.
    """
    return [sorted(entries) for entries in map(os.listdir, datasets)]
# Test 6: Data objects placed inside a list that is passed by keyword.
result = test_nested(
    datasets=[
        Data(dataset_dir),
        Data(dataset_dir),
    ]
)
print(f"Test 6 (nested): {result}")
assert len(result) == 2
# A length check alone would also accept two empty listings. Both entries
# wrap the same dataset directory, which holds exactly one file, so each
# listing must be ["train.csv"].
assert all(files == ["train.csv"] for files in result)

print("\nAll E2E tests passed!")