
Commit 51ecd5f

fix(local-kill): fix local kill (#303)
1. Fix local sequential job termination: ensure pending jobs are handled correctly when jobs are killed.
2. Clarify the kill error message shown when a kill command cannot be executed because the job has already finished or been canceled.

Signed-off-by: Anna Warno <awarno@nvidia.com>
Parent: b6cb857

8 files changed

Lines changed: 153 additions & 173 deletions


packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/configs/execution/local.yaml

Lines changed: 1 addition & 0 deletions
```diff
@@ -16,3 +16,4 @@
 type: local
 output_dir: ???
 extra_docker_args: ""
+mode: sequential
```

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/base.py

Lines changed: 23 additions & 0 deletions
```diff
@@ -95,3 +95,26 @@ def kill_job(job_id: str) -> None:
             NotImplementedError: If not implemented by a subclass.
         """
         raise NotImplementedError("Subclasses must implement this method")
+
+    @staticmethod
+    def get_kill_failure_message(
+        job_id: str, container_or_id: str, status: Optional[ExecutionState] = None
+    ) -> str:
+        """Generate an informative error message when kill fails based on job status.
+
+        Args:
+            job_id: The job ID that failed to kill.
+            container_or_id: Container name, SLURM job ID, or other identifier.
+            status: Optional execution state of the job.
+
+        Returns:
+            str: An informative error message with job status context.
+        """
+        if status == ExecutionState.SUCCESS:
+            return f"Could not find or kill job {job_id} ({container_or_id}) - job already completed successfully"
+        elif status == ExecutionState.FAILED:
+            return f"Could not find or kill job {job_id} ({container_or_id}) - job already failed"
+        elif status == ExecutionState.KILLED:
+            return f"Could not find or kill job {job_id} ({container_or_id}) - job was already killed"
+        # Generic error message
+        return f"Could not find or kill job {job_id} ({container_or_id})"
```
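To make the new shared helper concrete, here is a minimal usage sketch; the job ID, container name, and import paths are illustrative assumptions, not part of this commit:

```python
# Hypothetical usage sketch of the new helper (identifiers are examples).
# Assumes ExecutionState is importable from the launcher's execdb module and that
# LocalExecutor inherits get_kill_failure_message from the base executor class.
from nemo_evaluator_launcher.common.execdb import ExecutionState  # assumed path
from nemo_evaluator_launcher.executors.local.executor import LocalExecutor  # assumed path

msg = LocalExecutor.get_kill_failure_message(
    "abc123.0", "container: eval-abc123-0", ExecutionState.SUCCESS
)
print(msg)
# Could not find or kill job abc123.0 (container: eval-abc123-0) - job already completed successfully
```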

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/lepton/executor.py

Lines changed: 17 additions & 71 deletions
```diff
@@ -622,76 +622,14 @@ def get_status(id: str) -> List[ExecutionStatus]:
     def kill_job(job_id: str) -> None:
         """Kill Lepton evaluation jobs and clean up endpoints.
 
-        For invocation IDs, this will kill all jobs and clean up all
-        dedicated endpoints created for the invocation.
-
         Args:
-            job_id: The job ID or invocation ID to kill.
+            job_id: The job ID to kill.
 
         Raises:
             ValueError: If job is not found or invalid.
             RuntimeError: If job cannot be killed.
         """
         db = ExecutionDB()
-
-        # If it looks like an invocation_id, kill all jobs for that invocation
-        if len(job_id) == 8 and "." not in job_id:
-            jobs = db.get_jobs(job_id)
-            if not jobs:
-                raise ValueError(f"No jobs found for invocation {job_id}")
-
-            endpoint_names = (
-                set()
-            )  # Use set to avoid duplicates (though each should be unique)
-            lepton_job_names = []
-
-            # Collect all Lepton jobs and endpoint info
-            for curr_job_data in jobs.values():
-                if curr_job_data.executor != "lepton":
-                    continue
-
-                # Collect endpoint name for this job (each task may have its own)
-                endpoint_name = curr_job_data.data.get("endpoint_name")
-                if endpoint_name:
-                    endpoint_names.add(endpoint_name)
-
-                lepton_job_name = curr_job_data.data.get("lepton_job_name")
-                if lepton_job_name:
-                    lepton_job_names.append(lepton_job_name)
-
-                # Mark job as killed in database
-                curr_job_data.data["status"] = "killed"
-                curr_job_data.data["killed_time"] = time.time()
-                db.write_job(curr_job_data)
-
-            print(
-                f"🛑 Killing {len(lepton_job_names)} Lepton jobs for invocation {job_id}"
-            )
-
-            # Cancel all Lepton jobs
-            for lepton_job_name in lepton_job_names:
-                success = delete_lepton_job(lepton_job_name)
-                if success:
-                    print(f"✅ Cancelled Lepton job: {lepton_job_name}")
-                else:
-                    print(f"⚠️ Failed to cancel Lepton job: {lepton_job_name}")
-
-            # Clean up all dedicated endpoints
-            if endpoint_names:
-                print(f"🧹 Cleaning up {len(endpoint_names)} dedicated endpoints")
-                for endpoint_name in endpoint_names:
-                    success = delete_lepton_endpoint(endpoint_name)
-                    if success:
-                        print(f"✅ Cleaned up endpoint: {endpoint_name}")
-                    else:
-                        print(f"⚠️ Failed to cleanup endpoint: {endpoint_name}")
-            else:
-                print("📌 No dedicated endpoints to clean up (using shared endpoint)")
-
-            print(f"🛑 Killed all resources for invocation {job_id}")
-            return
-
-        # Otherwise, treat as individual job_id
         job_data = db.get_job(job_id)
         if job_data is None:
             raise ValueError(f"Job {job_id} not found")
@@ -703,17 +641,25 @@ def kill_job(job_id: str) -> None:
 
         # Cancel the specific Lepton job
         lepton_job_name = job_data.data.get("lepton_job_name")
+
         if lepton_job_name:
-            success = delete_lepton_job(lepton_job_name)
-            if success:
+            cancel_success = delete_lepton_job(lepton_job_name)
+            if cancel_success:
                 print(f"✅ Cancelled Lepton job: {lepton_job_name}")
+                # Mark job as killed in database
+                job_data.data["status"] = "killed"
+                job_data.data["killed_time"] = time.time()
+                db.write_job(job_data)
             else:
-                print(f"⚠️ Failed to cancel Lepton job: {lepton_job_name}")
-
-        # Mark job as killed in database
-        job_data.data["status"] = "killed"
-        job_data.data["killed_time"] = time.time()
-        db.write_job(job_data)
+                # Use common helper to get informative error message based on job status
+                status_list = LeptonExecutor.get_status(job_id)
+                current_status = status_list[0].state if status_list else None
+                error_msg = LeptonExecutor.get_kill_failure_message(
+                    job_id, f"lepton_job: {lepton_job_name}", current_status
+                )
+                raise RuntimeError(error_msg)
+        else:
+            raise ValueError(f"No Lepton job name found for job {job_id}")
 
         print(f"🛑 Killed Lepton job {job_id}")
 
```
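As a rough illustration of the behavior change, killing a Lepton job whose cancellation fails (for example because it already finished) now raises a status-aware error instead of silently marking the job as killed; the job ID below is a placeholder and the import path is assumed from the file location:

```python
# Illustrative only; assumes the module path matches the file location above.
from nemo_evaluator_launcher.executors.lepton.executor import LeptonExecutor

try:
    LeptonExecutor.kill_job("abc123.0")  # placeholder job ID
except RuntimeError as err:
    # e.g. "Could not find or kill job abc123.0 (lepton_job: ...) - job already completed successfully"
    print(f"kill failed: {err}")
```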

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/local/executor.py

Lines changed: 48 additions & 7 deletions
```diff
@@ -415,10 +415,10 @@ def get_status(id: str) -> List[ExecutionStatus]:
 
     @staticmethod
     def kill_job(job_id: str) -> None:
-        """Kill a local job by stopping its Docker container and related processes.
+        """Kill a local job.
 
         Args:
-            job_id: The job ID to kill.
+            job_id: The job ID (e.g., abc123.0) to kill.
 
         Raises:
             ValueError: If job is not found or invalid.
@@ -463,14 +463,55 @@ def kill_job(job_id: str) -> None:
         if result.returncode == 0:
             killed_something = True
 
-        # Mark job as killed in database if we killed something
+        # If we successfully killed something, mark as killed
         if killed_something:
             job_data.data["killed"] = True
             db.write_job(job_data)
-        else:
-            raise RuntimeError(
-                f"Could not find or kill job {job_id} (container: {container_name})"
-            )
+            LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
+            return
+
+        # If nothing was killed, check if this is a pending job
+        status_list = LocalExecutor.get_status(job_id)
+        if status_list and status_list[0].state == ExecutionState.PENDING:
+            # For pending jobs, mark as killed even though there's nothing to kill yet
+            job_data.data["killed"] = True
+            db.write_job(job_data)
+            LocalExecutor._add_to_killed_jobs(job_data.invocation_id, job_id)
+            return
+
+        # Use common helper to get informative error message based on job status
+        current_status = status_list[0].state if status_list else None
+        error_msg = LocalExecutor.get_kill_failure_message(
+            job_id, f"container: {container_name}", current_status
+        )
+        raise RuntimeError(error_msg)
+
+    @staticmethod
+    def _add_to_killed_jobs(invocation_id: str, job_id: str) -> None:
+        """Add a job ID to the killed jobs file for this invocation.
+
+        Args:
+            invocation_id: The invocation ID.
+            job_id: The job ID to mark as killed.
+        """
+        db = ExecutionDB()
+        jobs = db.get_jobs(invocation_id)
+        if not jobs:
+            return
+
+        # Get invocation output directory from any job's output_dir
+        first_job_data = next(iter(jobs.values()))
+        job_output_dir = pathlib.Path(first_job_data.data.get("output_dir", ""))
+        if not job_output_dir.exists():
+            return
+
+        # Invocation dir is parent of job output dir
+        invocation_dir = job_output_dir.parent
+        killed_jobs_file = invocation_dir / "killed_jobs.txt"
+
+        # Append job_id to file
+        with open(killed_jobs_file, "a") as f:
+            f.write(f"{job_id}\n")
 
 
 def _get_progress(artifacts_dir: pathlib.Path) -> Optional[float]:
```
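The `_add_to_killed_jobs` bookkeeping boils down to appending the job ID to a `killed_jobs.txt` file in the invocation directory, which is taken to be the parent of the job's output directory. A minimal standalone sketch with example paths (in the real code, `output_dir` comes from the job record in `ExecutionDB`):

```python
# Sketch of the bookkeeping performed by _add_to_killed_jobs; paths are examples.
import pathlib

job_output_dir = pathlib.Path("/tmp/results/abc123/abc123.0")  # example output_dir
invocation_dir = job_output_dir.parent                          # /tmp/results/abc123
killed_jobs_file = invocation_dir / "killed_jobs.txt"

invocation_dir.mkdir(parents=True, exist_ok=True)  # only needed for this standalone example
with open(killed_jobs_file, "a") as f:
    f.write("abc123.0\n")  # one job ID per line; the run script greps for whole-line matches
```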

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/local/run.template.sh

Lines changed: 18 additions & 6 deletions
```diff
@@ -17,6 +17,11 @@
 # check if docker exists
 command -v docker >/dev/null 2>&1 || { echo 'docker not found'; exit 1; }
 
+# Initialize: remove killed jobs file from previous runs
+script_dir="$(cd "$(dirname "${BASH_SOURCE[0]}")" && pwd)"
+killed_jobs_file="$script_dir/killed_jobs.txt"
+rm -f "$killed_jobs_file"
+
 {% for task in evaluation_tasks %}
 # {{ task.job_id }} {{ task.name }}
 
@@ -28,13 +33,17 @@ mkdir -m 777 -p "$task_dir"
 mkdir -m 777 -p "$artifacts_dir"
 mkdir -m 777 -p "$logs_dir"
 
-# Create pre-start stage file
-echo "$(date -u +%Y-%m-%dT%H:%M:%SZ)" > "$logs_dir/stage.pre-start"
+# Check if this job was killed
+if [ -f "$killed_jobs_file" ] && grep -q "^{{ task.job_id }}$" "$killed_jobs_file"; then
+    echo "$(date -u +%Y-%m-%dT%H:%M:%SZ) Job {{ task.job_id }} ({{ task.name }}) was killed, skipping execution" | tee -a "$logs_dir/stdout.log"
+else
+    # Create pre-start stage file
+    echo "$(date -u +%Y-%m-%dT%H:%M:%SZ)" > "$logs_dir/stage.pre-start"
 
-# Docker run with eval factory command
-(
-    echo "$(date -u +%Y-%m-%dT%H:%M:%SZ)" > "$logs_dir/stage.running"
-    docker run --rm --shm-size=100g {{ extra_docker_args }} \
+    # Docker run with eval factory command
+    (
+        echo "$(date -u +%Y-%m-%dT%H:%M:%SZ)" > "$logs_dir/stage.running"
+        docker run --rm --shm-size=100g {{ extra_docker_args }} \
         --name {{ task.container_name }} \
         --volume "$artifacts_dir":/results \
     {% for env_var in task.env_vars -%}
@@ -85,4 +94,7 @@ echo "$(date -u +%Y-%m-%dT%H:%M:%SZ)" > "$logs_dir/stage.pre-start"
 )
 
 {% endif %}
+fi
+
+
 {% endfor %}
```
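For clarity, the grep guard in the template skips a task only when its job ID appears as a whole line in `killed_jobs.txt`. A rough Python equivalent of that check (file path and job ID are placeholders):

```python
# Rough Python equivalent of: [ -f "$killed_jobs_file" ] && grep -q "^<job_id>$" "$killed_jobs_file"
import pathlib

def was_killed(killed_jobs_file: pathlib.Path, job_id: str) -> bool:
    """Return True if job_id appears as a whole line in the killed-jobs file."""
    if not killed_jobs_file.exists():
        return False
    return job_id in killed_jobs_file.read_text().splitlines()

print(was_killed(pathlib.Path("/tmp/results/abc123/killed_jobs.txt"), "abc123.0"))
```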

packages/nemo-evaluator-launcher/src/nemo_evaluator_launcher/executors/slurm/executor.py

Lines changed: 40 additions & 22 deletions
```diff
@@ -400,7 +400,7 @@ def kill_job(job_id: str) -> None:
         """Kill a SLURM job.
 
         Args:
-            job_id: The job ID to kill.
+            job_id: The job ID (e.g., abc123.0) to kill.
         """
         db = ExecutionDB()
         job_data = db.get_job(job_id)
@@ -413,26 +413,31 @@ def kill_job(job_id: str) -> None:
                 f"Job {job_id} is not a slurm job (executor: {job_data.executor})"
             )
 
-        killed_something = False
-
-        result = _kill_slurm_job(
+        # OPTIMIZATION: Query status AND kill in ONE SSH call
+        slurm_status, result = _kill_slurm_job(
            slurm_job_ids=[job_data.data.get("slurm_job_id")],
            username=job_data.data.get("username"),
            hostname=job_data.data.get("hostname"),
            socket=job_data.data.get("socket"),
        )
 
+        # Mark job as killed in database if kill succeeded
         if result.returncode == 0:
-            killed_something = True
-
-        # Mark job as killed in database if we killed something
-        if killed_something:
             job_data.data["killed"] = True
             db.write_job(job_data)
         else:
-            raise RuntimeError(
-                f"Could not find or kill job {job_id} (slurm_job_id: {job_data.data.get('slurm_job_id')})"
+            # Use the pre-fetched status for better error message
+            current_status = None
+            if slurm_status:
+                current_status = SlurmExecutor._map_slurm_state_to_execution_state(
+                    slurm_status
+                )
+            error_msg = SlurmExecutor.get_kill_failure_message(
+                job_id,
+                f"slurm_job_id: {job_data.data.get('slurm_job_id')}",
+                current_status,
            )
+            raise RuntimeError(error_msg)
 
 
 def _create_slurm_sbatch_script(
@@ -883,34 +888,47 @@ def _query_slurm_jobs_status(
 
 def _kill_slurm_job(
     slurm_job_ids: List[str], username: str, hostname: str, socket: str | None
-) -> None:
-    """Kill a SLURM job.
+) -> tuple[str | None, subprocess.CompletedProcess]:
+    """Kill a SLURM job, querying status first in one SSH call for efficiency.
 
     Args:
         slurm_job_ids: List of SLURM job IDs to kill.
         username: SSH username.
         hostname: SSH hostname.
        socket: control socket location or None
+
+    Returns:
+        Tuple of (status_string, completed_process) where status_string is the SLURM status or None
     """
     if len(slurm_job_ids) == 0:
-        return {}
-    kill_command = "scancel {}".format(",".join(slurm_job_ids))
+        return None, subprocess.CompletedProcess(args=[], returncode=0)
+
+    jobs_str = ",".join(slurm_job_ids)
+    # Combine both commands in one SSH call: query THEN kill
+    combined_command = (
+        f"sacct -j {jobs_str} --format='JobID,State%32' --noheader -P 2>/dev/null; "
+        f"scancel {jobs_str}"
+    )
+
     ssh_command = ["ssh"]
     if socket is not None:
         ssh_command.append(f"-S {socket}")
     ssh_command.append(f"{username}@{hostname}")
-    ssh_command.append(kill_command)
+    ssh_command.append(combined_command)
     ssh_command = " ".join(ssh_command)
+
     completed_process = subprocess.run(
         args=shlex.split(ssh_command), capture_output=True
     )
-    if completed_process.returncode != 0:
-        raise RuntimeError(
-            "failed to kill slurm job\n{}".format(
-                completed_process.stderr.decode("utf-8")
-            )
-        )
-    return completed_process
+
+    # Parse the sacct output (before scancel runs)
+    sacct_output = completed_process.stdout.decode("utf-8")
+    sacct_output_lines = sacct_output.strip().split("\n")
+    slurm_status = None
+    if sacct_output_lines and len(slurm_job_ids) == 1:
+        slurm_status = _parse_slurm_job_status(slurm_job_ids[0], sacct_output_lines)
+
+    return slurm_status, completed_process
 
 
 def _parse_slurm_job_status(slurm_job_id: str, sacct_output_lines: List[str]) -> str:
```
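To show what the combined call looks like on the wire, here is a sketch of the single SSH command `_kill_slurm_job` now builds; the user, host, and SLURM job ID are placeholders:

```python
# Illustrative reconstruction of the combined command; values are placeholders.
jobs_str = "1234567"  # example SLURM job ID
combined_command = (
    f"sacct -j {jobs_str} --format='JobID,State%32' --noheader -P 2>/dev/null; "
    f"scancel {jobs_str}"
)
ssh_command = f"ssh user@login.cluster.example {combined_command}"
print(ssh_command)
# sacct prints the job's pre-kill state to stdout before scancel runs,
# so one SSH round trip both reports the prior status and cancels the job.
```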
