Skip to content
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
35 changes: 35 additions & 0 deletions .github/workflows/tests.yaml
Original file line number Diff line number Diff line change
@@ -0,0 +1,35 @@
# CI workflow: install the package with its test extras, run the unit and
# integration suites under coverage, and upload the XML report to Codecov.
name: Tests

on:
  push:
    branches: [main]
  pull_request:
    branches: [main]

jobs:
  test:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@v4

      - name: Set up Python 3.13
        uses: actions/setup-python@v5
        with:
          python-version: "3.13"

      - name: Install dependencies
        run: pip install -e ".[test,cli]"

      - name: Run unit and integration tests
        # Folded scalar: the lines below are joined into one pytest command.
        run: >
          pytest keras_remote/ tests/integration/
          -v --tb=short
          --cov=keras_remote
          --cov-report=xml
          --ignore=keras_remote/src

      - name: Upload coverage
        uses: codecov/codecov-action@v4
        with:
          files: coverage.xml
          # codecov-action v4 does not support tokenless uploads outside of
          # fork PRs. Without a token the upload is rejected and, because
          # failures are tolerated below, coverage would silently go missing.
          token: ${{ secrets.CODECOV_TOKEN }}
          fail_ci_if_error: false
43 changes: 43 additions & 0 deletions keras_remote/conftest.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,43 @@
"""Shared fixtures for colocated unit tests."""

import pytest


@pytest.fixture
def sample_function():
    """Provide a plain two-argument adder for serialization tests."""

    # Defined locally so each test receives a freshly created function object.
    def add(a, b):
        return a + b

    return add


@pytest.fixture
def gcp_env(monkeypatch):
    """Populate the KERAS_REMOTE_* environment variables for unit tests.

    The values are applied through ``monkeypatch``, so they are undone
    automatically when the requesting test finishes.
    """
    env_overrides = {
        "KERAS_REMOTE_PROJECT": "test-project",
        "KERAS_REMOTE_ZONE": "us-central1-a",
        "KERAS_REMOTE_GKE_CLUSTER": "test-cluster",
    }
    for var_name, var_value in env_overrides.items():
        monkeypatch.setenv(var_name, var_value)


@pytest.fixture
def mock_storage_client(mocker):
    """Patch ``google.cloud.storage.Client`` and return the stand-in client.

    While the fixture is active, calling ``google.cloud.storage.Client(...)``
    yields the returned ``MagicMock`` instead of a real client.
    """
    client_double = mocker.MagicMock()
    patch_target = "google.cloud.storage.Client"
    mocker.patch(patch_target, return_value=client_double)
    return client_double


@pytest.fixture
def mock_kube_config(mocker):
    """Mock kubernetes config loading.

    Replaces ``keras_remote.backend.gke_client._load_kube_config`` with a
    mock for the duration of the test.

    Returns:
        The mock installed in place of ``_load_kube_config``, so tests can
        assert on how (or whether) it was called. Tests that only need the
        patching side effect can ignore the return value.
    """
    # Return the patch object, matching the sibling fixtures that hand their
    # mocks back to the test instead of discarding them.
    return mocker.patch("keras_remote.backend.gke_client._load_kube_config")


@pytest.fixture
def mock_batch_v1(mocker):
    """Patch ``kubernetes.client.BatchV1Api`` and return the stand-in API.

    While the fixture is active, constructing ``BatchV1Api(...)`` yields the
    returned ``MagicMock``.
    """
    api_double = mocker.MagicMock()
    patch_target = "kubernetes.client.BatchV1Api"
    mocker.patch(patch_target, return_value=api_double)
    return api_double
191 changes: 191 additions & 0 deletions keras_remote/core/test_accelerators.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,191 @@
"""Tests for keras_remote.core.accelerators — parser, registry, categories."""

import pytest

from keras_remote.core.accelerators import (
_GPU_ALIASES,
GPUS,
TPUS,
GpuConfig,
TpuConfig,
get_category,
parse_accelerator,
)


class TestParseGpuDirect:
    """Bare GPU names parse to a ``GpuConfig`` with a count of one."""

    def test_l4(self):
        cfg = parse_accelerator("l4")
        assert isinstance(cfg, GpuConfig)
        assert (cfg.name, cfg.count) == ("l4", 1)
        assert cfg.gke_label == "nvidia-l4"
        assert cfg.machine_type == "g2-standard-4"

    # One test case per key registered in GPUS.
    @pytest.mark.parametrize("gpu_name", list(GPUS))
    def test_all_gpu_types_parse_with_count_1(self, gpu_name):
        cfg = parse_accelerator(gpu_name)
        assert isinstance(cfg, GpuConfig)
        assert (cfg.count, cfg.name) == (1, gpu_name)


class TestParseGpuMultiCount:
    """The ``<gpu>x<count>`` form yields the GPU name and requested count."""

    def test_a100x4(self):
        cfg = parse_accelerator("a100x4")
        assert isinstance(cfg, GpuConfig)
        assert (cfg.name, cfg.count) == ("a100", 4)

    def test_a100_80gbx4(self):
        # The GPU name itself contains a dash here.
        cfg = parse_accelerator("a100-80gbx4")
        assert isinstance(cfg, GpuConfig)
        assert (cfg.name, cfg.count) == ("a100-80gb", 4)


