Skip to content

Commit 3f70738

Browse files
Migrate e2e tests from pytest to absl testing
1 parent 930fa15 commit 3f70738

File tree

2 files changed

+35
-26
lines changed

2 files changed

+35
-26
lines changed

.github/workflows/e2e-tests.yaml

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -48,4 +48,4 @@ jobs:
4848
KERAS_REMOTE_PROJECT: ${{ secrets.GCP_PROJECT }}
4949
KERAS_REMOTE_ZONE: ${{ secrets.GKE_ZONE }}
5050
KERAS_REMOTE_GKE_CLUSTER: ${{ secrets.GKE_CLUSTER }}
51-
run: pytest tests/e2e/ -v --tb=long --timeout=600
51+
run: python -m unittest discover -s tests/e2e -p "test_*.py" -v

tests/e2e/test_cpu_execution.py

Lines changed: 34 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -9,26 +9,31 @@
99
"""
1010

1111
import os
12+
from unittest import mock
1213

13-
import pytest
14+
from absl.testing import absltest
1415

1516
import keras_remote
17+
from tests.e2e.e2e_utils import get_gcp_project, skip_unless_e2e
1618

1719

18-
@pytest.mark.e2e
19-
@pytest.mark.timeout(600)
20-
class TestCpuExecution:
21-
def test_simple_function(self, gcp_project):
20+
@skip_unless_e2e()
21+
class TestCpuExecution(absltest.TestCase):
22+
def setUp(self):
23+
super().setUp()
24+
self.project = get_gcp_project()
25+
26+
def test_simple_function(self):
2227
"""Execute a simple add function remotely and verify the result."""
2328

2429
@keras_remote.run(accelerator="cpu")
2530
def add(a, b):
2631
return a + b
2732

2833
result = add(2, 3)
29-
assert result == 5
34+
self.assertEqual(result, 5)
3035

31-
def test_complex_return_type(self, gcp_project):
36+
def test_complex_return_type(self):
3237
"""Verify complex return types survive serialization roundtrip."""
3338

3439
@keras_remote.run(accelerator="cpu")
@@ -40,41 +45,45 @@ def complex_return():
4045
}
4146

4247
result = complex_return()
43-
assert result["key"] == [1, 2, 3]
44-
assert result["nested"]["a"] is True
45-
assert result["nested"]["b"] is None
46-
assert result["tuple"] == (4, 5)
48+
self.assertEqual(result["key"], [1, 2, 3])
49+
self.assertTrue(result["nested"]["a"])
50+
self.assertIsNone(result["nested"]["b"])
51+
self.assertEqual(result["tuple"], (4, 5))
4752

48-
def test_function_that_raises(self, gcp_project):
53+
def test_function_that_raises(self):
4954
"""Verify remote exceptions are re-raised locally."""
5055

5156
@keras_remote.run(accelerator="cpu")
5257
def bad_func():
5358
raise ValueError("intentional test error")
5459

55-
with pytest.raises(ValueError, match="intentional test error"):
60+
with self.assertRaisesRegex(ValueError, "intentional test error"):
5661
bad_func()
5762

58-
def test_env_var_propagation(self, gcp_project, monkeypatch):
63+
def test_env_var_propagation(self):
5964
"""Verify captured env vars are available in the remote environment."""
60-
monkeypatch.setenv("E2E_TEST_VAR", "hello_from_local")
65+
with mock.patch.dict(os.environ, {"E2E_TEST_VAR": "hello_from_local"}):
6166

62-
@keras_remote.run(
63-
accelerator="cpu",
64-
capture_env_vars=["E2E_TEST_VAR"],
65-
)
66-
def read_env():
67-
return os.environ.get("E2E_TEST_VAR")
67+
@keras_remote.run(
68+
accelerator="cpu",
69+
capture_env_vars=["E2E_TEST_VAR"],
70+
)
71+
def read_env():
72+
return os.environ.get("E2E_TEST_VAR")
6873

69-
result = read_env()
70-
assert result == "hello_from_local"
74+
result = read_env()
75+
self.assertEqual(result, "hello_from_local")
7176

72-
def test_function_with_args_and_kwargs(self, gcp_project):
77+
def test_function_with_args_and_kwargs(self):
7378
"""Verify positional and keyword arguments are passed correctly."""
7479

7580
@keras_remote.run(accelerator="cpu")
7681
def compute(x, y, scale=1.0, offset=0.0):
7782
return (x + y) * scale + offset
7883

7984
result = compute(3, 4, scale=2.0, offset=1.0)
80-
assert result == 15.0
85+
self.assertEqual(result, 15.0)
86+
87+
88+
if __name__ == "__main__":
89+
absltest.main()

0 commit comments

Comments
 (0)