-
Notifications
You must be signed in to change notification settings - Fork 13
/
Copy pathtest_init_command.py
1670 lines (1483 loc) · 55.9 KB
/
test_init_command.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
440
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
469
470
471
472
473
474
475
476
477
478
479
480
481
482
483
484
485
486
487
488
489
490
491
492
493
494
495
496
497
498
499
500
501
502
503
504
505
506
507
508
509
510
511
512
513
514
515
516
517
518
519
520
521
522
523
524
525
526
527
528
529
530
531
532
533
534
535
536
537
538
539
540
541
542
543
544
545
546
547
548
549
550
551
552
553
554
555
556
557
558
559
560
561
562
563
564
565
566
567
568
569
570
571
572
573
574
575
576
577
578
579
580
581
582
583
584
585
586
587
588
589
590
591
592
593
594
595
596
597
598
599
600
601
602
603
604
605
606
607
608
609
610
611
612
613
614
615
616
617
618
619
620
621
622
623
624
625
626
627
628
629
630
631
632
633
634
635
636
637
638
639
640
641
642
643
644
645
646
647
648
649
650
651
652
653
654
655
656
657
658
659
660
661
662
663
664
665
666
667
668
669
670
671
672
673
674
675
676
677
678
679
680
681
682
683
684
685
686
687
688
689
690
691
692
693
694
695
696
697
698
699
700
701
702
703
704
705
706
707
708
709
710
711
712
713
714
715
716
717
718
719
720
721
722
723
724
725
726
727
728
729
730
731
732
733
734
735
736
737
738
739
740
741
742
743
744
745
746
747
748
749
750
751
752
753
754
755
756
757
758
759
760
761
762
763
764
765
766
767
768
769
770
771
772
773
774
775
776
777
778
779
780
781
782
783
784
785
786
787
788
789
790
791
792
793
794
795
796
797
798
799
800
801
802
803
804
805
806
807
808
809
810
811
812
813
814
815
816
817
818
819
820
821
822
823
824
825
826
827
828
829
830
831
832
833
834
835
836
837
838
839
840
841
842
843
844
845
846
847
848
849
850
851
852
853
854
855
856
857
858
859
860
861
862
863
864
865
866
867
868
869
870
871
872
873
874
875
876
877
878
879
880
881
882
883
884
885
886
887
888
889
890
891
892
893
894
895
896
897
898
899
900
901
902
903
904
905
906
907
908
909
910
911
912
913
914
915
916
917
918
919
920
921
922
923
924
925
926
927
928
929
930
931
932
933
934
935
936
937
938
939
940
941
942
943
944
945
946
947
948
949
950
951
952
953
954
955
956
957
958
959
960
961
962
963
964
965
966
967
968
969
970
971
972
973
974
975
976
977
978
979
980
981
982
983
984
985
986
987
988
989
990
991
992
993
994
995
996
997
998
999
1000
from __future__ import annotations
import contextlib
import errno
import getpass
import json
import os
import shutil
import subprocess
import sys
import textwrap
from functools import partial
from logging import getLogger as get_logger
from pathlib import Path, PurePosixPath
from unittest.mock import Mock
import invoke
import paramiko
import pytest
import pytest_mock
import questionary
from prompt_toolkit.input import PipeInput, create_pipe_input
from pytest_regressions.file_regression import FileRegressionFixture
from milatools.cli import init_command
from milatools.cli.init_command import (
DRAC_CLUSTERS,
_get_drac_username,
_get_mila_username,
_setup_ssh_config_file,
create_ssh_keypair,
get_windows_home_path_in_wsl,
has_passphrase,
run_ssh_copy_id,
setup_keys_on_login_node,
setup_passwordless_ssh_access,
setup_ssh_config,
setup_vscode_settings,
setup_windows_ssh_config_from_wsl,
)
from milatools.cli.utils import (
SSHConfig,
running_inside_WSL,
)
from milatools.utils.local_v1 import check_passwordless
from milatools.utils.remote_v1 import RemoteV1
from milatools.utils.remote_v2 import (
SSH_CACHE_DIR,
SSH_CONFIG_FILE,
RemoteV2,
get_controlpath_for,
is_already_logged_in,
)
from .common import (
function_call_string,
in_github_CI,
in_self_hosted_github_CI,
on_windows,
passwordless_ssh_connection_to_localhost_is_setup,
xfails_on_windows,
)
# Module-level logger for this test module.
logger = get_logger(__name__)
def raises_NoConsoleScreenBufferError_on_windows_ci_action():
    """Build a non-strict `xfails_on_windows` mark for tests that use input pipes.

    On Windows the expected exception is prompt_toolkit's
    `NoConsoleScreenBufferError`; on any other platform an empty tuple is passed as
    `raises` (no specific exception type).
    """
    expected_error = ()
    if sys.platform == "win32":
        # Only importable / meaningful on Windows.
        from prompt_toolkit.output.win32 import NoConsoleScreenBufferError

        expected_error = NoConsoleScreenBufferError
    return xfails_on_windows(
        raises=expected_error,
        reason="TODO: Tests using input pipes don't work on GitHub CI.",
        strict=False,
    )
