Skip to content
56 changes: 45 additions & 11 deletions sirepo/job_driver/sbatch.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,9 @@
from sirepo import job_driver
from sirepo import util
import asyncssh
import asyncio
import uuid
from urllib.parse import urlparse
import datetime
import errno
import sirepo.const
Expand Down Expand Up @@ -69,6 +72,13 @@ async def kill(self):
pkdlog("websocket closed {}", self)
except Exception as e:
pkdlog("{} error={} stack={}", self, e, pkdexc())
if self.conn is not None:
try:
self.conn.close()
await self.conn.wait_closed()
self.conn = None
except Exception as e:
pkdlog("{} error={} stack={}", self, e, pkdexc())

@classmethod
def get_instance(cls, op):
Expand Down Expand Up @@ -225,6 +235,18 @@ def _write_to_log(stdout, stderr, filename):
agent_start_dir = self._srdb_root
if pkconfig.in_dev_mode():
pkdlog("agent_log={}/{}", agent_start_dir, log_file)

# If the supervisor URI points to a loopback address but the Slurm host does not,
# the supervisor is likely unreachable from that host, so forward it over SSH.
supervisor_uri = urlparse(self.cfg.supervisor_uri)
loopback_addresses = ['localhost', '127.0.0.1']
is_supervisor_uri_proper = not (supervisor_uri.hostname in loopback_addresses and
self.cfg.host not in loopback_addresses)
if not is_supervisor_uri_proper:
dest_domain_socket0 = f"/tmp/sirepo-{uuid.uuid4().hex}.sock" # Domain socket name just needs to be unique, so we can use a UUID
old_supervisor_uri = self.cfg.supervisor_uri # Save the old supervisor URI to restore it later
self.cfg.supervisor_uri = f"unix:/{dest_domain_socket0};" # Semicolon is needed to separate the socket path from the resource path

script = f"""#!/bin/bash
set -euo pipefail
{_agent_start_dev()}
Expand All @@ -235,17 +257,29 @@ def _write_to_log(stdout, stderr, filename):
disown
"""
try:
async with asyncssh.connect(self.cfg.host, **_creds()) as c:
async with c.create_process("/bin/bash --noprofile --norc -l") as p:
await _get_agent_log(c, before_start=True)
o, e = await p.communicate(input=script)
if o or e:
_write_to_log(o, e, "start")
self.driver_details.pkupdate(
host=self.cfg.host,
username=self._creds.username,
)
await _get_agent_log(c, before_start=False)
self.conn = await asyncssh.connect(self.cfg.host, **_creds())
if not is_supervisor_uri_proper:
supervisor_port = supervisor_uri.port
if not supervisor_port:
supervisor_port = 443 if supervisor_uri.scheme == 'https' else 80 # Default to the port for HTTPS or HTTP
listener0 = await self.conn.forward_remote_path_to_port(dest_domain_socket0, 'localhost', supervisor_port)
self.conn_listener0 = asyncio.create_task(listener0.wait_closed())
self.cfg.supervisor_uri = old_supervisor_uri # Restore the original supervisor URI

async with self.conn.create_process("/bin/bash --noprofile --norc -l") as p:
await _get_agent_log(self.conn, before_start=True)
o, e = await p.communicate(input=script)
if o or e:
_write_to_log(o, e, "start")
self.driver_details.pkupdate(
host=self.cfg.host,
username=self._creds.username,
)
await _get_agent_log(self.conn, before_start=False)
if is_supervisor_uri_proper:
self.conn.close()
await self.conn.wait_closed()
self.conn = None
except Exception as e:
pkdlog("error={} stack={}", e, pkdexc())
self._srdb_root = None
Expand Down
16 changes: 15 additions & 1 deletion sirepo/pkcli/job_agent.py
Original file line number Diff line number Diff line change
Expand Up @@ -175,6 +175,12 @@ def _remove_own_pid_file(info):
finally:
_remove_own_pid_file(p)

class UnixResolver(tornado.netutil.Resolver):
    """Resolver that directs every lookup to one Unix domain socket.

    The host and port passed to `resolve` are ignored; the answer is
    always the socket path given to `initialize`.
    """

    def initialize(self, socket_path):
        """Remember the path of the Unix domain socket to connect to."""
        self.socket_path = socket_path

    async def resolve(self, host, port, *args, **kwargs):
        """Return a one-element list holding ``(AF_UNIX, socket_path)``."""
        unix_address = (socket.AF_UNIX, self.socket_path)
        return [unix_address]

class _Dispatcher(PKDict):
def __init__(self):
Expand Down Expand Up @@ -245,15 +251,23 @@ def _parse_text():

async def loop(self):
async def _connect_and_loop():
tgt_url = _cfg.supervisor_uri
resolver = None
if _cfg.supervisor_uri.startswith("unix:/"): # Handle unix domain sockets
socket_path, _, resource_path = _cfg.supervisor_uri.replace("unix:/", "", 1).partition(";")
tgt_url = f"ws://localhost:0/{resource_path.lstrip('/')}"
resolver = UnixResolver(socket_path)

self._websocket = await tornado.websocket.websocket_connect(
tornado.httpclient.HTTPRequest(
connect_timeout=_CONNECT_SECS,
url=_cfg.supervisor_uri,
url=tgt_url,
validate_cert=job.cfg().verify_tls,
),
max_message_size=job.cfg().max_message_bytes,
ping_interval=job.cfg().ping_interval_secs,
ping_timeout=job.cfg().ping_timeout_secs,
resolver=resolver,
)
s = self.format_op(None, job.OP_ALIVE)
rv = False
Expand Down