11"""Tests for keras_remote.backend.execution — JobContext and execute_remote."""
22
3- import re
3+ import os
4+ import pathlib
5+ import tempfile
6+ from unittest import mock
47from unittest .mock import MagicMock
58
6- import pytest
9+ from absl . testing import absltest
710
811from keras_remote .backend .execution import (
912 JobContext ,
1215)
1316
1417
15- class TestJobContext :
18+ def _make_temp_path (test_case ):
19+ """Create a temp directory that is cleaned up after the test."""
20+ td = tempfile .TemporaryDirectory ()
21+ test_case .addCleanup (td .cleanup )
22+ return pathlib .Path (td .name )
23+
24+
25+ class TestJobContext (absltest .TestCase ):
1626 def _make_func (self ):
1727 def my_train ():
1828 return 42
@@ -30,10 +40,10 @@ def test_post_init_derived_fields(self):
3040 zone = "europe-west4-b" ,
3141 project = "my-proj" ,
3242 )
33- assert ctx .bucket_name == "my-proj-keras-remote-jobs"
34- assert ctx .region == "europe-west4"
35- assert ctx .display_name .startswith ("keras-remote-my_train-" )
36- assert re . fullmatch ( r" job-[0-9a-f]{8}" , ctx . job_id )
43+ self . assertEqual ( ctx .bucket_name , "my-proj-keras-remote-jobs" )
44+ self . assertEqual ( ctx .region , "europe-west4" )
45+ self . assertTrue ( ctx .display_name .startswith ("keras-remote-my_train-" ) )
46+ self . assertRegex ( ctx . job_id , r"^ job-[0-9a-f]{8}$" )
3747
3848 def test_from_params_explicit (self ):
3949 ctx = JobContext .from_params (
@@ -46,35 +56,38 @@ def test_from_params_explicit(self):
4656 project = "explicit-proj" ,
4757 env_vars = {"X" : "Y" },
4858 )
49- assert ctx .zone == "us-west1-a"
50- assert ctx .project == "explicit-proj"
51- assert ctx .accelerator == "l4"
52- assert ctx .container_image == "my-image:latest"
53- assert ctx .args == (1 , 2 )
54- assert ctx .kwargs == {"k" : "v" }
55- assert ctx .env_vars == {"X" : "Y" }
56-
57- def test_from_params_resolves_zone_from_env (self , monkeypatch ):
58- monkeypatch .setenv ("KERAS_REMOTE_ZONE" , "asia-east1-c" )
59- monkeypatch .setenv ("KERAS_REMOTE_PROJECT" , "env-proj" )
60-
61- ctx = JobContext .from_params (
62- func = self ._make_func (),
63- args = (),
64- kwargs = {},
65- accelerator = "cpu" ,
66- container_image = None ,
67- zone = None ,
68- project = None ,
69- env_vars = {},
70- )
71- assert ctx .zone == "asia-east1-c"
72- assert ctx .project == "env-proj"
73-
74- def test_from_params_no_project_raises (self , monkeypatch ):
75- monkeypatch .delenv ("KERAS_REMOTE_PROJECT" , raising = False )
76-
77- with pytest .raises (ValueError , match = "project must be specified" ):
59+ self .assertEqual (ctx .zone , "us-west1-a" )
60+ self .assertEqual (ctx .project , "explicit-proj" )
61+ self .assertEqual (ctx .accelerator , "l4" )
62+ self .assertEqual (ctx .container_image , "my-image:latest" )
63+ self .assertEqual (ctx .args , (1 , 2 ))
64+ self .assertEqual (ctx .kwargs , {"k" : "v" })
65+ self .assertEqual (ctx .env_vars , {"X" : "Y" })
66+
67+ def test_from_params_resolves_zone_from_env (self ):
68+ with mock .patch .dict (
69+ os .environ ,
70+ {"KERAS_REMOTE_ZONE" : "asia-east1-c" , "KERAS_REMOTE_PROJECT" : "env-proj" },
71+ ):
72+ ctx = JobContext .from_params (
73+ func = self ._make_func (),
74+ args = (),
75+ kwargs = {},
76+ accelerator = "cpu" ,
77+ container_image = None ,
78+ zone = None ,
79+ project = None ,
80+ env_vars = {},
81+ )
82+ self .assertEqual (ctx .zone , "asia-east1-c" )
83+ self .assertEqual (ctx .project , "env-proj" )
84+
85+ def test_from_params_no_project_raises (self ):
86+ env = {k : v for k , v in os .environ .items () if k != "KERAS_REMOTE_PROJECT" }
87+ with (
88+ mock .patch .dict (os .environ , env , clear = True ),
89+ self .assertRaisesRegex (ValueError , "project must be specified" ),
90+ ):
7891 JobContext .from_params (
7992 func = self ._make_func (),
8093 args = (),
@@ -87,29 +100,36 @@ def test_from_params_no_project_raises(self, monkeypatch):
87100 )
88101
89102
90- class TestFindRequirements :
91- def test_finds_in_start_dir (self , tmp_path ):
103+ class TestFindRequirements ( absltest . TestCase ) :
104+ def test_finds_in_start_dir (self ):
92105 """Returns the path when requirements.txt exists in the start directory."""
106+ tmp_path = _make_temp_path (self )
93107 (tmp_path / "requirements.txt" ).write_text ("numpy\n " )
94- assert _find_requirements (str (tmp_path )) == str (
95- tmp_path / "requirements.txt"
108+ self .assertEqual (
109+ _find_requirements (str (tmp_path )),
110+ str (tmp_path / "requirements.txt" ),
96111 )
97112
98- def test_finds_in_parent_dir (self , tmp_path ):
113+ def test_finds_in_parent_dir (self ):
99114 """Walks up the directory tree to find requirements.txt in a parent."""
115+ tmp_path = _make_temp_path (self )
100116 (tmp_path / "requirements.txt" ).write_text ("numpy\n " )
101117 child = tmp_path / "subdir"
102118 child .mkdir ()
103- assert _find_requirements (str (child )) == str (tmp_path / "requirements.txt" )
119+ self .assertEqual (
120+ _find_requirements (str (child )),
121+ str (tmp_path / "requirements.txt" ),
122+ )
104123
105- def test_returns_none_when_not_found (self , tmp_path ):
124+ def test_returns_none_when_not_found (self ):
106125 """Returns None when no requirements.txt exists in any ancestor."""
126+ tmp_path = _make_temp_path (self )
107127 empty = tmp_path / "empty"
108128 empty .mkdir ()
109- assert _find_requirements (str (empty )) is None
129+ self . assertIsNone ( _find_requirements (str (empty )))
110130
111131
112- class TestExecuteRemote :
132+ class TestExecuteRemote ( absltest . TestCase ) :
113133 def _make_func (self ):
114134 def my_train ():
115135 return 42
@@ -128,38 +148,44 @@ def _make_ctx(self, container_image=None):
128148 project = "proj" ,
129149 )
130150
131- def test_success_flow (self , mocker ):
132- mocker .patch ("keras_remote.backend.execution._build_container" )
133- mocker .patch ("keras_remote.backend.execution._upload_artifacts" )
134- mocker .patch (
135- "keras_remote.backend.execution._download_result" ,
136- return_value = {"success" : True , "result" : 42 },
137- )
138- mocker .patch (
139- "keras_remote.backend.execution._cleanup_and_return" ,
140- return_value = 42 ,
141- )
142-
143- ctx = self ._make_ctx ()
144- backend = MagicMock ()
145-
146- result = execute_remote (ctx , backend )
147-
148- backend .submit_job .assert_called_once_with (ctx )
149- backend .wait_for_job .assert_called_once ()
150- backend .cleanup_job .assert_called_once ()
151- assert result == 42
152-
153- def test_cleanup_on_wait_failure (self , mocker ):
154- mocker .patch ("keras_remote.backend.execution._build_container" )
155- mocker .patch ("keras_remote.backend.execution._upload_artifacts" )
156-
157- ctx = self ._make_ctx ()
158- backend = MagicMock ()
159- backend .wait_for_job .side_effect = RuntimeError ("job failed" )
160-
161- with pytest .raises (RuntimeError , match = "job failed" ):
162- execute_remote (ctx , backend )
163-
164- # cleanup_job is called in finally block even when wait fails
165- backend .cleanup_job .assert_called_once ()
151+ def test_success_flow (self ):
152+ with (
153+ mock .patch ("keras_remote.backend.execution._build_container" ),
154+ mock .patch ("keras_remote.backend.execution._upload_artifacts" ),
155+ mock .patch (
156+ "keras_remote.backend.execution._download_result" ,
157+ return_value = {"success" : True , "result" : 42 },
158+ ),
159+ mock .patch (
160+ "keras_remote.backend.execution._cleanup_and_return" ,
161+ return_value = 42 ,
162+ ),
163+ ):
164+ ctx = self ._make_ctx ()
165+ backend = MagicMock ()
166+
167+ result = execute_remote (ctx , backend )
168+
169+ backend .submit_job .assert_called_once_with (ctx )
170+ backend .wait_for_job .assert_called_once ()
171+ backend .cleanup_job .assert_called_once ()
172+ self .assertEqual (result , 42 )
173+
174+ def test_cleanup_on_wait_failure (self ):
175+ with (
176+ mock .patch ("keras_remote.backend.execution._build_container" ),
177+ mock .patch ("keras_remote.backend.execution._upload_artifacts" ),
178+ ):
179+ ctx = self ._make_ctx ()
180+ backend = MagicMock ()
181+ backend .wait_for_job .side_effect = RuntimeError ("job failed" )
182+
183+ with self .assertRaisesRegex (RuntimeError , "job failed" ):
184+ execute_remote (ctx , backend )
185+
186+ # cleanup_job is called in finally block even when wait fails
187+ backend .cleanup_job .assert_called_once ()
188+
189+
190+ if __name__ == "__main__" :
191+ absltest .main ()
0 commit comments