Skip to content
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.

Commit 2c4ec6e

Browse files
authored May 21, 2025
Consolidate bastion _kill_job logic (#1193)
1 parent 765d92b commit 2c4ec6e

File tree

2 files changed

+51
-25
lines changed

2 files changed

+51
-25
lines changed
 

‎axlearn/cloud/common/bastion.py

Lines changed: 19 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -61,6 +61,7 @@
6161
import tempfile
6262
import time
6363
from concurrent.futures import ThreadPoolExecutor, wait
64+
from contextlib import suppress
6465
from datetime import datetime, timezone
6566
from subprocess import CalledProcessError
6667
from typing import IO, Any, NamedTuple, Optional, Union
@@ -884,8 +885,20 @@ def _kill_job(self, job: Job):
884885
"""
885886
if job.command_proc is not None:
886887
self._wait_and_close_proc(job.command_proc, kill=True)
888+
job.command_proc = None
887889
if job.cleanup_proc is not None:
888890
self._wait_and_close_proc(job.cleanup_proc, kill=True)
891+
job.cleanup_proc = None
892+
893+
def _remove_local_job(self, job: Job):
    """Kills a locally tracked job and drops its entry from self._active_jobs.

    The entry is looked up by `job.spec.name`. The entry is deleted only after
    `self._kill_job(job)` returns without raising, so a job whose kill fails
    stays in self._active_jobs.

    Args:
        job: The job to kill and stop tracking.

    Raises:
        Exception: Any error raised while killing the job, deleting its entry, or
            logging the removal. A warning is logged first, then the same
            exception is re-raised for the caller to handle.
    """
    try:
        self._kill_job(job)
        # Reached only when the kill above did not raise.
        del self._active_jobs[job.spec.name]
        logging.info("Removed job %s.", job.spec.name)
    except Exception as err:  # pylint: disable=broad-except
        # Record which job failed before propagating the error to the caller.
        logging.warning("Fail to remove a job %s with error: %s", job.spec.name, err)
        raise
889902

890903
def _sync_jobs(self):
891904
"""Makes the local bastion state consistent with the remote state.
@@ -940,9 +953,7 @@ def _sync_jobs(self):
940953
job = self._active_jobs[job_name]
941954
if job.state.status != JobStatus.COMPLETED:
942955
logging.warning("Detected orphaned job %s! Killing it...", job.spec.name)
943-
self._kill_job(job)
944-
logging.info("Removed job %s.", job_name)
945-
del self._active_jobs[job_name]
956+
self._remove_local_job(job)
946957
# Detected updated job: exists in both.
947958
else:
948959
curr_job = self._active_jobs[job_name]
@@ -1286,12 +1297,11 @@ def execute(self):
12861297
self._execute()
12871298
except Exception:
12881299
logging.error("Caught exception, will cleanup all child jobs.")
1289-
for job in self._active_jobs.values():
1290-
try:
1291-
self._kill_job(job)
1292-
except Exception as e: # pylint: disable=broad-except
1293-
logging.warning("Fail to kill a job with error: %s", e)
1294-
self._active_jobs = {}
1300+
for job in [*self._active_jobs.values()]:
1301+
# Gracefully attempt to kill each job, ignoring any exception without
1302+
# affecting cleanup of the next job.
1303+
with suppress(Exception): # pylint: disable=broad-except
1304+
self._remove_local_job(job)
12951305
self._uploader.cleanup()
12961306
raise # Re-raise.
12971307

‎axlearn/cloud/common/bastion_test.py

Lines changed: 32 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -1573,15 +1573,17 @@ def test_update_scheduler(
15731573
def test_exception(self):
15741574
patch_signal = mock.patch(f"{bastion.__name__}.send_signal")
15751575
with patch_signal, self._patch_bastion() as mock_bastion:
1576+
mock_command_proc = mock.Mock()
1577+
mock_cleanup_proc = mock.Mock()
15761578
mock_job = Job(
15771579
spec=mock.Mock(),
15781580
state=mock.Mock(),
1579-
command_proc=mock.Mock(),
1580-
cleanup_proc=mock.Mock(),
1581+
command_proc=mock_command_proc,
1582+
cleanup_proc=mock_cleanup_proc,
15811583
)
15821584

15831585
def mock_execute():
1584-
mock_bastion._active_jobs["test"] = mock_job
1586+
mock_bastion._active_jobs[mock_job.spec.name] = mock_job
15851587
raise ValueError("Mock error")
15861588

15871589
with mock.patch.multiple(
@@ -1591,14 +1593,19 @@ def mock_execute():
15911593
) as mock_methods:
15921594
with self.assertRaisesRegex(ValueError, "Mock error"):
15931595
mock_bastion.execute()
1594-
self.assertIn(
1595-
mock_job.command_proc, mock_methods["_wait_and_close_proc"].call_args_list[0][0]
1596-
)
1597-
self.assertIn(
1598-
mock_job.cleanup_proc, mock_methods["_wait_and_close_proc"].call_args_list[1][0]
1599-
)
1596+
mock_wait_and_close_proc = mock_methods["_wait_and_close_proc"]
1597+
self.assertIn(mock_command_proc, mock_wait_and_close_proc.call_args_list[0][0])
1598+
self.assertIn(mock_cleanup_proc, mock_wait_and_close_proc.call_args_list[1][0])
16001599

1601-
def test_execute_with_exception_and_job_failure(self):
1600+
@parameterized.product(
1601+
kill_job1_error=[None, Exception("Cannot kill job1")],
1602+
kill_job2_error=[None, Exception("Cannot kill job2")],
1603+
)
1604+
def test_execute_with_exception_and_job_failure(
1605+
self,
1606+
kill_job1_error: Optional[Exception],
1607+
kill_job2_error: Optional[Exception],
1608+
):
16021609
job_1 = Job(
16031610
spec=mock.Mock(),
16041611
state=mock.Mock(),
@@ -1612,22 +1619,31 @@ def test_execute_with_exception_and_job_failure(self):
16121619
cleanup_proc=mock.Mock(),
16131620
)
16141621
active_jobs = {
1615-
"job1": job_1,
1616-
"job2": job_2,
1622+
job_1.spec.name: job_1,
1623+
job_2.spec.name: job_2,
16171624
}
16181625

16191626
with self._patch_bastion() as mock_bastion:
16201627
mock_bastion._execute = mock.Mock(side_effect=Exception("Execution failed"))
1621-
mock_bastion._kill_job = mock.Mock(side_effect=[Exception("Cannot kill job"), None])
1622-
1628+
mock_bastion._kill_job = mock.Mock(side_effect=[kill_job1_error, kill_job2_error])
1629+
mock_bastion._remove_local_job = mock.Mock(wraps=mock_bastion._remove_local_job)
16231630
mock_bastion._active_jobs = active_jobs
16241631

16251632
with self.assertRaises(Exception):
16261633
mock_bastion.execute()
16271634

1628-
self.assertEqual(mock_bastion._kill_job.call_count, 2)
1635+
self.assertEqual(mock_bastion._remove_local_job.call_count, 2)
16291636
expected_calls = [mock.call(job_1), mock.call(job_2)]
1630-
self.assertEqual(mock_bastion._kill_job.call_args_list, expected_calls)
1637+
self.assertEqual(mock_bastion._remove_local_job.call_args_list, expected_calls)
1638+
# A job remains if and only if there is exception during the clean up process.
1639+
self.assertEqual(
1640+
job_1 in mock_bastion._active_jobs.values(),
1641+
kill_job1_error is not None,
1642+
)
1643+
self.assertEqual(
1644+
job_2 in mock_bastion._active_jobs.values(),
1645+
kill_job2_error is not None,
1646+
)
16311647

16321648
def test_sync_jobs_for_valid_pending_to_sudden_invalid_jobs(self):
16331649
"""Test behavior of state transition for pending invalid jobs."""

0 commit comments

Comments
 (0)
Please sign in to comment.