class TestParseGpuAlias:
    """``nvidia-tesla-*`` spellings resolve to the short GPU name."""

    def test_nvidia_tesla_t4(self):
        cfg = parse_accelerator("nvidia-tesla-t4")
        assert isinstance(cfg, GpuConfig)
        assert (cfg.name, cfg.count) == ("t4", 1)

    def test_nvidia_tesla_v100x4(self):
        # An alias combined with an explicit count suffix.
        cfg = parse_accelerator("nvidia-tesla-v100x4")
        assert isinstance(cfg, GpuConfig)
        assert (cfg.name, cfg.count) == ("v100", 4)


class TestParseGpuErrors:
    """A GPU spec the parser rejects must raise ``ValueError``."""

    def test_l4x8_invalid_count(self):
        bad_spec = "l4x8"
        with pytest.raises(ValueError, match="not supported"):
            parse_accelerator(bad_spec)


class TestParseTpuBare:
    """Bare TPU names parse to a ``TpuConfig`` with the default chips."""

    def test_v5litepod(self):
        cfg = parse_accelerator("v5litepod")
        assert isinstance(cfg, TpuConfig)
        assert cfg.name == "v5litepod"
        assert (cfg.chips, cfg.topology) == (4, "2x2")

    # One test case per key registered in TPUS.
    @pytest.mark.parametrize("tpu_name", list(TPUS))
    def test_all_tpu_types_parse_with_default_chips(self, tpu_name):
        cfg = parse_accelerator(tpu_name)
        assert isinstance(cfg, TpuConfig)
        assert cfg.name == tpu_name
        expected_chips = TPUS[tpu_name].default_chips
        assert cfg.chips == expected_chips


class TestParseTpuChipCount:
    """The ``<tpu>-<chips>`` form yields the chip count and its topology."""

    def test_v3_8(self):
        cfg = parse_accelerator("v3-8")
        assert isinstance(cfg, TpuConfig)
        assert cfg.name == "v3"
        assert (cfg.chips, cfg.topology) == (8, "2x2")

    def test_v3_32(self):
        cfg = parse_accelerator("v3-32")
        assert isinstance(cfg, TpuConfig)
        assert cfg.name == "v3"
        assert (cfg.chips, cfg.topology) == (32, "4x4")

    def test_v5litepod_1(self):
        cfg = parse_accelerator("v5litepod-1")
        assert isinstance(cfg, TpuConfig)
        assert (cfg.chips, cfg.topology) == (1, "1x1")


class TestParseTpuTopology:
    """The ``<tpu>-<topology>`` form yields the topology and chip count."""

    def test_v5litepod_2x2(self):
        cfg = parse_accelerator("v5litepod-2x2")
        assert isinstance(cfg, TpuConfig)
        assert cfg.name == "v5litepod"
        assert (cfg.chips, cfg.topology) == (4, "2x2")

    def test_v5litepod_1x1(self):
        cfg = parse_accelerator("v5litepod-1x1")
        assert isinstance(cfg, TpuConfig)
        assert (cfg.chips, cfg.topology) == (1, "1x1")


class TestParseTpuErrors:
    """TPU specs the parser rejects must raise ``ValueError``."""

    def test_v3_16_invalid_chips(self):
        bad_spec = "v3-16"
        with pytest.raises(ValueError, match="not supported"):
            parse_accelerator(bad_spec)

    def test_v5litepod_3x3_invalid_topology(self):
        # This case is reported as an unknown accelerator, not "not supported".
        bad_spec = "v5litepod-3x3"
        with pytest.raises(ValueError, match="Unknown accelerator"):
            parse_accelerator(bad_spec)


class TestParseTpuConfigFields:
    """Check the GKE-related fields of a parsed TPU config."""

    def test_v3_8_full_config(self):
        cfg = parse_accelerator("v3-8")
        assert cfg.gke_accelerator == "tpu-v3-podslice"
        assert (cfg.machine_type, cfg.num_nodes) == ("ct3p-hightpu-4t", 2)


class TestParseCpu:
    """The ``cpu`` spec parses to ``None`` rather than a config object."""

    def test_cpu(self):
        parsed = parse_accelerator("cpu")
        assert parsed is None


class TestParseNormalizationAndErrors:
    """Input normalization and rejection of unrecognized specs."""

    def test_whitespace_and_case(self):
        # Surrounding spaces and upper-case letters must not affect parsing.
        cfg = parse_accelerator(" A100X4 ")
        assert isinstance(cfg, GpuConfig)
        assert (cfg.name, cfg.count) == ("a100", 4)

    def test_empty_string(self):
        empty_spec = ""
        with pytest.raises(ValueError, match="Unknown accelerator"):
            parse_accelerator(empty_spec)

    def test_unknown_accelerator(self):
        bogus_spec = "unknown"
        with pytest.raises(ValueError, match="Unknown accelerator"):
            parse_accelerator(bogus_spec)


