From 840c1e2971a4821875125994a8433b9b7dee9fbc Mon Sep 17 00:00:00 2001 From: aldbr Date: Thu, 27 Jun 2024 09:21:51 +0200 Subject: [PATCH] feat(Resources): introduce fabric in SSHCE --- environment.yml | 3 + setup.cfg | 3 + .../Computing/SSHBatchComputingElement.py | 117 ++- .../Computing/SSHComputingElement.py | 699 +++++++----------- .../scripts/dirac_admin_debug_ce.py | 81 +- 5 files changed, 418 insertions(+), 485 deletions(-) diff --git a/environment.yml b/environment.yml index 3c25b0c3c2e..5bfe0cd922d 100644 --- a/environment.yml +++ b/environment.yml @@ -18,11 +18,14 @@ dependencies: - cwltool - db12 - opensearch-py + - fabric - fts3 - gitpython >=2.1.0 + - invoke - m2crypto >=0.38.0 - matplotlib - numpy + - paramiko - pexpect >=4.0.1 - pillow - prompt-toolkit >=3,<4 diff --git a/setup.cfg b/setup.cfg index 7becf2289f7..6c9e1c1b2bd 100644 --- a/setup.cfg +++ b/setup.cfg @@ -35,12 +35,15 @@ install_requires = diracx-core >=v0.0.1 diracx-cli >=v0.0.1 db12 + fabric fts3 gfal2-python importlib_metadata >=4.4 importlib_resources + invoke M2Crypto >=0.36 packaging + paramiko pexpect prompt-toolkit >=3 psutil diff --git a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py index 38c70887d4e..52fb59982b5 100644 --- a/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHBatchComputingElement.py @@ -1,4 +1,4 @@ -""" SSH (Virtual) Computing Element: For a given list of ip/cores pair it will send jobs +""" SSH (Virtual) Batch Computing Element: For a given list of ip/cores pair it will send jobs directly through ssh """ @@ -12,64 +12,77 @@ class SSHBatchComputingElement(SSHComputingElement): - ############################################################################# def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) - self.ceType = "SSHBatch" - self.sshHost = [] + self.connections = {} self.execution = "SSHBATCH" def _reset(self): """Process CE parameters and make necessary adjustments""" + # Get the Batch System instance result = self._getBatchSystem() if not result["OK"]: return result + + # Get the location of the remote directories self._getBatchSystemDirectoryLocations() - self.user = self.ceParameters["SSHUser"] + # Get the SSH parameters + self.timeout = self.ceParameters.get("Timeout", self.timeout) + self.user = self.ceParameters.get("SSHUser", self.user) + port = self.ceParameters.get("SSHPort", None) + password = self.ceParameters.get("SSHPassword", None) + key = self.ceParameters.get("SSHKey", None) + + # Get submission parameters + self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions) + self.preamble = self.ceParameters.get("Preamble", self.preamble) + self.account = self.ceParameters.get("Account", self.account) self.queue = self.ceParameters["Queue"] self.log.info("Using queue: ", self.queue) - self.submitOptions = self.ceParameters.get("SubmitOptions", "") - self.preamble = self.ceParameters.get("Preamble", "") - self.account = self.ceParameters.get("Account", "") - - # Prepare all the hosts - for hPar in self.ceParameters["SSHHost"].strip().split(","): - host = hPar.strip().split("/")[0] - result = self._prepareRemoteHost(host=host) - if result["OK"]: - self.log.info(f"Host {host} registered for usage") - self.sshHost.append(hPar.strip()) + # Get output and error templates + self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate) + self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate) + + # Prepare the remote hosts + for host in self.ceParameters.get("SSHHost", "").strip().split(","): + hostDetails = host.strip().split("/") + if len(hostDetails) > 1: + hostname = hostDetails[0] + maxJobs = int(hostDetails[1]) else: - self.log.error("Failed to initialize host", host) + hostname = hostDetails[0] + maxJobs = int(self.ceParameters.get("MaxTotalJobs", 0)) + + connection = self._getConnection(hostname, self.user, port, password, key) + + result = self._prepareRemoteHost(connection) + if not result["OK"]: return result + self.connections[hostname] = {"connection": connection, "maxJobs": maxJobs} + self.log.info(f"Host {hostname} registered for usage") + return S_OK() ############################################################################# + def submitJob(self, executableFile, proxy, numberOfJobs=1): """Method to submit job""" - # Choose eligible hosts, rank them by the number of available slots rankHosts = {} maxSlots = 0 - for host in self.sshHost: - thost = host.split("/") - hostName = thost[0] - maxHostJobs = 1 - if len(thost) > 1: - maxHostJobs = int(thost[1]) - - result = self._getHostStatus(hostName) + for _, details in self.connections.items(): + result = self._getHostStatus(details["connection"]) if not result["OK"]: continue - slots = maxHostJobs - result["Value"]["Running"] + slots = details["maxJobs"] - result["Value"]["Running"] if slots > 0: rankHosts.setdefault(slots, []) - rankHosts[slots].append(hostName) + rankHosts[slots].append(details["connection"]) if slots > maxSlots: maxSlots = slots @@ -83,18 +96,28 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): restJobs = numberOfJobs submittedJobs = [] stampDict = {} + batchSystemName = self.batchSystem.__class__.__name__.lower() + for slots in range(maxSlots, 0, -1): if slots not in rankHosts: continue - for host in rankHosts[slots]: - result = self._submitJobToHost(executableFile, min(slots, restJobs), host) + for connection in rankHosts[slots]: + result = self._submitJobToHost(connection, executableFile, min(slots, restJobs)) if not result["OK"]: continue - nJobs = len(result["Value"]) + batchIDs, jobStamps = result["Value"] + + nJobs = len(batchIDs) if nJobs > 0: - submittedJobs.extend(result["Value"]) - stampDict.update(result.get("PilotStampDict", {})) + jobIDs = [ + f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{connection.host}/{_id}" + for _id in batchIDs + ] + submittedJobs.extend(jobIDs) + for iJob, jobID in enumerate(jobIDs): + stampDict[jobID] = jobStamps[iJob] + restJobs = restJobs - nJobs if restJobs <= 0: break @@ -105,6 +128,8 @@ def submitJob(self, executableFile, proxy, numberOfJobs=1): result["PilotStampDict"] = stampDict return result + ############################################################################# + def killJob(self, jobIDs): """Kill specified jobs""" jobIDList = list(jobIDs) @@ -120,7 +145,7 @@ def killJob(self, jobIDs): failed = [] for host, jobIDList in hostDict.items(): - result = self._killJobOnHost(jobIDList, host) + result = self._killJobOnHost(self.connections[host]["connection"], jobIDList) if not result["OK"]: failed.extend(jobIDList) message = result["Message"] @@ -133,6 +158,8 @@ def killJob(self, jobIDs): return result + ############################################################################# + def getCEStatus(self): """Method to return information on running and pending jobs.""" result = S_OK() @@ -140,9 +167,8 @@ def getCEStatus(self): result["RunningJobs"] = 0 result["WaitingJobs"] = 0 - for host in self.sshHost: - thost = host.split("/") - resultHost = self._getHostStatus(thost[0]) + for _, details in self.connections.items(): + resultHost = self._getHostStatus(details["connection"]) if resultHost["OK"]: result["RunningJobs"] += resultHost["Value"]["Running"] @@ -151,6 +177,8 @@ def getCEStatus(self): return result + ############################################################################# + def getJobStatus(self, jobIDList): """Get status of the jobs in the given list""" hostDict = {} @@ -162,7 +190,7 @@ def getJobStatus(self, jobIDList): resultDict = {} failed = [] for host, jobIDList in hostDict.items(): - result = self._getJobStatusOnHost(jobIDList, host) + result = self._getJobStatusOnHost(self.connections[host]["connection"], jobIDList) if not result["OK"]: failed.extend(jobIDList) continue @@ -173,3 +201,16 @@ def getJobStatus(self, jobIDList): resultDict[job] = PilotStatus.UNKNOWN return S_OK(resultDict) + + ############################################################################# + + def getJobOutput(self, jobID, localDir=None): + """Get the specified job standard output and error files. If the localDir is provided, + the output is returned as file in this directory. Otherwise, the output is returned + as strings. + """ + self.log.verbose("Getting output for jobID", jobID) + + # host can be retrieved from the path of the jobID + host = os.path.dirname(urlparse(jobID).path).lstrip("/") + return self._getJobOutputFilesOnHost(self.connections[host]["connection"], jobID, localDir) diff --git a/src/DIRAC/Resources/Computing/SSHComputingElement.py b/src/DIRAC/Resources/Computing/SSHComputingElement.py index 25668e62b57..aa6cf11127c 100644 --- a/src/DIRAC/Resources/Computing/SSHComputingElement.py +++ b/src/DIRAC/Resources/Computing/SSHComputingElement.py @@ -40,25 +40,20 @@ SSH password SSHPort: - Port number if not standard, e.g. for the gsissh access + Port number if not standard SSHKey: Location of the ssh private key for no-password connection -SSHOptions: - Any other SSH options to be used. Example:: - - SSHOptions = -o UserKnownHostsFile=/local/path/to/known_hosts - - Allows to have a local copy of the ``known_hosts`` file, independent of the HOME directory. - SSHTunnel: - String defining the use of intermediate SSH host. Example:: + Gateway/jump host used to reach the final destination. Example:: - ssh -i /private/key/location -l final_user final_host + gateway_host + gateway_host:port + +Timeout: + Timeout for the SSH commands. Default is 120 seconds. -SSHType: - SSH (default) or gsissh **Code Documentation** """ @@ -69,278 +64,99 @@ import stat import tempfile import uuid -from shlex import quote as shlex_quote from urllib.parse import urlparse -import pexpect +from fabric import Connection +from invoke.exceptions import CommandTimedOut +from paramiko.ssh_exception import SSHException import DIRAC -from DIRAC import S_ERROR, S_OK, gLogger +from DIRAC import S_ERROR, S_OK from DIRAC.Core.Utilities.List import breakListIntoChunks, uniqueElements from DIRAC.Resources.Computing.BatchSystems.executeBatch import executeBatchContent from DIRAC.Resources.Computing.ComputingElement import ComputingElement -class SSH: - """SSH class encapsulates passing commands and files through an SSH tunnel - to a remote host. It can use either ssh or gsissh access. The final host - where the commands will be executed and where the files will copied/retrieved - can be reached through an intermediate host if SSHTunnel parameters is defined. - - SSH constructor parameters are defined in a SSH accessible Computing Element - in the Configuration System: - - - SSHHost: SSH host name - - SSHUser: SSH user login - - SSHPassword: SSH password - - SSHPort: port number if not standard, e.g. for the gsissh access - - SSHKey: location of the ssh private key for no-password connection - - SSHOptions: any other SSH options to be used - - SSHTunnel: string defining the use of intermediate SSH host. Example: - 'ssh -i /private/key/location -l final_user final_host' - - SSHType: ssh ( default ) or gsissh - - The class public interface includes two methods: - - sshCall( timeout, command_sequence ) - scpCall( timeout, local_file, remote_file, upload = False/True ) - """ - - def __init__(self, host=None, parameters=None): - self.host = host - if parameters is None: - parameters = {} - if not host: - self.host = parameters.get("SSHHost", "") - - self.user = parameters.get("SSHUser", "") - self.password = parameters.get("SSHPassword", "") - self.port = parameters.get("SSHPort", "") - self.key = parameters.get("SSHKey", "") - self.options = parameters.get("SSHOptions", "") - self.sshTunnel = parameters.get("SSHTunnel", "") - self.sshType = parameters.get("SSHType", "ssh") - - if self.port: - self.options += f" -p {self.port}" - if self.key: - self.options += f" -i {self.key}" - self.options = self.options.strip() - - self.log = gLogger.getSubLogger("SSH") - - def __ssh_call(self, command, timeout): - if not timeout: - timeout = 999 - - ssh_newkey = "Are you sure you want to continue connecting" - try: - child = pexpect.spawn(command, timeout=timeout, encoding="utf-8") - i = child.expect([pexpect.TIMEOUT, ssh_newkey, pexpect.EOF, "assword: "]) - if i == 0: # Timeout - return S_OK((-1, child.before, "SSH login failed")) - - if i == 1: # SSH does not have the public key. Just accept it. - child.sendline("yes") - child.expect("assword: ") - i = child.expect([pexpect.TIMEOUT, "assword: "]) - if i == 0: # Timeout - return S_OK((-1, str(child.before) + str(child.after), "SSH login failed")) - if i == 1: - child.sendline(self.password) - child.expect(pexpect.EOF) - return S_OK((0, child.before, "")) - - if i == 2: - # Passwordless login, get the output - return S_OK((0, child.before, "")) - - if self.password: - child.sendline(self.password) - child.expect(pexpect.EOF) - return S_OK((0, child.before, "")) - - return S_ERROR(f"Unknown error: {child.before}") - except Exception as x: - return S_ERROR(f"Encountered exception: {str(x)}") - - def sshCall(self, timeout, cmdSeq): - """Execute remote command via a ssh remote call - - :param int timeout: timeout of the command - :param cmdSeq: list of command components - :type cmdSeq: python:list - """ - - command = cmdSeq - if isinstance(cmdSeq, list): - command = " ".join(cmdSeq) - - pattern = "__DIRAC__" - - if self.sshTunnel: - command = command.replace("'", '\\\\\\"') - command = command.replace("$", "\\\\\\$") - command = '/bin/sh -c \' {} -q {} -l {} {} "{} \\"echo {}; {}\\" " \' '.format( - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - pattern, - command, - ) - else: - # command = command.replace( '$', '\$' ) - command = '{} -q {} -l {} {} "echo {}; {}"'.format( - self.sshType, - self.options, - self.user, - self.host, - pattern, - command, - ) - self.log.debug(f"SSH command: {command}") - result = self.__ssh_call(command, timeout) - self.log.debug(f"SSH command result {str(result)}") - if not result["OK"]: - return result - - # Take the output only after the predefined pattern - ind = result["Value"][1].find("__DIRAC__") - if ind == -1: - return result - - status, output, error = result["Value"] - output = output[ind + 9 :] - if output.startswith("\r"): - output = output[1:] - if output.startswith("\n"): - output = output[1:] - - result["Value"] = (status, output, error) - return result - - def scpCall(self, timeout, localFile, remoteFile, postUploadCommand="", upload=True): - """Perform file copy through an SSH magic. - - :param int timeout: timeout of the command - :param str localFile: local file path, serves as source for uploading and destination for downloading. - Can take 'Memory' as value, in this case the downloaded contents is returned - as result['Value'] - :param str remoteFile: remote file full path - :param str postUploadCommand: command executed on the remote side after file upload - :param bool upload: upload if True, download otherwise - """ - # shlex_quote aims to prevent any security issue or problems with filepath containing spaces - # it returns a shell-escaped version of the filename - localFile = shlex_quote(localFile) - remoteFile = shlex_quote(remoteFile) - if upload: - if self.sshTunnel: - remoteFile = remoteFile.replace("$", r"\\\\\$") - postUploadCommand = postUploadCommand.replace("$", r"\\\\\$") - command = '/bin/sh -c \'cat {} | {} -q {} {}@{} "{} \\"cat > {}; {}\\""\' '.format( - localFile, - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - remoteFile, - postUploadCommand, - ) - else: - command = "/bin/sh -c \"cat {} | {} -q {} {}@{} 'cat > {}; {}'\" ".format( - localFile, - self.sshType, - self.options, - self.user, - self.host, - remoteFile, - postUploadCommand, - ) - else: - finalCat = f"| cat > {localFile}" - if localFile.lower() == "memory": - finalCat = "" - if self.sshTunnel: - remoteFile = remoteFile.replace("$", "\\\\\\$") - command = '/bin/sh -c \'{} -q {} -l {} {} "{} \\"cat {}\\"" {}\''.format( - self.sshType, - self.options, - self.user, - self.host, - self.sshTunnel, - remoteFile, - finalCat, - ) - else: - remoteFile = remoteFile.replace("$", r"\$") - command = "/bin/sh -c '{} -q {} -l {} {} \"cat {}\" {}'".format( - self.sshType, - self.options, - self.user, - self.host, - remoteFile, - finalCat, - ) - - self.log.debug(f"SSH copy command: {command}") - return self.__ssh_call(command, timeout) - - class SSHComputingElement(ComputingElement): ############################################################################# def __init__(self, ceUniqueID): """Standard constructor.""" super().__init__(ceUniqueID) - self.execution = "SSHCE" self.submittedJobs = 0 - self.outputTemplate = "" - self.errorTemplate = "" - - ############################################################################ - def setProxy(self, proxy): - """ - Set and prepare proxy to use - :param str proxy: proxy to use - :return: S_OK/S_ERROR - """ - ComputingElement.setProxy(self, proxy) - if self.ceParameters.get("SSHType", "ssh") == "gsissh": - result = self._prepareProxy() - if not result["OK"]: - gLogger.error("SSHComputingElement: failed to set up proxy", result["Message"]) - return result - return S_OK() + # SSH connection + self.connection = None + self.timeout = 120 + self.user = None + self.host = None + + # Submission parameters + self.queue = None + self.submitOptions = None + self.preamble = None + self.account = None + self.execution = "SSHCE" - ############################################################################# - def _addCEConfigDefaults(self): - """Method to make sure all necessary Configuration Parameters are defined""" - # First assure that any global parameters are loaded - ComputingElement._addCEConfigDefaults(self) - # Now batch system specific ones - if "SharedArea" not in self.ceParameters: - # . isn't a good location, move to $HOME - self.ceParameters["SharedArea"] = "$HOME" + # Directories + self.sharedArea = "$HOME" + self.batchOutput = "data" + self.batchError = "data" + self.infoArea = "data" + self.executableArea = "info" + self.workArea = "work" - if "BatchOutput" not in self.ceParameters: - self.ceParameters["BatchOutput"] = "data" + # Output and error templates + self.outputTemplate = "" + self.errorTemplate = "" - if "BatchError" not in self.ceParameters: - self.ceParameters["BatchError"] = "data" + ############################################################################# - if "ExecutableArea" not in self.ceParameters: - self.ceParameters["ExecutableArea"] = "data" + def _run(self, connection: Connection, command: str): + """Run the command on the remote host""" + try: + result = connection.run(command, warn=True, hide=True) + if result.failed: + return S_ERROR(f"[{connection.host}] Command returned an error: {result.stderr}") + return S_OK(result.stdout) + except CommandTimedOut as e: + return S_ERROR( + errno.ETIME, f"[{connection.host}] The command timed out. Consider increasing the timeout: {e}" + ) + except SSHException as e: + return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}") - if "InfoArea" not in self.ceParameters: - self.ceParameters["InfoArea"] = "info" + def _put(self, connection: Connection, local: str, remote: str, preserveMode: bool = True): + """Upload a file to the remote host""" + try: + connection.put(local, remote=remote, preserve_mode=preserveMode) + return S_OK() + except OSError as e: + return S_ERROR(f"[{connection.host}] Failed uploading file: {str(e)}") + except SSHException as e: + return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}") + + def _get(self, connection: Connection, remote: str, local: str, preserveMode: bool = True): + """Download a file from the remote host""" + try: + if local == "Memory": + # Download to memory: use BytesIO buffer + from io import BytesIO + + buffer = BytesIO() + connection.get(remote, local=buffer) + content = buffer.getvalue().decode("utf-8", errors="replace") + return S_OK(content) + else: + # Download to file + connection.get(remote, local=local, preserve_mode=preserveMode) + return S_OK() + except OSError as e: + return S_ERROR(f"[{connection.host}] Failed downloading file: {str(e)}") + except SSHException as e: + return S_ERROR(f"[{connection.host}] SSH error occurred: {str(e)}") - if "WorkArea" not in self.ceParameters: - self.ceParameters["WorkArea"] = "work" + ############################################################################# def _getBatchSystem(self): """Load a Batch System instance from the CE Parameters""" @@ -354,90 +170,163 @@ def _getBatchSystem(self): def _getBatchSystemDirectoryLocations(self): """Get names of the locations to store outputs, errors, info and executables.""" - self.sharedArea = self.ceParameters["SharedArea"] - self.batchOutput = self.ceParameters["BatchOutput"] - if not self.batchOutput.startswith("/"): - self.batchOutput = os.path.join(self.sharedArea, self.batchOutput) - self.batchError = self.ceParameters["BatchError"] - if not self.batchError.startswith("/"): - self.batchError = os.path.join(self.sharedArea, self.batchError) - self.infoArea = self.ceParameters["InfoArea"] - if not self.infoArea.startswith("/"): - self.infoArea = os.path.join(self.sharedArea, self.infoArea) - self.executableArea = self.ceParameters["ExecutableArea"] - if not self.executableArea.startswith("/"): - self.executableArea = os.path.join(self.sharedArea, self.executableArea) - self.workArea = self.ceParameters["WorkArea"] - if not self.workArea.startswith("/"): - self.workArea = os.path.join(self.sharedArea, self.workArea) + self.sharedArea = self.ceParameters.get("SharedArea", self.sharedArea) + + def _get_dir(directory: str, defaultValue: str) -> str: + value = self.ceParameters.get(directory, defaultValue) + if value.startswith("/"): + return value + return os.path.join(self.sharedArea, value) + + self.batchOutput = _get_dir("BatchOutput", self.batchOutput) + self.batchError = _get_dir("BatchError", self.batchError) + self.infoArea = _get_dir("InfoArea", self.infoArea) + self.executableArea = _get_dir("ExecutableArea", self.executableArea) + self.workArea = _get_dir("WorkArea", self.workArea) + + def _parseTunnel(self, tunnel: str) -> tuple[str, int | None]: + """Parse a tunnel string to extract the gateway hostname and port. + + Supported formats: + - ``hostname`` + - ``hostname:port`` + + :param tunnel: The tunnel/gateway host string + :return: Tuple of (hostname, port) where port is None if not specified + """ + if not tunnel: + return None, None + + tunnel = tunnel.strip() + + if ":" in tunnel: + hostname, port_str = tunnel.rsplit(":", 1) + try: + port = int(port_str) + except ValueError: + self.log.warn(f"Invalid port number '{port_str}' in SSHTunnel value '{tunnel}', ignoring port") + return tunnel, None + return hostname, port + + return tunnel, None + + def _getConnection( + self, + host: str, + user: str, + port: int, + password: str, + key: str, + gateway_host: str | None = None, + gateway_port: int | None = None, + ) -> Connection: + """Get a Connection instance to the host. + + :param host: The final destination host + :param user: SSH username + :param port: SSH port + :param password: SSH password + :param key: SSH key file path + :param gateway_host: The gateway/jump host (None if no gateway) + """ + connectionParams = {} + if password: + connectionParams["password"] = password + if key: + connectionParams["key_filename"] = key + + gateway = None + if gateway_host: + gateway = Connection(gateway_host, user=user, port=gateway_port, connect_kwargs=connectionParams) + + return Connection( + host, + user=user, + port=port, + gateway=gateway, + connect_kwargs=connectionParams, + connect_timeout=self.timeout, + ) def _reset(self): """Process CE parameters and make necessary adjustments""" + # Get the Batch System instance result = self._getBatchSystem() if not result["OK"]: return result + + # Get the location of the remote directories self._getBatchSystemDirectoryLocations() - self.user = self.ceParameters["SSHUser"] + # Get the SSH parameters + self.host = self.ceParameters.get("SSHHost", self.host) + self.timeout = self.ceParameters.get("Timeout", self.timeout) + self.user = self.ceParameters.get("SSHUser", self.user) + port = self.ceParameters.get("SSHPort", None) + password = self.ceParameters.get("SSHPassword", None) + key = self.ceParameters.get("SSHKey", None) + tunnel = self.ceParameters.get("SSHTunnel", None) + + # When SSHTunnel is specified, it defines the gateway/jump host + # used to reach the final destination (SSHHost). + if tunnel: + gateway_host, gateway_port = self._parseTunnel(tunnel) + else: + gateway_host = None + gateway_port = None + + # Configure the SSH connection + self.connection = self._getConnection(self.host, self.user, port, password, key, gateway_host, gateway_port) + + # Get submission parameters + self.submitOptions = self.ceParameters.get("SubmitOptions", self.submitOptions) + self.preamble = self.ceParameters.get("Preamble", self.preamble) + self.account = self.ceParameters.get("Account", self.account) self.queue = self.ceParameters["Queue"] self.log.info("Using queue: ", self.queue) - self.submitOptions = self.ceParameters.get("SubmitOptions", "") - self.preamble = self.ceParameters.get("Preamble", "") + # Get output and error templates + self.outputTemplate = self.ceParameters.get("OutputTemplate", self.outputTemplate) + self.errorTemplate = self.ceParameters.get("ErrorTemplate", self.errorTemplate) - self.account = self.ceParameters.get("Account", "") - result = self._prepareRemoteHost() + # Prepare the remote host + result = self._prepareRemoteHost(self.connection) if not result["OK"]: return result return S_OK() - def _prepareRemoteHost(self, host=None): + def _prepareRemoteHost(self, connection: Connection): """Prepare remote directories and upload control script""" - - ssh = SSH(host=host, parameters=self.ceParameters) - # Make remote directories + self.log.verbose(f"Creating working directories on {self.host}") dirTuple = tuple( uniqueElements( [self.sharedArea, self.executableArea, self.infoArea, self.batchOutput, self.batchError, self.workArea] ) ) - nDirs = len(dirTuple) - cmd = "mkdir -p %s; " * nDirs % dirTuple - cmd = f"bash -c '{cmd}'" - self.log.verbose(f"Creating working directories on {self.ceParameters['SSHHost']}") - result = ssh.sshCall(30, cmd) + cmd = f"mkdir -p {' '.join(dirTuple)}" + result = self._run(connection, cmd) if not result["OK"]: - self.log.error("Failed creating working directories", f"({result['Message']})") + self.log.error("Failed creating working directories: ", result["Message"]) return result - status, output, _error = result["Value"] - if status == -1: - self.log.error("Timeout while creating directories") - return S_ERROR(errno.ETIME, "Timeout while creating directories") - if "cannot" in output: - self.log.error("Failed to create directories", f"({output})") - return S_ERROR(errno.EACCES, "Failed to create directories") # Upload the control script now + self.log.verbose("Generating control script") result = self._generateControlScript() if not result["OK"]: - self.log.warn("Failed generating control script") + self.log.error("Failed generating control script") return result localScript = result["Value"] - self.log.verbose(f"Uploading {self.batchSystem.__class__.__name__} script to {self.ceParameters['SSHHost']}") + os.chmod(localScript, 0o755) + + self.log.verbose(f"Uploading {self.batchSystem.__class__.__name__} script to {self.host}") remoteScript = f"{self.sharedArea}/execute_batch" - result = ssh.scpCall(30, localScript, remoteScript, postUploadCommand=f"chmod +x {remoteScript}") + + result = self._put(connection, localScript, remote=remoteScript) if not result["OK"]: - self.log.warn(f"Failed uploading control script: {result['Message']}") + self.log.error(f"Failed uploading control script: {result['Message']}") return result - status, output, _error = result["Value"] - if status != 0: - if status == -1: - self.log.warn("Timeout while uploading control script") - return S_ERROR("Timeout while uploading control script") - self.log.warn(f"Failed uploading control script: {output}") - return S_ERROR("Failed uploading control script") # Delete the generated control script locally try: @@ -470,10 +359,10 @@ def _generateControlScript(self): return S_OK(f"{controlScript}") - def __executeHostCommand(self, command, options, ssh=None, host=None): - if not ssh: - ssh = SSH(host=host, parameters=self.ceParameters) + ############################################################################# + def __executeHostCommand(self, connection: Connection, command: str, options: dict[str]): + """Execute a command on the remote host""" options["BatchSystem"] = self.batchSystem.__class__.__name__ options["Method"] = command options["SharedDir"] = self.sharedArea @@ -497,7 +386,7 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): # Upload the options file to the remote host remoteOptionsFile = f"{self.sharedArea}/batch_options_{uuid.uuid4().hex}.json" - result = ssh.scpCall(30, localOptionsFile, remoteOptionsFile) + result = self._put(connection, localOptionsFile, remote=remoteOptionsFile) if not result["OK"]: return result @@ -510,17 +399,10 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): self.log.verbose(f"CE submission command: {cmd}") - result = ssh.sshCall(120, cmd) + result = self._run(connection, cmd) if not result["OK"]: - self.log.error(f"{self.ceType} CE job submission failed", result["Message"]) return result - sshStatus = result["Value"][0] - if sshStatus != 0: - sshStdout = result["Value"][1] - sshStderr = result["Value"][2] - return S_ERROR(f"CE job submission command failed with status {sshStatus}: {sshStdout} {sshStderr}") - # The result should be written to a JSON file by execute_batch # Compute the expected result file path remoteResultFile = remoteOptionsFile.replace(".json", "_result.json") @@ -529,7 +411,7 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): with tempfile.NamedTemporaryFile(mode="r", suffix=".json", delete=False) as f: localResultFile = f.name - result = ssh.scpCall(30, localResultFile, remoteResultFile, upload=False) + result = self._get(connection, remoteResultFile, local=localResultFile) if not result["OK"]: return result @@ -545,28 +427,44 @@ def __executeHostCommand(self, command, options, ssh=None, host=None): os.remove(localResultFile) # Clean up remote temporary files if remoteOptionsFile: - ssh.sshCall(30, f"rm -f {remoteOptionsFile}") + self._run(connection, f"rm -f {remoteOptionsFile}") if remoteResultFile: - ssh.sshCall(30, f"rm -f {remoteResultFile}") + self._run(connection, f"rm -f {remoteResultFile}") def submitJob(self, executableFile, proxy, numberOfJobs=1): - # self.log.verbose( "Executable file path: %s" % executableFile ) if not os.access(executableFile, 5): os.chmod(executableFile, stat.S_IRWXU | stat.S_IRGRP | stat.S_IXGRP | stat.S_IROTH | stat.S_IXOTH) - return self._submitJobToHost(executableFile, numberOfJobs) + result = self._submitJobToHost(self.connection, executableFile, numberOfJobs) + if not result["OK"]: + return result + + batchIDs, jobStamps = result["Value"] + batchSystemName = self.batchSystem.__class__.__name__.lower() + jobIDs = [f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{_id}" for _id in batchIDs] + + result = S_OK(jobIDs) + stampDict = {} + for iJob, jobID in enumerate(jobIDs): + stampDict[jobID] = jobStamps[iJob] + result["PilotStampDict"] = stampDict + self.submittedJobs += len(batchIDs) + + return result - def _submitJobToHost(self, executableFile, numberOfJobs, host=None): + def _submitJobToHost(self, connection: Connection, executableFile: str, numberOfJobs: int): """Submit prepared executable to the given host""" - ssh = SSH(host=host, parameters=self.ceParameters) # Copy the executable + self.log.verbose(f"Copying executable to {self.host}") submitFile = os.path.join(self.executableArea, os.path.basename(executableFile)) - result = ssh.scpCall(30, executableFile, submitFile, postUploadCommand=f"chmod +x {submitFile}") + os.chmod(executableFile, 0o755) + + result = self._put(connection, executableFile, submitFile) if not result["OK"]: return result jobStamps = [] - for _i in range(numberOfJobs): + for _ in range(numberOfJobs): jobStamps.append(uuid.uuid4().hex) numberOfProcessors = self.ceParameters.get("NumberOfProcessors", 1) @@ -589,10 +487,8 @@ def _submitJobToHost(self, executableFile, numberOfJobs, host=None): "NumberOfGPUs": self.numberOfGPUs, "Account": self.account, } - if host: - commandOptions["SSHNodeHost"] = host - resultCommand = self.__executeHostCommand("submitJob", commandOptions, ssh=ssh, host=host) + resultCommand = self.__executeHostCommand(connection, "submitJob", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -601,42 +497,29 @@ def _submitJobToHost(self, executableFile, numberOfJobs, host=None): return S_ERROR("Invalid result from job submission") if result["Status"] != 0: return S_ERROR(f"Failed job submission: {result['Message']}") - else: - batchIDs = result["Jobs"] - if batchIDs: - batchSystemName = self.batchSystem.__class__.__name__.lower() - if host is None: - jobIDs = [f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{_id}" for _id in batchIDs] - else: - jobIDs = [ - f"{self.ceType.lower()}{batchSystemName}://{self.ceName}/{host}/{_id}" for _id in batchIDs - ] - else: - return S_ERROR("No jobs IDs returned") - result = S_OK(jobIDs) - stampDict = {} - for iJob, jobID in enumerate(jobIDs): - stampDict[jobID] = jobStamps[iJob] - result["PilotStampDict"] = stampDict - self.submittedJobs += len(batchIDs) + batchIDs = result["Jobs"] + if not batchIDs: + return S_ERROR("No jobs IDs returned") - return result + return S_OK((batchIDs, jobStamps)) + + ############################################################################# def killJob(self, jobIDList): """Kill a bunch of jobs""" if isinstance(jobIDList, str): jobIDList = [jobIDList] - return self._killJobOnHost(jobIDList) + return self._killJobOnHost(self.connection, jobIDList) - def _killJobOnHost(self, jobIDList, host=None): + def _killJobOnHost(self, connection: Connection, jobIDList: list[str]): """Kill the jobs for the given list of job IDs""" batchSystemJobList = [] for jobID in jobIDList: batchSystemJobList.append(os.path.basename(urlparse(jobID.split(":::")[0]).path)) commandOptions = {"JobIDList": batchSystemJobList, "User": self.user} - resultCommand = self.__executeHostCommand("killJob", commandOptions, host=host) + resultCommand = self.__executeHostCommand(connection, "killJob", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -651,6 +534,8 @@ def _killJobOnHost(self, jobIDList, host=None): return S_OK(len(result["Successful"])) + ############################################################################# + def getCEStatus(self): """Method to return information on running and pending jobs.""" result = S_OK() @@ -658,7 +543,7 @@ def getCEStatus(self): result["RunningJobs"] = 0 result["WaitingJobs"] = 0 - resultHost = self._getHostStatus() + resultHost = self._getHostStatus(self.connection) if not resultHost["OK"]: return resultHost @@ -671,9 +556,9 @@ def getCEStatus(self): return result - def _getHostStatus(self, host=None): + def _getHostStatus(self, connection: Connection): """Get jobs running at a given host""" - resultCommand = self.__executeHostCommand("getCEStatus", {}, host=host) + resultCommand = self.__executeHostCommand(connection, "getCEStatus", {}) if not resultCommand["OK"]: return resultCommand @@ -685,11 +570,13 @@ def _getHostStatus(self, host=None): return S_OK(result) + ############################################################################# + def getJobStatus(self, jobIDList): """Get the status information for the given list of jobs""" - return self._getJobStatusOnHost(jobIDList) + return self._getJobStatusOnHost(self.connection, jobIDList) - def _getJobStatusOnHost(self, jobIDList, host=None): + def _getJobStatusOnHost(self, connection: Connection, jobIDList: list[str]): """Get the status information for the given list of jobs""" resultDict = {} batchSystemJobDict = {} @@ -698,7 +585,7 @@ def _getJobStatusOnHost(self, jobIDList, host=None): batchSystemJobDict[batchSystemJobID] = jobID for jobList in breakListIntoChunks(list(batchSystemJobDict), 100): - resultCommand = self.__executeHostCommand("getJobStatus", {"JobIDList": jobList}, host=host) + resultCommand = self.__executeHostCommand(connection, "getJobStatus", {"JobIDList": jobList}) if not resultCommand["OK"]: return resultCommand @@ -713,65 +600,23 @@ def _getJobStatusOnHost(self, jobIDList, host=None): return S_OK(resultDict) + ############################################################################# + def getJobOutput(self, jobID, localDir=None): """Get the specified job standard output and error files. If the localDir is provided, the output is returned as file in this directory. Otherwise, the output is returned as strings. """ self.log.verbose("Getting output for jobID", jobID) - result = self._getJobOutputFiles(jobID) - if not result["OK"]: - return result - - batchSystemJobID, host, outputFile, errorFile = result["Value"] - - if localDir: - localOutputFile = f"{localDir}/{batchSystemJobID}.out" - localErrorFile = f"{localDir}/{batchSystemJobID}.err" - else: - localOutputFile = "Memory" - localErrorFile = "Memory" + return self._getJobOutputFilesOnHost(self.connection, jobID, localDir) - # Take into account the SSHBatch possible SSHHost syntax - host = host.split("/")[0] - - ssh = SSH(host=host, parameters=self.ceParameters) - resultStdout = ssh.scpCall(30, localOutputFile, outputFile, upload=False) - if not resultStdout["OK"]: - return resultStdout - - resultStderr = ssh.scpCall(30, localErrorFile, errorFile, upload=False) - if not resultStderr["OK"]: - return resultStderr - - if localDir: - output = localOutputFile - error = localErrorFile - else: - output = resultStdout["Value"][1] - error = resultStderr["Value"][1] - - return S_OK((output, error)) - - def _getJobOutputFiles(self, jobID): + def _getJobOutputFilesOnHost(self, connection: Connection, jobID: str, localDir: str | None = None): """Get output file names for the specific CE""" batchSystemJobID = os.path.basename(urlparse(jobID.split(":::")[0]).path) - # host can be retrieved from the path of the jobID - # it might not be present, in this case host is an empty string and will be defined by the CE parameters later - host = os.path.dirname(urlparse(jobID).path).lstrip("/") - - if "OutputTemplate" in self.ceParameters: - self.outputTemplate = self.ceParameters["OutputTemplate"] - self.errorTemplate = self.ceParameters["ErrorTemplate"] if self.outputTemplate: - output = self.outputTemplate % batchSystemJobID - error = self.errorTemplate % batchSystemJobID - elif "OutputTemplate" in self.ceParameters: - self.outputTemplate = self.ceParameters["OutputTemplate"] - self.errorTemplate = self.ceParameters["ErrorTemplate"] - output = self.outputTemplate % batchSystemJobID - error = self.errorTemplate % batchSystemJobID + outputFile = self.outputTemplate % batchSystemJobID + errorFile = self.errorTemplate % batchSystemJobID elif hasattr(self.batchSystem, "getJobOutputFiles"): # numberOfNodes is treated as a string as it can contain values such as "2-4" # where 2 would represent the minimum number of nodes to allocate, and 4 the maximum @@ -782,7 +627,7 @@ def _getJobOutputFiles(self, jobID): "ErrorDir": self.batchError, "NumberOfNodes": numberOfNodes, } - resultCommand = self.__executeHostCommand("getJobOutputFiles", commandOptions, host=host) + resultCommand = self.__executeHostCommand(connection, "getJobOutputFiles", commandOptions) if not resultCommand["OK"]: return resultCommand @@ -796,10 +641,32 @@ def _getJobOutputFiles(self, jobID): self.outputTemplate = result["OutputTemplate"] self.errorTemplate = result["ErrorTemplate"] - output = result["Jobs"][batchSystemJobID]["Output"] - error = result["Jobs"][batchSystemJobID]["Error"] + outputFile = result["Jobs"][batchSystemJobID]["Output"] + errorFile = result["Jobs"][batchSystemJobID]["Error"] + else: + outputFile = f"{self.batchOutput}/{batchSystemJobID}.out" + errorFile = f"{self.batchError}/{batchSystemJobID}.err" + + if localDir: + localOutputFile = f"{localDir}/{batchSystemJobID}.out" + localErrorFile = f"{localDir}/{batchSystemJobID}.err" else: - output = f"{self.batchOutput}/{batchSystemJobID}.out" - error = f"{self.batchError}/{batchSystemJobID}.err" + localOutputFile = "Memory" + localErrorFile = "Memory" - return S_OK((batchSystemJobID, host, output, error)) + resultStdout = self._get(connection, outputFile, local=localOutputFile, preserveMode=False) + if not resultStdout["OK"]: + return resultStdout + + resultStderr = self._get(connection, errorFile, local=localErrorFile, preserveMode=False) + if not resultStderr["OK"]: + return resultStderr + + if localDir: + output = localOutputFile + error = localErrorFile + else: + output = resultStdout["Value"] + error = resultStderr["Value"] + + return S_OK((output, error)) diff --git a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py index dcae26da313..da8d8e519ee 100644 --- a/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py +++ b/src/DIRAC/WorkloadManagementSystem/scripts/dirac_admin_debug_ce.py @@ -1,20 +1,19 @@ #!/usr/bin/env python """ -Test the interactions with a given set of Computing Elements (CE). For each CE: +Test the interactions with a given set of Computing Elements (CE). For each CE - Get the CE status if available - Submit a job to the CE - Get the job status - Get the job output/error/log if available -Conditions: +Conditions - The CEs must be configured in the DIRAC configuration -- The script should be executed with an admin proxy: used to fetch a pilot proxy and a token -- The script should be executed: - - - in a DIRAC client environment for normal CEs, such as AREX and HTCondor - - in a DIRAC host environment for SSH/Local CEs (credentials would not be available otherwise) +- For non-SSH CEs (e.g. AREX, HTCondor): requires an admin proxy (FullDelegation) + to fetch pilot proxy and token. Run from a DIRAC client environment. +- For SSH/SSHBatch CEs: no proxy required. These use SSH credentials from the CS + configuration. Run from a DIRAC server/host environment. Usage: dirac-admin-debug-ce [--site ] [--ce ] [--ce-type ] [--script ] @@ -23,6 +22,7 @@ $ dirac-admin-debug-ce dteam --site LCG.CERN.cern --ce-type HTCondorCE """ import concurrent.futures +import tempfile import time from pathlib import Path @@ -153,10 +153,11 @@ def buildQueues(vo, sites, ces, ceTypes): return result["Value"] -def interactWithCE(ce): +def interactWithCE(ce, executableFile): """Interact with a given Computing Element (CE). :param ce: The Computing Element instance. + :param str executableFile: The path to the executable script to submit. :return: A dictionary with the result of each check. """ checks = { @@ -178,7 +179,7 @@ def interactWithCE(ce): # Submit a job to the CE gLogger.info(f"[{ce.ceName}]Submitting a job") - res = ce.submitJob("workloadExec.sh", None) + res = ce.submitJob(executableFile, None) if not res["OK"]: gLogger.error(f"[{ce.ceName}]Cannot submit job to CE: {res['Message']}") checks["job_submit"]["Message"] = res["Message"] @@ -236,31 +237,19 @@ def main(): Script.registerSwitch("", "script=", "Path to custom executable script (default: workloadExec.sh)", setScript) Script.registerSwitch("", "timeout=", "Timeout in seconds for job status polling (default: 1800)", setTimeout) Script.registerArgument("VO: Virtual Organization") - Script.parseCommandLine() - from DIRAC.Core.Security.Properties import SecurityProperty - from DIRAC.Core.Security.ProxyInfo import getProxyInfo + # Determine the credential context before parseCommandLine contacts the CS: + # if no proxy file exists, assume we are on a server with a host certificate. + from DIRAC.Core.Security.Locations import getProxyLocation - # Check credentials - result = getProxyInfo() - if not result["OK"]: - gLogger.error("Do you have a valid proxy?") - gLogger.error(result["Message"]) - DIRAC.exit(1) - proxyProps = result["Value"] + hasProxy = getProxyLocation() is not None + if not hasProxy: + Script.localCfg.addDefaultEntry("/DIRAC/Security/UseServerCertificate", "true") - if SecurityProperty.FULL_DELEGATION not in proxyProps.get("groupProperties", []): - gLogger.error("You need an admin proxy (with FullDelegation property) to run this script") - DIRAC.exit(1) + Script.parseCommandLine() vo = Script.getPositionalArgs()[0] - # Get credentials for the given VO - pilotDN, pilotGroup = findGenericCreds(vo) - if not pilotDN or not pilotGroup: - gLogger.error("Cannot get pilot credentials") - DIRAC.exit(1) - # Get the queues queueDict = buildQueues( vo=vo, @@ -272,6 +261,36 @@ def main(): gLogger.error("Cannot get queues") DIRAC.exit(1) + # SSH-based CEs run from the server and use SSH credentials from the CS, + # so they do not need a proxy or pilot credentials. + hasNonSSHCEs = any(not q["CE"].ceType.startswith("SSH") for q in queueDict.values()) + + pilotDN, pilotGroup = None, None + if hasNonSSHCEs: + from DIRAC.Core.Security.Properties import SecurityProperty + from DIRAC.Core.Security.ProxyInfo import getProxyInfo + + # Non-SSH CEs require an admin proxy to fetch pilot credentials + if not hasProxy: + gLogger.error("Non-SSH CEs found but no proxy available. Do you have a valid proxy?") + DIRAC.exit(1) + + result = getProxyInfo() + if not result["OK"]: + gLogger.error("Failed to read proxy info", result["Message"]) + DIRAC.exit(1) + proxyProps = result["Value"] + + if SecurityProperty.FULL_DELEGATION not in proxyProps.get("groupProperties", []): + gLogger.error("You need an admin proxy (with FullDelegation property) to run this script") + DIRAC.exit(1) + + # Get credentials for the given VO + pilotDN, pilotGroup = findGenericCreds(vo) + if not pilotDN or not pilotGroup: + gLogger.error("Cannot get pilot credentials") + DIRAC.exit(1) + if scriptPath: gLogger.info(f"Using custom script: {scriptPath}") executable = Path(scriptPath) @@ -280,7 +299,7 @@ def main(): DIRAC.exit(1) else: gLogger.info("Creating default workloadExec.sh") - executable = Path("workloadExec.sh") + executable = Path(tempfile.gettempdir()) / "workloadExec.sh" with open(executable, "w") as f: f.write("#!/bin/bash\n") f.write("echo 'Hello from DIRAC!'\n") @@ -293,7 +312,7 @@ def main(): def process_queue(queueName): ce = queueDict[queueName]["CE"] - if ce.ceType != "SSH": + if not ce.ceType.startswith("SSH"): gLogger.info(f"Getting creds for CE: {ce.ceName} ({ce.ceType})") proxy, token = getCredentials(pilotDN, pilotGroup, ce) if not proxy or not token: @@ -308,7 +327,7 @@ def process_queue(queueName): if ce.ceType == "HTCondorCE": ce.workingDirectory = str(Path.cwd()) gLogger.info(f"Interacting with CE: {ce.ceName} ({ce.ceType})") - return queueName, interactWithCE(ce) + return queueName, interactWithCE(ce, str(executable)) with concurrent.futures.ThreadPoolExecutor() as executor: results = executor.map(process_queue, list(queueDict.keys()))