Skip to content

Commit 0f778c8

Browse files
committed
resume and async resume
1 parent ab53fad commit 0f778c8

File tree

1 file changed

+52
-31
lines changed

1 file changed

+52
-31
lines changed

metaflow/metaflow_runner.py

+52-31
Original file line numberDiff line numberDiff line change
@@ -109,61 +109,82 @@ def __init__(
109109
if profile:
110110
self.env_vars["METAFLOW_PROFILE"] = profile
111111
self.spm = SubprocessManager()
112-
self.api = MetaflowAPI.from_cli(self.flow_file, start)(**kwargs)
112+
self.top_level_kwargs = kwargs
113+
self.api = MetaflowAPI.from_cli(self.flow_file, start)
113114

114115
def __enter__(self):
115116
return self
116117

117118
async def __aenter__(self):
118119
return self
119120

121+
def __get_executing_run(self, tfp_pathspec, command_obj):
122+
try:
123+
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
124+
run_object = Run(pathspec, _namespace_check=False)
125+
return ExecutingRun(self, command_obj, run_object)
126+
except TimeoutError as e:
127+
stdout_log = open(command_obj.log_files["stdout"]).read()
128+
stderr_log = open(command_obj.log_files["stderr"]).read()
129+
command = " ".join(command_obj.command)
130+
error_message = "Error executing: '%s':\n" % command
131+
if stdout_log.strip():
132+
error_message += "\nStdout:\n%s\n" % stdout_log
133+
if stderr_log.strip():
134+
error_message += "\nStderr:\n%s\n" % stderr_log
135+
raise RuntimeError(error_message) from e
136+
120137
def run(self, **kwargs):
121138
with tempfile.TemporaryDirectory() as temp_dir:
122139
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
123-
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
140+
command = self.api(**self.top_level_kwargs).run(
141+
pathspec_file=tfp_pathspec.name, **kwargs
142+
)
124143

125144
pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
126145
command_obj = self.spm.get(pid)
127146

128-
try:
129-
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
130-
run_object = Run(pathspec, _namespace_check=False)
131-
return ExecutingRun(self, command_obj, run_object)
132-
except TimeoutError as e:
133-
stdout_log = open(command_obj.log_files["stdout"]).read()
134-
stderr_log = open(command_obj.log_files["stderr"]).read()
135-
command = " ".join(command_obj.command)
136-
error_message = "Error executing: '%s':\n" % command
137-
if stdout_log.strip():
138-
error_message += "\nStdout:\n%s\n" % stdout_log
139-
if stderr_log.strip():
140-
error_message += "\nStderr:\n%s\n" % stderr_log
141-
raise RuntimeError(error_message) from e
147+
return self.__get_executing_run(tfp_pathspec, command_obj)
148+
149+
def resume(self, **kwargs):
150+
with tempfile.TemporaryDirectory() as temp_dir:
151+
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
152+
command = self.api(**self.top_level_kwargs).resume(
153+
pathspec_file=tfp_pathspec.name, **kwargs
154+
)
155+
156+
pid = self.spm.run_command([sys.executable, *command], env=self.env_vars)
157+
command_obj = self.spm.get(pid)
158+
159+
return self.__get_executing_run(tfp_pathspec, command_obj)
142160

143161
async def async_run(self, **kwargs):
144162
with tempfile.TemporaryDirectory() as temp_dir:
145163
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
146-
command = self.api.run(pathspec_file=tfp_pathspec.name, **kwargs)
164+
command = self.api(**self.top_level_kwargs).run(
165+
pathspec_file=tfp_pathspec.name, **kwargs
166+
)
167+
168+
pid = await self.spm.async_run_command(
169+
[sys.executable, *command], env=self.env_vars
170+
)
171+
command_obj = self.spm.get(pid)
172+
173+
return self.__get_executing_run(tfp_pathspec, command_obj)
174+
175+
async def async_resume(self, **kwargs):
176+
with tempfile.TemporaryDirectory() as temp_dir:
177+
tfp_pathspec = tempfile.NamedTemporaryFile(dir=temp_dir, delete=False)
178+
command = self.api(**self.top_level_kwargs).resume(
179+
pathspec_file=tfp_pathspec.name, **kwargs
180+
)
147181

148182
pid = await self.spm.async_run_command(
149183
[sys.executable, *command], env=self.env_vars
150184
)
151185
command_obj = self.spm.get(pid)
152186

153-
try:
154-
pathspec = read_from_file_when_ready(tfp_pathspec.name, timeout=5)
155-
run_object = Run(pathspec, _namespace_check=False)
156-
return ExecutingRun(self, command_obj, run_object)
157-
except TimeoutError as e:
158-
stdout_log = open(command_obj.log_files["stdout"]).read()
159-
stderr_log = open(command_obj.log_files["stderr"]).read()
160-
command = " ".join(command_obj.command)
161-
error_message = "Error executing: '%s':\n" % command
162-
if stdout_log.strip():
163-
error_message += "\nStdout:\n%s\n" % stdout_log
164-
if stderr_log.strip():
165-
error_message += "\nStderr:\n%s\n" % stderr_log
166-
raise RuntimeError(error_message) from e
187+
return self.__get_executing_run(tfp_pathspec, command_obj)
167188

168189
def __exit__(self, exc_type, exc_value, traceback):
169190
self.spm.cleanup()

0 commit comments

Comments
 (0)