63 changes: 49 additions & 14 deletions sky/provision/slurm/instance.py
@@ -1,5 +1,6 @@
"""Slurm instance provisioning."""

import os
import tempfile
import textwrap
import time
@@ -22,38 +23,44 @@
# TODO(kevin): This assumes $HOME is in a shared filesystem.
# We should probably make it configurable, and add a check
# during sky check.
SHARED_ROOT_SKY_DIRECTORY = '~/.sky_clusters'
SKY_CLUSTERS_DIRECTORY_NAME = '.sky_clusters'
PROVISION_SCRIPTS_DIRECTORY_NAME = '.sky_provision'
PROVISION_SCRIPTS_DIRECTORY = f'~/{PROVISION_SCRIPTS_DIRECTORY_NAME}'

POLL_INTERVAL_SECONDS = 2
# Default KillWait is 30 seconds, so we add some buffer time here.
_JOB_TERMINATION_TIMEOUT_SECONDS = 60
_SKY_DIR_CREATION_TIMEOUT_SECONDS = 30


def _sky_cluster_home_dir(cluster_name_on_cloud: str) -> str:
def _sky_cluster_home_dir(workdir: Optional[str],
cluster_name_on_cloud: str) -> str:
"""Returns the SkyPilot cluster's home directory path on the Slurm cluster.

This path is assumed to be on a shared NFS mount accessible by all nodes.
To support clusters with non-NFS home directories, we would need to let
users specify an NFS-backed "working directory" or use a different
coordination mechanism.
"""
return f'{SHARED_ROOT_SKY_DIRECTORY}/{cluster_name_on_cloud}'
wd = workdir if workdir is not None else '~'
return os.path.join(wd, SKY_CLUSTERS_DIRECTORY_NAME, cluster_name_on_cloud)


def _sbatch_provision_script_path(filename: str) -> str:
def _sbatch_provision_script_path(workdir: Optional[str],
cluster_name_on_cloud: str) -> str:
"""Returns the path to the sbatch provision script on the login node."""
# Put sbatch script in $HOME instead of /tmp as there can be
# multiple login nodes, and different SSH connections
# can land on different login nodes.
return f'{PROVISION_SCRIPTS_DIRECTORY}/{filename}'
wd = workdir if workdir is not None else '~'
return os.path.join(wd, PROVISION_SCRIPTS_DIRECTORY_NAME,
f'{cluster_name_on_cloud}.sh')


def _skypilot_runtime_dir(cluster_name_on_cloud: str) -> str:
def _skypilot_runtime_dir(tmpdir: Optional[str],
cluster_name_on_cloud: str) -> str:
"""Returns the SkyPilot runtime directory path on the Slurm cluster."""
return f'/tmp/{cluster_name_on_cloud}'
tmp = tmpdir if tmpdir is not None else '/tmp'
return os.path.join(tmp, cluster_name_on_cloud)


@timeline.event
@@ -128,6 +135,16 @@ def _create_virtual_instance(
keys=('provision_timeout',),
default_value=None)

workdir = skypilot_config.get_effective_region_config(cloud='slurm',
region=region,
keys=('workdir',),
default_value=None)
tmpdir = skypilot_config.get_effective_region_config(cloud='slurm',
region=region,
keys=('tmpdir',),
default_value=None)
logger.info(f'workdir: {workdir}, tmpdir: {tmpdir}')
Review comment (Contributor, severity: medium):
This logger.info statement is useful for debugging during development but might be too verbose for production logs. Consider changing it to logger.debug to keep the default log level cleaner, or remove it once the feature is stable.

Suggested change:
-    logger.info(f'workdir: {workdir}, tmpdir: {tmpdir}')
+    logger.debug(f'workdir: {workdir}, tmpdir: {tmpdir}')

Author (Collaborator) reply:
Removing.


if existing_jobs:
assert len(existing_jobs) == 1, (
f'Multiple jobs found with name {cluster_name_on_cloud}: '
@@ -167,8 +184,8 @@ def _create_virtual_instance(
except (TypeError, ValueError):
accelerator_count = 0

skypilot_runtime_dir = _skypilot_runtime_dir(cluster_name_on_cloud)
sky_home_dir = _sky_cluster_home_dir(cluster_name_on_cloud)
skypilot_runtime_dir = _skypilot_runtime_dir(tmpdir, cluster_name_on_cloud)
sky_home_dir = _sky_cluster_home_dir(workdir, cluster_name_on_cloud)
ready_signal = f'{sky_home_dir}/.sky_sbatch_ready'

# Build the sbatch script
@@ -233,8 +250,11 @@ def _create_virtual_instance(
ssh_proxy_command=ssh_proxy_command,
ssh_proxy_jump=ssh_proxy_jump,
)
provision_script_path = _sbatch_provision_script_path(
workdir, cluster_name_on_cloud)
provision_scripts_dir = os.path.dirname(provision_script_path)

cmd = f'mkdir -p {PROVISION_SCRIPTS_DIRECTORY}'
cmd = f'mkdir -p {provision_scripts_dir}'
rc, stdout, stderr = login_node_runner.run(cmd,
require_outputs=True,
stream_logs=False)
@@ -248,7 +268,7 @@ def _create_virtual_instance(
f.write(provision_script)
f.flush()
src_path = f.name
tgt_path = _sbatch_provision_script_path(f'{cluster_name_on_cloud}.sh')
tgt_path = provision_script_path
login_node_runner.rsync(src_path, tgt_path, up=True, stream_logs=False)

job_id = client.submit_job(partition, cluster_name_on_cloud, tgt_path)
@@ -599,13 +619,28 @@ def get_command_runners(
# of the cluster yaml.
ssh_proxy_jump = cluster_info.provider_config.get('ssh', {}).get(
'proxyjump', None)

slurm_cluster_name = cluster_info.provider_config.get('cluster')
workdir = skypilot_config.get_effective_region_config(
cloud='slurm',
region=slurm_cluster_name,
keys=('workdir',),
default_value=None)
tmpdir = skypilot_config.get_effective_region_config(
cloud='slurm',
region=slurm_cluster_name,
keys=('tmpdir',),
default_value=None)
sky_dir = _sky_cluster_home_dir(workdir, cluster_name_on_cloud)
skypilot_runtime_dir = _skypilot_runtime_dir(tmpdir, cluster_name_on_cloud)

runners = [
command_runner.SlurmCommandRunner(
(instance_info.external_ip or '', instance_info.ssh_port),
ssh_user,
ssh_private_key,
sky_dir=_sky_cluster_home_dir(cluster_name_on_cloud),
skypilot_runtime_dir=_skypilot_runtime_dir(cluster_name_on_cloud),
sky_dir=sky_dir,
skypilot_runtime_dir=skypilot_runtime_dir,
job_id=instance_info.tags['job_id'],
slurm_node=instance_info.tags['node'],
ssh_proxy_jump=ssh_proxy_jump,
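To make the new path resolution concrete, here is a small standalone sketch of the fallback behavior added above (the helper names mirror the diff; the cluster name and override paths are made-up examples, not values from this PR):

import os
from typing import Optional

SKY_CLUSTERS_DIRECTORY_NAME = '.sky_clusters'

def sky_cluster_home_dir(workdir: Optional[str], cluster_name: str) -> str:
    # Fall back to the remote $HOME when no workdir is configured.
    wd = workdir if workdir is not None else '~'
    return os.path.join(wd, SKY_CLUSTERS_DIRECTORY_NAME, cluster_name)

def skypilot_runtime_dir(tmpdir: Optional[str], cluster_name: str) -> str:
    # Fall back to /tmp when no tmpdir is configured.
    tmp = tmpdir if tmpdir is not None else '/tmp'
    return os.path.join(tmp, cluster_name)

# Defaults keep the pre-existing layout:
print(sky_cluster_home_dir(None, 'sky-abc123'))    # ~/.sky_clusters/sky-abc123
print(skypilot_runtime_dir(None, 'sky-abc123'))    # /tmp/sky-abc123

# With per-cluster overrides, both trees move under the configured roots:
print(sky_cluster_home_dir('/nfs/scratch/alice', 'sky-abc123'))
# -> /nfs/scratch/alice/.sky_clusters/sky-abc123
print(skypilot_runtime_dir('/local/ssd', 'sky-abc123'))
# -> /local/ssd/sky-abc123
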
2 changes: 2 additions & 0 deletions sky/utils/config_utils.py
@@ -274,6 +274,8 @@ def get_cloud_config_value_from_dict(
region_key = None
if cloud in ('kubernetes', 'ssh'):
region_key = 'context_configs'
elif cloud == 'slurm':
region_key = 'cluster_configs'
elif cloud in _REGION_CONFIG_CLOUDS:
region_key = 'region_configs'

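The config_utils change maps the Slurm "region" onto a per-cluster section named cluster_configs. Below is a minimal sketch of that traversal over a hand-built config dict; the real code path goes through skypilot_config.get_effective_region_config, and the cluster name and paths here are invented for illustration:

from typing import Any, Optional

config = {
    'slurm': {
        'cluster_configs': {
            'hpc-a': {
                'workdir': '/nfs/scratch/alice',
                'tmpdir': '/local/ssd',
            },
        },
    },
}

def get_value(cloud: str, region: str, key: str,
              default: Optional[Any] = None) -> Optional[Any]:
    # For slurm, the per-"region" table is keyed by Slurm cluster name under
    # 'cluster_configs'; other clouds use 'region_configs' or 'context_configs'.
    region_key = 'cluster_configs' if cloud == 'slurm' else 'region_configs'
    return (config.get(cloud, {})
            .get(region_key, {})
            .get(region, {})
            .get(key, default))

print(get_value('slurm', 'hpc-a', 'workdir'))  # /nfs/scratch/alice
print(get_value('slurm', 'hpc-b', 'tmpdir'))   # None -> provisioning falls back to /tmp
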
19 changes: 19 additions & 0 deletions sky/utils/schemas.py
@@ -1420,6 +1420,25 @@ def _get_controller_schema():
'provision_timeout': {
'type': 'integer',
},
'cluster_configs': {
'type': 'object',
'required': [],
'properties': {},
# Properties are slurm cluster names.
'additionalProperties': {
'type': 'object',
'required': [],
'additionalProperties': False,
'properties': {
'workdir': {
'type': 'string',
},
'tmpdir': {
'type': 'string',
},
},
},
},
}
},
'oci': {
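As a sanity check of the schema addition, here is a hedged example using the jsonschema package against a cut-down copy of the cluster_configs block; the cluster name and paths are illustrative, not taken from the PR:

import jsonschema

# Trimmed copy of the 'cluster_configs' schema added above: keys are Slurm
# cluster names, and each entry only allows 'workdir' and 'tmpdir' strings.
CLUSTER_CONFIGS_SCHEMA = {
    'type': 'object',
    'additionalProperties': {
        'type': 'object',
        'additionalProperties': False,
        'properties': {
            'workdir': {'type': 'string'},
            'tmpdir': {'type': 'string'},
        },
    },
}

good = {'hpc-a': {'workdir': '/nfs/scratch/alice', 'tmpdir': '/local/ssd'}}
bad = {'hpc-a': {'work_dir': '/nfs/scratch/alice'}}  # misspelled key

jsonschema.validate(good, CLUSTER_CONFIGS_SCHEMA)  # passes silently
try:
    jsonschema.validate(bad, CLUSTER_CONFIGS_SCHEMA)
except jsonschema.exceptions.ValidationError as exc:
    print(f'rejected: {exc.message}')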