def permission_bits_check_doesnt_work_on_windows():
    """Return an `xfail` mark for tests whose permission-bit assertions fail on Windows.

    The mark is only active when `sys.platform == "win32"`, and only an
    `AssertionError` counts as the expected failure.
    """
    running_on_win32 = sys.platform == "win32"
    return pytest.mark.xfail(
        running_on_win32,
        raises=AssertionError,
        reason="TODO: The check for permission bits is failing on Windows in CI.",
    )
# Set a module-level mark: each test in this module times out after 10 seconds
# (individual tests can override this with their own `pytest.mark.timeout`).
pytestmark = pytest.mark.timeout(10)
@pytest.fixture
def input_pipe(monkeypatch: pytest.MonkeyPatch, request: pytest.FixtureRequest):
    """Yield a prompt_toolkit input pipe that `questionary.confirm`/`text` read from.

    Tests feed answers with `input_pipe.send_text("some text")`.

    NOTE: Important: send the \\r character (e.g. `"bob\\r"`) to submit a text
    prompt. For confirmation prompts, send just one letter: a trailing \\r would be
    passed on to the next prompt, which sees it as pressing enter and therefore
    uses that prompt's default value.
    """
    request.node.add_marker(raises_NoConsoleScreenBufferError_on_windows_ci_action())
    with create_pipe_input() as pipe:
        # Re-route both prompt factories to read from the pipe instead of stdin.
        for prompt_name in ("confirm", "text"):
            original_prompt = getattr(questionary, prompt_name)
            monkeypatch.setattr(
                f"questionary.{prompt_name}",
                partial(original_prompt, input=pipe),
            )
        yield pipe
def test_questionary_uses_input_pipe(input_pipe: PipeInput):
    """Sanity check: questionary prompts read their answers from `input_pipe`.

    TODO: Ideally we'd want to make sure that the input prompts work exactly the same
    way in our tests as they will for the users, but that's not something I'm confident
    I can guarantee.
    """
    input_pipe.send_text("bob\r")
    typed_name = questionary.text("name?").unsafe_ask()
    assert typed_name == "bob"

    for key, expected_answer in (("y", True), ("n", False)):
        input_pipe.send_text(key)
        assert questionary.confirm("confirm?").unsafe_ask() is expected_answer
def _join_blocks(*blocks: str, user: str = "bob") -> str:
return "\n".join(textwrap.dedent(block) for block in blocks).format(user=user)
def _yn(accept: bool):
return "y" if accept else "n"
def test_creates_ssh_config_file(tmp_path: Path, input_pipe: PipeInput):
    """`setup_ssh_config` creates the SSH config file when the user agrees to it."""
    ssh_config_path = tmp_path / "ssh_config"
    answers = (
        "y",  # create the ssh config file?
        "bob\r",  # mila username
        "y",  # drac?
        "bob\r",  # drac username
        *(["y"] * 5),  # accept any remaining confirmation prompts
    )
    for answer in answers:
        input_pipe.send_text(answer)

    setup_ssh_config(ssh_config_path)

    assert ssh_config_path.exists()
