|
| 1 | +import json |
| 2 | +import os |
| 3 | +import tempfile |
| 4 | + |
| 5 | +import keras_remote |
| 6 | +from keras_remote import Data |
| 7 | + |
# Setup: create temporary dummy data
tmp_dir = tempfile.mkdtemp(prefix="kr-data-example-")
dataset_dir = os.path.join(tmp_dir, "dataset")
os.makedirs(dataset_dir, exist_ok=True)

# A small CSV file used by several tests below.
train_csv = os.path.join(dataset_dir, "train.csv")
with open(train_csv, "w") as csv_out:
    csv_out.write("feature,label\n1,100\n2,200\n3,300\n")

# A JSON config file used by the single-file and mixed tests.
config_json = os.path.join(tmp_dir, "config.json")
with open(config_json, "w") as cfg_out:
    cfg_out.write(json.dumps({"lr": 0.01, "epochs": 10}))

print(f"Created temp data in {tmp_dir}\n")
| 25 | + |
# Data as function arg (local directory)
@keras_remote.run(accelerator="cpu")
def test_data_arg(data_dir):
    """List a mounted dataset directory and return its CSV contents."""
    listing = sorted(os.listdir(data_dir))
    with open(f"{data_dir}/train.csv") as fh:
        csv_text = fh.read()
    return {"files": listing, "content": csv_text}


result = test_data_arg(Data(dataset_dir))
print(f"Test 1 (dir arg): {result}")
assert result["files"] == ["train.csv"]
assert "1,100" in result["content"]
| 40 | + |
# Data as function arg (single file)
@keras_remote.run(accelerator="cpu")
def test_file_arg(config_path):
    """Parse and return the JSON config mounted at config_path."""
    with open(config_path) as fh:
        raw = fh.read()
    return json.loads(raw)


result = test_file_arg(Data(config_json))
print(f"Test 2 (file arg): {result}")
assert result["lr"] == 0.01

# Cache hit (re-run same data, check logs for "cache hit")
result = test_file_arg(Data(config_json))
print(f"Test 3 (cache hit): {result}")
assert result["lr"] == 0.01
| 57 | + |
# volumes (fixed-path mount)
@keras_remote.run(
    accelerator="cpu",
    volumes={"/data": Data(dataset_dir)},
)
def test_volumes():
    """Inspect the fixed-path /data mount: file listing plus CSV contents."""
    mount = "/data"
    listing = sorted(os.listdir(mount))
    with open(f"{mount}/train.csv") as fh:
        csv_text = fh.read()
    return {"files": listing, "content": csv_text}


result = test_volumes()
print(f"Test 4 (volumes): {result}")
assert result["files"] == ["train.csv"]
| 74 | + |
# Mixed — volumes + Data arg + plain arg
@keras_remote.run(
    accelerator="cpu",
    volumes={"/weights": Data(dataset_dir)},
)
def test_mixed(config_path, lr=0.001):
    """Combine a fixed volume, a Data file argument, and a plain keyword arg."""
    with open(config_path) as fh:
        parsed = json.load(fh)
    return {
        "config": parsed,
        "lr": lr,
        "has_weights": os.path.isdir("/weights"),
    }


result = test_mixed(Data(config_json), lr=0.01)
print(f"Test 5 (mixed): {result}")
assert result["config"]["lr"] == 0.01
assert result["lr"] == 0.01
assert result["has_weights"] is True
| 93 | + |
# Data in nested structure
@keras_remote.run(accelerator="cpu")
def test_nested(datasets):
    """Return the sorted file listing of each mounted dataset directory."""
    return [sorted(os.listdir(mount)) for mount in datasets]


result = test_nested(
    datasets=[Data(dataset_dir), Data(dataset_dir)]
)
print(f"Test 6 (nested): {result}")
assert len(result) == 2

print("\nAll E2E tests passed!")
0 commit comments