Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
282 changes: 245 additions & 37 deletions nemo_skills/code_execution/local_sandbox/local_sandbox_server.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,11 @@
import threading
import time
import traceback
import uuid
from concurrent.futures import ThreadPoolExecutor
from contextlib import redirect_stderr, redirect_stdout
from dataclasses import asdict, dataclass, field
from typing import Any, Dict, List, Optional

import psutil
from flask import Flask, request
Expand All @@ -41,6 +45,158 @@
)


@dataclass
class Job:
job_id: str
request: Dict[str, Any]
status: str = "queued"
created_at: float = field(default_factory=time.time)
started_at: Optional[float] = None
finished_at: Optional[float] = None
result: Optional[Dict[str, Any]] = None


class JobManager:
def __init__(self, shell_manager):
self.jobs = {}
Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

A potential issue with keeping this all in memory, is if the worker stops responding or crashes for some reason, all of your state disappears with the worker. Might not be something we need to fix right now, but if we start having issues of workers stopping/crashing, this will be a crux.

Copy link
Collaborator

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It's not clear to me if we need something more robust yet, though.

self.futures = {}
self.lock = threading.Lock()
self.shell_manager = shell_manager
self.executor = ThreadPoolExecutor(max_workers=1)

def submit(self, request: Dict[str, Any]):
job_id = str(uuid.uuid4())
job = Job(job_id, request=request)
with self.lock:
self.jobs[job_id] = job
# schedule serially via single-worker executor
future = self.executor.submit(self._run_job, job_id)
self.futures[job_id] = future
return job_id

def get_job(self, job_id: str):
with self.lock:
job = self.jobs.get(job_id)
if job is None:
return None
# Evict finished jobs after they are returned once to save memory
if job.status not in ("queued", "running"):
return self.jobs.pop(job_id)
return job

def queued_ahead_count(self, job_id: str) -> int:
with self.lock:
return sum((j.job_id != job_id) and (j.status in ("queued", "running")) for j in self.jobs.values())

def get_related_jobs(self, job_id: str) -> List[str]:
with self.lock:
base = self.jobs.get(job_id)
if base is None:
return []
sid = base.request.get("session_id")
return [j.job_id for j in self.jobs.values() if j.request.get("session_id") == sid and j.job_id != job_id]

def _run_job(self, job_id: str) -> None:
with self.lock:
job = self.jobs.get(job_id)
if job is None:
return
if job.status == "canceled":
# was canceled before start; nothing to do
self.futures.pop(job_id, None)
return
job.started_at = time.time()
job.status = "running"
logging.info("Worker %s started running job %s", worker_id, job_id)

result = self.execute_job(job.request)

with self.lock:
# If the job was canceled while running, set the final status
if job.status == "canceled":
job.result = {
"process_status": "canceled",
"stdout": result.get("stdout", ""),
"stderr": "Job was canceled.",
}
else:
job.result = result
job.status = result["process_status"]
logging.info("Worker %s finished running job %s with status %s", worker_id, job_id, job.status)
job.finished_at = time.time()
self.futures.pop(job_id, None)

def cancel_job(self, job_id: str):
"""Attempt to cancel a queued or running job.

- If the job hasn't started, mark it canceled and prevent execution.
- If running and language is ipython, stop the associated shell to interrupt execution.
- For other languages, we mark as canceled; underlying process may still finish, but
result will be recorded as canceled.
"""
with self.lock:
job = self.jobs.get(job_id)
if job is None:
return False, f"Job {job_id} not found"
if job.status not in ("queued", "running"):
return False, f"Job {job_id} already finished with status {job.status}"

if job.status == "queued":
job.finished_at = time.time()
job.result = {"process_status": "canceled", "stdout": "", "stderr": "Job was canceled."}

job.status = "canceled"
logging.info("Worker %s canceled job %s", worker_id, job_id)
future = self.futures.get(job_id)
if future is not None:
future.cancel()
self.futures.pop(job_id, None)
language = job.request.get("language")
session_id = job.request.get("session_id")

if language == "ipython" and session_id:
ok, msg = self.shell_manager.stop_shell(session_id)
if not ok:
logging.warning(
"Tried to cancel job %s by stopping session %s but failed: %s", job_id, session_id, msg
)

return True, f"Job {job_id} canceled"

def execute_job(self, request: Dict[str, Any]):
try:
language = request["language"]
generated_code = request["generated_code"]
timeout = request.get("timeout", 10.0)
session_id = request.get("session_id", None)
std_input = request.get("std_input", "")
max_output_characters = request.get("max_output_characters", 1000)
traceback_verbosity = request.get("traceback_verbosity", "Plain")
result = {}

if language == "ipython":
if session_id is None:
result = {"process_status": "error", "stderr": "X-Session-ID header required for ipython sessions"}
return result
result = execute_ipython_session(generated_code, session_id, timeout, traceback_verbosity)
elif language == "lean4":
result = execute_lean4(generated_code, timeout)
elif language == "shell":
result = execute_shell(generated_code, timeout)
else:
result = execute_python(generated_code, std_input, timeout, language)

if len(result.get("stdout", "")) > max_output_characters:
result["stdout"] = result["stdout"][:max_output_characters] + "<output cut>"
if len(result.get("stderr", "")) > max_output_characters:
result["stderr"] = result["stderr"][:max_output_characters] + "<output cut>"

return result
except Exception as e:
logging.error(f"Error during job execution: {e}\n{traceback.format_exc()}")
return {"process_status": "error", "stderr": f"An unexpected error occurred: {e}"}