@pytest.mark.parametrize(
    "drac_username",
    [None, "bob"],
    ids=["no_drac", "drac"],
)
@pytest.mark.parametrize(
    "confirm_changes",
    [False, True],
    ids=["reject_changes", "confirm_changes"],
)
@pytest.mark.parametrize(
    "initial_contents",
    [
        "",
        """\
        # A comment in the file.
        """,
        """\
        # a comment
        Host foo
        HostName foobar.com
        """,
        """\
        # a comment
        Host foo
        HostName foobar.com
        # another comment
        """,
        """\
        # a comment
        Host foo
        HostName foobar.com
        # another comment after lots of empty lines.
        """,
    ],
    ids=[
        "empty",
        "has_comment",
        "has_different_indent",
        "has_comment_and_entry",
        "has_comment_and_entry_with_extra_space",
    ],
)
def test_setup_ssh(
    initial_contents: str,
    confirm_changes: bool,
    drac_username: str | None,
    tmp_path: Path,
    file_regression: FileRegressionFixture,
    input_pipe: PipeInput,
):
    """Checks what entries are added to the ssh config file when running the
    corresponding portion of `mila init`.

    The resulting file (together with the initial contents and the user inputs) is
    compared against a stored markdown regression file.
    """
    ssh_config_path = tmp_path / ".ssh" / "config"
    ssh_config_path.parent.mkdir(parents=True, exist_ok=False)
    if initial_contents:
        initial_contents = textwrap.dedent(initial_contents)
    # `initial_contents` is always a string, so the config file always exists before
    # `setup_ssh_config` runs (it is simply empty in the "empty" case).
    ssh_config_path.write_text(initial_contents)

    # Mila username, then the DRAC question (+ username), then the final confirmation.
    drac_answers = ["n"] if drac_username is None else ["y", f"{drac_username}\r"]
    user_inputs = ["bob\r", *drac_answers, _yn(confirm_changes)]
    for text in user_inputs:
        input_pipe.send_text(text)

    # Rejecting the changes makes `setup_ssh_config` exit.
    expectation = (
        contextlib.nullcontext() if confirm_changes else pytest.raises(SystemExit)
    )
    with expectation:
        setup_ssh_config(ssh_config_path=ssh_config_path)

    assert ssh_config_path.exists()
    resulting_contents = ssh_config_path.read_text()

    if initial_contents:
        initial_description = "\n".join(
            ["this initial content:", "", "```", initial_contents, "```"]
        )
    else:
        initial_description = "no initial ssh config file"
    expected_text = "\n".join(
        [
            "Running the `mila init` command with " + initial_description,
            "",
            f"and these user inputs: {tuple(user_inputs)}",
            "leads the following ssh config file:",
            "",
            "```",
            resulting_contents,
            "```",
            "",
        ]
    )
    file_regression.check(expected_text, extension=".md")
def test_fixes_overly_general_entry(
    tmp_path: Path,
    input_pipe: PipeInput,
    file_regression: FileRegressionFixture,
):
    """Test the case where the user has a *.server.mila.quebec entry.

    After accepting the fix, that entry's `Host` line must exclude the login nodes.
    """
    ssh_config_path = tmp_path / ".ssh" / "config"
    ssh_config_path.parent.mkdir(parents=True, exist_ok=False)
    ssh_config_path.write_text(
        textwrap.dedent(
            """\
            Host *.server.mila.quebec
            User bob
            """
        )
    )

    # Enter the username, decline the DRAC account, accept fixing that entry, then
    # confirm.
    for answer in ("bob\r", "n", "y", "y"):
        input_pipe.send_text(answer)

    setup_ssh_config(ssh_config_path=ssh_config_path)

    resulting_contents = ssh_config_path.read_text()
    file_regression.check(resulting_contents)
    fixed_host_line = "Host *.server.mila.quebec !*login.server.mila.quebec"
    assert fixed_host_line in resulting_contents.splitlines()
def test_ssh_config_host(tmp_path: Path):
    """`SSHConfig.host` returns the options of a matching entry, with lower-cased keys."""
    ssh_config_path = tmp_path / "config"
    ssh_config_path.write_text(
        textwrap.dedent(
            """\
            Host mila
            HostName login.server.mila.quebec
            User normandf
            PreferredAuthentications publickey,keyboard-interactive
            Port 2222
            ServerAliveInterval 120
            ServerAliveCountMax 5
            BatchMode yes
            """
        )
    )
    expected_options = {
        "hostname": "login.server.mila.quebec",
        "user": "normandf",
        "preferredauthentications": "publickey,keyboard-interactive",
        "port": "2222",
        "serveraliveinterval": "120",
        "serveralivecountmax": "5",
        "batchmode": "yes",
    }
    assert SSHConfig(str(ssh_config_path)).host("mila") == expected_options
