Skip to content

Commit c99e8ef

Browse files
committed
ssh: only allow string command
List of strings was previously accepted, but was giving the false feeling that we could avoid the command being parsed by the remote shell. ssh doesn't support passing separated parameters. It reconstructs a string with the arguments it gets from the command line separated by a single space. This string is then passed to the ssh servers that passes it to the shell. With this change, the string passed to the ssh() function is the exact string received by the shell on the remote host. Signed-off-by: Gaëtan Lehmann <gaetan.lehmann@vates.tech>
1 parent 71f5651 commit c99e8ef

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

54 files changed

+509
-503
lines changed

conftest.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -681,7 +681,7 @@ def uefi_vm(imported_vm):
681681
yield vm
682682

683683
@pytest.fixture(scope='session')
684-
def additional_repos(request, hosts):
684+
def additional_repos(request, hosts: list[Host]):
685685
if request.param is None:
686686
yield []
687687
return
@@ -707,7 +707,7 @@ def additional_repos(request, hosts):
707707

708708
for host in hosts:
709709
for host_ in host.pool.hosts:
710-
host_.ssh(['rm', '-f', repo_file])
710+
host_.ssh('rm -f {repo_file}')
711711

712712
@pytest.fixture(scope='session')
713713
def second_network(pytestconfig, host):

jobs.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -544,7 +544,7 @@ def build_pytest_cmd(job_data, hosts=None, host_version=None, pytest_args=[]):
544544
if hosts is not None:
545545
try:
546546
host = hosts.split(',')[0]
547-
cmd = ["lsb_release", "-sr"]
547+
cmd = "lsb_release -sr"
548548
host_version = ssh(host, cmd)
549549
except Exception as e:
550550
print(e, file=sys.stderr)

lib/commands.py

Lines changed: 29 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,18 @@
1+
from __future__ import annotations
2+
13
import base64
24
import logging
35
import os
46
import platform
57
import subprocess
68

79
import lib.config as config
8-
from lib.common import HostAddress
910
from lib.netutil import wrap_ip
1011

11-
from typing import List, Literal, Union, overload
12+
from typing import TYPE_CHECKING, List, Literal, Union, overload
13+
14+
if TYPE_CHECKING:
15+
from lib.common import HostAddress
1216

1317
class BaseCommandFailed(Exception):
1418
__slots__ = 'returncode', 'stdout', 'cmd'
@@ -63,8 +67,17 @@ def _ellide_log_lines(log):
6367
reduced_message.append("(...)")
6468
return "\n{}".format("\n".join(reduced_message))
6569

66-
def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
67-
background, decode, options, multiplexing) -> Union[SSHResult, SSHCommandFailed, str, bytes, None]:
70+
def _ssh(
71+
hostname_or_ip: str,
72+
cmd: str,
73+
check: bool,
74+
simple_output: bool,
75+
suppress_fingerprint_warnings: bool,
76+
background: bool,
77+
decode: bool,
78+
options: list[str],
79+
multiplexing: bool,
80+
) -> SSHResult | SSHCommandFailed | str | bytes | None:
6881
opts = list(options)
6982
opts += ['-o', 'BatchMode yes']
7083
opts += ['-o', 'PubkeyAcceptedAlgorithms +ssh-rsa']
@@ -87,12 +100,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
87100
else:
88101
opts += ['-o', 'ControlMaster no']
89102

90-
if isinstance(cmd, str):
91-
command = [cmd]
92-
else:
93-
command = cmd
94-
95-
ssh_cmd = ['ssh', f'root@{hostname_or_ip}'] + opts + command
103+
ssh_cmd = ['ssh', f'root@{hostname_or_ip}'] + opts + [cmd]
96104

97105
# Fetch banner and remove it to avoid stdout/stderr pollution.
98106
banner_res = None
@@ -104,7 +112,7 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
104112
check=False
105113
)
106114

107-
logging.debug(f"[{hostname_or_ip}] {' '.join(command)}")
115+
logging.debug(f"[{hostname_or_ip}] {cmd}")
108116
process = subprocess.Popen(
109117
ssh_cmd,
110118
stdout=subprocess.PIPE,
@@ -127,20 +135,20 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
127135

128136
# Even if check is False, we still raise in case of return code 255, which means a SSH error.
129137
if res.returncode == 255:
130-
return SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, command)
138+
return SSHCommandFailed(255, "SSH Error: %s" % output_for_errors, cmd)
131139

132140
output: Union[bytes, str] = res.stdout
133141
if banner_res:
134142
if banner_res.returncode == 255:
135-
return SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), command)
143+
return SSHCommandFailed(255, "SSH Error: %s" % banner_res.stdout.decode(errors='replace'), cmd)
136144
output = output[len(banner_res.stdout):]
137145

