@@ -109,61 +109,82 @@ def __init__(
109
109
if profile :
110
110
self .env_vars ["METAFLOW_PROFILE" ] = profile
111
111
self .spm = SubprocessManager ()
112
- self .api = MetaflowAPI .from_cli (self .flow_file , start )(** kwargs )
112
+ self .top_level_kwargs = kwargs
113
+ self .api = MetaflowAPI .from_cli (self .flow_file , start )
113
114
114
115
def __enter__ (self ):
115
116
return self
116
117
117
118
async def __aenter__ (self ):
118
119
return self
119
120
121
+ def __get_executing_run (self , tfp_pathspec , command_obj ):
122
+ try :
123
+ pathspec = read_from_file_when_ready (tfp_pathspec .name , timeout = 5 )
124
+ run_object = Run (pathspec , _namespace_check = False )
125
+ return ExecutingRun (self , command_obj , run_object )
126
+ except TimeoutError as e :
127
+ stdout_log = open (command_obj .log_files ["stdout" ]).read ()
128
+ stderr_log = open (command_obj .log_files ["stderr" ]).read ()
129
+ command = " " .join (command_obj .command )
130
+ error_message = "Error executing: '%s':\n " % command
131
+ if stdout_log .strip ():
132
+ error_message += "\n Stdout:\n %s\n " % stdout_log
133
+ if stderr_log .strip ():
134
+ error_message += "\n Stderr:\n %s\n " % stderr_log
135
+ raise RuntimeError (error_message ) from e
136
+
120
137
def run (self , ** kwargs ):
121
138
with tempfile .TemporaryDirectory () as temp_dir :
122
139
tfp_pathspec = tempfile .NamedTemporaryFile (dir = temp_dir , delete = False )
123
- command = self .api .run (pathspec_file = tfp_pathspec .name , ** kwargs )
140
+ command = self .api (** self .top_level_kwargs ).run (
141
+ pathspec_file = tfp_pathspec .name , ** kwargs
142
+ )
124
143
125
144
pid = self .spm .run_command ([sys .executable , * command ], env = self .env_vars )
126
145
command_obj = self .spm .get (pid )
127
146
128
- try :
129
- pathspec = read_from_file_when_ready (tfp_pathspec .name , timeout = 5 )
130
- run_object = Run (pathspec , _namespace_check = False )
131
- return ExecutingRun (self , command_obj , run_object )
132
- except TimeoutError as e :
133
- stdout_log = open (command_obj .log_files ["stdout" ]).read ()
134
- stderr_log = open (command_obj .log_files ["stderr" ]).read ()
135
- command = " " .join (command_obj .command )
136
- error_message = "Error executing: '%s':\n " % command
137
- if stdout_log .strip ():
138
- error_message += "\n Stdout:\n %s\n " % stdout_log
139
- if stderr_log .strip ():
140
- error_message += "\n Stderr:\n %s\n " % stderr_log
141
- raise RuntimeError (error_message ) from e
147
+ return self .__get_executing_run (tfp_pathspec , command_obj )
148
+
149
+ def resume (self , ** kwargs ):
150
+ with tempfile .TemporaryDirectory () as temp_dir :
151
+ tfp_pathspec = tempfile .NamedTemporaryFile (dir = temp_dir , delete = False )
152
+ command = self .api (** self .top_level_kwargs ).resume (
153
+ pathspec_file = tfp_pathspec .name , ** kwargs
154
+ )
155
+
156
+ pid = self .spm .run_command ([sys .executable , * command ], env = self .env_vars )
157
+ command_obj = self .spm .get (pid )
158
+
159
+ return self .__get_executing_run (tfp_pathspec , command_obj )
142
160
143
161
async def async_run (self , ** kwargs ):
144
162
with tempfile .TemporaryDirectory () as temp_dir :
145
163
tfp_pathspec = tempfile .NamedTemporaryFile (dir = temp_dir , delete = False )
146
- command = self .api .run (pathspec_file = tfp_pathspec .name , ** kwargs )
164
+ command = self .api (** self .top_level_kwargs ).run (
165
+ pathspec_file = tfp_pathspec .name , ** kwargs
166
+ )
167
+
168
+ pid = await self .spm .async_run_command (
169
+ [sys .executable , * command ], env = self .env_vars
170
+ )
171
+ command_obj = self .spm .get (pid )
172
+
173
+ return self .__get_executing_run (tfp_pathspec , command_obj )
174
+
175
+ async def async_resume (self , ** kwargs ):
176
+ with tempfile .TemporaryDirectory () as temp_dir :
177
+ tfp_pathspec = tempfile .NamedTemporaryFile (dir = temp_dir , delete = False )
178
+ command = self .api (** self .top_level_kwargs ).resume (
179
+ pathspec_file = tfp_pathspec .name , ** kwargs
180
+ )
147
181
148
182
pid = await self .spm .async_run_command (
149
183
[sys .executable , * command ], env = self .env_vars
150
184
)
151
185
command_obj = self .spm .get (pid )
152
186
153
- try :
154
- pathspec = read_from_file_when_ready (tfp_pathspec .name , timeout = 5 )
155
- run_object = Run (pathspec , _namespace_check = False )
156
- return ExecutingRun (self , command_obj , run_object )
157
- except TimeoutError as e :
158
- stdout_log = open (command_obj .log_files ["stdout" ]).read ()
159
- stderr_log = open (command_obj .log_files ["stderr" ]).read ()
160
- command = " " .join (command_obj .command )
161
- error_message = "Error executing: '%s':\n " % command
162
- if stdout_log .strip ():
163
- error_message += "\n Stdout:\n %s\n " % stdout_log
164
- if stderr_log .strip ():
165
- error_message += "\n Stderr:\n %s\n " % stderr_log
166
- raise RuntimeError (error_message ) from e
187
+ return self .__get_executing_run (tfp_pathspec , command_obj )
167
188
168
189
def __exit__ (self , exc_type , exc_value , traceback ):
169
190
self .spm .cleanup ()
0 commit comments