Skip to content

Commit e3bc550

Browse files
authored
Fix bug with mila code --cluster=<DRAC> (#115)
* Hotfix for bug in `mila code --cluster=<DRAC>` Signed-off-by: Fabrice Normandin <[email protected]> * Fix running sync vscode extensions in background Signed-off-by: Fabrice Normandin <[email protected]> * Use ssh key from ssh config for ssh-copy-id Signed-off-by: Fabrice Normandin <[email protected]> * Fix dumb unit test Signed-off-by: Fabrice Normandin <[email protected]> * Fix broken test for `mila init` Signed-off-by: Fabrice Normandin <[email protected]> * Fix broken test for `make_process` Signed-off-by: Fabrice Normandin <[email protected]> * Always run `ssh-copy-id` (with the right key) Signed-off-by: Fabrice Normandin <[email protected]> * Change where the compute node setup occurs in init Signed-off-by: Fabrice Normandin <[email protected]> * Fix error in test for syncing vscode extensions Signed-off-by: Fabrice Normandin <[email protected]> * Add temporary "fix" for failing test Signed-off-by: Fabrice Normandin <[email protected]> * Remove flaky check from test_ensure_allocation Signed-off-by: Fabrice Normandin <[email protected]> * Always cd to $SCRATCH before salloc/sbatch/srun Signed-off-by: Fabrice Normandin <[email protected]> * Adjust unit tests following `cd $SCRATCH` change Signed-off-by: Fabrice Normandin <[email protected]> * Fix test and make it slightly more agnostic to imp Signed-off-by: Fabrice Normandin <[email protected]> * Fix failing check for the workdir in test_code Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]>
1 parent 5adb85c commit e3bc550

File tree

11 files changed

+128
-88
lines changed

11 files changed

+128
-88
lines changed

milatools/cli/commands.py

