Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
37 changes: 37 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
# Runs the colocated unit tests under coverage for pushes and pull requests
# that target main, then uploads the coverage report to Codecov.
name: Tests

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.13
        uses: actions/setup-python@v5
        with:
          python-version: "3.13"

      - name: Install dependencies
        run: pip install -e ".[test,cli]"

      # The folded scalar (`>`) joins the lines below into a single command.
      - name: Run unit tests
        run: >
          coverage run
          -m unittest discover
          -s keras_remote -p "test_*.py"
          -v

      - name: Generate coverage report
        run: coverage xml

      - name: Upload coverage
        uses: codecov/codecov-action@v4
        with:
          files: coverage.xml
          # codecov-action v4 only supports tokenless uploads for fork PRs;
          # without a token, uploads from pushes to main are rejected, and
          # fail_ci_if_error=false would hide that. An unset secret expands
          # to an empty string, which matches the previous behaviour.
          token: ${{ secrets.CODECOV_TOKEN }}
          # Upload problems are logged but never fail the job.
          fail_ci_if_error: false
199 changes: 199 additions & 0 deletions keras_remote/core/test_accelerators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,199 @@
"""Tests for keras_remote.core.accelerators — parser, registry, categories."""

from absl.testing import absltest, parameterized

from keras_remote.core.accelerators import (
_GPU_ALIASES,
GPUS,
TPUS,
GpuConfig,
TpuConfig,
get_category,
parse_accelerator,
)


class TestParseGpuDirect(parameterized.TestCase):
    """Bare GPU names must parse to a ``GpuConfig`` with a count of one."""

    def test_l4(self):
        # Spot-check the fields asserted for the "l4" entry.
        config = parse_accelerator("l4")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "l4")
        self.assertEqual(config.count, 1)
        self.assertEqual(config.gke_label, "nvidia-l4")
        self.assertEqual(config.machine_type, "g2-standard-4")

    # Iterating the registry dict yields its keys, one test case per key.
    @parameterized.parameters(*GPUS)
    def test_all_gpu_types_parse_with_count_1(self, gpu_name):
        config = parse_accelerator(gpu_name)
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.count, 1)
        self.assertEqual(config.name, gpu_name)


class TestParseGpuMultiCount(absltest.TestCase):
    """GPU specs written as ``<name>x<count>`` carry the requested count."""

    def test_a100x4(self):
        config = parse_accelerator("a100x4")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "a100")
        self.assertEqual(config.count, 4)

    def test_a100_80gbx4(self):
        # The GPU name itself contains a hyphen here.
        config = parse_accelerator("a100-80gbx4")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "a100-80gb")
        self.assertEqual(config.count, 4)


class TestParseGpuAlias(absltest.TestCase):
    """``nvidia-tesla-*`` style names resolve to the short GPU name."""

    def test_nvidia_tesla_t4(self):
        config = parse_accelerator("nvidia-tesla-t4")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "t4")
        self.assertEqual(config.count, 1)

    def test_nvidia_tesla_v100x4(self):
        # An alias may also carry an ``x<count>`` suffix.
        config = parse_accelerator("nvidia-tesla-v100x4")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "v100")
        self.assertEqual(config.count, 4)


class TestParseGpuErrors(absltest.TestCase):
    """Invalid GPU specs are rejected with a ``ValueError``."""

    def test_l4x8_invalid_count(self):
        # "l4x8" must raise with a message matching "not supported".
        self.assertRaisesRegex(
            ValueError, "not supported", parse_accelerator, "l4x8"
        )


class TestParseTpuBare(parameterized.TestCase):
    """Bare TPU names parse to a ``TpuConfig`` using the default chip count."""

    def test_v5litepod(self):
        config = parse_accelerator("v5litepod")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v5litepod")
        self.assertEqual(config.chips, 4)
        self.assertEqual(config.topology, "2x2")

    # Iterating the registry dict yields its keys, one test case per key.
    @parameterized.parameters(*TPUS)
    def test_all_tpu_types_parse_with_default_chips(self, tpu_name):
        expected_chips = TPUS[tpu_name].default_chips
        config = parse_accelerator(tpu_name)
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, tpu_name)
        self.assertEqual(config.chips, expected_chips)


