Skip to content

Commit 8bddab6

Browse files
committed
ssh properties as launcher method
1 parent 9fd5971 commit 8bddab6

File tree

3 files changed

+48
-5
lines changed

3 files changed

+48
-5
lines changed

runhouse/resources/hardware/cluster.py

+11-1
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ def __init__(
160160

161161
self._ips = ips
162162
self._http_client = None
163+
self._ssh_properties = ssh_properties
163164
self.den_auth = den_auth or False
164165
self.cert_config = TLSCertConfig(cert_path=ssl_certfile, key_path=ssl_keyfile)
165166

@@ -169,7 +170,6 @@ def __init__(
169170
self.server_port = server_port
170171
self.client_port = client_port
171172
self.ssh_port = ssh_port or self.DEFAULT_SSH_PORT
172-
self.ssh_properties = ssh_properties or {}
173173
self.server_host = server_host
174174
self.domain = domain
175175
self.compute_properties = {}
@@ -200,6 +200,16 @@ def head_ip(self):
200200
"""Head IP"""
201201
return self.ips[0] if self.ips else None
202202

203+
@property
204+
def ssh_properties(self):
205+
return self._ssh_properties or {}
206+
207+
@ssh_properties.setter
208+
def ssh_properties(self, value):
209+
if not isinstance(value, dict):
210+
raise ValueError(f"SSH properties must be a dict, not {type(value)}.")
211+
self._ssh_properties = value
212+
203213
@property
204214
def client(self):
205215
def check_connect_server():

runhouse/resources/hardware/launcher_utils.py

+22-2
Original file line numberDiff line numberDiff line change
@@ -93,7 +93,7 @@ def log_processor(cls, cluster_name: str):
9393
return LogProcessor(cluster_name)
9494

9595
@classmethod
96-
def up(cls, cluster, verbose: bool = True):
96+
def up(cls, cluster, verbose: bool = True, force: bool = False):
9797
"""Abstract method for launching a cluster."""
9898
raise NotImplementedError
9999

@@ -107,6 +107,10 @@ def keep_warm(cls, cluster, mins: int):
107107
"""Abstract method for keeping a cluster warm."""
108108
raise NotImplementedError
109109

110+
@classmethod
111+
def ssh_properties(cls, cluster):
112+
raise NotImplementedError
113+
110114
@staticmethod
111115
def supported_providers():
112116
"""Return the base list of Sky supported providers."""
@@ -360,6 +364,10 @@ def load_creds(cls):
360364

361365
return secret
362366

367+
@classmethod
368+
def ssh_properties(cls, cluster):
369+
return cluster._ssh_properties
370+
363371

364372
class LocalLauncher(Launcher):
365373
"""Launcher APIs for operations handled locally via Sky."""
@@ -375,7 +383,7 @@ def _validate_provider(cls, cluster):
375383
)
376384

377385
@classmethod
378-
def up(cls, cluster, verbose: bool = True):
386+
def up(cls, cluster, verbose: bool = True, force: bool = False):
379387
"""Launch the cluster locally."""
380388
import sky
381389

@@ -462,6 +470,18 @@ def keep_warm(cls, cluster, mins: int):
462470
set_cluster_autostop_cmd = _cluster_set_autostop_command(mins)
463471
cluster.run_bash_over_ssh([set_cluster_autostop_cmd], node=cluster.head_ip)
464472

473+
@classmethod
474+
def ssh_properties(cls, cluster):
475+
ssh_properties = cluster._ssh_properties
476+
# Note: Sky requires the ssh private key path to be in this specific path
477+
ssh_properties["ssh_private_key"] = "~/.ssh/ssh-sky-key"
478+
return ssh_properties
479+
480+
@classmethod
481+
def load_creds(cls):
482+
# Note: We rely on Sky for handling creds with local launching
483+
return None
484+
465485
@staticmethod
466486
def _set_docker_env_vars(image, task):
467487
"""Helper method to set Docker login environment variables."""

runhouse/resources/hardware/on_demand_cluster.py

+15-2
Original file line numberDiff line numberDiff line change
@@ -146,6 +146,19 @@ def ips(self):
146146
def internal_ips(self):
147147
return self.compute_properties.get("internal_ips", [])
148148

149+
@property
150+
def ssh_properties(self):
151+
if self.launcher == LauncherType.LOCAL:
152+
return LocalLauncher.ssh_properties(self)
153+
if self.launcher == LauncherType.DEN:
154+
return DenLauncher.ssh_properties(self)
155+
156+
@ssh_properties.setter
157+
def ssh_properties(self, value):
158+
if not isinstance(value, dict):
159+
raise ValueError(f"SSH properties must be a dict, not {type(value)}.")
160+
self._ssh_properties = value
161+
149162
@property
150163
def client(self):
151164
try:
@@ -177,10 +190,10 @@ def autostop_mins(self, mins):
177190
else:
178191
self.call_client_method("set_settings", {"autostop_mins": mins})
179192

180-
if self.launcher == "local":
193+
if self.launcher == LauncherType.LOCAL:
181194
LocalLauncher.keep_warm(self, mins)
182195

183-
elif self.launcher == "den":
196+
elif self.launcher == LauncherType.DEN:
184197
DenLauncher.keep_warm(self, mins)
185198

186199
@property

0 commit comments

Comments
 (0)