class TestGetCategory:
    """``get_category`` maps an accelerator spec to cpu, gpu or tpu."""

    def test_cpu(self):
        category = get_category("cpu")
        assert category == "cpu"

    def test_gpu(self):
        category = get_category("l4")
        assert category == "gpu"

    def test_tpu(self):
        category = get_category("v5litepod")
        assert category == "tpu"


class TestRegistryIntegrity:
    """Consistency checks over the GPUS, TPUS and _GPU_ALIASES tables."""

    def test_all_gpus_have_nonempty_counts(self):
        for name, spec in GPUS.items():
            num_counts = len(spec.counts)
            assert num_counts > 0, f"GPU '{name}' has empty counts"

    def test_all_tpus_have_nonempty_topologies(self):
        for name, spec in TPUS.items():
            num_topologies = len(spec.topologies)
            assert num_topologies > 0, f"TPU '{name}' has empty topologies"

    def test_all_tpu_default_chips_valid(self):
        # The default chip count of every TPU must be one of its topologies.
        for name, spec in TPUS.items():
            known_topologies = spec.topologies
            assert spec.default_chips in known_topologies, (
                f"TPU '{name}' default_chips={spec.default_chips} "
                f"not in topologies {list(spec.topologies.keys())}"
            )

    def test_all_gpus_have_gke_label_alias(self):
        # Every GKE label must be an alias that maps back to its own GPU.
        for name, spec in GPUS.items():
            label = spec.gke_label
            assert label in _GPU_ALIASES, (
                f"GPU '{name}' gke_label '{spec.gke_label}' not in aliases"
            )
            assert _GPU_ALIASES[label] == name
96 changes: 96 additions & 0 deletions keras_remote/test_constants.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,96 @@
"""Tests for keras_remote.constants — zone/region/location helpers."""

import pytest

from keras_remote.constants import (
DEFAULT_REGION,
DEFAULT_ZONE,
get_default_zone,
zone_to_ar_location,
zone_to_region,
)


class TestZoneToRegion:
    """``zone_to_region`` maps a zone to its region, with a fallback."""

    @pytest.mark.parametrize(
        ("zone", "expected_region"),
        [
            ("us-central1-a", "us-central1"),
            ("us-central1-b", "us-central1"),
            ("us-east1-b", "us-east1"),
            ("us-east4-c", "us-east4"),
            ("us-west1-a", "us-west1"),
            ("us-west4-b", "us-west4"),
            ("europe-west1-b", "europe-west1"),
            ("europe-west4-b", "europe-west4"),
            ("asia-east1-c", "asia-east1"),
            ("asia-southeast1-a", "asia-southeast1"),
            ("me-west1-a", "me-west1"),
            ("southamerica-east1-b", "southamerica-east1"),
        ],
    )
    def test_zone_to_region(self, zone, expected_region):
        region = zone_to_region(zone)
        assert region == expected_region

    # Empty, missing and malformed zones all fall back to DEFAULT_REGION.
    @pytest.mark.parametrize("zone", ["", None, "invalid"])
    def test_fallback_returns_default(self, zone):
        region = zone_to_region(zone)
        assert region == DEFAULT_REGION


class TestZoneToArLocation:
    """``zone_to_ar_location`` maps a zone to its leading location part."""

    @pytest.mark.parametrize(
        ("zone", "expected_location"),
        [
            ("us-central1-a", "us"),
            ("us-east1-b", "us"),
            ("us-west1-a", "us"),
            ("europe-west1-b", "europe"),
            ("europe-west4-b", "europe"),
            ("asia-east1-c", "asia"),
            ("asia-southeast1-a", "asia"),
            ("me-west1-a", "me"),
            ("southamerica-east1-b", "southamerica"),
        ],
    )
    def test_zone_to_ar_location(self, zone, expected_location):
        location = zone_to_ar_location(zone)
        assert location == expected_location


class TestGetDefaultZone:
    """``get_default_zone`` honours KERAS_REMOTE_ZONE, else DEFAULT_ZONE."""

    @pytest.mark.parametrize(
        ("env_value", "expected_zone"),
        [
            ("us-west1-b", "us-west1-b"),
            ("europe-west4-a", "europe-west4-a"),
            ("asia-east1-c", "asia-east1-c"),
        ],
    )
    def test_returns_env_var_when_set(
        self, monkeypatch, env_value, expected_zone
    ):
        # monkeypatch restores the previous environment after the test.
        monkeypatch.setenv("KERAS_REMOTE_ZONE", env_value)
        zone = get_default_zone()
        assert zone == expected_zone

    def test_returns_default_when_unset(self, monkeypatch):
        # raising=False makes the removal a no-op when the variable is absent.
        monkeypatch.delenv("KERAS_REMOTE_ZONE", raising=False)
        zone = get_default_zone()
        assert zone == DEFAULT_ZONE

    # Pins the literal values of the module-level defaults.
    @pytest.mark.parametrize(
        ("constant", "expected_value"),
        [
            (DEFAULT_ZONE, "us-central1-a"),
            (DEFAULT_REGION, "us-central1"),
        ],
    )
    def test_default_constants(self, constant, expected_value):
        assert constant == expected_value
Loading