99"""
1010
1111import os
12+ from unittest import mock
1213
13- import pytest
14+ from absl . testing import absltest
1415
1516import keras_remote
17+ from tests .e2e .e2e_utils import get_gcp_project , skip_unless_e2e
1618
1719
18- @pytest .mark .e2e
19- @pytest .mark .timeout (600 )
20- class TestCpuExecution :
21- def test_simple_function (self , gcp_project ):
20+ @skip_unless_e2e ()
21+ class TestCpuExecution (absltest .TestCase ):
22+ def setUp (self ):
23+ super ().setUp ()
24+ self .project = get_gcp_project ()
25+
26+ def test_simple_function (self ):
2227 """Execute a simple add function remotely and verify the result."""
2328
2429 @keras_remote .run (accelerator = "cpu" )
2530 def add (a , b ):
2631 return a + b
2732
2833 result = add (2 , 3 )
29- assert result == 5
34+ self . assertEqual ( result , 5 )
3035
31- def test_complex_return_type (self , gcp_project ):
36+ def test_complex_return_type (self ):
3237 """Verify complex return types survive serialization roundtrip."""
3338
3439 @keras_remote .run (accelerator = "cpu" )
@@ -40,41 +45,45 @@ def complex_return():
4045 }
4146
4247 result = complex_return ()
43- assert result ["key" ] == [1 , 2 , 3 ]
44- assert result ["nested" ]["a" ] is True
45- assert result ["nested" ]["b" ] is None
46- assert result ["tuple" ] == (4 , 5 )
48+ self . assertEqual ( result ["key" ], [1 , 2 , 3 ])
49+ self . assertTrue ( result ["nested" ]["a" ])
50+ self . assertIsNone ( result ["nested" ]["b" ])
51+ self . assertEqual ( result ["tuple" ], (4 , 5 ) )
4752
48- def test_function_that_raises (self , gcp_project ):
53+ def test_function_that_raises (self ):
4954 """Verify remote exceptions are re-raised locally."""
5055
5156 @keras_remote .run (accelerator = "cpu" )
5257 def bad_func ():
5358 raise ValueError ("intentional test error" )
5459
55- with pytest . raises (ValueError , match = "intentional test error" ):
60+ with self . assertRaisesRegex (ValueError , "intentional test error" ):
5661 bad_func ()
5762
58- def test_env_var_propagation (self , gcp_project , monkeypatch ):
63+ def test_env_var_propagation (self ):
5964 """Verify captured env vars are available in the remote environment."""
60- monkeypatch . setenv ( "E2E_TEST_VAR" , "hello_from_local" )
65+ with mock . patch . dict ( os . environ , { "E2E_TEST_VAR" : "hello_from_local" }):
6166
62- @keras_remote .run (
63- accelerator = "cpu" ,
64- capture_env_vars = ["E2E_TEST_VAR" ],
65- )
66- def read_env ():
67- return os .environ .get ("E2E_TEST_VAR" )
67+ @keras_remote .run (
68+ accelerator = "cpu" ,
69+ capture_env_vars = ["E2E_TEST_VAR" ],
70+ )
71+ def read_env ():
72+ return os .environ .get ("E2E_TEST_VAR" )
6873
69- result = read_env ()
70- assert result == "hello_from_local"
74+ result = read_env ()
75+ self . assertEqual ( result , "hello_from_local" )
7176
72- def test_function_with_args_and_kwargs (self , gcp_project ):
77+ def test_function_with_args_and_kwargs (self ):
7378 """Verify positional and keyword arguments are passed correctly."""
7479
7580 @keras_remote .run (accelerator = "cpu" )
7681 def compute (x , y , scale = 1.0 , offset = 0.0 ):
7782 return (x + y ) * scale + offset
7883
7984 result = compute (3 , 4 , scale = 2.0 , offset = 1.0 )
80- assert result == 15.0
85+ self .assertEqual (result , 15.0 )
86+
87+
88+ if __name__ == "__main__" :
89+ absltest .main ()
0 commit comments