class TestParseTpuChipCount(absltest.TestCase):
    """TPU specs written as ``<name>-<chips>`` select chips and topology."""

    def test_v3_8(self):
        config = parse_accelerator("v3-8")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v3")
        self.assertEqual(config.chips, 8)
        self.assertEqual(config.topology, "2x2")

    def test_v3_32(self):
        config = parse_accelerator("v3-32")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v3")
        self.assertEqual(config.chips, 32)
        self.assertEqual(config.topology, "4x4")

    def test_v5litepod_1(self):
        # Only the chip count and topology are checked for this spec.
        config = parse_accelerator("v5litepod-1")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.chips, 1)
        self.assertEqual(config.topology, "1x1")


class TestParseTpuTopology(absltest.TestCase):
    """TPU specs written as ``<name>-<topology>`` resolve the chip count."""

    def test_v5litepod_2x2(self):
        config = parse_accelerator("v5litepod-2x2")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.name, "v5litepod")
        self.assertEqual(config.chips, 4)
        self.assertEqual(config.topology, "2x2")

    def test_v5litepod_1x1(self):
        config = parse_accelerator("v5litepod-1x1")
        self.assertIsInstance(config, TpuConfig)
        self.assertEqual(config.chips, 1)
        self.assertEqual(config.topology, "1x1")


class TestParseTpuErrors(absltest.TestCase):
    """Invalid TPU specs are rejected with a ``ValueError``."""

    def test_v3_16_invalid_chips(self):
        # "v3-16" must raise with a message matching "not supported".
        self.assertRaisesRegex(
            ValueError, "not supported", parse_accelerator, "v3-16"
        )

    def test_v5litepod_3x3_invalid_topology(self):
        # "v5litepod-3x3" must raise with "Unknown accelerator".
        self.assertRaisesRegex(
            ValueError, "Unknown accelerator", parse_accelerator, "v5litepod-3x3"
        )


class TestParseTpuConfigFields(absltest.TestCase):
    """Checks the GKE-related fields of a parsed TPU config."""

    def test_v3_8_full_config(self):
        config = parse_accelerator("v3-8")
        self.assertEqual(config.gke_accelerator, "tpu-v3-podslice")
        self.assertEqual(config.machine_type, "ct3p-hightpu-4t")
        self.assertEqual(config.num_nodes, 2)


class TestParseCpu(absltest.TestCase):
    """The ``"cpu"`` spec parses to ``None`` rather than a config object."""

    def test_cpu(self):
        parsed = parse_accelerator("cpu")
        self.assertIsNone(parsed)


class TestParseNormalizationAndErrors(absltest.TestCase):
    """Input normalization and rejection of unrecognized specs."""

    def test_whitespace_and_case(self):
        # Surrounding spaces and upper-case letters must not matter.
        config = parse_accelerator(" A100X4 ")
        self.assertIsInstance(config, GpuConfig)
        self.assertEqual(config.name, "a100")
        self.assertEqual(config.count, 4)

    def test_empty_string(self):
        self.assertRaisesRegex(
            ValueError, "Unknown accelerator", parse_accelerator, ""
        )

    def test_unknown_accelerator(self):
        self.assertRaisesRegex(
            ValueError, "Unknown accelerator", parse_accelerator, "unknown"
        )


class TestGetCategory(absltest.TestCase):
    """``get_category`` maps a spec string to "cpu", "gpu" or "tpu"."""

    def test_cpu(self):
        category = get_category("cpu")
        self.assertEqual(category, "cpu")

    def test_gpu(self):
        category = get_category("l4")
        self.assertEqual(category, "gpu")

    def test_tpu(self):
        category = get_category("v5litepod")
        self.assertEqual(category, "tpu")


