Skip to content

Commit efea2fd

Browse files
Adds test infrastructure and pure logic tests (#30)
- Adds pytest, pytest-cov, pytest-mock, pytest-timeout as test dependencies - Adds [test] extra and updates [dev] extra to include it - Adds pytest configuration and ruff ignores for test files - Adds GitHub Actions workflow for running tests on PRs
1 parent 516107f commit efea2fd

File tree

10 files changed

+426
-1
lines changed

10 files changed

+426
-1
lines changed

.github/workflows/tests.yaml

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
name: Tests
2+
3+
on:
4+
push:
5+
branches: [main]
6+
pull_request:
7+
branches: [main]
8+
9+
jobs:
10+
test:
11+
runs-on: ubuntu-latest
12+
steps:
13+
- uses: actions/checkout@v4
14+
15+
- name: Set up Python 3.13
16+
uses: actions/setup-python@v5
17+
with:
18+
python-version: "3.13"
19+
20+
- name: Install dependencies
21+
run: pip install -e ".[test,cli]"
22+
23+
- name: Run unit and integration tests
24+
run: >
25+
pytest keras_remote/ tests/integration/
26+
-v --tb=short
27+
--cov=keras_remote
28+
--cov-report=xml
29+
--ignore=keras_remote/src
30+
31+
- name: Upload coverage
32+
uses: codecov/codecov-action@v4
33+
with:
34+
files: coverage.xml
35+
fail_ci_if_error: false

keras_remote/conftest.py

Lines changed: 43 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,43 @@
1+
"""Shared fixtures for colocated unit tests."""
2+
3+
import pytest
4+
5+
6+
@pytest.fixture
7+
def sample_function():
8+
"""A simple function suitable for serialization tests."""
9+
10+
def add(a, b):
11+
return a + b
12+
13+
return add
14+
15+
16+
@pytest.fixture
17+
def gcp_env(monkeypatch):
18+
"""Set standard GCP env vars for unit tests."""
19+
monkeypatch.setenv("KERAS_REMOTE_PROJECT", "test-project")
20+
monkeypatch.setenv("KERAS_REMOTE_ZONE", "us-central1-a")
21+
monkeypatch.setenv("KERAS_REMOTE_GKE_CLUSTER", "test-cluster")
22+
23+
24+
@pytest.fixture
25+
def mock_storage_client(mocker):
26+
"""Mock google.cloud.storage.Client."""
27+
mock_client = mocker.MagicMock()
28+
mocker.patch("google.cloud.storage.Client", return_value=mock_client)
29+
return mock_client
30+
31+
32+
@pytest.fixture
33+
def mock_kube_config(mocker):
34+
"""Mock kubernetes config loading."""
35+
mocker.patch("keras_remote.backend.gke_client._load_kube_config")
36+
37+
38+
@pytest.fixture
39+
def mock_batch_v1(mocker):
40+
"""Mock kubernetes BatchV1Api."""
41+
mock_api = mocker.MagicMock()
42+
mocker.patch("kubernetes.client.BatchV1Api", return_value=mock_api)
43+
return mock_api
Lines changed: 191 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,191 @@
1+
"""Tests for keras_remote.core.accelerators — parser, registry, categories."""
2+
3+
import pytest
4+
5+
from keras_remote.core.accelerators import (
6+
_GPU_ALIASES,
7+
GPUS,
8+
TPUS,
9+
GpuConfig,
10+
TpuConfig,
11+
get_category,
12+
parse_accelerator,
13+
)
14+
15+
16+
class TestParseGpuDirect:
17+
def test_l4(self):
18+
result = parse_accelerator("l4")
19+
assert isinstance(result, GpuConfig)
20+
assert result.name == "l4"
21+
assert result.count == 1
22+
assert result.gke_label == "nvidia-l4"
23+
assert result.machine_type == "g2-standard-4"
24+
25+
@pytest.mark.parametrize("gpu_name", list(GPUS.keys()))
26+
def test_all_gpu_types_parse_with_count_1(self, gpu_name):
27+
result = parse_accelerator(gpu_name)
28+
assert isinstance(result, GpuConfig)
29+
assert result.count == 1
30+
assert result.name == gpu_name
31+
32+
33+
class TestParseGpuMultiCount:
34+
def test_a100x4(self):
35+
result = parse_accelerator("a100x4")
36+
assert isinstance(result, GpuConfig)
37+
assert result.name == "a100"
38+
assert result.count == 4
39+
40+
def test_a100_80gbx4(self):
41+
result = parse_accelerator("a100-80gbx4")
42+
assert isinstance(result, GpuConfig)
43+
assert result.name == "a100-80gb"
44+
assert result.count == 4
45+
46+
47+
class TestParseGpuAlias:
48+
def test_nvidia_tesla_t4(self):
49+
result = parse_accelerator("nvidia-tesla-t4")
50+
assert isinstance(result, GpuConfig)
51+
assert result.name == "t4"
52+
assert result.count == 1
53+
54+
def test_nvidia_tesla_v100x4(self):
55+
result = parse_accelerator("nvidia-tesla-v100x4")
56+
assert isinstance(result, GpuConfig)
57+
assert result.name == "v100"
58+
assert result.count == 4
59+
60+
61+
class TestParseGpuErrors:
62+
def test_l4x8_invalid_count(self):
63+
with pytest.raises(ValueError, match="not supported"):
64+
parse_accelerator("l4x8")
65+
66+
67+
class TestParseTpuBare:
68+
def test_v5litepod(self):
69+
result = parse_accelerator("v5litepod")
70+
assert isinstance(result, TpuConfig)
71+
assert result.name == "v5litepod"
72+
assert result.chips == 4
73+
assert result.topology == "2x2"
74+
75+
@pytest.mark.parametrize("tpu_name", list(TPUS.keys()))
76+
def test_all_tpu_types_parse_with_default_chips(self, tpu_name):
77+
result = parse_accelerator(tpu_name)
78+
assert isinstance(result, TpuConfig)
79+
assert result.name == tpu_name
80+
assert result.chips == TPUS[tpu_name].default_chips
81+
82+
83+
class TestParseTpuChipCount:
84+
def test_v3_8(self):
85+
result = parse_accelerator("v3-8")
86+
assert isinstance(result, TpuConfig)
87+
assert result.name == "v3"
88+
assert result.chips == 8
89+
assert result.topology == "2x2"
90+
91+
def test_v3_32(self):
92+
result = parse_accelerator("v3-32")
93+
assert isinstance(result, TpuConfig)
94+
assert result.name == "v3"
95+
assert result.chips == 32
96+
assert result.topology == "4x4"
97+
98+
def test_v5litepod_1(self):
99+
result = parse_accelerator("v5litepod-1")
100+
assert isinstance(result, TpuConfig)
101+
assert result.chips == 1
102+
assert result.topology == "1x1"
103+
104+
105+
class TestParseTpuTopology:
106+
def test_v5litepod_2x2(self):
107+
result = parse_accelerator("v5litepod-2x2")
108+
assert isinstance(result, TpuConfig)
109+
assert result.name == "v5litepod"
110+
assert result.chips == 4
111+
assert result.topology == "2x2"
112+
113+
def test_v5litepod_1x1(self):
114+
result = parse_accelerator("v5litepod-1x1")
115+
assert isinstance(result, TpuConfig)
116+
assert result.chips == 1
117+
assert result.topology == "1x1"
118+
119+
120+
class TestParseTpuErrors:
121+
def test_v3_16_invalid_chips(self):
122+
with pytest.raises(ValueError, match="not supported"):
123+
parse_accelerator("v3-16")
124+
125+
def test_v5litepod_3x3_invalid_topology(self):
126+
with pytest.raises(ValueError, match="Unknown accelerator"):
127+
parse_accelerator("v5litepod-3x3")
128+
129+
130+
class TestParseTpuConfigFields:
131+
def test_v3_8_full_config(self):
132+
result = parse_accelerator("v3-8")
133+
assert result.gke_accelerator == "tpu-v3-podslice"
134+
assert result.machine_type == "ct3p-hightpu-4t"
135+
assert result.num_nodes == 2
136+
137+
138+
class TestParseCpu:
139+
def test_cpu(self):
140+
assert parse_accelerator("cpu") is None
141+
142+
143+
class TestParseNormalizationAndErrors:
144+
def test_whitespace_and_case(self):
145+
result = parse_accelerator(" A100X4 ")
146+
assert isinstance(result, GpuConfig)
147+
assert result.name == "a100"
148+
assert result.count == 4
149+
150+
def test_empty_string(self):
151+
with pytest.raises(ValueError, match="Unknown accelerator"):
152+
parse_accelerator("")
153+
154+
def test_unknown_accelerator(self):
155+
with pytest.raises(ValueError, match="Unknown accelerator"):
156+
parse_accelerator("unknown")
157+
158+
159+
class TestGetCategory:
160+
def test_cpu(self):
161+
assert get_category("cpu") == "cpu"
162+
163+
def test_gpu(self):
164+
assert get_category("l4") == "gpu"
165+
166+
def test_tpu(self):
167+
assert get_category("v5litepod") == "tpu"
168+
169+
170+
class TestRegistryIntegrity:
171+
def test_all_gpus_have_nonempty_counts(self):
172+
for name, spec in GPUS.items():
173+
assert len(spec.counts) > 0, f"GPU '{name}' has empty counts"
174+
175+
def test_all_tpus_have_nonempty_topologies(self):
176+
for name, spec in TPUS.items():
177+
assert len(spec.topologies) > 0, f"TPU '{name}' has empty topologies"
178+
179+
def test_all_tpu_default_chips_valid(self):
180+
for name, spec in TPUS.items():
181+
assert spec.default_chips in spec.topologies, (
182+
f"TPU '{name}' default_chips={spec.default_chips} "
183+
f"not in topologies {list(spec.topologies.keys())}"
184+
)
185+
186+
def test_all_gpus_have_gke_label_alias(self):
187+
for name, spec in GPUS.items():
188+
assert spec.gke_label in _GPU_ALIASES, (
189+
f"GPU '{name}' gke_label '{spec.gke_label}' not in aliases"
190+
)
191+
assert _GPU_ALIASES[spec.gke_label] == name

keras_remote/test_constants.py

Lines changed: 96 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,96 @@
1+
"""Tests for keras_remote.constants — zone/region/location helpers."""
2+
3+
import pytest
4+
5+
from keras_remote.constants import (
6+
DEFAULT_REGION,
7+
DEFAULT_ZONE,
8+
get_default_zone,
9+
zone_to_ar_location,
10+
zone_to_region,
11+
)
12+
13+
14+
class TestZoneToRegion:
15+
@pytest.mark.parametrize(
16+
"zone, expected_region",
17+
[
18+
("us-central1-a", "us-central1"),
19+
("us-central1-b", "us-central1"),
20+
("us-east1-b", "us-east1"),
21+
("us-east4-c", "us-east4"),
22+
("us-west1-a", "us-west1"),
23+
("us-west4-b", "us-west4"),
24+
("europe-west1-b", "europe-west1"),
25+
("europe-west4-b", "europe-west4"),
26+
("asia-east1-c", "asia-east1"),
27+
("asia-southeast1-a", "asia-southeast1"),
28+
("me-west1-a", "me-west1"),
29+
("southamerica-east1-b", "southamerica-east1"),
30+
],
31+
)
32+
def test_zone_to_region(self, zone, expected_region):
33+
assert zone_to_region(zone) == expected_region
34+
35+
@pytest.mark.parametrize(
36+
"zone",
37+
[
38+
"",
39+
None,
40+
"invalid",
41+
],
42+
)
43+
def test_fallback_returns_default(self, zone):
44+
assert zone_to_region(zone) == DEFAULT_REGION
45+
46+
47+
class TestZoneToArLocation:
48+
@pytest.mark.parametrize(
49+
"zone, expected_location",
50+
[
51+
("us-central1-a", "us"),
52+
("us-east1-b", "us"),
53+
("us-west1-a", "us"),
54+
("europe-west1-b", "europe"),
55+
("europe-west4-b", "europe"),
56+
("asia-east1-c", "asia"),
57+
("asia-southeast1-a", "asia"),
58+
("me-west1-a", "me"),
59+
("southamerica-east1-b", "southamerica"),
60+
],
61+
)
62+
def test_zone_to_ar_location(self, zone, expected_location):
63+
assert zone_to_ar_location(zone) == expected_location
64+
65+
66+
class TestGetDefaultZone:
67+
@pytest.mark.parametrize(
68+
"env_value, expected_zone",
69+
[
70+
("us-west1-b", "us-west1-b"),
71+
("europe-west4-a", "europe-west4-a"),
72+
("asia-east1-c", "asia-east1-c"),
73+
],
74+
)
75+
def test_returns_env_var_when_set(
76+
self, monkeypatch, env_value, expected_zone
77+
):
78+
# Temporarily set the env var so get_default_zone() picks it up.
79+
monkeypatch.setenv("KERAS_REMOTE_ZONE", env_value)
80+
assert get_default_zone() == expected_zone
81+
82+
def test_returns_default_when_unset(self, monkeypatch):
83+
# Remove the env var if set
84+
# raising=False avoids KeyError when it's already absent.
85+
monkeypatch.delenv("KERAS_REMOTE_ZONE", raising=False)
86+
assert get_default_zone() == DEFAULT_ZONE
87+
88+
@pytest.mark.parametrize(
89+
"constant, expected_value",
90+
[
91+
(DEFAULT_ZONE, "us-central1-a"),
92+
(DEFAULT_REGION, "us-central1"),
93+
],
94+
)
95+
def test_default_constants(self, constant, expected_value):
96+
assert constant == expected_value

0 commit comments

Comments
 (0)