@pytest.mark.parametrize(
    "already_has_drac", [True, False], ids=["has_drac_entries", "no_drac_entries"]
)
@pytest.mark.parametrize(
    "already_has_mila", [True, False], ids=["has_mila_entry", "no_mila_entry"]
)
@pytest.mark.parametrize(
    "already_has_mila_cpu",
    [True, False],
    ids=["has_mila_cpu_entry", "no_mila_cpu_entry"],
)
@pytest.mark.parametrize(
    "already_has_mila_compute",
    [True, False],
    ids=["has_mila_compute_entry", "no_mila_compute_entry"],
)
def test_with_existing_entries(
    already_has_mila: bool,
    already_has_mila_cpu: bool,
    already_has_mila_compute: bool,
    already_has_drac: bool,
    file_regression: FileRegressionFixture,
    tmp_path: Path,
    input_pipe: PipeInput,
):
    """Checks what `setup_ssh_config` produces when some of the Mila / DRAC entries
    are already present in the ssh config file (regression-tested as markdown)."""
    user = "bob"
    existing_mila = textwrap.dedent(
        f"""\
        Host mila
        HostName login.server.mila.quebec
        User {user}
        """
    )
    existing_mila_cpu = textwrap.dedent(
        """\
        Host mila-cpu
        HostName login.server.mila.quebec
        """
    )
    existing_mila_compute = textwrap.dedent(
        """\
        Host *.server.mila.quebec !*login.server.mila.quebec
        HostName foooobar.com
        """
    )
    existing_drac = textwrap.dedent(
        f"""
        # Compute Canada
        Host beluga cedar graham narval niagara
        Hostname %h.alliancecan.ca
        User {user}
        Host mist
        Hostname mist.scinet.utoronto.ca
        User {user}
        Host !beluga bc????? bg????? bl?????
        ProxyJump beluga
        User {user}
        Host !cedar cdr? cdr?? cdr??? cdr????
        ProxyJump cedar
        User {user}
        Host !graham gra??? gra????
        ProxyJump graham
        User {user}
        Host !narval nc????? ng?????
        ProxyJump narval
        User {user}
        Host !niagara nia????
        ProxyJump niagara
        User {user}
        """
    )
    # Keep only the blocks selected by the parametrization, in this fixed order.
    candidate_blocks = [
        (already_has_mila, existing_mila),
        (already_has_mila_cpu, existing_mila_cpu),
        (already_has_mila_compute, existing_mila_compute),
        (already_has_drac, existing_drac),
    ]
    initial_contents = _join_blocks(
        *(block for is_present, block in candidate_blocks if is_present)
    )
    # TODO: Need to insert the entries in the right place, in the right order!
    ssh_config_path = tmp_path / ".ssh" / "config"
    ssh_config_path.parent.mkdir(parents=True, exist_ok=False)
    ssh_config_path.write_text(initial_contents)

    # Accept all the prompts.
    # The Mila username is asked unless an existing `mila` entry already sets `User`.
    mila_entry_has_user = already_has_mila and "User" in existing_mila
    username_input = [] if mila_entry_has_user else ["bob\r"]

    controlmaster_block = "\n".join(
        [
            " ControlMaster auto",
            " ControlPath ~/.cache/ssh/%r@%h:%p",
            " ControlPersist 600",
        ]
    )
    nothing_to_add = all(
        (
            already_has_mila and controlmaster_block in existing_mila,
            already_has_mila_cpu,
            already_has_mila_compute and controlmaster_block in existing_mila_compute,
            already_has_drac,
        )
    )
    # There's a confirmation prompt only if we're adding some entry.
    confirm_inputs = [] if nothing_to_add else ["y"]
    drac_username_inputs = [] if already_has_drac else ["y", f"{user}\r"]

    prompt_inputs = username_input + drac_username_inputs + confirm_inputs
    for prompt_input in prompt_inputs:
        input_pipe.send_text(prompt_input)

    setup_ssh_config(ssh_config_path=ssh_config_path)

    resulting_contents = ssh_config_path.read_text()
    if initial_contents:
        initial_description = "\n".join(
            ["this initial content:", "", "```", initial_contents, "```"]
        )
    else:
        initial_description = "no initial ssh config file"
    expected_text = "\n".join(
        [
            "Running the `mila init` command with " + initial_description,
            "",
            f"and these user inputs: {prompt_inputs}",
            "leads to the following ssh config file:",
            "",
            "```",
            resulting_contents,
            "```",
        ]
    )
    file_regression.check(expected_text, extension=".md")
