39 changes: 24 additions & 15 deletions sky/adaptors/slurm.py
@@ -50,17 +50,18 @@ class SlurmClient:

def __init__(
Review comment (Collaborator):
what do you think about just adding a boolean argument is_local_execution_mode: bool = False and pass that in from sky/provision/slurm/instance.py to make the intent more explicit?
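
A rough sketch of that suggestion, for comparison. This is hypothetical: only the flag name is_local_execution_mode comes from the comment above; the rest mirrors the new constructor in this diff.

# Hypothetical alternative: make local execution an explicit flag instead
# of inferring it from ssh_host being None.
def __init__(
    self,
    ssh_host: Optional[str] = None,
    ssh_port: Optional[int] = None,
    ssh_user: Optional[str] = None,
    ssh_key: Optional[str] = None,
    ssh_proxy_command: Optional[str] = None,
    ssh_proxy_jump: Optional[str] = None,
    is_local_execution_mode: bool = False,
):
    if is_local_execution_mode:
        # Run Slurm CLI commands directly on the current node.
        self._runner = command_runner.LocalProcessCommandRunner()
    else:
        assert ssh_host is not None
        assert ssh_port is not None
        assert ssh_user is not None
        self._runner = command_runner.SSHCommandRunner(
            (ssh_host, ssh_port),
            ssh_user,
            ssh_key,
            ssh_proxy_command=ssh_proxy_command,
            ssh_proxy_jump=ssh_proxy_jump,
            enable_interactive_auth=True,
        )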

self,
ssh_host: str,
ssh_port: int,
ssh_user: str,
ssh_key: Optional[str],
ssh_host: Optional[str] = None,
ssh_port: Optional[int] = None,
ssh_user: Optional[str] = None,
ssh_key: Optional[str] = None,
ssh_proxy_command: Optional[str] = None,
ssh_proxy_jump: Optional[str] = None,
):
"""Initialize SlurmClient.

Args:
ssh_host: Hostname of the Slurm controller.
ssh_host: Hostname of the Slurm controller. If None, uses local
execution mode (for when running on the Slurm cluster itself).
ssh_port: SSH port on the controller.
ssh_user: SSH username.
ssh_key: Path to SSH private key, or None for keyless SSH.
@@ -74,16 +74,24 @@ def __init__(
self.ssh_proxy_command = ssh_proxy_command
self.ssh_proxy_jump = ssh_proxy_jump

# Internal runner for executing Slurm CLI commands
# on the controller node.
self._runner = command_runner.SSHCommandRunner(
(ssh_host, ssh_port),
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
ssh_proxy_jump=ssh_proxy_jump,
enable_interactive_auth=True,
)
self._runner: command_runner.CommandRunner

if ssh_host is None:
# Local execution mode - for running on the Slurm cluster itself
# (e.g., autodown from skylet).
self._runner = command_runner.LocalProcessCommandRunner()
else:
# Remote execution via SSH
assert ssh_port is not None
assert ssh_user is not None
self._runner = command_runner.SSHCommandRunner(
(ssh_host, ssh_port),
ssh_user,
ssh_key,
ssh_proxy_command=ssh_proxy_command,
ssh_proxy_jump=ssh_proxy_jump,
enable_interactive_auth=True,
)

def _run_slurm_cmd(self, cmd: str) -> Tuple[int, str, str]:
return self._runner.run(cmd,
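
For reference, a minimal usage sketch of the two modes the constructor now supports. The hostname, port, user, and key path below are illustrative, not taken from the PR.

from sky.adaptors import slurm

# Local execution mode: no SSH arguments, so Slurm CLI commands run on the
# current node via LocalProcessCommandRunner (e.g. autodown from the skylet).
local_client = slurm.SlurmClient()

# Remote execution mode: SSH into the Slurm controller, as before.
remote_client = slurm.SlurmClient(
    ssh_host='slurm-controller.example.com',  # illustrative hostname
    ssh_port=22,
    ssh_user='ubuntu',
    ssh_key='~/.ssh/slurm_key',  # illustrative key path
)
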
54 changes: 28 additions & 26 deletions sky/provision/slurm/instance.py
@@ -170,6 +170,7 @@ def _create_virtual_instance(
skypilot_runtime_dir = _skypilot_runtime_dir(cluster_name_on_cloud)
sky_home_dir = _sky_cluster_home_dir(cluster_name_on_cloud)
ready_signal = f'{sky_home_dir}/.sky_sbatch_ready'
slurm_marker_file = f'{sky_home_dir}/{slurm_utils.SLURM_MARKER_FILE}'

# Build the sbatch script
gpu_directive = ''
@@ -217,6 +218,8 @@ def _create_virtual_instance(
mkdir -p {sky_home_dir}
# Create sky runtime directory on each node.
srun --nodes={num_nodes} mkdir -p {skypilot_runtime_dir}
# Marker file to indicate we're in a Slurm cluster.
touch {slurm_marker_file}
# Suppress login messages.
touch {sky_home_dir}/.hushlogin
# Signal that the sbatch script has completed setup.
@@ -487,32 +490,31 @@ def terminate_instances(
'worker_only=True is not supported for Slurm, this is a no-op.')
return

ssh_config_dict = provider_config['ssh']
ssh_host = ssh_config_dict['hostname']
ssh_port = int(ssh_config_dict['port'])
ssh_user = ssh_config_dict['user']
ssh_private_key = ssh_config_dict['private_key']
# Check if we are running inside a Slurm job (Only happens with autodown,
# where the Skylet will invoke terminate_instances on the remote cluster),
# where we assume SSH between nodes have been set up on each node's
# ssh config.
# TODO(kevin): Validate this assumption. Another way would be to
# mount the private key to the remote cluster, like we do with
# other clouds' API keys.
if slurm_utils.is_inside_slurm_job():
logger.debug('Running inside a Slurm job, using machine\'s ssh config')
ssh_private_key = None
ssh_proxy_command = ssh_config_dict.get('proxycommand', None)
ssh_proxy_jump = ssh_config_dict.get('proxyjump', None)

client = slurm.SlurmClient(
ssh_host,
ssh_port,
ssh_user,
ssh_private_key,
ssh_proxy_command=ssh_proxy_command,
ssh_proxy_jump=ssh_proxy_jump,
)
# Check if we are running inside a Slurm cluster (only happens with
# autodown, where the Skylet invokes terminate_instances on the remote
# cluster). In this case, use local execution instead of SSH.
# This assumes that the compute node is able to run scancel.
# TODO(kevin): Validate this assumption.
if slurm_utils.is_inside_slurm_cluster():
logger.debug('Running inside a Slurm cluster, using local execution')
client = slurm.SlurmClient()
else:
ssh_config_dict = provider_config['ssh']
ssh_host = ssh_config_dict['hostname']
ssh_port = int(ssh_config_dict['port'])
ssh_user = ssh_config_dict['user']
ssh_private_key = ssh_config_dict['private_key']
ssh_proxy_command = ssh_config_dict.get('proxycommand', None)
ssh_proxy_jump = ssh_config_dict.get('proxyjump', None)

client = slurm.SlurmClient(
ssh_host,
ssh_port,
ssh_user,
ssh_private_key,
ssh_proxy_command=ssh_proxy_command,
ssh_proxy_jump=ssh_proxy_jump,
)
jobs_state = client.get_jobs_state_by_name(cluster_name_on_cloud)
if not jobs_state:
logger.debug(f'Job for cluster {cluster_name_on_cloud} not found, '
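
One lightweight way to start validating the TODO's assumption that the compute node can run scancel. This helper is hypothetical and not part of the diff.

import shutil


def _can_cancel_jobs_locally() -> bool:
    # scancel must be on PATH for the local-execution termination path to
    # work on this node; callers could fall back to SSH otherwise.
    return shutil.which('scancel') is not None
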
9 changes: 7 additions & 2 deletions sky/provision/slurm/utils.py
@@ -15,6 +15,7 @@
logger = sky_logging.init_logger(__name__)

DEFAULT_SLURM_PATH = '~/.slurm/config'
SLURM_MARKER_FILE = '.sky_slurm_cluster'


def get_slurm_ssh_config() -> SSHConfig:
@@ -523,8 +524,12 @@ def slurm_node_info(
return node_list


def is_inside_slurm_job() -> bool:
return os.environ.get('SLURM_JOB_ID') is not None
def is_inside_slurm_cluster() -> bool:
# Check for the marker file in the current home directory. When run by
# the skylet on a compute node, the HOME environment variable is set to
# the cluster's sky home directory by the SlurmCommandRunner.
marker_file = os.path.join(os.path.expanduser('~'), SLURM_MARKER_FILE)
return os.path.exists(marker_file)


@annotations.lru_cache(scope='request')
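
A quick sketch of how the marker-file contract could be exercised in a unit test. This is a hypothetical pytest snippet; tmp_path and monkeypatch are standard pytest fixtures, and the import path follows the file layout shown in this diff.

import pathlib

from sky.provision.slurm import utils as slurm_utils


def test_is_inside_slurm_cluster(tmp_path, monkeypatch):
    # is_inside_slurm_cluster() resolves '~' against HOME, which the
    # SlurmCommandRunner points at the cluster's sky home directory.
    monkeypatch.setenv('HOME', str(tmp_path))
    assert not slurm_utils.is_inside_slurm_cluster()

    # Simulate the touch step the sbatch script performs on the marker file.
    pathlib.Path(tmp_path, slurm_utils.SLURM_MARKER_FILE).touch()
    assert slurm_utils.is_inside_slurm_cluster()
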
18 changes: 8 additions & 10 deletions tests/unit_tests/test_sky/clouds/test_slurm.py
@@ -117,15 +117,13 @@ class TestTerminateInstances:
('SUSPENDED', True, True),
('STAGING_OUT', True, True),
])
@patch('sky.provision.slurm.instance.slurm_utils.is_inside_slurm_job')
@patch('sky.provision.slurm.instance.slurm_utils.is_inside_slurm_cluster')
@patch('sky.provision.slurm.instance.slurm.SlurmClient')
def test_terminate_instances_handles_job_states(self,
mock_slurm_client_class,
mock_is_inside_job,
job_state, should_cancel,
should_signal):
def test_terminate_instances_handles_job_states(
self, mock_slurm_client_class, mock_is_inside_slurm_cluster,
job_state, should_cancel, should_signal):
"""Test terminate_instances handles different job states correctly."""
mock_is_inside_job.return_value = False
mock_is_inside_slurm_cluster.return_value = False

mock_client = mock.MagicMock()
mock_slurm_client_class.return_value = mock_client
@@ -163,12 +161,12 @@ def test_terminate_instances_handles_job_states(self,
else:
mock_client.cancel_jobs_by_name.assert_not_called()

@patch('sky.provision.slurm.instance.slurm_utils.is_inside_slurm_job')
@patch('sky.provision.slurm.instance.slurm_utils.is_inside_slurm_cluster')
@patch('sky.provision.slurm.instance.slurm.SlurmClient')
def test_terminate_instances_no_jobs_found(self, mock_slurm_client_class,
mock_is_inside_job):
mock_is_inside_slurm_cluster):
"""Test terminate_instances when no jobs are found."""
mock_is_inside_job.return_value = False
mock_is_inside_slurm_cluster.return_value = False

mock_client = mock.MagicMock()
mock_slurm_client_class.return_value = mock_client
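
A complementary check of the local-execution mode added in sky/adaptors/slurm.py could look like the following. This is a hypothetical test: it assumes command_runner lives at sky.utils.command_runner and inspects the private _runner attribute purely for illustration.

from sky.adaptors import slurm
from sky.utils import command_runner


def test_slurm_client_local_mode_uses_local_runner():
    # With no SSH arguments, SlurmClient should fall back to running
    # Slurm CLI commands locally via LocalProcessCommandRunner.
    client = slurm.SlurmClient()
    assert isinstance(client._runner,
                      command_runner.LocalProcessCommandRunner)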