# Worker that runs inside the shell process and owns a TerminalInteractiveShell()
def shell_worker(conn):
shell = TerminalInteractiveShell()
Expand Down Expand Up @@ -133,18 +289,28 @@ def stop_shell(self, shell_id):
entry = self.shells.pop(shell_id, None)

if not entry:
return
return False, f"IPython session {shell_id} not found"
proc, conn = entry["proc"], entry["conn"]
try:
conn.send({"cmd": "shutdown"})
except Exception:
pass
try:
conn.close()
except Exception:
pass
proc.terminate()
proc.join(timeout=2.0)
try:
conn.send({"cmd": "shutdown"})
except Exception:
pass
try:
conn.close()
except Exception:
pass
try:
proc.terminate()
except Exception:
try:
os.kill(proc.pid, signal.SIGKILL)
except Exception:
pass
proc.join(timeout=2.0)
return True, f"IPython session {shell_id} deleted successfully"
except Exception as e:
return False, f"Error stopping IPython session {shell_id}: {e}"

def run_cell(self, shell_id, code, timeout=1.0, grace=0.5, traceback_verbosity="Plain"):
"""
Expand Down Expand Up @@ -325,8 +491,11 @@ def cleanup_expired_sessions():

for session_id in expired_sessions:
try:
shell_manager.stop_shell(session_id)
logging.info(f"Cleaned up expired session: {session_id}")
ok, msg = shell_manager.stop_shell(session_id)
if ok:
logging.info(msg)
else:
logging.error(msg)
except Exception as e:
logging.warning(f"Error cleaning up session {session_id}: {e}")

Expand Down Expand Up @@ -605,35 +774,28 @@ def execute_shell(command, timeout):
os.remove(tmp_path)


job_manager = JobManager(shell_manager)


@app.route("/jobs/<job_id>", methods=["GET"])
def get_job(job_id):
job = job_manager.get_job(job_id)
if job is None:
return {"error": f"Job {job_id} not found"}, 404
return asdict(job)


# Main Flask endpoint to handle execution requests
@app.route("/execute", methods=["POST"])
def execute():
generated_code = request.json["generated_code"]
timeout = request.json["timeout"]
language = request.json.get("language", "ipython")
std_input = request.json.get("std_input", "")
max_output_characters = request.json.get("max_output_characters", 1000)
traceback_verbosity = request.json.get("traceback_verbosity", "Plain")

session_id = request.headers.get("X-Session-ID")

if language == "ipython":
if session_id is None:
return {"error": "X-Session-ID header required for ipython sessions"}, 400
result = execute_ipython_session(generated_code, session_id, timeout, traceback_verbosity)
elif language == "lean4":
result = execute_lean4(generated_code, timeout)
elif language == "shell":
result = execute_shell(generated_code, timeout)
else:
result = execute_python(generated_code, std_input, timeout, language)

if len(result["stdout"]) > max_output_characters:
result["stdout"] = result["stdout"][:max_output_characters] + "<output cut>"
if len(result["stderr"]) > max_output_characters:
result["stderr"] = result["stderr"][:max_output_characters] + "<output cut>"

return result
request_dict = request.json
request_dict["session_id"] = session_id
job_id = job_manager.submit(request_dict)
queued_ahead = job_manager.queued_ahead_count(job_id)
logging.info("Accepted job %s queued_ahead=%d pid=%d worker=%s", job_id, queued_ahead, os.getpid(), worker_id)
return {"job_id": job_id, "queued_ahead": queued_ahead, "pid": os.getpid(), "worker": worker_id}, 202


# Session management endpoints
Expand Down Expand Up @@ -679,7 +841,53 @@ def delete_session(session_id):

@app.route("/health", methods=["GET"])
def health():
return {"status": "healthy", "worker": os.environ.get("WORKER_NUM", "unknown")}
return {"status": "healthy", "worker": worker_id}


@app.route("/admin/reset_worker", methods=["POST"])
def reset_worker():
"""
Forcefully exit this worker process.

We return a 200 response first, then exit shortly after in a
background thread to ensure the HTTP response is delivered.
"""
try:
pid = os.getpid()
wid = os.environ.get("WORKER_NUM", "unknown")

def _delayed_exit():
try:
time.sleep(0.25)
finally:
os._exit(0)

threading.Thread(target=_delayed_exit, daemon=True).start()
logging.warning("Hard reset requested: exiting worker pid=%d worker=%s", pid, wid)
return {"message": "worker exiting for hard reset", "pid": pid, "worker": wid}
except Exception as e:
logging.error("Hard reset failed: %s", e)
return {"error": f"Hard reset failed: {e}"}, 500


@app.route("/jobs/<job_id>/cancel", methods=["POST"])
def cancel_job(job_id):
"""Cancel a queued or running job.

Best-effort: for ipython, we also stop the associated session to interrupt execution.
"""
req_session_id = request.headers.get("X-Session-ID")
for related_job_id in job_manager.get_related_jobs(job_id):
ok, msg = job_manager.cancel_job(related_job_id)

ok, msg = job_manager.cancel_job(job_id)
if ok:
logging.info("Cancel request for job %s (header_session=%s): %s", job_id, req_session_id, msg)
return {"message": msg}, 200
else:
logging.warning("Cancel request for job %s (header_session=%s) failed: %s", job_id, req_session_id, msg)
status = 404 if "not found" in msg else 409
return {"error": msg}, status


if __name__ == "__main__":
Expand Down
Loading
Loading