Skip to content

Commit 68b7dde

Browse files
lebrice and satyaog authored
Fix sshkeys passphrase issue with mila init [MT-72] (#93)
* Add `has_passphrase` function and test Signed-off-by: Fabrice Normandin <[email protected]> * Add passphrase param to test_create_ssh_keypair Signed-off-by: Fabrice Normandin <[email protected]> * Remove check for number of lines in private key Signed-off-by: Fabrice Normandin <[email protected]> * Update poetry.lock file Signed-off-by: Fabrice Normandin <[email protected]> * Increase timeout value for test_create_ssh_keypair Signed-off-by: Fabrice Normandin <[email protected]> * Fix bug in create_ssh_keypair Signed-off-by: Fabrice <[email protected]> * Add `use_shjoin` arg to `display` Signed-off-by: Fabrice <[email protected]> * Simplify passing path of keyfile to ssh-keygen Signed-off-by: Fabrice <[email protected]> * Simpify sending of ssh key on Windows Signed-off-by: Fabrice <[email protected]> * Add a hardcore integration test (not in CI yet) Signed-off-by: Fabrice <[email protected]> * Pass passphrase as f"-N='{passphrase}'" Signed-off-by: Fabrice Normandin <[email protected]> * Simplify and add a docstring to create_ssh_keypair Signed-off-by: Fabrice Normandin <[email protected]> * Setup ssh keypair if needed during test Signed-off-by: Fabrice Normandin <[email protected]> * Tweak comment Signed-off-by: Fabrice Normandin <[email protected]> * Also catch socket.gaierror for Windows errors Signed-off-by: Fabrice <[email protected]> * Fix small typing error in utils.py Signed-off-by: Fabrice <[email protected]> * Create the parent dir of sshkey in test Signed-off-by: Fabrice Normandin <[email protected]> * Try to fix ssh-keygen errors on Windows (again) Signed-off-by: Fabrice Normandin <[email protected]> * Create the SSH dir during test Signed-off-by: Fabrice Normandin <[email protected]> * Update milatools/cli/init_command.py * Remove the xfails for weird paths for ssh keys Signed-off-by: Fabrice <[email protected]> * Fix failing test param on windows Signed-off-by: Fabrice Normandin <[email protected]> * Apply suggestions from code review Co-authored-by: 
satyaog <[email protected]> * Change has_passphrase Signed-off-by: Fabrice Normandin <[email protected]> * Fix pre-commit hook issues Signed-off-by: Fabrice Normandin <[email protected]> * Remove unused "test" dep group Signed-off-by: Fabrice Normandin <[email protected]> --------- Signed-off-by: Fabrice Normandin <[email protected]> Signed-off-by: Fabrice <[email protected]> Co-authored-by: satyaog <[email protected]>
1 parent f8bb9c1 commit 68b7dde

File tree

5 files changed

+201
-34
lines changed

5 files changed

+201
-34
lines changed

milatools/cli/init_command.py

Lines changed: 63 additions & 18 deletions
Original file line number | Diff line number | Diff line change
@@ -4,7 +4,6 @@
44
import difflib
55
import functools
66
import json
7-
import shlex
87
import shutil
98
import subprocess
109
import sys
@@ -16,7 +15,7 @@
1615
import questionary as qn
1716
from invoke.exceptions import UnexpectedExit
1817

19-
from .local import Local, check_passwordless
18+
from .local import Local, check_passwordless, display
2019
from .remote import Remote
2120
from .utils import SSHConfig, T, running_inside_WSL, yn
2221
from .vscode_utils import (
@@ -292,11 +291,11 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
292291
here = Local()
293292
# Check that it is possible to connect without using a password.
294293
print(f"Checking if passwordless SSH access is setup for the {cluster} cluster.")
295-
# TODO: Potentially use the public key from the SSH config file instead of
296-
# the default. It's also possible that ssh-copy-id selects the key from the
297-
# config file, I'm not sure.
298-
# ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
299-
294+
# TODO: Potentially use a custom key like `~/.ssh/id_milatools.pub` instead of
295+
# the default.
296+
ssh_private_key_path = Path.home() / ".ssh" / "id_rsa"
297+
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
298+
assert ssh_public_key_path.exists()
300299
if check_passwordless(cluster):
301300
logger.info(f"Passwordless SSH access to {cluster} is already setup correctly.")
302301
return True
@@ -310,19 +309,23 @@ def setup_passwordless_ssh_access_to_cluster(cluster: str) -> bool:
310309

311310
print("Please enter your password when prompted.")
312311
if sys.platform == "win32":
313-
# todo: the path to the key is hard-coded here.
312+
# NOTE: This is to remove extra '^M' characters that would be added at the end
313+
# of the file on the remote!
314+
public_key_contents = ssh_public_key_path.read_text().replace("\r\n", "\n")
314315
command = (
315-
"powershell.exe",
316-
"type",
317-
"$env:USERPROFILE\\.ssh\\id_rsa.pub",
318-
"|",
319316
"ssh",
320317
"-o",
321318
"StrictHostKeyChecking=no",
322319
cluster,
323-
'"cat >> ~/.ssh/authorized_keys"',
320+
"cat >> ~/.ssh/authorized_keys",
324321
)
325-
here.run(*command, check=True)
322+
display(command)
323+
import tempfile
324+
325+
with tempfile.NamedTemporaryFile("w", newline="\n") as f:
326+
print(public_key_contents, end="", file=f)
327+
f.seek(0)
328+
subprocess.run(command, check=True, text=False, stdin=f)
326329
else:
327330
here.run("ssh-copy-id", "-o", "StrictHostKeyChecking=no", cluster, check=True)
328331

@@ -410,11 +413,53 @@ def get_windows_home_path_in_wsl() -> Path:
410413
return Path(f"/mnt/c/Users/{windows_username}")
411414

412415

413-
def create_ssh_keypair(ssh_private_key_path: Path, local: Local) -> None:
414-
local.run(
415-
*shlex.split(
416-
f'ssh-keygen -f {shlex.quote(str(ssh_private_key_path))} -t rsa -N=""'
416+
def create_ssh_keypair(
417+
ssh_private_key_path: Path,
418+
local: Local | None = None,
419+
passphrase: str | None = "",
420+
) -> None:
421+
"""Creates a public/private key pair at the given path using ssh-keygen.
422+
423+
If passphrase is `None`, ssh-keygen will prompt the user for a passphrase.
424+
Otherwise, if passphrase is an empty string, no passphrase will be used (default).
425+
If a string is passed, it is passed to ssh-keygen and used as the passphrase.
426+
"""
427+
local = local or Local()
428+
command = [
429+
"ssh-keygen",
430+
"-f",
431+
str(ssh_private_key_path.expanduser()),
432+
"-t",
433+
"rsa",
434+
]
435+
if passphrase is not None:
436+
command.extend(["-N", passphrase])
437+
display(command)
438+
subprocess.run(command, check=True)
439+
440+
441+
def has_passphrase(ssh_private_key_path: Path) -> bool:
442+
"""Returns whether the SSH private key has a passphrase or not."""
443+
assert ssh_private_key_path.exists()
444+
result = subprocess.run(
445+
args=(
446+
"ssh-keygen",
447+
"-y",
448+
"-P=''",
449+
"-f",
450+
str(ssh_private_key_path),
417451
),
452+
capture_output=True,
453+
text=True,
454+
)
455+
logger.debug(f"Result of ssh-keygen: {result}")
456+
if result.returncode == 0:
457+
return False
458+
elif "incorrect passphrase supplied to decrypt private key" in result.stderr:
459+
return True
460+
raise NotImplementedError(
461+
f"TODO: Unable to tell if the key at {ssh_private_key_path} has a passphrase "
462+
f"or not! (result={result})"
418463
)
419464

420465

milatools/cli/local.py

Lines changed: 6 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -66,8 +66,12 @@ def check_passwordless(self, host: str):
6666
return check_passwordless(host)
6767

6868

69-
def display(split_command: list[str] | tuple[str, ...]) -> None:
70-
print(T.bold_green("(local) $ ", shjoin(split_command)))
69+
def display(split_command: list[str] | tuple[str, ...] | str) -> None:
70+
if isinstance(split_command, str):
71+
command = split_command
72+
else:
73+
command = shjoin(split_command)
74+
print(T.bold_green("(local) $ ", command))
7175

7276

7377
def check_passwordless(host: str) -> bool:

milatools/cli/remote.py

Lines changed: 2 additions & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -119,8 +119,9 @@ def __init__(
119119
# assert isinstance(connection.transport, paramiko.Transport)
120120
transport: paramiko.Transport = connection.transport # type: ignore
121121
transport.set_keepalive(keepalive)
122-
except paramiko.SSHException as err:
122+
except (paramiko.SSHException, socket.gaierror) as err:
123123
raise SSHConnectionError(node_hostname=self.hostname, error=err)
124+
124125
self.connection = connection
125126
self.transforms = transforms
126127

milatools/cli/utils.py

Lines changed: 2 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -131,7 +131,7 @@ def __str__(self):
131131

132132

133133
class SSHConnectionError(paramiko.SSHException):
134-
def __init__(self, node_hostname: str, error: paramiko.SSHException):
134+
def __init__(self, node_hostname: str, error: Exception):
135135
super().__init__()
136136
self.node_hostname = node_hostname
137137
self.error = error
@@ -158,7 +158,7 @@ def yn(prompt: str, default: bool = True) -> bool:
158158
return qn.confirm(prompt, default=default).unsafe_ask()
159159

160160

161-
def askpath(prompt, remote):
161+
def askpath(prompt: str, remote: Remote) -> str:
162162
while True:
163163
pth = qn.text(prompt).unsafe_ask()
164164
try:

tests/cli/test_init_command.py

Lines changed: 128 additions & 11 deletions
Original file line number | Diff line number | Diff line change
@@ -1,6 +1,7 @@
11
from __future__ import annotations
22

33
import contextlib
4+
import getpass
45
import json
56
import os
67
import shutil
@@ -12,6 +13,7 @@
1213
from pathlib import Path
1314
from unittest.mock import Mock
1415

16+
import fabric
1517
import pytest
1618
import pytest_mock
1719
import questionary
@@ -26,14 +28,15 @@
2628
_setup_ssh_config_file,
2729
create_ssh_keypair,
2830
get_windows_home_path_in_wsl,
31+
has_passphrase,
2932
setup_passwordless_ssh_access,
3033
setup_passwordless_ssh_access_to_cluster,
3134
setup_ssh_config,
3235
setup_vscode_settings,
3336
setup_windows_ssh_config_from_wsl,
3437
)
3538
from milatools.cli.local import Local, check_passwordless
36-
from milatools.cli.utils import SSHConfig, running_inside_WSL
39+
from milatools.cli.utils import SSHConfig, T, running_inside_WSL
3740

3841
from .common import (
3942
in_github_CI,
@@ -727,20 +730,52 @@ def test_fixes_dir_permission_issues(
727730

728731

729732
# takes a little longer in the CI runner (Windows in particular)
730-
@pytest.mark.timeout(10)
731-
def test_create_ssh_keypair(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
732-
here = Local()
733-
mock_run = Mock(
734-
wraps=subprocess.run,
735-
)
736-
monkeypatch.setattr(subprocess, "run", mock_run)
733+
@pytest.mark.timeout(20)
734+
@pytest.mark.parametrize(
735+
("passphrase", "expected"),
736+
[("", False), ("bobobo", True), ("\n", True), (" ", True)],
737+
)
738+
@pytest.mark.parametrize(
739+
"filename",
740+
[
741+
"bob",
742+
"dir with spaces/somefile",
743+
"dir_with_'single_quotes'/somefile",
744+
pytest.param(
745+
'dir_with_"doublequotes"/somefile',
746+
marks=pytest.mark.xfail(
747+
sys.platform == "win32",
748+
strict=True,
749+
raises=OSError,
750+
reason="Doesn't work on Windows.",
751+
),
752+
),
753+
pytest.param(
754+
"windows_style_dir\\bob",
755+
marks=pytest.mark.skipif(
756+
sys.platform != "win32", reason="only runs on Windows."
757+
),
758+
),
759+
],
760+
)
761+
def test_create_ssh_keypair(
762+
mocker: pytest_mock.MockerFixture,
763+
tmp_path: Path,
764+
filename: str,
765+
passphrase: str,
766+
expected: bool,
767+
):
768+
# Wrap the subprocess.run call (but also actually execute the commands).
769+
subprocess_run = mocker.patch("subprocess.run", wraps=subprocess.run)
770+
737771
fake_ssh_folder = tmp_path / "fake_ssh"
738772
fake_ssh_folder.mkdir(mode=0o700)
739-
ssh_private_key_path = fake_ssh_folder / "bob"
773+
ssh_private_key_path = fake_ssh_folder / filename
774+
ssh_private_key_path.parent.mkdir(mode=0o700, exist_ok=True, parents=True)
740775

741-
create_ssh_keypair(ssh_private_key_path=ssh_private_key_path, local=here)
776+
create_ssh_keypair(ssh_private_key_path=ssh_private_key_path, passphrase=passphrase)
742777

743-
mock_run.assert_called_once()
778+
subprocess_run.assert_called_once()
744779
assert ssh_private_key_path.exists()
745780
if not on_windows:
746781
assert ssh_private_key_path.stat().st_mode & 0o777 == 0o600
@@ -749,6 +784,8 @@ def test_create_ssh_keypair(monkeypatch: pytest.MonkeyPatch, tmp_path: Path):
749784
if not on_windows:
750785
assert ssh_public_key_path.stat().st_mode & 0o777 == 0o644
751786

787+
assert has_passphrase(ssh_private_key_path) == expected
788+
752789

753790
@pytest.fixture
754791
def linux_ssh_config(
@@ -1075,6 +1112,13 @@ def test_setup_passwordless_ssh_access_to_cluster(
10751112
backup_authorized_keys_file = backup_ssh_dir / "authorized_keys"
10761113
assert backup_authorized_keys_file.exists()
10771114

1115+
ssh_private_key_path = ssh_dir / "id_rsa"
1116+
ssh_public_key_path = ssh_private_key_path.with_suffix(".pub")
1117+
if not ssh_public_key_path.exists():
1118+
create_ssh_keypair(ssh_private_key_path=ssh_private_key_path)
1119+
assert ssh_public_key_path.exists()
1120+
assert not has_passphrase(ssh_private_key_path)
1121+
10781122
if not passwordless_to_cluster_is_already_setup:
10791123
if authorized_keys_file.exists():
10801124
logger.warning(
@@ -1166,6 +1210,7 @@ def test_setup_passwordless_ssh_access(
11661210
f"Temporarily deleting the ssh dir (backed up at {backup_ssh_dir})"
11671211
)
11681212
shutil.rmtree(ssh_dir)
1213+
ssh_dir.mkdir(mode=0o700, exist_ok=False)
11691214

11701215
if not public_key_exists:
11711216
# There should be no ssh keys in the ssh dir before calling the function.
@@ -1252,3 +1297,75 @@ def test_setup_passwordless_ssh_access(
12521297
for drac_cluster in drac_clusters_in_ssh_config:
12531298
mock_setup_passwordless_ssh_access_to_cluster.assert_any_call(drac_cluster)
12541299
assert result is True
1300+
1301+
1302+
@pytest.fixture()
1303+
def cluster(request: pytest.FixtureRequest) -> str:
1304+
cluster_name: str | None = getattr(
1305+
request, "param", os.environ.get("SLURM_CLUSTER", None)
1306+
)
1307+
if not cluster_name:
1308+
pytest.skip(reason="Need a real slurm cluster to be specified")
1309+
return cluster_name
1310+
1311+
1312+
@pytest.fixture()
1313+
def authorized_keys_backup(cluster: str):
1314+
"""Fixture used to backup the authorized_keys file on the remote and restore it
1315+
after tests."""
1316+
connect_kwargs = {}
1317+
backup_authorized_keys_path = "~/.ssh/authorized_keys.backup"
1318+
if not check_passwordless(cluster):
1319+
if in_github_CI:
1320+
pytest.skip(
1321+
f"Can't run this test because passwordless SSH access to {cluster} is "
1322+
"not setup."
1323+
)
1324+
password = getpass.getpass(
1325+
T.red("\nEnter your password for SSH-ing to the cluster\n")
1326+
)
1327+
connect_kwargs = {"password": password}
1328+
1329+
remote = fabric.Connection(cluster, connect_kwargs=connect_kwargs)
1330+
remote.run(
1331+
f"cp ~/.ssh/authorized_keys {backup_authorized_keys_path}",
1332+
echo=True,
1333+
echo_format=T.bold_cyan(f"({cluster})" + " $ {command}"),
1334+
in_stream=False,
1335+
)
1336+
try:
1337+
yield backup_authorized_keys_path
1338+
finally:
1339+
remote.run(
1340+
"cp ~/.ssh/authorized_keys.backup ~/.ssh/authorized_keys",
1341+
echo=True,
1342+
echo_format=T.bold_cyan(f"({cluster})" + " $ {command}"),
1343+
in_stream=False,
1344+
)
1345+
1346+
1347+
@pytest.mark.timeout(None)
1348+
@pytest.mark.skipif(
1349+
in_github_CI, reason="Can't run this in the github CI since it asks for a password."
1350+
)
1351+
@pytest.mark.skipif(
1352+
"SLURM_CLUSTER" not in os.environ, reason="Only runs with a real cluster."
1353+
)
1354+
def test_setup_passwordless_ssh_access_to_real_cluster(
1355+
cluster: str,
1356+
authorized_keys_backup: str,
1357+
):
1358+
if check_passwordless(cluster):
1359+
logger.warning(
1360+
f"Temporarily removing the ~/.ssh/authorized_keys file on {cluster} "
1361+
f"(backed up at {cluster}:{authorized_keys_backup})"
1362+
)
1363+
fabric.Connection(cluster).run(
1364+
"rm ~/.ssh/authorized_keys",
1365+
echo=True,
1366+
echo_format=T.bold_cyan(f"({cluster})" + " $ {command}"),
1367+
in_stream=False,
1368+
)
1369+
assert not check_passwordless(cluster)
1370+
setup_passwordless_ssh_access_to_cluster(cluster)
1371+
assert check_passwordless(cluster)

0 commit comments

Comments
 (0)