class TestRegistryIntegrity(absltest.TestCase):
    """Consistency checks over the ``GPUS``, ``TPUS`` and alias registries.

    Each per-entry check runs inside ``self.subTest`` so that a failure for
    one registry entry is reported without hiding failures in later entries.
    """

    def test_all_gpus_have_nonempty_counts(self):
        for name, spec in GPUS.items():
            with self.subTest(gpu=name):
                self.assertNotEmpty(
                    spec.counts, f"GPU '{name}' has empty counts"
                )

    def test_all_tpus_have_nonempty_topologies(self):
        for name, spec in TPUS.items():
            with self.subTest(tpu=name):
                self.assertNotEmpty(
                    spec.topologies, f"TPU '{name}' has empty topologies"
                )

    def test_all_tpu_default_chips_valid(self):
        for name, spec in TPUS.items():
            with self.subTest(tpu=name):
                self.assertIn(
                    spec.default_chips,
                    spec.topologies,
                    f"TPU '{name}' default_chips={spec.default_chips} "
                    f"not in topologies {list(spec.topologies.keys())}",
                )

    def test_all_gpus_have_gke_label_alias(self):
        for name, spec in GPUS.items():
            with self.subTest(gpu=name):
                self.assertIn(
                    spec.gke_label,
                    _GPU_ALIASES,
                    f"GPU '{name}' gke_label '{spec.gke_label}' not in aliases",
                )
                # Reached only when the label is present, so the lookup
                # below cannot raise KeyError.
                self.assertEqual(_GPU_ALIASES[spec.gke_label], name)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    absltest.main()
84 changes: 84 additions & 0 deletions keras_remote/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,84 @@
"""Tests for keras_remote.constants — zone/region/location helpers."""

import os
from unittest import mock

from absl.testing import absltest, parameterized

from keras_remote.constants import (
DEFAULT_REGION,
DEFAULT_ZONE,
get_default_zone,
zone_to_ar_location,
zone_to_region,
)


class TestZoneToRegion(parameterized.TestCase):
    """``zone_to_region`` maps a zone name to its region name."""

    # Each case is (zone, expected region).
    @parameterized.parameters(
        ("us-central1-a", "us-central1"),
        ("us-central1-b", "us-central1"),
        ("us-east1-b", "us-east1"),
        ("us-east4-c", "us-east4"),
        ("us-west1-a", "us-west1"),
        ("us-west4-b", "us-west4"),
        ("europe-west1-b", "europe-west1"),
        ("europe-west4-b", "europe-west4"),
        ("asia-east1-c", "asia-east1"),
        ("asia-southeast1-a", "asia-southeast1"),
        ("me-west1-a", "me-west1"),
        ("southamerica-east1-b", "southamerica-east1"),
    )
    def test_zone_to_region(self, zone, expected_region):
        actual_region = zone_to_region(zone)
        self.assertEqual(actual_region, expected_region)

    # Empty, missing and unparseable zones all fall back to DEFAULT_REGION.
    @parameterized.parameters(
        ("",),
        (None,),
        ("invalid",),
    )
    def test_fallback_returns_default(self, zone):
        actual_region = zone_to_region(zone)
        self.assertEqual(actual_region, DEFAULT_REGION)


class TestZoneToArLocation(parameterized.TestCase):
    """``zone_to_ar_location`` maps a zone name to a location string."""

    # Each case is (zone, expected location).
    @parameterized.parameters(
        ("us-central1-a", "us"),
        ("us-east1-b", "us"),
        ("us-west1-a", "us"),
        ("europe-west1-b", "europe"),
        ("europe-west4-b", "europe"),
        ("asia-east1-c", "asia"),
        ("asia-southeast1-a", "asia"),
        ("me-west1-a", "me"),
        ("southamerica-east1-b", "southamerica"),
    )
    def test_zone_to_ar_location(self, zone, expected_location):
        actual_location = zone_to_ar_location(zone)
        self.assertEqual(actual_location, expected_location)


class TestGetDefaultZone(parameterized.TestCase):
    """``get_default_zone`` honours ``KERAS_REMOTE_ZONE`` or uses the default."""

    @parameterized.parameters(
        ("us-west1-b", "us-west1-b"),
        ("europe-west4-a", "europe-west4-a"),
        ("asia-east1-c", "asia-east1-c"),
    )
    def test_returns_env_var_when_set(self, env_value, expected_zone):
        # patch.dict restores the environment when the block exits.
        with mock.patch.dict(os.environ, KERAS_REMOTE_ZONE=env_value):
            zone = get_default_zone()
            self.assertEqual(zone, expected_zone)

    def test_returns_default_when_unset(self):
        # Snapshot the environment, drop the variable for the duration of
        # the block, and let patch.dict restore the original on exit.
        with mock.patch.dict(os.environ):
            os.environ.pop("KERAS_REMOTE_ZONE", None)
            zone = get_default_zone()
            self.assertEqual(zone, DEFAULT_ZONE)

    # Pins the literal values of the module-level defaults.
    @parameterized.parameters(
        (DEFAULT_ZONE, "us-central1-a"),
        (DEFAULT_REGION, "us-central1"),
    )
    def test_default_constants(self, constant, expected_value):
        self.assertEqual(constant, expected_value)