@pytest.mark.parametrize(
    ("contents", "prompt_inputs", "expected"),
    [
        pytest.param(
            "",  # empty file.
            ["bob\r"],  # enter "bob" then enter.
            "bob",  # get "bob" as username.
            id="empty_file",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host mila
                HostName login.server.mila.quebec
                User bob
                """
            ),
            [],
            "bob",
            id="existing_mila_entry",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host mila
                HostName login.server.mila.quebec
                """
            ),
            ["bob\r"],
            "bob",
            id="entry_without_user",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host mila
                HostName login.server.mila.quebec
                User george
                # duplicate entry
                Host mila mila_alias
                User Bob
                """
            ),
            ["bob\r"],
            "bob",
            id="two_matching_entries",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host fooo mila bar baz
                HostName login.server.mila.quebec
                User george
                """
            ),
            [],
            "george",
            id="with_aliases",
        ),
        pytest.param(
            "",
            [" \r", "bob\r"],
            "bob",
            id="empty_username",
        ),
    ],
)
def test_get_username(
    contents: str,
    prompt_inputs: list[str],
    expected: str,
    input_pipe: PipeInput,
    tmp_path: Path,
):
    """`_get_mila_username` reads the username from the ssh config or asks for it."""
    # TODO: We should probably also have a test that checks that keyboard interrupts
    # work.
    # Seems like the text to send for that would be "\x03".
    ssh_config_path = tmp_path / "config"
    ssh_config_path.write_text(contents)
    ssh_config = SSHConfig(ssh_config_path)
    if prompt_inputs:
        for prompt_input in prompt_inputs:
            input_pipe.send_text(prompt_input)
    else:
        # No prompt is expected; presumably closing the pipe makes an unexpected
        # prompt error out rather than block.
        input_pipe.close()
    assert _get_mila_username(ssh_config) == expected