Lines changed: 10 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -584,7 +584,8 @@ def code(
584584
)
585585
elif no_internet_on_compute_nodes(cluster):
586586
# Sync the VsCode extensions from the local machine over to the target cluster.
587-
run_in_the_background = False # if "pytest" not in sys.modules else True
587+
# TODO: Make this happen in the background (without overwriting the output).
588+
run_in_the_background = False
588589
print(
589590
console.log(
590591
f"[cyan]Installing VSCode extensions that are on the local machine on "
@@ -680,11 +681,17 @@ def code(
680681
if persist:
681682
print("This allocation is persistent and is still active.")
682683
print("To reconnect to this node:")
683-
print(T.bold(f" mila code {path} --node {node_name}"))
684+
print(
685+
T.bold(
686+
f" mila code {path} "
687+
+ (f"--cluster={cluster} " if cluster != "mila" else "")
688+
+ f"--node {node_name}"
689+
)
690+
)
684691
print("To kill this allocation:")
685692
assert data is not None
686693
assert "jobid" in data
687-
print(T.bold(f" ssh mila scancel {data['jobid']}"))
694+
print(T.bold(f" ssh {cluster} scancel {data['jobid']}"))
688695

689696

690697
def connect(identifier: str, port: int | None):

milatools/cli/init_command.py

Lines changed: 46 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
import questionary as qn
1616
from invoke.exceptions import UnexpectedExit
1717

18+
from milatools.utils.remote_v2 import SSH_CONFIG_FILE
19+
1820
from ..utils.vscode_utils import (
1921
get_expected_vscode_settings_json_path,
2022
vscode_installed,
@@ -238,20 +240,23 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
238240

239241
here = Local()
240242
sshdir = Path.home() / ".ssh"
241-
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
242243

243244
# Check if there is a public key file in ~/.ssh
244245
if not list(sshdir.glob("id*.pub")):
245246
if yn("You have no public keys. Generate one?"):
246247
# Run ssh-keygen with the given location and no passphrase.
248+
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
247249
create_ssh_keypair(ssh_private_key_path, here)
248250
else:
249251
print("No public keys.")
250252
return False
251253

254+
# TODO: This uses the public key set in the SSH config file, which may (or may not)
255+
# be the random id*.pub file that was just checked for above.
252256
success = setup_passwordless_ssh_access_to_cluster("mila")
253257
if not success:
254258
return False
259+
setup_keys_on_login_node("mila")
255260

256261
drac_clusters_in_ssh_config: list[str] = []
257262
hosts_in_config = ssh_config.hosts()
@@ -277,6 +282,7 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool:
277282
success = setup_passwordless_ssh_access_to_cluster(drac_cluster)
278283
if not success:
279284
return False
285+
setup_keys_on_login_node(drac_cluster)
280286
return True
281287

282288

@@ -293,28 +299,38 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
293299
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
294300
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
295301
# the default.
296-
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
302+
303+
from paramiko.config import SSHConfig
304+
305+
config = SSHConfig.from_path(str(SSH_CONFIG_FILE))
306+
identity_file = config.lookup(cluster).get("identityfile", "~/.ssh/id_rsa")
307+
# Seems to be a list for some reason?
308+
if isinstance(identity_file, list):
309+
assert identity_file
310+
identity_file = identity_file[0]
311+
ssh_private_key_path = Path(identity_file).expanduser()
297312
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
298313
assert ssh_public_key_path.exists()
299-
# TODO: This will fail for clusters with 2FA.
300-
if check_passwordless(cluster):
301-
logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
302-
return True
303314

304-
if not yn(
305-
f"Your public key does not appear be registered on the {cluster} cluster. "
306-
"Register it?"
307-
):
308-
print("No passwordless login.")
309-
return False
310-
311-
print("Please enter your password when prompted.")
315+
# TODO: This will fail on Windows for clusters with 2FA.
316+
# if check_passwordless(cluster):
317+
# logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
318+
# return True
319+
# if not yn(
320+
# f"Your public key does not appear be registered on the {cluster} cluster. "
321+
# "Register it?"
322+
# ):
323+
# print("No passwordless login.")
324+
# return False
325+
print("Please enter your password if prompted.")
312326
if sys.platform == "win32":
313327
# NOTE: This is to remove extra '^M' characters that would be added at the end
314328
# of the file on the remote!
315329
public_key_contents = ssh_public_key_path.read_text().replace("\r\n", "\n")
316330
command = (
317331
"ssh",
332+
"-i",
333+
str(ssh_private_key_path),
318334
"-o",
319335
"StrictHostKeyChecking=no",
320336
cluster,
@@ -328,7 +344,15 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
328344
f.seek(0)
329345
subprocess.run(command, check=True, text=False, stdin=f)
330346
else:
331-
here.run("ssh-copy-id", "-o", "StrictHostKeyChecking=no", cluster, check=True)
347+
here.run(
348+
"ssh-copy-id",
349+
"-i",
350+
str(ssh_private_key_path),
351+
"-o",
352+
"StrictHostKeyChecking=no",
353+
cluster,
354+
check=True,
355+
)
332356

333357
# double-check that this worked.
334358
if not check_passwordless(cluster):
@@ -337,14 +361,17 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
337361
return True
338362

339363

340-
def setup_keys_on_login_node():
364+
def setup_keys_on_login_node(cluster: str = "mila"):
341365
#####################################
342366
# Step 3: Set up keys on login node #
343367
#####################################
344368

345-
print("Checking connection to compute nodes")
346-
347-
remote = Remote("mila")
369+
print(
370+
f"Checking connection to compute nodes on the {cluster} cluster. "
371+
"This is required for `mila code` to work properly."
372+
)
373+
# todo: avoid re-creating the `Remote` here, since it goes through 2FA each time!
374+
remote = Remote(cluster)
348375
try:
349376
pubkeys = remote.get_lines("ls -t ~/.ssh/id*.pub")
350377
print("# OK")

milatools/cli/remote.py

Lines changed: 11 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,6 @@
2020
from typing_extensions import Self, TypedDict, deprecated
2121

2222
from .utils import (
23-
DRAC_CLUSTERS,
2423
SSHConnectionError,
2524
T,
2625
cluster_to_connect_kwargs,
@@ -344,7 +343,6 @@ def extract(
344343
# something like a `io.StringIO` here instead, and create an object that manages
345344
# reading from it, and pass that `io.StringIO` buffer to `self.run`.
346345
qio: TextIO = QueueIO()
347-
348346
promise = self.run(
349347
cmd,
350348
hide=hide,
@@ -467,7 +465,10 @@ def __init__(
467465
)
468466

469467
def srun_transform(self, cmd: str) -> str:
470-
return shlex.join(["srun", *self.alloc, "bash", "-c", cmd])
468+
cmd = shlex.join(["srun", *self.alloc, "bash", "-c", cmd])
469+
# We need to cd to $SCRATCH before we can run jobs with `srun` on some clusters.
470+
cmd = f"cd $SCRATCH && {cmd}"
471+
return cmd
471472

472473
def srun_transform_persist(self, cmd: str) -> str:
473474
tag = time.time_ns()
@@ -482,10 +483,8 @@ def srun_transform_persist(self, cmd: str) -> str:
482483
self.puttext(text=batch, dest=batch_file)
483484
self.simple_run(f"chmod +x {batch_file}")
484485
cmd = shlex.join(["sbatch", *self.alloc, str(batch_file)])
485-
486-
# NOTE: We need to cd to $SCRATCH before we run `sbatch` on DRAC clusters.
487-
if self.connection.host in DRAC_CLUSTERS:
488-
cmd = f"cd $SCRATCH && {cmd}"
486+
# We need to cd to $SCRATCH before we run `sbatch` on some SLURM clusters.
487+
cmd = f"cd $SCRATCH && {cmd}"
489488
return f"{cmd}; touch {output_file}; tail -n +1 -f {output_file}"
490489

491490
def with_transforms(
@@ -518,9 +517,10 @@ def ensure_allocation(
518517
- a dict with the compute node name (without the jobid)
519518
- a `fabric.runners.Remote` object connected to the *login* node.
520519
"""
520+
521521
if self._persist:
522522
login_node_runner, results = self.extract(
523-
"echo @@@ $(hostname) @@@ && sleep 1000d",
523+
"echo @@@ $SLURMD_NODENAME @@@ && sleep 1000d",
524524
patterns={
525525
"node_name": "@@@ ([^ ]+) @@@",
526526
"jobid": "Submitted batch job ([0-9]+)",
@@ -535,10 +535,9 @@ def ensure_allocation(
535535
else:
536536
remote = Remote(hostname=self.hostname, connection=self.connection)
537537
command = shlex.join(["salloc", *self.alloc])
538-
# NOTE: On some DRAC clusters, it's required to first cd to $SCRATCH or
539-
# /projects before submitting a job.
540-
if self.connection.host in DRAC_CLUSTERS:
541-
command = f"cd $SCRATCH && {command}"
538+
# We need to cd to $SCRATCH before we can run `salloc` on some clusters.
539+
command = f"cd $SCRATCH && {command}"
540+
542541
proc, results = remote.extract(
543542
command,
544543
patterns={"node_name": "salloc: Nodes ([^ ]+) are ready for job"},

milatools/cli/utils.py

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -297,7 +297,9 @@ def make_process(
297297
) -> multiprocessing.Process:
298298
# Tiny wrapper around the `multiprocessing.Process` init to detect if the args and
299299
# kwargs don't match the target signature using typing instead of at runtime.
300-
return multiprocessing.Process(target=target, daemon=True, args=args, kwargs=kwargs)
300+
return multiprocessing.Process(
301+
target=target, daemon=False, args=args, kwargs=kwargs
302+
)
301303

302304

303305
def currently_in_a_test() -> bool:

milatools/utils/vscode_utils.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -458,7 +458,7 @@ def find_code_server_executable(
458458
def parse_vscode_extensions_versions(
459459
list_extensions_output_lines: list[str],
460460
) -> dict[str, str]:
461-
extensions = list_extensions_output_lines
461+
extensions = [line for line in list_extensions_output_lines if "@" in line]
462462

463463
def _extension_name_and_version(extension: str) -> tuple[str, str]:
464464
# extensions should include name@version since we use --show-versions.

tests/cli/test_init_command.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -31,6 +31,7 @@
3131
create_ssh_keypair,
3232
get_windows_home_path_in_wsl,
3333
has_passphrase,
34+
setup_keys_on_login_node,
3435
setup_passwordless_ssh_access,
3536
setup_passwordless_ssh_access_to_cluster,
3637
setup_ssh_config,
@@ -1575,6 +1576,12 @@ def test_setup_passwordless_ssh_access(
15751576
mock_setup_passwordless_ssh_access_to_cluster,
15761577
)
15771578

1579+
monkeypatch.setattr(
1580+
milatools.cli.init_command,
1581+
setup_keys_on_login_node.__name__,
1582+
Mock(spec=setup_keys_on_login_node),
1583+
)
1584+
15781585
result = setup_passwordless_ssh_access(ssh_config)
15791586

15801587
if not public_key_exists:

0 commit comments

Comments
 (0)