# Allow running this test module directly as a script.
if __name__ == "__main__":
    absltest.main()
22 changes: 22 additions & 0 deletions keras_remote/test_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,22 @@
"""Shared test utilities for colocated unit tests."""

from unittest import mock


def create_mock_storage_client():
    """Build a stand-in storage client plus the patcher that installs it.

    Returns:
        A ``(mock_client, patcher)`` tuple. ``mock_client`` is a
        ``MagicMock``; ``patcher`` is an unstarted ``mock.patch`` targeting
        ``google.cloud.storage.Client`` that makes the patched callable
        return ``mock_client``. The caller is responsible for activating
        the patcher.
    """
    storage_client = mock.MagicMock()
    return storage_client, mock.patch(
        "google.cloud.storage.Client", return_value=storage_client
    )


def create_mock_kube_config():
    """Return an unstarted patcher for the kube-config loader.

    The patcher targets ``keras_remote.backend.gke_client._load_kube_config``;
    the caller is responsible for activating it.
    """
    kube_config_patcher = mock.patch(
        "keras_remote.backend.gke_client._load_kube_config"
    )
    return kube_config_patcher


def create_mock_batch_v1():
    """Build a stand-in ``BatchV1Api`` plus the patcher that installs it.

    Returns:
        A ``(mock_api, patcher)`` tuple. ``mock_api`` is a ``MagicMock``;
        ``patcher`` is an unstarted ``mock.patch`` targeting
        ``kubernetes.client.BatchV1Api`` that makes the patched callable
        return ``mock_api``. The caller is responsible for activating the
        patcher.
    """
    batch_api = mock.MagicMock()
    return batch_api, mock.patch(
        "kubernetes.client.BatchV1Api", return_value=batch_api
    )
9 changes: 8 additions & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,13 @@ cli = [
"pulumi-gcp>=9.0",
"pulumi-command>=1.0",
]
# Extra providing the coverage tool used when running the unit tests.
test = [
"coverage>=7.0",
]
# Developer tooling; also pulls in the `test` extra via `keras-remote[test]`.
dev = [
"pre-commit",
"ruff",
"keras-remote[test]",
]
demo = [
"jax",
Expand Down Expand Up @@ -67,4 +71,7 @@ select = ["B", "E", "F", "N", "PYI", "T20", "TID", "SIM", "W", "I", "NPY"]
ignore = ["E501"]

[tool.ruff.lint.per-file-ignores]
"examples/*" = ["T201", "NPY002"]
"examples/*" = ["T201", "NPY002"]
# Test code is exempt from T201 (ruff's flake8-print "print found" rule).
"**/test_*.py" = ["T201"]
"tests/**" = ["T201"]

Empty file added tests/__init__.py
Empty file.
Empty file added tests/e2e/__init__.py
Empty file.
17 changes: 17 additions & 0 deletions tests/e2e/e2e_utils.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,17 @@
"""E2E test utilities — helpers for tests requiring GCP infrastructure."""

import os
import unittest


def skip_unless_e2e(reason="E2E_TESTS not set"):
    """Return a decorator that skips the test unless ``E2E_TESTS`` is set.

    The environment is read when this function is called (i.e. at decoration
    time). Any non-empty value of ``E2E_TESTS`` enables the test.

    Args:
        reason: Skip message reported when the variable is unset or empty.
    """
    e2e_enabled = os.environ.get("E2E_TESTS")
    return unittest.skipUnless(e2e_enabled, reason)


def get_gcp_project():
    """Return the project named by the ``KERAS_REMOTE_PROJECT`` variable.

    Raises:
        unittest.SkipTest: If the variable is unset or empty, so the calling
            test is skipped.
    """
    project_id = os.environ.get("KERAS_REMOTE_PROJECT")
    if project_id:
        return project_id
    raise unittest.SkipTest("KERAS_REMOTE_PROJECT not set")