@pytest.mark.parametrize(
    ("contents", "prompt_inputs", "expected"),
    [
        pytest.param(
            "",  # empty file.
            ["n"],  # No I don't have a DRAC account.
            None,  # get None as a result
            id="no_drac_account",
        ),
        pytest.param(
            "",  # empty file.
            ["y", "bob\r"],  # enter yes, then "bob" then enter.
            "bob",  # get "bob" as username.
            id="empty_file",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host narval
                HostName narval.computecanada.ca
                User bob
                """
            ),
            [],
            "bob",
            id="existing_drac_entry",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host beluga cedar graham narval niagara
                HostName %h.computecanada.ca
                ControlMaster auto
                ControlPath ~/.cache/ssh/%r@%h:%p
                ControlPersist 600
                """
            ),
            ["y", "bob\r"],  # Yes I have a username on the drac clusters, and it's bob.
            "bob",
            id="entry_without_user",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host beluga cedar graham narval niagara
                HostName login.server.mila.quebec
                User george
                # duplicate entry
                Host beluga cedar graham narval niagara other_cluster
                User Bob
                """
            ),
            ["y", "bob\r"],
            "bob",
            id="two_matching_entries",
        ),
        pytest.param(
            textwrap.dedent(
                """\
                Host fooo beluga bar baz
                HostName beluga.alliancecan.ca
                User george
                """
            ),
            [],
            "george",
            id="with_aliases",
        ),
        pytest.param(
            "",
            # Yes (by pressing just enter), then an invalid username (space), then a
            # real username.
            ["\r", " \r", "bob\r"],
            "bob",
            id="empty_username",
        ),
    ],
)
def test_get_drac_username(
    contents: str,
    prompt_inputs: list[str],
    expected: str | None,
    input_pipe: PipeInput,
    tmp_path: Path,
):
    """`_get_drac_username` reads the DRAC username from the ssh config or asks for
    it, returning None when the user says they have no DRAC account."""
    ssh_config_path = tmp_path / "config"
    ssh_config_path.write_text(contents)
    ssh_config = SSHConfig(ssh_config_path)
    if prompt_inputs:
        for prompt_input in prompt_inputs:
            input_pipe.send_text(prompt_input)
    else:
        # No prompt is expected; presumably closing the pipe makes an unexpected
        # prompt error out rather than block.
        input_pipe.close()
    assert _get_drac_username(ssh_config) == expected
class TestSetupSshFile:
    """Tests for `_setup_ssh_config_file`: creating the file and fixing permissions."""

    def test_create_file(self, tmp_path: Path, input_pipe: PipeInput):
        """Accepting the prompt creates the file with 0o600 permissions."""
        input_pipe.send_text("y")
        created = _setup_ssh_config_file(tmp_path / "config")
        assert created.exists()
        assert created.stat().st_mode & 0o777 == 0o600

    def test_refuse_creating_file(self, tmp_path: Path, input_pipe: PipeInput):
        """Refusing to create the file exits without creating it."""
        config_path = tmp_path / "config"
        input_pipe.send_text("n")
        with pytest.raises(SystemExit):
            _setup_ssh_config_file(config_path)
        assert not config_path.exists()

    @permission_bits_check_doesnt_work_on_windows()
    def test_fix_file_permissions(self, tmp_path: Path):
        """An existing file with 0o644 permissions is changed to 0o600."""
        config_path = tmp_path / "config"
        config_path.touch(mode=0o644)
        assert config_path.stat().st_mode & 0o777 == 0o644
        # todo: Do we want to have a prompt in this case here?
        # idea: might be nice to also test that the right output is printed
        fixed = _setup_ssh_config_file(config_path)
        assert fixed.exists()
        assert fixed.stat().st_mode & 0o777 == 0o600

    def test_creates_dir(self, tmp_path: Path, input_pipe: PipeInput):
        """A missing parent directory is created with 0o700 permissions."""
        input_pipe.send_text("y")
        created = _setup_ssh_config_file(tmp_path / "fake_ssh" / "config")
        assert created.parent.exists()
        assert created.parent.stat().st_mode & 0o777 == 0o700
        assert created.exists()
        assert created.stat().st_mode & 0o777 == 0o600

    @pytest.mark.parametrize(
        "file_exists",
        [
            pytest.param(True, marks=permission_bits_check_doesnt_work_on_windows()),
            False,
        ],
    )
    def test_fixes_dir_permission_issues(
        self, file_exists: bool, tmp_path: Path, input_pipe: PipeInput
    ):
        """A 0o755 parent directory is tightened to 0o700, with or without the file."""
        config_path = tmp_path / "fake_ssh" / "config"
        config_path.parent.mkdir(mode=0o755)
        if not file_exists:
            # The file will have to be created: accept the creation prompt.
            input_pipe.send_text("y")
        else:
            config_path.touch(mode=0o600)
        fixed = _setup_ssh_config_file(config_path)
        assert fixed.parent.exists()
        assert fixed.parent.stat().st_mode & 0o777 == 0o700
        assert fixed.exists()
        assert fixed.stat().st_mode & 0o777 == 0o600
# takes a little longer in the CI runner (Windows in particular)
@pytest.mark.timeout(20)
@pytest.mark.parametrize(
    ("passphrase", "expected"),
    [("", False), ("bobobo", True), ("\n", True), (" ", True)],
)
@pytest.mark.parametrize(
    "filename",
    [
        "bob",
        "dir with spaces/somefile",
        "dir_with_'single_quotes'/somefile",
        pytest.param(
            'dir_with_"doublequotes"/somefile',
            marks=pytest.mark.xfail(
                sys.platform == "win32",
                strict=True,
                raises=OSError,
                reason="Doesn't work on Windows.",
            ),
        ),
        pytest.param(
            "windows_style_dir\\bob",
            marks=pytest.mark.skipif(
                sys.platform != "win32", reason="only runs on Windows."
            ),
        ),
    ],
)
def test_create_ssh_keypair(
    mocker: pytest_mock.MockerFixture,
    tmp_path: Path,
    filename: str,
    passphrase: str,
    expected: bool,
):
    """`create_ssh_keypair` runs one subprocess and creates a private/public key pair.

    Checks the key files' permissions (except on Windows), and that the private key
    has a passphrase exactly when a non-empty passphrase was given.
    """
    # Wrap the subprocess.run call (but also actually execute the commands).
    subprocess_run = mocker.patch("subprocess.run", wraps=subprocess.run)
    fake_ssh_folder = tmp_path / "fake_ssh"
    fake_ssh_folder.mkdir(mode=0o700)
    ssh_private_key_path = fake_ssh_folder / filename
    # `filename` may include a sub-directory, which has to exist beforehand.
    ssh_private_key_path.parent.mkdir(mode=0o700, exist_ok=True, parents=True)

    create_ssh_keypair(ssh_private_key_path=ssh_private_key_path, passphrase=passphrase)

    subprocess_run.assert_called_once()
    assert ssh_private_key_path.exists()
    if not on_windows:
        assert ssh_private_key_path.stat().st_mode & 0o777 == 0o600

    # The public key is expected next to the private key, with ".pub" *appended* to
    # its file name (the ssh-keygen convention -- presumably what
    # `create_ssh_keypair` runs). `Path.with_suffix(".pub")` would instead *replace*
    # an existing suffix (e.g. "id.rsa" -> "id.pub"), so build the name explicitly.
    ssh_public_key_path = ssh_private_key_path.with_name(
        ssh_private_key_path.name + ".pub"
    )
    assert ssh_public_key_path.exists()
    if not on_windows:
        assert ssh_public_key_path.stat().st_mode & 0o777 == 0o644

    assert has_passphrase(ssh_private_key_path) == expected
@pytest.fixture
def linux_ssh_config(
    tmp_path: Path, input_pipe: PipeInput, monkeypatch: pytest.MonkeyPatch
) -> SSHConfig:
    """Creates the SSH config that is generated by `mila init` on a Linux machine.

    The config is written to a fresh `tmp_path / "ssh_config"` file by answering the
    prompts of `setup_ssh_config` (username "bob" on both the Mila and DRAC
    clusters). Skipped on Windows. `monkeypatch` is currently unused here.
    """
    # Skip before queueing any input, so nothing is written to the pipe needlessly.
    if sys.platform.startswith("win"):
        pytest.skip(
            "TODO: Issue when changing sys.platform to get the Linux config when "
            "on Windows."
        )
    ssh_config_path = tmp_path / "ssh_config"
    # Create the file, enter both usernames, then accept adding the entries.
    for prompt in [
        "y",  # Create an ssh config file?
        "bob\r",  # What's your username on the Mila cluster?
        "y",  # Do you also have a DRAC account?
        "bob\r",  # username on DRAC
        "y",  # accept adding the entries in the ssh config
    ]:
        input_pipe.send_text(prompt)
    setup_ssh_config(ssh_config_path)
    return SSHConfig(ssh_config_path)
@pytest.mark.parametrize("accept_changes", [True, False], ids=["accept", "reject"])
def test_setup_windows_ssh_config_from_wsl(
    pretend_to_be_in_WSL,  # here even if `windows_home` already uses it (more explicit)
    windows_home: Path,
    linux_ssh_config: SSHConfig,
    input_pipe: PipeInput,
    file_regression: FileRegressionFixture,
    fake_linux_ssh_keypair: tuple[Path, Path],  # add this fixture so the keys exist.
    accept_changes: bool,
):
    """Checks the Windows-side SSH config that is derived from the WSL SSH config."""
    initial_contents = linux_ssh_config.cfg.config()
    windows_ssh_config_path = windows_home / ".ssh" / "config"

    # We accept creating the Windows SSH config file (if needed) for now, then accept
    # or reject the proposed changes.
    user_inputs: list[str] = [] if windows_ssh_config_path.exists() else ["y"]
    user_inputs.append(_yn(accept_changes))
    for text in user_inputs:
        input_pipe.send_text(text)

    setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config)

    assert windows_ssh_config_path.exists()
    assert windows_ssh_config_path.stat().st_mode & 0o777 == 0o600
    assert windows_ssh_config_path.parent.stat().st_mode & 0o777 == 0o700
    if not accept_changes:
        # Rejecting the changes leaves the Windows config file empty.
        assert windows_ssh_config_path.read_text() == ""

    if initial_contents.strip():
        initial_description = "\n".join(
            ["these initial contents:", "```", initial_contents, "```", ""]
        )
    else:
        initial_description = "no initial ssh config file"
    expected_text = "\n".join(
        [
            "When this SSH config is already present in the WSL environment with "
            + initial_description,
            "",
            f"and these user inputs: {tuple(user_inputs)}",
            "leads the following ssh config file on the Windows side:",
            "",
            "```",
            windows_ssh_config_path.read_text(),
            "```",
        ]
    )
    file_regression.check(expected_text, extension=".md")
@pytest.fixture
def windows_ssh_config(
    linux_ssh_config: SSHConfig,
    windows_home: Path,
    input_pipe: PipeInput,
    monkeypatch: pytest.MonkeyPatch,
) -> SSHConfig:
    """Returns the Windows ssh config as it would be when we create it from WSL."""
    windows_ssh_config_path = windows_home / ".ssh" / "config"
    # Pretend we're inside WSL, with `windows_home` as the Windows home directory.
    for patched_function, fake_return_value in (
        (running_inside_WSL, True),
        (get_windows_home_path_in_wsl, windows_home),
    ):
        monkeypatch.setattr(
            init_command,
            patched_function.__name__,  # type: ignore
            Mock(spec=patched_function, return_value=fake_return_value),
        )

    # We accept creating the Windows SSH config file (if needed) for now, then accept
    # the changes.
    creation_answers = [] if windows_ssh_config_path.exists() else ["y"]
    user_inputs: list[str] = [*creation_answers, "y"]
    for text in user_inputs:
        input_pipe.send_text(text)

    setup_windows_ssh_config_from_wsl(linux_ssh_config=linux_ssh_config)

    assert windows_ssh_config_path.exists()
    assert windows_ssh_config_path.stat().st_mode & 0o777 == 0o600
    assert windows_ssh_config_path.parent.stat().st_mode & 0o777 == 0o700
    return SSHConfig(windows_ssh_config_path)
@xfails_on_windows(
raises=AssertionError, reason="TODO: buggy test: getting assert None is not None."
)
@pytest.mark.parametrize(
"initial_settings", [None, {}, {"foo": "bar"}, {"remote.SSH.connectTimeout": 123}]
)
@pytest.mark.parametrize("accept_changes", [True, False], ids=["accept", "reject"])
def test_setup_vscode_settings(
tmp_path: Path,
monkeypatch: pytest.MonkeyPatch,
input_pipe: PipeInput,
initial_settings: dict | None,
file_regression: FileRegressionFixture,
accept_changes: bool,
):
vscode_settings_json_path = tmp_path / "settings.json"
if initial_settings is not None:
with open(vscode_settings_json_path, "w") as f:
json.dump(initial_settings, f, indent=4)
monkeypatch.setattr(
init_command,
init_command.vscode_installed.__name__,
Mock(spec=init_command.vscode_installed, return_value=True),
)
monkeypatch.setattr(
init_command,
init_command.get_expected_vscode_settings_json_path.__name__,
Mock(
spec=init_command.get_expected_vscode_settings_json_path,
return_value=vscode_settings_json_path,
),
)
user_inputs = ["y" if accept_changes else "n"]
for user_input in user_inputs:
input_pipe.send_text(user_input)
setup_vscode_settings()
resulting_contents: str | None = None
resulting_settings: dict | None = None
if not accept_changes and initial_settings is None:
# Shouldn't create the file if we don't accept the changes and there's no
# initial file.
assert not vscode_settings_json_path.exists()
if vscode_settings_json_path.exists():
resulting_contents = vscode_settings_json_path.read_text()
resulting_settings = json.loads(resulting_contents)
assert isinstance(resulting_settings, dict)
if not accept_changes:
if initial_settings is None:
assert not vscode_settings_json_path.exists()
return # skip creating the regression file in that case.
assert resulting_settings == initial_settings
return
assert resulting_contents is not None
assert resulting_settings is not None
expected_text = "\n".join(
[
f"Calling `{setup_vscode_settings.__name__}()` with "
+ (
"\n".join(
[
"this initial content:",
"",
"```json",