Skip to content

Commit 829c816

Browse files
committed
Fix local executor tests
Also demo file transfers in local executor
1 parent fd82b5e commit 829c816

File tree

3 files changed

+105
-42
lines changed

3 files changed

+105
-42
lines changed

covalent/executor/executor_plugins/local.py

+9-4
Original file line numberDiff line numberDiff line change
@@ -162,24 +162,26 @@ def _send(
162162
resources: ResourceMap,
163163
task_group_metadata: dict,
164164
):
165+
os.makedirs(self.workdir, exist_ok=True)
165166
dispatch_id = task_group_metadata["dispatch_id"]
166167
task_ids = task_group_metadata["node_ids"]
167168
gid = task_group_metadata["task_group_id"]
168169
output_uris = []
169170
for node_id in task_ids:
170-
result_uri = os.path.join(self.cache_dir, f"result_{dispatch_id}-{node_id}.pkl")
171-
stdout_uri = os.path.join(self.cache_dir, f"stdout_{dispatch_id}-{node_id}.txt")
172-
stderr_uri = os.path.join(self.cache_dir, f"stderr_{dispatch_id}-{node_id}.txt")
171+
result_uri = os.path.join(self.workdir, f"result_{dispatch_id}-{node_id}.pkl")
172+
stdout_uri = os.path.join(self.workdir, f"stdout_{dispatch_id}-{node_id}.txt")
173+
stderr_uri = os.path.join(self.workdir, f"stderr_{dispatch_id}-{node_id}.txt")
173174
output_uris.append((result_uri, stdout_uri, stderr_uri))
174175

175176
server_url = format_server_url()
176177

177178
app_log.debug(f"Running task group {dispatch_id}:{task_ids}")
179+
app_log.debug(f"Generated artifacts will be saved at: {output_uris}")
178180
future = proc_pool.submit(
179181
run_task_group,
180182
list(map(lambda t: t.model_dump(), task_specs)),
181183
output_uris,
182-
self.cache_dir,
184+
self.workdir,
183185
task_group_metadata,
184186
server_url,
185187
)
@@ -190,6 +192,9 @@ def handle_cancelled(fut):
190192
if ex is not None:
191193
tb = "".join(traceback.TracebackException.from_exception(ex).format())
192194
app_log.debug(tb)
195+
for task_id in task_ids:
196+
url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/job"
197+
requests.put(url, json={"status": "FAILED"})
193198
if fut.cancelled():
194199
for task_id in task_ids:
195200
url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/{task_id}/job"

covalent/executor/utils/wrappers.py

+7-15
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@
2929

3030
import requests
3131