138146
if decode:
139147
assert isinstance(output, bytes)
140148
output = output.decode()
141149

142150
if res.returncode and check:
143-
return SSHCommandFailed(res.returncode, output_for_errors, command)
151+
return SSHCommandFailed(res.returncode, output_for_errors, cmd)
144152

145153
if simple_output:
146154
return output.strip()
@@ -150,37 +158,37 @@ def _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warning
150158
# This function is kept short for shorter pytest traces upon SSH failures, which are common,
151159
# as pytest prints the whole function definition that raised the SSHCommandFailed exception
152160
@overload
153-
def ssh(hostname_or_ip: HostAddress, cmd: Union[str, List[str]], *, check: bool = True,
161+
def ssh(hostname_or_ip: HostAddress, cmd: str, *, check: bool = True,
154162
simple_output: Literal[True] = True,
155163
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
156164
decode: Literal[True] = True, options: List[str] = [], multiplexing=True) -> str:
157165
...
158166
@overload
159-
def ssh(hostname_or_ip: HostAddress, cmd: Union[str, List[str]], *, check: bool = True,
167+
def ssh(hostname_or_ip: HostAddress, cmd: str, *, check: bool = True,
160168
simple_output: Literal[True] = True,
161169
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
162170
decode: Literal[False], options: List[str] = [], multiplexing=True) -> bytes:
163171
...
164172
@overload
165-
def ssh(hostname_or_ip: HostAddress, cmd: Union[str, List[str]], *, check: bool = True,
173+
def ssh(hostname_or_ip: HostAddress, cmd: str, *, check: bool = True,
166174
simple_output: Literal[False],
167175
suppress_fingerprint_warnings: bool = True, background: Literal[False] = False,
168176
decode: bool = True, options: List[str] = [], multiplexing=True) -> SSHResult:
169177
...
170178
@overload
171-
def ssh(hostname_or_ip: HostAddress, cmd: Union[str, List[str]], *, check: bool = True,
179+
def ssh(hostname_or_ip: HostAddress, cmd: str, *, check: bool = True,
172180
simple_output: Literal[False],
173181
suppress_fingerprint_warnings: bool = True, background: Literal[True],
174182
decode: bool = True, options: List[str] = [], multiplexing=True) -> None:
175183
...
176184
@overload
177-
def ssh(hostname_or_ip: HostAddress, cmd: Union[str, List[str]], *, check=True,
185+
def ssh(hostname_or_ip: HostAddress, cmd: str, *, check=True,
178186
simple_output: bool = True,
179187
suppress_fingerprint_warnings=True, background: bool = False,
180188
decode: bool = True, options: List[str] = [], multiplexing=True) \
181189
-> Union[str, bytes, SSHResult, None]:
182190
...
183-
def ssh(hostname_or_ip, cmd, *, check=True, simple_output=True,
191+
def ssh(hostname_or_ip: HostAddress, cmd: str, *, check=True, simple_output=True,
184192
suppress_fingerprint_warnings=True,
185193
background=False, decode=True, options=[], multiplexing=True):
186194
result_or_exc = _ssh(hostname_or_ip, cmd, check, simple_output, suppress_fingerprint_warnings,
@@ -190,7 +198,7 @@ def ssh(hostname_or_ip, cmd, *, check=True, simple_output=True,
190198
else:
191199
return result_or_exc
192200

193-
def ssh_with_result(hostname_or_ip, cmd, suppress_fingerprint_warnings=True,
201+
def ssh_with_result(hostname_or_ip: HostAddress, cmd: str, suppress_fingerprint_warnings=True,
194202
background=False, decode=True, options=[], multiplexing=True) -> SSHResult:
195203
result_or_exc = _ssh(hostname_or_ip, cmd, False, False, suppress_fingerprint_warnings,
196204
background, decode, options, multiplexing)

lib/common.py

Lines changed: 19 additions & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -35,7 +35,8 @@
3535
)
3636

3737
if TYPE_CHECKING:
38-
import lib.host
38+
from lib.host import Host
39+
3940

4041
KiB = 2**10
4142
MiB = KiB**2
@@ -217,7 +218,7 @@ def strip_suffix(string, suffix):
217218
return string[:-len(suffix)]
218219
return string
219220

220-
def setup_formatted_and_mounted_disk(host, sr_disk, fs_type, mountpoint):
221+
def setup_formatted_and_mounted_disk(host: Host, sr_disk, fs_type, mountpoint):
221222
if fs_type == 'ext4':
222223
option_force = '-F'
223224
elif fs_type == 'xfs':
@@ -226,29 +227,25 @@ def setup_formatted_and_mounted_disk(host, sr_disk, fs_type, mountpoint):
226227
raise Exception(f"Unsupported fs_type '{fs_type}' in this function")
227228
device = '/dev/' + sr_disk
228229
logging.info(f">> Format sr_disk {sr_disk} and mount it on host {host}")
229-
host.ssh(['mkfs.' + fs_type, option_force, device])
230-
host.ssh(['rm', '-rf', mountpoint]) # Remove any existing leftover to ensure rmdir will not fail in teardown
231-
host.ssh(['mkdir', '-p', mountpoint])
232-
host.ssh(['cp', '-f', '/etc/fstab', '/etc/fstab.orig'])
233-
ssh_client = host.ssh(['echo', '$SSH_CLIENT']).split()[0]
230+
host.ssh(f'mkfs.{fs_type} {option_force} {device}')
231+
host.ssh(f'rm -rf {mountpoint}') # Remove any existing leftover to ensure rmdir will not fail in teardown
232+
host.ssh(f'mkdir -p {mountpoint}')
233+
host.ssh('cp -f /etc/fstab /etc/fstab.orig')
234+
ssh_client = host.ssh('echo $SSH_CLIENT').split()[0]
234235
now = datetime.now().isoformat()
235-
host.ssh([
236-
'echo',
237-
f'"# added by {ssh_client} on {now}\n{device} {mountpoint} {fs_type} defaults 0 0"',
238-
'>>/etc/fstab',
239-
])
236+
host.ssh(f'echo "# added by {ssh_client} on {now}\n{device} {mountpoint} {fs_type} defaults 0 0" >>/etc/fstab')
240237
try:
241-
host.ssh(['mount', mountpoint])
238+
host.ssh(f'mount {mountpoint}')
242239
except Exception:
243240
# restore fstab then re-raise
244-
host.ssh(['cp', '-f', '/etc/fstab.orig', '/etc/fstab'])
241+
host.ssh('cp -f /etc/fstab.orig /etc/fstab')
245242
raise
246243

247-
def teardown_formatted_and_mounted_disk(host, mountpoint):
244+
def teardown_formatted_and_mounted_disk(host: Host, mountpoint):
248245
logging.info(f"<< Restore fstab and unmount {mountpoint} on host {host}")
249-
host.ssh(['cp', '-f', '/etc/fstab.orig', '/etc/fstab'])
250-
host.ssh(['umount', mountpoint])
251-
host.ssh(['rmdir', mountpoint])
246+
host.ssh('cp -f /etc/fstab.orig /etc/fstab')
247+
host.ssh(f'umount {mountpoint}')
248+
host.ssh(f'rmdir {mountpoint}')
252249

253250
def exec_nofail(func):
254251
""" Execute a function, log a warning if it fails, and return eiter [] or [e] where e is the exception. """
@@ -305,21 +302,21 @@ def randid(length=6):
305302
return ''.join(random.choices(characters, k=length))
306303

307304
@overload
308-
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
305+
def _param_get(host: Host, xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
309306
accept_unknown_key: Literal[False] = ...) -> str:
310307
...
311308

312309
@overload
313-
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
310+
def _param_get(host: Host, xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
314311
accept_unknown_key: Literal[True] = ...) -> Optional[str]:
315312
...
316313

317314
@overload
318-
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
315+
def _param_get(host: Host, xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = ...,
319316
accept_unknown_key: bool = ...) -> Optional[str]:
320317
...
321318

322-
def _param_get(host: 'lib.host.Host', xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = None,
319+
def _param_get(host: Host, xe_prefix: str, uuid: str, param_name: str, key: Optional[str] = None,
323320
accept_unknown_key: bool = False) -> Optional[str]:
324321
""" Common implementation for param_get. """
325322
import lib.commands as commands

lib/fistpoint.py

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -40,7 +40,7 @@ def enable_exit_on_fistpoint(host: Host):
4040

4141
@staticmethod
4242
def disable_exit_on_fistpoint(host: Host):
43-
host.ssh(["rm", FistPoint._get_path(LVHDRT_EXIT_FIST)])
43+
host.ssh(f'rm {FistPoint._get_path(LVHDRT_EXIT_FIST)}')
4444

4545
@staticmethod
4646
def _get_name(name: str) -> str:
@@ -60,7 +60,7 @@ def enable(self):
6060
def disable(self):
6161
logging.info(f"Disabling fistpoint {self.fistpointName}")
6262
try:
63-
self.host.ssh(["rm", self._get_path(self.fistpointName)])
63+
self.host.ssh(f'rm {self._get_path(self.fistpointName)}')
6464
except SSHCommandFailed as e:
6565
logging.info(f"Failed trying to disable fistpoint {self._get_path(self.fistpointName)} with error {e}")
6666
raise

0 commit comments

Comments
 (0)