Skip to content

Commit c8b9ab2

Browse files
committed
resume and async resume
1 parent 4d6460a commit c8b9ab2

File tree

1 file changed

+101
-42
lines changed

1 file changed

+101
-42
lines changed

metaflow/runner/metaflow_runner.py

+101-42
Original file line numberDiff line numberDiff line change
@@ -208,6 +208,7 @@ def __init__(
208208
if profile:
209209
self.env_vars["METAFLOW_PROFILE"] = profile
210210
self.spm = SubprocessManager()
211+
self.top_level_kwargs = kwargs
211212
self.api = MetaflowAPI.from_cli(self.flow_file, start)
212213

213214
def __enter__(self) -> "Runner":
@@ -216,19 +217,35 @@ def __enter__(self) -> "Runner":
216217
async def __aenter__(self) -> "Runner":
217218
return self
218219

219-
def __exit__(self, exc_type, exc_value, traceback):
220-
self.spm.cleanup()
221-
222-
async def __aexit__(self, exc_type, exc_value, traceback):
223-
self.spm.cleanup()
224-
225-
def run(self, **kwargs) -> ExecutingRun:
220+
def __get_executing_run(self, tfp_pathspec, command_obj):
221+
try:
222+
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
223+
run_object = Run(pathspec, _namespace_check=False)
224+
return ExecutingRun(self, command_obj, run_object)
225+
except TimeoutError as e:
226+
stdout_log = open(command_obj.log_files["stdout"]).read()
227+
stderr_log = open(command_obj.log_files["stderr"]).read()
228+
command = " ".join(command_obj.command)
229+
error_message = "Error executing: '%s':\n" % command
230+
if stdout_log.strip():
231+
error_message += "\nStdout:\n%s\n" % stdout_log
232+
if stderr_log.strip():
233+
error_message += "\nStderr:\n%s\n" % stderr_log
234+
raise RuntimeError(error_message) from e
235+
236+
def run(self, show_output: bool = False, **kwargs) -> ExecutingRun:
226237
"""
227238
Synchronous execution of the run. This method will *block* until
228239
the run has completed execution.
229240
230241
Parameters
231242
----------
243+
show_output : bool, default False
244+
Suppress the 'stdout' and 'stderr' to the console by default.
245+
They can be accessed later by reading the files present in the
246+
ExecutingRun object (referenced as 'result' below) returned:
247+
- result.stdout
248+
- result.stderr
232249
**kwargs : Any
233250
Additional arguments that you would pass to `python ./myflow.py` after
234251
the `run` command.
@@ -240,25 +257,51 @@ def run(self, **kwargs) -> ExecutingRun:
240257
"""
241258
with tempfile.TemporaryDirectory() as temp_dir:
242259
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
243-
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
260+
command = self.api(**self.top_level_kwargs).run(
261+
pathspec_file=tfp_pathspec.name, **kwargs
262+
)
244263

245-
pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
264+
pid = self.spm.run_command(
265+
[sys.executable, *command], env=self.env_vars, show_output=show_output
266+
)
246267
command_obj = self.spm.get(pid)
247268

248-
try:
249-
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
250-
run_object = Run(pathspec, _namespace_check=False)
251-
return ExecutingRun(self, command_obj, run_object)
252-
except TimeoutError as e:
253-
stdout_log = open(command_obj.log_files["stdout"]).read()
254-
stderr_log = open(command_obj.log_files["stderr"]).read()
255-
command = " ".join(command_obj.command)
256-
error_message = "Error executing: '%s':\n" % command
257-
if stdout_log.strip():
258-
error_message += "\nStdout:\n%s\n" % stdout_log
259-
if stderr_log.strip():
260-
error_message += "\nStderr:\n%s\n" % stderr_log
261-
raise RuntimeError(error_message) from e
269+
return self.__get_executing_run(tfp_pathspec, command_obj)
270+
271+
def resume(self, show_output: bool = False, **kwargs):
272+
"""
273+
Synchronous resume execution of the run.
274+
This method will *block* until the resumed run has completed execution.
275+
276+
Parameters
277+
----------
278+
show_output : bool, default False
279+
Suppress the 'stdout' and 'stderr' to the console by default.
280+
They can be accessed later by reading the files present in the
281+
ExecutingRun object (referenced as 'result' below) returned:
282+
- result.stdout
283+
- result.stderr
284+
**kwargs : Any
285+
Additional arguments that you would pass to `python ./myflow.py` after
286+
the `resume` command.
287+
288+
Returns
289+
-------
290+
ExecutingRun
291+
ExecutingRun object for this resumed run.
292+
"""
293+
with tempfile.TemporaryDirectory() as temp_dir:
294+
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
295+
command = self.api(**self.top_level_kwargs).resume(
296+
pathspec_file=tfp_pathspec.name, **kwargs
297+
)
298+
299+
pid = self.spm.run_command(
300+
[sys.executable, *command], env=self.env_vars, show_output=show_output
301+
)
302+
command_obj = self.spm.get(pid)
303+
304+
return self.__get_executing_run(tfp_pathspec, command_obj)
262305

263306
async def async_run(self, **kwargs) -> ExecutingRun:
264307
"""
@@ -278,32 +321,48 @@ async def async_run(self, **kwargs) -> ExecutingRun:
278321
"""
279322
with tempfile.TemporaryDirectory() as temp_dir:
280323
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
281-
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
324+
command = self.api(**self.top_level_kwargs).run(
325+
pathspec_file=tfp_pathspec.name, **kwargs
326+
)
282327

283328
pid = await self.spm.async_run_command(
284329
[sys.executable, *command], env=self.env_vars
285330
)
286331
command_obj = self.spm.get(pid)
287332

288-
try:
289-
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
290-
run_object = Run(pathspec, _namespace_check=False)
291-
return ExecutingRun(self, command_obj, run_object)
292-
except TimeoutError as e:
293-
stdout_log = open(
294-
command_obj.log_files["stdout"], encoding="utf-8"
295-
).read()
296-
stderr_log = open(
297-
command_obj.log_files["stderr"], encoding="utf-8"
298-
).read()
299-
command = " ".join(command_obj.command)
333+
return self.__get_executing_run(tfp_pathspec, command_obj)
300334

301-
error_message = "Error executing: '%s':\n" % command
335+
async def async_resume(self, **kwargs):
336+
"""
337+
Asynchronous resume execution of the run.
338+
This method will return as soon as the resume has launched.
302339
303-
if stdout_log.strip():
304-
error_message += "\nStdout:\n%s\n" % stdout_log
340+
Parameters
341+
----------
342+
**kwargs : Any
343+
Additional arguments that you would pass to `python ./myflow.py` after
344+
the `resume` command.
305345
306-
if stderr_log.strip():
307-
error_message += "\nStderr:\n%s\n" % stderr_log
346+
Returns
347+
-------
348+
ExecutingRun
349+
ExecutingRun object for this resumed run.
350+
"""
351+
with tempfile.TemporaryDirectory() as temp_dir:
352+
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
353+
command = self.api(**self.top_level_kwargs).resume(
354+
pathspec_file=tfp_pathspec.name, **kwargs
355+
)
356+
357+
pid = await self.spm.async_run_command(
358+
[sys.executable, *command], env=self.env_vars
359+
)
360+
command_obj = self.spm.get(pid)
308361

309-
raise RuntimeError(error_message) from e
362+
return self.__get_executing_run(tfp_pathspec, command_obj)
363+
364+
def __exit__(self, exc_type, exc_value, traceback):
365+
self.spm.cleanup()
366+
367+
async def __aexit__(self, exc_type, exc_value, traceback):
368+
self.spm.cleanup()

0 commit comments

Comments
 (0)