From 885d198751c6e28c7aaeaa8a53321b99a3ea5850 Mon Sep 17 00:00:00 2001 From: Fabrice Normandin Date: Tue, 21 Jan 2025 17:08:56 -0500 Subject: [PATCH] [wip] Refactor of `mila init` command Signed-off-by: Fabrice Normandin --- milatools/cli/commands.py | 2 +- milatools/cli/init_command.py | 152 ++++++++++++++++----------------- tests/cli/test_init_command.py | 20 +++-- 3 files changed, 85 insertions(+), 89 deletions(-) diff --git a/milatools/cli/commands.py b/milatools/cli/commands.py index a218e1ad..7087c6fc 100644 --- a/milatools/cli/commands.py +++ b/milatools/cli/commands.py @@ -530,7 +530,7 @@ def init(): if running_inside_WSL(): setup_windows_ssh_config_from_wsl(linux_ssh_config=ssh_config) - setup_keys_on_login_node() + setup_keys_on_login_node(cluster="mila") setup_vscode_settings() print_welcome_message() diff --git a/milatools/cli/init_command.py b/milatools/cli/init_command.py index 914af098..6014b44f 100644 --- a/milatools/cli/init_command.py +++ b/milatools/cli/init_command.py @@ -4,6 +4,7 @@ import difflib import functools import json +import shlex import shutil import subprocess import sys @@ -14,10 +15,11 @@ import questionary as qn from invoke.exceptions import UnexpectedExit +from paramiko.config import SSHConfig as SSHConfigReader -from milatools.utils.remote_v2 import SSH_CONFIG_FILE +from milatools.utils.local_v2 import LocalV2 -from ..utils.local_v1 import LocalV1, check_passwordless, display +from ..utils.local_v1 import check_passwordless, display from ..utils.remote_v1 import RemoteV1 from ..utils.vscode_utils import ( get_expected_vscode_settings_json_path, @@ -239,8 +241,25 @@ def setup_windows_ssh_config_from_wsl(linux_ssh_config: SSHConfig): _copy_if_needed(linux_key_file, windows_key_file) -def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool: - """Sets up passwordless ssh access to the Mila and optionally also to DRAC. +def get_identityfile_from_ssh_config( + ssh_config: SSHConfig, hostname: str +) -> Path | None: + ssh_config_reader = SSHConfigReader.from_path(ssh_config.path) + private_key_path = ssh_config_reader.lookup(hostname).get("identityfile") + if private_key_path is None: + return None + # Seems to be a list for some reason? + if isinstance(private_key_path, list): + assert private_key_path + private_key_path = private_key_path[0] + return Path(private_key_path) + + +def setup_passwordless_ssh_access( + ssh_config: SSHConfig, + clusters: list[str] | tuple[str, ...] = ("mila", *DRAC_CLUSTERS), +) -> bool: + """Sets up passwordless ssh access to the given clusters. Sets up ssh connection to the DRAC clusters if they are present in the SSH config file. @@ -248,57 +267,50 @@ def setup_passwordless_ssh_access(ssh_config: SSHConfig) -> bool: Returns whether the operation completed successfully or not. """ print("Checking passwordless authentication") + clusters = list(clusters) + if not clusters: + print("No clusters to setup.") + return True - here = LocalV1() - sshdir = Path.home() / ".ssh" - - # Check if there is a public key file in ~/.ssh - if not list(sshdir.glob("id*.pub")): - if yn("You have no public keys. Generate one?"): - # Run ssh-keygen with the given location and no passphrase. - ssh_private_key_path = Path.home() / ".ssh" / "id_rsa" - create_ssh_keypair(ssh_private_key_path, here) - else: - print("No public keys.") - return False - - # TODO: This uses the public key set in the SSH config file, which may (or may not) - # be the random id*.pub file that was just checked for above. - success = setup_passwordless_ssh_access_to_cluster("mila") - - if not success: - return False - setup_keys_on_login_node("mila") + printed_drac_warning = False - drac_clusters_in_ssh_config: list[str] = [] - hosts_in_config = ssh_config.hosts() - for cluster in DRAC_CLUSTERS: - if any(cluster in hostname for hostname in hosts_in_config): - drac_clusters_in_ssh_config.append(cluster) + for cluster in clusters: + private_key_path = get_identityfile_from_ssh_config(ssh_config, cluster) + if private_key_path is None: + # todo: if the cluster doesn't have an `IdentityFile` set in the config, + # should we set a `IdentityFile` based on the cluster name? Or use the + # default key? + # For now, we just create the default ~/.ssh/id_rsa key if needed. + private_key_path = Path.home() / ".ssh" / "id_rsa" - if not drac_clusters_in_ssh_config: - logger.debug( - f"There are no DRAC clusters in the SSH config at {ssh_config.path}." - ) - return True + if not private_key_path.exists(): + # Run ssh-keygen with the given location and no passphrase. + print( + f"You don't have an SSH key for the {cluster!r} cluster. " + f"Generating one at {private_key_path}." + ) + create_ssh_keypair(private_key_path) + + if cluster in DRAC_CLUSTERS and not printed_drac_warning: + print( + "Setting up passwordless ssh access to the DRAC clusters with ssh-copy-id.\n" + "\n" + "Please note that you can also setup passwordless SSH access to all the DRAC " + "clusters by visiting https://ccdb.alliancecan.ca/ssh_authorized_keys and " + "copying in the content of your public key in the box.\n" + "See https://docs.alliancecan.ca/wiki/SSH_Keys#Using_CCDB for more info." + ) + printed_drac_warning = True + success = run_ssh_copy_id(cluster, private_key_path) - print( - "Setting up passwordless ssh access to the DRAC clusters with ssh-copy-id.\n" - "\n" - "Please note that you can also setup passwordless SSH access to all the DRAC " - "clusters by visiting https://ccdb.alliancecan.ca/ssh_authorized_keys and " - "copying in the content of your public key in the box.\n" - "See https://docs.alliancecan.ca/wiki/SSH_Keys#Using_CCDB for more info." - ) - for drac_cluster in drac_clusters_in_ssh_config: - success = setup_passwordless_ssh_access_to_cluster(drac_cluster) + setup_keys_on_login_node(cluster) if not success: return False - setup_keys_on_login_node(drac_cluster) + return True -def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: +def run_ssh_copy_id(cluster: str, identity_file: Path) -> bool: """Sets up passwordless SSH access to the given hostname. On Mac/Linux, uses `ssh-copy-id`. Performs the steps of ssh-copy-id manually on @@ -306,34 +318,13 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: Returns whether the operation completed successfully or not. """ - here = LocalV1() # Check that it is possible to connect without using a password. print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.") - # TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of - # the default. - - from paramiko.config import SSHConfig - config = SSHConfig.from_path(str(SSH_CONFIG_FILE)) - identity_file = config.lookup(cluster).get("identityfile", "~/.ssh/id_rsa") - # Seems to be a list for some reason? - if isinstance(identity_file, list): - assert identity_file - identity_file = identity_file[0] ssh_private_key_path = Path(identity_file).expanduser() ssh_public_key_path = ssh_private_key_path.with_suffix(".pub") assert ssh_public_key_path.exists() - # TODO: This will fail on Windows for clusters with 2FA. - # if check_passwordless(cluster): - # logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.") - # return True - # if not yn( - # f"Your public key does not appear be registered on the {cluster} cluster. " - # "Register it?" - # ): - # print("No passwordless login.") - # return False print("Please enter your password if prompted.") if sys.platform == "win32": # NOTE: This is to remove extra '^M' characters that would be added at the end @@ -356,14 +347,15 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: f.seek(0) subprocess.run(command, check=True, text=False, stdin=f) else: - here.run( - "ssh-copy-id", - "-i", - str(ssh_private_key_path), - "-o", - "StrictHostKeyChecking=no", - cluster, - check=True, + LocalV2.run( + ( + "ssh-copy-id", + "-i", + str(ssh_private_key_path), + "-o", + "StrictHostKeyChecking=no", + cluster, + ), ) # double-check that this worked. @@ -373,6 +365,10 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool: return True +def run_in_bash(cmd: str) -> str: + return shlex.join(["bash", "-c", cmd]) + + def setup_keys_on_login_node(cluster: str = "mila"): ##################################### # Step 3: Set up keys on login node # @@ -396,8 +392,8 @@ def setup_keys_on_login_node(cluster: str = "mila"): else: exit("Cannot proceed because there is no public key") - common = remote.with_bash().get_output( - "comm -12 <(sort ~/.ssh/authorized_keys) <(sort ~/.ssh/*.pub)" + common = remote.get_output( + run_in_bash("comm -12 <(sort ~/.ssh/authorized_keys) <(sort ~/.ssh/*.pub)") ) if common: print("# OK") @@ -465,7 +461,6 @@ def get_windows_home_path_in_wsl() -> Path: def create_ssh_keypair( ssh_private_key_path: Path, - local: LocalV1 | None = None, passphrase: str | None = "", ) -> None: """Creates a public/private key pair at the given path using ssh-keygen. @@ -474,7 +469,6 @@ def create_ssh_keypair( Otherwise, if passphrase is an empty string, no passphrase will be used (default). If a string is passed, it is passed to ssh-keygen and used as the passphrase. """ - local = local or LocalV1() command = [ "ssh-keygen", "-f", diff --git a/tests/cli/test_init_command.py b/tests/cli/test_init_command.py index d456ca40..ceef0938 100644 --- a/tests/cli/test_init_command.py +++ b/tests/cli/test_init_command.py @@ -31,9 +31,9 @@ create_ssh_keypair, get_windows_home_path_in_wsl, has_passphrase, + run_ssh_copy_id, setup_keys_on_login_node, setup_passwordless_ssh_access, - setup_passwordless_ssh_access_to_cluster, setup_ssh_config, setup_vscode_settings, setup_windows_ssh_config_from_wsl, @@ -42,7 +42,7 @@ SSHConfig, running_inside_WSL, ) -from milatools.utils.local_v1 import LocalV1, check_passwordless +from milatools.utils.local_v1 import check_passwordless from milatools.utils.remote_v1 import RemoteV1 from milatools.utils.remote_v2 import ( SSH_CACHE_DIR, @@ -1473,7 +1473,7 @@ def _mock_subprocess_run(command: tuple[str], *args, **kwargs): return subprocess_run(command, *args, **kwargs) mock_subprocess_run = mocker.patch("subprocess.run", wraps=_mock_subprocess_run) - success = setup_passwordless_ssh_access_to_cluster(cluster) + success = run_ssh_copy_id(cluster) if passwordless_ssh_was_previously_setup: # We already had access to the cluster. assert success is True @@ -1501,7 +1501,7 @@ def _mock_subprocess_run(command: tuple[str], *args, **kwargs): ] regression_text = "\n".join( [ - f"Calling {function_call_string(setup_passwordless_ssh_access_to_cluster, cluster)}", + f"Calling {function_call_string(run_ssh_copy_id, cluster)}", ] + [ f"with passwordless SSH access to {cluster} already setup" @@ -1582,9 +1582,7 @@ def test_setup_passwordless_ssh_access( else: # There should be an ssh key in the .ssh dir. # Won't ask to generate a key. - create_ssh_keypair( - ssh_private_key_path=ssh_dir / "id_rsa_milatools", local=LocalV1() - ) + create_ssh_keypair(ssh_private_key_path=ssh_dir / "id_rsa_milatools") if drac_clusters_in_ssh_config: # We should get a prompt asking if we want to register the public key # on the DRAC clusters or not. @@ -1609,14 +1607,14 @@ def test_setup_passwordless_ssh_access( # It's okay because we have a good test for it above. Therefore we just test how it # gets called here. mock_setup_passwordless_ssh_access_to_cluster = Mock( - spec=setup_passwordless_ssh_access_to_cluster, + spec=run_ssh_copy_id, side_effect=[accept_mila, *(accept_drac for _ in drac_clusters_in_ssh_config)], ) import milatools.cli.init_command monkeypatch.setattr( milatools.cli.init_command, - setup_passwordless_ssh_access_to_cluster.__name__, + run_ssh_copy_id.__name__, mock_setup_passwordless_ssh_access_to_cluster, ) @@ -1666,3 +1664,7 @@ def test_setup_passwordless_ssh_access( for drac_cluster in drac_clusters_in_ssh_config: mock_setup_passwordless_ssh_access_to_cluster.assert_any_call(drac_cluster) assert result is True + + +def test_inaccessible_cluster_is_skipped_in_mila_init(): + ...