32+
from covalent._file_transfer import FileTransfer
3233
from covalent._workflow.depsbash import DepsBash
3334
from covalent._workflow.depscall import RESERVED_RETVAL_KEY__FILES, DepsCall
3435
from covalent._workflow.depspip import DepsPip
@@ -169,14 +170,10 @@ def run_task_group(
169170
170171
"""
171172

172-
prefix = "file://"
173-
prefix_len = len(prefix)
174-
175173
outputs = {}
176174
results = []
177175
dispatch_id = task_group_metadata["dispatch_id"]
178176
task_ids = task_group_metadata["node_ids"]
179-
gid = task_group_metadata["task_group_id"]
180177

181178
os.environ["COVALENT_DISPATCH_ID"] = dispatch_id
182179
os.environ["COVALENT_DISPATCHER_URL"] = server_url
@@ -288,9 +285,8 @@ def run_task_group(
288285
uri_resp = requests.post(upload_url, headers=headers)
289286
uri_resp.raise_for_status()
290287
remote_uri = uri_resp.json()["remote_uri"]
291-
292-
with open(result_uri, "rb") as f:
293-
requests.put(remote_uri, data=f)
288+
_, cp = FileTransfer(f"file://{result_uri}", remote_uri).cp()
289+
cp()
294290

295291
sys.stdout.flush()
296292
if stdout_uri:
@@ -299,10 +295,8 @@ def run_task_group(
299295
uri_resp = requests.post(upload_url, headers=headers)
300296
uri_resp.raise_for_status()
301297
remote_uri = uri_resp.json()["remote_uri"]
302-
303-
with open(stdout_uri, "rb") as f:
304-
305-
requests.put(remote_uri, data=f)
298+
_, cp = FileTransfer(f"file://{stdout_uri}", remote_uri).cp()
299+
cp()
306300

307301
sys.stderr.flush()
308302
if stderr_uri:
@@ -311,10 +305,8 @@ def run_task_group(
311305
uri_resp = requests.post(upload_url, headers=headers)
312306
uri_resp.raise_for_status()
313307
remote_uri = uri_resp.json()["remote_uri"]
314-
315-
with open(stderr_uri, "rb") as f:
316-
headers = {"Content-Length": os.path.getsize(stderr_uri)}
317-
requests.put(remote_uri, data=f)
308+
_, cp = FileTransfer(f"file://{stderr_uri}", remote_uri).cp()
309+
cp()
318310

319311
result_path = os.path.join(results_dir, f"result-{dispatch_id}:{task_id}.json")
320312

tests/covalent_tests/executor/executor_plugins/local_test.py

+89-23
Original file line numberDiff line numberDiff line change
@@ -281,47 +281,80 @@ def task(x, y):
281281
node_0_function_url = (
282282
f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/function"
283283
)
284+
node_0_function_file_url = f"{server_url}/files/node_0_function"
285+
node_0_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/output"
286+
node_0_output_file_url = f"{server_url}/files/node_0_output"
287+
node_0_stdout_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stdout"
288+
node_0_stdout_file_url = f"{server_url}/files/node_0_stdout"
289+
node_0_stderr_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stderr"
290+
node_0_stderr_file_url = f"{server_url}/files/node_0_stderr"
284291

285292
hooks_file = tempfile.NamedTemporaryFile("wb")
286293
hooks_file.write(ser_hooks)
287294
hooks_file.flush()
288295
hooks_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/hooks"
296+
hooks_file_url = f"{server_url}/files/hooks"
289297

290298
node_1_file = tempfile.NamedTemporaryFile("wb")
291299
node_1_file.write(ser_x)
292300
node_1_file.flush()
293301
node_1_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/1/assets/output"
302+
node_1_output_file_url = f"{server_url}/node_1_output"
294303

295304
node_2_file = tempfile.NamedTemporaryFile("wb")
296305
node_2_file.write(ser_y)
297306
node_2_file.flush()
298307
node_2_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/2/assets/output"
308+
node_2_output_file_url = f"{server_url}/node_2_output"
299309

300310
task_spec = TaskSpec(
301311
electron_id=0,
302312
args=[1, 2],
303313
kwargs={},
304314
)
305315

316+
# GET/POST URLs
317+
url_map = {
318+
node_0_function_url: node_0_function_file_url,
319+
node_0_output_url: node_0_output_file_url,
320+
node_0_stdout_url: node_0_stdout_file_url,
321+
node_0_stderr_url: node_0_stderr_file_url,
322+
hooks_url: hooks_file_url,
323+
node_1_output_url: node_1_output_file_url,
324+
node_2_output_url: node_2_output_file_url,
325+
}
326+
327+
# GET/PUT files
306328
resources = {
307-
node_0_function_url: ser_task,
308-
node_1_output_url: ser_x,
309-
node_2_output_url: ser_y,
310-
hooks_url: ser_hooks,
329+
node_0_function_file_url: ser_task,
330+
node_1_output_file_url: ser_x,
331+
node_2_output_file_url: ser_y,
332+
hooks_file_url: ser_hooks,
311333
}
312334

313-
def mock_req_get(url, stream):
335+
def mock_req_get(url, **kwargs):
314336
mock_resp = MagicMock()
315337
mock_resp.status_code = 200
316-
mock_resp.content = resources[url]
338+
if url in url_map:
339+
mock_resp.json.return_value = {"remote_uri": url_map[url]}
340+
else:
341+
mock_resp.content = resources[url]
317342
return mock_resp
318343

319-
def mock_req_post(url, files):
320-
resources[url] = files["asset_file"].read()
344+
def mock_req_post(url, **kwargs):
345+
mock_resp = MagicMock()
346+
mock_resp.status_code = 200
347+
mock_resp.json.return_value = {"remote_uri": url_map[url]}
348+
return mock_resp
349+
350+
def mock_req_put(url, data=None, headers={}, json={}):
351+
if data is not None:
352+
resources[url] = data if isinstance(data, bytes) else data.read()
353+
return MagicMock()
321354

322355
mocker.patch("requests.get", mock_req_get)
323356
mocker.patch("requests.post", mock_req_post)
324-
mock_put = mocker.patch("requests.put")
357+
mocker.patch("requests.put", mock_req_put)
325358
task_group_metadata = {
326359
"dispatch_id": dispatch_id,
327360
"node_ids": [node_id],
@@ -342,8 +375,7 @@ def mock_req_post(url, files):
342375
server_url=server_url,
343376
)
344377

345-
with open(result_file.name, "rb") as f:
346-
output = TransportableObject.deserialize(f.read())
378+
output = TransportableObject.deserialize(resources[node_0_output_file_url])
347379
assert output.get_deserialized() == 3
348380

349381
with open(cb_tmpfile.name, "r") as f:
@@ -352,8 +384,6 @@ def mock_req_post(url, files):
352384
with open(ca_tmpfile.name, "r") as f:
353385
assert f.read() == "Bye\n"
354386

355-
mock_put.assert_called()
356-
357387

358388
def test_run_task_group_exception(mocker):
359389
"""Test the wrapper submitted to local"""
@@ -398,47 +428,80 @@ def task(x, y):
398428
node_0_function_url = (
399429
f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/function"
400430
)
431+
node_0_function_file_url = f"{server_url}/files/node_0_function"
432+
node_0_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/output"
433+
node_0_output_file_url = f"{server_url}/files/node_0_output"
434+
node_0_stdout_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stdout"
435+
node_0_stdout_file_url = f"{server_url}/files/node_0_stdout"
436+
node_0_stderr_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/stderr"
437+
node_0_stderr_file_url = f"{server_url}/files/node_0_stderr"
401438

402439
hooks_file = tempfile.NamedTemporaryFile("wb")
403440
hooks_file.write(ser_hooks)
404441
hooks_file.flush()
405442
hooks_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/0/assets/hooks"
443+
hooks_file_url = f"{server_url}/files/hooks"
406444

407445
node_1_file = tempfile.NamedTemporaryFile("wb")
408446
node_1_file.write(ser_x)
409447
node_1_file.flush()
410448
node_1_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/1/assets/output"
449+
node_1_output_file_url = f"{server_url}/node_1_output"
411450

412451
node_2_file = tempfile.NamedTemporaryFile("wb")
413452
node_2_file.write(ser_y)
414453
node_2_file.flush()
415454
node_2_output_url = f"{server_url}/api/v2/dispatches/{dispatch_id}/electrons/2/assets/output"
455+
node_2_output_file_url = f"{server_url}/node_2_output"
416456

417457
task_spec = TaskSpec(
418458
electron_id=0,
419459
args=[1],
420460
kwargs={"y": 2},
421461
)
422462

463+
# GET/POST URLs
464+
url_map = {
465+
node_0_function_url: node_0_function_file_url,
466+
node_0_output_url: node_0_output_file_url,
467+
node_0_stdout_url: node_0_stdout_file_url,
468+
node_0_stderr_url: node_0_stderr_file_url,
469+
hooks_url: hooks_file_url,
470+
node_1_output_url: node_1_output_file_url,
471+
node_2_output_url: node_2_output_file_url,
472+
}
473+
474+
# GET/PUT files
423475
resources = {
424-
node_0_function_url: ser_task,
425-
node_1_output_url: ser_x,
426-
node_2_output_url: ser_y,
427-
hooks_url: ser_hooks,
476+
node_0_function_file_url: ser_task,
477+
node_1_output_file_url: ser_x,
478+
node_2_output_file_url: ser_y,
479+
hooks_file_url: ser_hooks,
428480
}
429481

430-
def mock_req_get(url, stream):
482+
def mock_req_get(url, **kwargs):
431483
mock_resp = MagicMock()
432484
mock_resp.status_code = 200
433-
mock_resp.content = resources[url]
485+
if url in url_map:
486+
mock_resp.json.return_value = {"remote_uri": url_map[url]}
487+
else:
488+
mock_resp.content = resources[url]
434489
return mock_resp
435490

436-
def mock_req_post(url, files):
437-
resources[url] = files["asset_file"].read()
491+
def mock_req_post(url, **kwargs):
492+
mock_resp = MagicMock()
493+
mock_resp.status_code = 200
494+
mock_resp.json.return_value = {"remote_uri": url_map[url]}
495+
return mock_resp
496+
497+
def mock_req_put(url, data=None, headers={}, json={}):
498+
if data is not None:
499+
resources[url] = data if isinstance(data, bytes) else data.read()
500+
return MagicMock()
438501

439502
mocker.patch("requests.get", mock_req_get)
440503
mocker.patch("requests.post", mock_req_post)
441-
mocker.patch("requests.put")
504+
mocker.patch("requests.put", mock_req_put)
442505
task_group_metadata = {
443506
"dispatch_id": dispatch_id,
444507
"node_ids": [node_id],
@@ -459,6 +522,9 @@ def mock_req_post(url, files):
459522
server_url=server_url,
460523
)
461524

525+
stderr = resources[node_0_stderr_file_url].decode("utf-8")
526+
assert "AssertionError" in stderr
527+
462528
summary_file_path = f"{results_dir.name}/result-{dispatch_id}:{node_id}.json"
463529

464530
with open(summary_file_path, "r") as f:
@@ -571,7 +637,7 @@ def test_send_internal(
571637
run_task_group,
572638
list(map(lambda t: t.dict(), test_case["task_specs"])),
573639
test_case["expected_output_uris"],
574-
"mock_cache_dir",
640+
local_exec.workdir,
575641
test_case["task_group_metadata"],
576642
test_case["expected_server_url"],
577643
)

0 commit comments

Comments
 (0)