diff --git a/.github/BOTMETA.yml b/.github/BOTMETA.yml index 175982c6b6c..833f2ef3bbd 100644 --- a/.github/BOTMETA.yml +++ b/.github/BOTMETA.yml @@ -119,6 +119,8 @@ files: $connections/saltstack.py: labels: saltstack maintainers: mscherer + $connections/wsl.py: + maintainers: rgl $connections/zone.py: maintainers: $team_ansible_core $doc_fragments/: diff --git a/plugins/connection/wsl.py b/plugins/connection/wsl.py new file mode 100644 index 00000000000..f299ede3c56 --- /dev/null +++ b/plugins/connection/wsl.py @@ -0,0 +1,791 @@ +# -*- coding: utf-8 -*- +# Derived from ansible/plugins/connection/proxmox_pct_remote.py (c) 2024 Nils Stein (@mietzen) +# Derived from ansible/plugins/connection/paramiko_ssh.py (c) 2012, Michael DeHaan +# Copyright (c) 2025 Rui Lopes (@rgl) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import annotations + +DOCUMENTATION = r""" +author: Rui Lopes (@rgl) +name: wsl +short_description: Run tasks in WSL distribution using wsl.exe CLI via SSH +requirements: + - paramiko +description: + - Run commands or put/fetch files to an existing WSL distribution using wsl.exe CLI via SSH. + - Uses the Python SSH implementation (Paramiko) to connect to the WSL host. +version_added: "10.5.0" +options: + remote_addr: + description: + - Address of the remote target. + default: inventory_hostname + type: string + vars: + - name: inventory_hostname + - name: ansible_host + - name: ansible_ssh_host + - name: ansible_paramiko_host + port: + description: Remote port to connect to. + type: int + default: 22 + ini: + - section: defaults + key: remote_port + - section: paramiko_connection + key: remote_port + env: + - name: ANSIBLE_REMOTE_PORT + - name: ANSIBLE_REMOTE_PARAMIKO_PORT + vars: + - name: ansible_port + - name: ansible_ssh_port + - name: ansible_paramiko_port + keyword: + - name: port + remote_user: + description: + - User to login/authenticate as. + - Can be set from the CLI via the C(--user) or C(-u) options. + type: string + vars: + - name: ansible_user + - name: ansible_ssh_user + - name: ansible_paramiko_user + env: + - name: ANSIBLE_REMOTE_USER + - name: ANSIBLE_PARAMIKO_REMOTE_USER + ini: + - section: defaults + key: remote_user + - section: paramiko_connection + key: remote_user + keyword: + - name: remote_user + password: + description: + - Secret used to either login the SSH server or as a passphrase for SSH keys that require it. + - Can be set from the CLI via the C(--ask-pass) option. + type: string + vars: + - name: ansible_password + - name: ansible_ssh_pass + - name: ansible_ssh_password + - name: ansible_paramiko_pass + - name: ansible_paramiko_password + use_rsa_sha2_algorithms: + description: + - Whether or not to enable RSA SHA2 algorithms for pubkeys and hostkeys. + - On paramiko versions older than 2.9, this only affects hostkeys. + - For behavior matching paramiko<2.9 set this to V(false). + vars: + - name: ansible_paramiko_use_rsa_sha2_algorithms + ini: + - {key: use_rsa_sha2_algorithms, section: paramiko_connection} + env: + - {name: ANSIBLE_PARAMIKO_USE_RSA_SHA2_ALGORITHMS} + default: true + type: boolean + host_key_auto_add: + description: "Automatically add host keys to C(~/.ssh/known_hosts)." + env: + - name: ANSIBLE_PARAMIKO_HOST_KEY_AUTO_ADD + ini: + - key: host_key_auto_add + section: paramiko_connection + type: boolean + look_for_keys: + default: True + description: "Set to V(false) to disable searching for private key files in C(~/.ssh/)." + env: + - name: ANSIBLE_PARAMIKO_LOOK_FOR_KEYS + ini: + - {key: look_for_keys, section: paramiko_connection} + type: boolean + proxy_command: + default: "" + description: + - Proxy information for running the connection via a jumphost. + - This option is supported by paramiko version 1.9.0 or newer. + type: string + env: + - name: ANSIBLE_PARAMIKO_PROXY_COMMAND + ini: + - {key: proxy_command, section: paramiko_connection} + vars: + - name: ansible_paramiko_proxy_command + record_host_keys: + default: True + description: "Save the host keys to a file." + env: + - name: ANSIBLE_PARAMIKO_RECORD_HOST_KEYS + ini: + - section: paramiko_connection + key: record_host_keys + type: boolean + host_key_checking: + description: "Set this to V(false) if you want to avoid host key checking by the underlying tools Ansible uses to connect to the host." + type: boolean + default: true + env: + - name: ANSIBLE_HOST_KEY_CHECKING + - name: ANSIBLE_SSH_HOST_KEY_CHECKING + - name: ANSIBLE_PARAMIKO_HOST_KEY_CHECKING + ini: + - section: defaults + key: host_key_checking + - section: paramiko_connection + key: host_key_checking + vars: + - name: ansible_host_key_checking + - name: ansible_ssh_host_key_checking + - name: ansible_paramiko_host_key_checking + use_persistent_connections: + description: "Toggles the use of persistence for connections." + type: boolean + default: False + env: + - name: ANSIBLE_USE_PERSISTENT_CONNECTIONS + ini: + - section: defaults + key: use_persistent_connections + banner_timeout: + type: float + default: 30 + description: + - Configures, in seconds, the amount of time to wait for the SSH + banner to be presented. + - This option is supported by paramiko version 1.15.0 or newer. + ini: + - section: paramiko_connection + key: banner_timeout + env: + - name: ANSIBLE_PARAMIKO_BANNER_TIMEOUT + timeout: + type: int + default: 10 + description: + - Number of seconds until the plugin gives up on failing to establish a TCP connection. + - This option is supported by paramiko version 2.2.0 or newer. + ini: + - section: defaults + key: timeout + - section: ssh_connection + key: timeout + - section: paramiko_connection + key: timeout + env: + - name: ANSIBLE_TIMEOUT + - name: ANSIBLE_SSH_TIMEOUT + - name: ANSIBLE_PARAMIKO_TIMEOUT + vars: + - name: ansible_ssh_timeout + - name: ansible_paramiko_timeout + cli: + - name: timeout + lock_file_timeout: + type: int + default: 60 + description: Number of seconds until the plugin gives up on trying to write a lock file when writing SSH known host keys. + vars: + - name: ansible_lock_file_timeout + env: + - name: ANSIBLE_LOCK_FILE_TIMEOUT + private_key_file: + description: + - Path to private key file to use for authentication. + type: string + ini: + - section: defaults + key: private_key_file + - section: paramiko_connection + key: private_key_file + env: + - name: ANSIBLE_PRIVATE_KEY_FILE + - name: ANSIBLE_PARAMIKO_PRIVATE_KEY_FILE + vars: + - name: ansible_private_key_file + - name: ansible_ssh_private_key_file + - name: ansible_paramiko_private_key_file + cli: + - name: private_key_file + option: "--private-key" + user_known_hosts_file: + description: + - Path to the user known hosts file. + - Used to verify the ssh hosts keys. + type: string + default: ~/.ssh/known_hosts + ini: + - section: paramiko_connection + key: user_known_hosts_file + vars: + - name: ansible_paramiko_user_known_hosts_file + wsl_distribution: + description: + - WSL distribution name + type: string + required: true + vars: + - name: wsl_distribution + wsl_user: + description: + - WSL distribution user + type: string + vars: + - name: wsl_user + become_user: + description: + - WSL distribution user + type: string + default: root + vars: + - name: become_user + - name: ansible_become_user + become: + description: + - whether to use the user defined by ansible_become_user. + type: bool + default: false + vars: + - name: become + - name: ansible_become +""" + +EXAMPLES = r""" +# ------------------------ +# Inventory: inventory.yml +# ------------------------ +--- +all: + children: + wsl: + hosts: + example-wsl-ubuntu: + ansible_host: 10.0.0.10 + wsl_distribution: ubuntu + wsl_user: ubuntu + vars: + ansible_connection: community.general.wsl + ansible_user: vagrant +# ---------------------- +# Playbook: playbook.yml +# ---------------------- +--- +- name: WSL Example + hosts: wsl + gather_facts: true + become: true + tasks: + - name: Ping + ansible.builtin.ping: + - name: Id (with become false) + become: false + changed_when: false + args: + executable: /bin/bash + ansible.builtin.shell: | + exec 2>&1 + set -x + echo "$0" + pwd + id + - name: Id (with become true) + changed_when: false + args: + executable: /bin/bash + ansible.builtin.shell: | + exec 2>&1 + set -x + echo "$0" + pwd + id + - name: Reboot + ansible.builtin.reboot: + boot_time_command: systemctl show -p ActiveEnterTimestamp init.scope +""" + +import io +import os +import pathlib +import shlex +import socket +import tempfile +import typing as t + +from ansible.errors import ( + AnsibleAuthenticationFailure, + AnsibleConnectionFailure, + AnsibleError, +) +from ansible_collections.community.general.plugins.module_utils._filelock import FileLock, LockTimeout +from ansible.module_utils.common.text.converters import to_bytes, to_native, to_text +from ansible.module_utils.compat.paramiko import PARAMIKO_IMPORT_ERR, paramiko +from ansible.module_utils.compat.version import LooseVersion +from ansible.playbook.play_context import PlayContext +from ansible.plugins.connection import ConnectionBase +from ansible.utils.display import Display +from ansible.utils.path import makedirs_safe +from binascii import hexlify +from subprocess import list2cmdline +from paramiko.client import SSHClient +from paramiko.pkey import PKey + + +display = Display() + + +def authenticity_msg(hostname: str, ktype: str, fingerprint: str) -> str: + msg = f""" + paramiko: The authenticity of host '{hostname}' can't be established. + The {ktype} key fingerprint is {fingerprint}. + Are you sure you want to continue connecting (yes/no)? + """ + return msg + + +# TODO why is this here? why are we even trying to verify whether paramiko exists? +# if this is really required, we cannot just put SSHClient in a variable and +# use that as a type hint... so what should we do? +# 1. drop ansible.module_utils.compat.paramiko (which the ansible devel branch now marks this as deprecated) +# see https://github.com/ansible/ansible/blob/v2.18.3/lib/ansible/module_utils/compat/paramiko.py +# see deprecate('The paramiko compat import is deprecated', version='2.21') +# at https://github.com/ansible/ansible/blob/devel/lib/ansible/module_utils/compat/paramiko.py +# 2. drop type hints. +MissingHostKeyPolicy: type = object +if paramiko: + MissingHostKeyPolicy = paramiko.MissingHostKeyPolicy + + +class MyAddPolicy(MissingHostKeyPolicy): + """ + Based on AutoAddPolicy in paramiko so we can determine when keys are added + + and also prompt for input. + + Policy for automatically adding the hostname and new host key to the + local L{HostKeys} object, and saving it. This is used by L{SSHClient}. + """ + + def __init__(self, connection: Connection) -> None: + self.connection = connection + self._options = connection._options + + def missing_host_key(self, client: SSHClient, hostname: str, key: PKey) -> None: + + if all((self.connection.get_option('host_key_checking'), not self.connection.get_option('host_key_auto_add'))): + + fingerprint = hexlify(key.get_fingerprint()) + ktype = key.get_name() + + if self.connection.get_option('use_persistent_connections') or self.connection.force_persistence: + # don't print the prompt string since the user cannot respond + # to the question anyway + raise AnsibleError(authenticity_msg(hostname, ktype, fingerprint)[1:92]) + + inp = to_text( + display.prompt_until(authenticity_msg(hostname, ktype, fingerprint), private=False), + errors='surrogate_or_strict' + ) + + if inp.lower() not in ['yes', 'y', '']: + raise AnsibleError('host connection rejected by user') + + key._added_by_ansible_this_time = True + + # existing implementation below: + client._host_keys.add(hostname, key.get_name(), key) + + # host keys are actually saved in close() function below + # in order to control ordering. + + +class Connection(ConnectionBase): + """ SSH based connections (paramiko) to WSL """ + + transport = 'community.general.wsl' + _log_channel: str | None = None + + def __init__(self, play_context: PlayContext, new_stdin: io.TextIOWrapper | None = None, *args: t.Any, **kwargs: t.Any): + super(Connection, self).__init__(play_context, new_stdin, *args, **kwargs) + + def _set_log_channel(self, name: str) -> None: + """ Mimic paramiko.SSHClient.set_log_channel """ + self._log_channel = name + + def _parse_proxy_command(self, port: int = 22) -> dict[str, t.Any]: + proxy_command = self.get_option('proxy_command') or None + + sock_kwarg = {} + if proxy_command: + replacers: t.Dict[str, str] = { + '%h': self.get_option('remote_addr'), + '%p': str(port), + '%r': self.get_option('remote_user') + } + for find, replace in replacers.items(): + proxy_command = proxy_command.replace(find, replace) + try: + sock_kwarg = {'sock': paramiko.ProxyCommand(proxy_command)} + display.vvv(f'CONFIGURE PROXY COMMAND FOR CONNECTION: {proxy_command}', host=self.get_option('remote_addr')) + except AttributeError: + display.warning('Paramiko ProxyCommand support unavailable. ' + 'Please upgrade to Paramiko 1.9.0 or newer. ' + 'Not using configured ProxyCommand') + + return sock_kwarg + + def _connect(self) -> Connection: + """ activates the connection object """ + + if paramiko is None: + raise AnsibleError(f'paramiko is not installed: {to_native(PARAMIKO_IMPORT_ERR)}') + + port = self.get_option('port') + display.vvv(f'ESTABLISH PARAMIKO SSH CONNECTION FOR USER: {self.get_option("remote_user")} on PORT {to_text(port)} TO {self.get_option("remote_addr")}', + host=self.get_option('remote_addr')) + + ssh = paramiko.SSHClient() + + # Set pubkey and hostkey algorithms to disable, the only manipulation allowed currently + # is keeping or omitting rsa-sha2 algorithms + # default_keys: t.Tuple[str] = () + paramiko_preferred_pubkeys = getattr(paramiko.Transport, '_preferred_pubkeys', ()) + paramiko_preferred_hostkeys = getattr(paramiko.Transport, '_preferred_keys', ()) + use_rsa_sha2_algorithms = self.get_option('use_rsa_sha2_algorithms') + disabled_algorithms: t.Dict[str, t.Iterable[str]] = {} + if not use_rsa_sha2_algorithms: + if paramiko_preferred_pubkeys: + disabled_algorithms['pubkeys'] = tuple(a for a in paramiko_preferred_pubkeys if 'rsa-sha2' in a) + if paramiko_preferred_hostkeys: + disabled_algorithms['keys'] = tuple(a for a in paramiko_preferred_hostkeys if 'rsa-sha2' in a) + + # override paramiko's default logger name + if self._log_channel is not None: + ssh.set_log_channel(self._log_channel) + + self.keyfile = os.path.expanduser(self.get_option('user_known_hosts_file')) + + if self.get_option('host_key_checking'): + for ssh_known_hosts in ('/etc/ssh/ssh_known_hosts', '/etc/openssh/ssh_known_hosts', self.keyfile): + try: + ssh.load_system_host_keys(ssh_known_hosts) + break + except IOError: + pass # file was not found, but not required to function + except paramiko.hostkeys.InvalidHostKey as e: + raise AnsibleConnectionFailure(f'Invalid host key: {to_text(e.line)}') + try: + ssh.load_system_host_keys() + except paramiko.hostkeys.InvalidHostKey as e: + raise AnsibleConnectionFailure(f'Invalid host key: {to_text(e.line)}') + + ssh_connect_kwargs = self._parse_proxy_command(port) + ssh.set_missing_host_key_policy(MyAddPolicy(self)) + conn_password = self.get_option('password') + allow_agent = True + + if conn_password is not None: + allow_agent = False + + try: + key_filename = None + if self.get_option('private_key_file'): + key_filename = os.path.expanduser(self.get_option('private_key_file')) + + # paramiko 2.2 introduced auth_timeout parameter + if LooseVersion(paramiko.__version__) >= LooseVersion('2.2.0'): + ssh_connect_kwargs['auth_timeout'] = self.get_option('timeout') + + # paramiko 1.15 introduced banner timeout parameter + if LooseVersion(paramiko.__version__) >= LooseVersion('1.15.0'): + ssh_connect_kwargs['banner_timeout'] = self.get_option('banner_timeout') + + ssh.connect( + self.get_option('remote_addr').lower(), + username=self.get_option('remote_user'), + allow_agent=allow_agent, + look_for_keys=self.get_option('look_for_keys'), + key_filename=key_filename, + password=conn_password, + timeout=self.get_option('timeout'), + port=port, + disabled_algorithms=disabled_algorithms, + **ssh_connect_kwargs, + ) + except paramiko.ssh_exception.BadHostKeyException as e: + raise AnsibleConnectionFailure(f'host key mismatch for {to_text(e.hostname)}') + except paramiko.ssh_exception.AuthenticationException as e: + msg = f'Failed to authenticate: {e}' + raise AnsibleAuthenticationFailure(msg) + except Exception as e: + msg = to_text(e) + if u'PID check failed' in msg: + raise AnsibleError('paramiko version issue, please upgrade paramiko on the machine running ansible') + elif u'Private key file is encrypted' in msg: + msg = f'ssh {self.get_option("remote_user")}@{self.get_options("remote_addr")}:{port} : ' + \ + f'{msg}\nTo connect as a different user, use -u .' + raise AnsibleConnectionFailure(msg) + else: + raise AnsibleConnectionFailure(msg) + self.ssh = ssh + self._connected = True + return self + + def _any_keys_added(self) -> bool: + for hostname, keys in self.ssh._host_keys.items(): + for keytype, key in keys.items(): + added_this_time = getattr(key, '_added_by_ansible_this_time', False) + if added_this_time: + return True + return False + + def _save_ssh_host_keys(self, filename: str) -> None: + """ + not using the paramiko save_ssh_host_keys function as we want to add new SSH keys at the bottom so folks + don't complain about it :) + """ + + if not self._any_keys_added(): + return + + path = os.path.expanduser('~/.ssh') + makedirs_safe(path) + + with open(filename, 'w') as f: + for hostname, keys in self.ssh._host_keys.items(): + for keytype, key in keys.items(): + # was f.write + added_this_time = getattr(key, '_added_by_ansible_this_time', False) + if not added_this_time: + f.write(f'{hostname} {keytype} {key.get_base64()}\n') + + for hostname, keys in self.ssh._host_keys.items(): + for keytype, key in keys.items(): + added_this_time = getattr(key, '_added_by_ansible_this_time', False) + if added_this_time: + f.write(f'{hostname} {keytype} {key.get_base64()}\n') + + def _build_wsl_command(self, cmd: str) -> str: + wsl_distribution = self.get_option('wsl_distribution') + become = self.get_option('become') + become_user = self.get_option('become_user') + if become and become_user: + wsl_user = become_user + else: + wsl_user = self.get_option('wsl_user') + args = ['wsl.exe', '--distribution', wsl_distribution] + if wsl_user: + args.extend(['--user', wsl_user]) + args.extend(['--']) + args.extend(shlex.split(cmd)) + if os.getenv('_ANSIBLE_TEST_WSL_CONNECTION_PLUGIN_Waeri5tepheeSha2fae8'): + return shlex.join(args) + return list2cmdline(args) # see https://github.com/python/cpython/blob/3.11/Lib/subprocess.py#L576 + + def exec_command(self, cmd: str, in_data: bytes | None = None, sudoable: bool = True) -> tuple[int, bytes, bytes]: + """ run a command on inside a WSL distribution """ + + cmd = self._build_wsl_command(cmd) + + super(Connection, self).exec_command(cmd, in_data=in_data, sudoable=sudoable) + + bufsize = 4096 + + try: + self.ssh.get_transport().set_keepalive(5) + chan = self.ssh.get_transport().open_session() + except Exception as e: + text_e = to_text(e) + msg = 'Failed to open session' + if text_e: + msg += f': {text_e}' + raise AnsibleConnectionFailure(to_native(msg)) + + display.vvv(f'EXEC {cmd}', host=self.get_option('remote_addr')) + + cmd = to_bytes(cmd, errors='surrogate_or_strict') + + no_prompt_out = b'' + no_prompt_err = b'' + become_output = b'' + + try: + chan.exec_command(cmd) + if self.become and self.become.expect_prompt(): + password_prompt = False + become_success = False + while not (become_success or password_prompt): + display.debug('Waiting for Privilege Escalation input') + + chunk = chan.recv(bufsize) + display.debug(f'chunk is: {to_text(chunk)}') + if not chunk: + if b'unknown user' in become_output: + n_become_user = to_native(self.become.get_option('become_user')) + raise AnsibleError(f'user {n_become_user} does not exist') + else: + break + # raise AnsibleError('ssh connection closed waiting for password prompt') + become_output += chunk + + # need to check every line because we might get lectured + # and we might get the middle of a line in a chunk + for line in become_output.splitlines(True): + if self.become.check_success(line): + become_success = True + break + elif self.become.check_password_prompt(line): + password_prompt = True + break + + if password_prompt: + if self.become: + become_pass = self.become.get_option('become_pass') + chan.sendall(to_bytes(become_pass + '\n', errors='surrogate_or_strict')) + else: + raise AnsibleError('A password is required but none was supplied') + else: + no_prompt_out += become_output + no_prompt_err += become_output + + if in_data: + for i in range(0, len(in_data), bufsize): + chan.send(in_data[i:i + bufsize]) + chan.shutdown_write() + elif in_data == b'': + chan.shutdown_write() + + except socket.timeout: + raise AnsibleError('ssh timed out waiting for privilege escalation.\n' + to_text(become_output)) + + stdout = b''.join(chan.makefile('rb', bufsize)) + stderr = b''.join(chan.makefile_stderr('rb', bufsize)) + returncode = chan.recv_exit_status() + + # NB the full english error message is: + # 'wsl.exe' is not recognized as an internal or external command, + # operable program or batch file. + if "'wsl.exe' is not recognized" in stderr.decode('utf-8'): + raise AnsibleError( + f'wsl.exe not found in path of host: {to_text(self.get_option("remote_addr"))}') + + return (returncode, no_prompt_out + stdout, no_prompt_out + stderr) + + def put_file(self, in_path: str, out_path: str) -> None: + """ transfer a file from local to remote """ + + display.vvv(f'PUT {in_path} TO {out_path}', host=self.get_option('remote_addr')) + try: + with open(in_path, 'rb') as f: + data = f.read() + returncode, stdout, stderr = self.exec_command( + ' '.join([ + self._shell.executable, '-c', + self._shell.quote(f'cat > {out_path}')]), + in_data=data, + sudoable=False) + if returncode != 0: + if 'cat: not found' in stderr.decode('utf-8'): + raise AnsibleError( + f'cat not found in path of WSL distribution: {to_text(self.get_option("wsl_distribution"))}') + raise AnsibleError( + f'{to_text(stdout)}\n{to_text(stderr)}') + except Exception as e: + raise AnsibleError( + f'error occurred while putting file from {in_path} to {out_path}!\n{to_text(e)}') + + def fetch_file(self, in_path: str, out_path: str) -> None: + """ save a remote file to the specified path """ + + display.vvv(f'FETCH {in_path} TO {out_path}', host=self.get_option('remote_addr')) + try: + returncode, stdout, stderr = self.exec_command( + ' '.join([ + self._shell.executable, '-c', + self._shell.quote(f'cat {in_path}')]), + sudoable=False) + if returncode != 0: + if 'cat: not found' in stderr.decode('utf-8'): + raise AnsibleError( + f'cat not found in path of WSL distribution: {to_text(self.get_option("wsl_distribution"))}') + raise AnsibleError( + f'{to_text(stdout)}\n{to_text(stderr)}') + with open(out_path, 'wb') as f: + f.write(stdout) + except Exception as e: + raise AnsibleError( + f'error occurred while fetching file from {in_path} to {out_path}!\n{to_text(e)}') + + def reset(self) -> None: + """ reset the connection """ + + if not self._connected: + return + self.close() + self._connect() + + def close(self) -> None: + """ terminate the connection """ + + if self.get_option('host_key_checking') and self.get_option('record_host_keys') and self._any_keys_added(): + # add any new SSH host keys -- warning -- this could be slow + # (This doesn't acquire the connection lock because it needs + # to exclude only other known_hosts writers, not connections + # that are starting up.) + lockfile = os.path.basename(self.keyfile) + dirname = os.path.dirname(self.keyfile) + makedirs_safe(dirname) + tmp_keyfile_name = None + try: + with FileLock().lock_file(lockfile, dirname, self.get_option('lock_file_timeout')): + # just in case any were added recently + + self.ssh.load_system_host_keys() + self.ssh._host_keys.update(self.ssh._system_host_keys) + + # gather information about the current key file, so + # we can ensure the new file has the correct mode/owner + + key_dir = os.path.dirname(self.keyfile) + if os.path.exists(self.keyfile): + key_stat = os.stat(self.keyfile) + mode = key_stat.st_mode & 0o777 + uid = key_stat.st_uid + gid = key_stat.st_gid + else: + mode = 0o644 + uid = os.getuid() + gid = os.getgid() + + # Save the new keys to a temporary file and move it into place + # rather than rewriting the file. We set delete=False because + # the file will be moved into place rather than cleaned up. + + with tempfile.NamedTemporaryFile(dir=key_dir, delete=False) as tmp_keyfile: + tmp_keyfile_name = tmp_keyfile.name + os.chmod(tmp_keyfile_name, mode) + os.chown(tmp_keyfile_name, uid, gid) + self._save_ssh_host_keys(tmp_keyfile_name) + + os.rename(tmp_keyfile_name, self.keyfile) + except LockTimeout: + raise AnsibleError( + f'writing lock file for {self.keyfile} ran in to the timeout of {self.get_option("lock_file_timeout")}s') + except paramiko.hostkeys.InvalidHostKey as e: + raise AnsibleConnectionFailure(f'Invalid host key: {e.line}') + except Exception as e: + # unable to save keys, including scenario when key was invalid + # and caught earlier + raise AnsibleError( + f'error occurred while writing SSH host keys!\n{to_text(e)}') + finally: + if tmp_keyfile_name is not None: + pathlib.Path(tmp_keyfile_name).unlink(missing_ok=True) + + self.ssh.close() + self._connected = False diff --git a/tests/integration/targets/connection_wsl/aliases b/tests/integration/targets/connection_wsl/aliases new file mode 100644 index 00000000000..d2fefd10c74 --- /dev/null +++ b/tests/integration/targets/connection_wsl/aliases @@ -0,0 +1,12 @@ +# Copyright (c) 2025 Nils Stein (@mietzen) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +azp/posix/3 +destructive +needs/root +needs/target/connection +skip/docker +skip/alpine +skip/macos diff --git a/tests/integration/targets/connection_wsl/dependencies.yml b/tests/integration/targets/connection_wsl/dependencies.yml new file mode 100644 index 00000000000..a21674cf08c --- /dev/null +++ b/tests/integration/targets/connection_wsl/dependencies.yml @@ -0,0 +1,18 @@ +--- +# Copyright (c) 2025 Nils Stein (@mietzen) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +- hosts: localhost + gather_facts: true + serial: 1 + tasks: + - name: Copy wsl.exe mock + copy: + src: files/wsl.exe + dest: /usr/sbin/wsl.exe + mode: '0755' + - name: Install paramiko + pip: + name: "paramiko>=3.0.0" diff --git a/tests/integration/targets/connection_wsl/files/wsl.exe b/tests/integration/targets/connection_wsl/files/wsl.exe new file mode 100755 index 00000000000..0c6aafaf0fb --- /dev/null +++ b/tests/integration/targets/connection_wsl/files/wsl.exe @@ -0,0 +1,72 @@ +#!/usr/bin/env bash +# Derived from ../../connection_proxmox_pct_remote/files/pct Copyright (c) 2025 Nils Stein (@mietzen) +# Copyright (c) 2025 Rui Lopes (@rgl) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +# Shell script to mock wsl.exe behavior + +set -euo pipefail + +function quote_args { + local quoted_args=() + for arg in "$@"; do + if [[ -z "$arg" || "$arg" =~ [^a-zA-Z0-9@%+=:,./-] ]]; then + local escaped_arg=${arg//\'/\'\\\'\'} + quoted_args+=("'$escaped_arg'") + else + quoted_args+=("$arg") + fi + done + echo -n "${quoted_args[@]}" +} + +declare -a mock_args=() +declare -a cmd_args=() +wsl_distribution="" +wsl_user="" + +while [[ $# -gt 0 ]]; do + case $1 in + --distribution|-d) + wsl_distribution="$2" + mock_args+=("$1" "$2") + shift 2 + ;; + --user|-u) + wsl_user="$2" + mock_args+=("$1" "$2") + shift 2 + ;; + --) + mock_args+=("$1") + shift + while [[ $# -gt 0 ]]; do + mock_args+=("$1") + cmd_args+=("$1") + shift + done + ;; + *) + >&2 echo "unexpected args: $@" + exit 1 + ;; + esac +done + +mock_cmd="wsl.exe $(quote_args "${mock_args[@]}")" +cmd="$(quote_args "${cmd_args[@]}")" + +>&2 echo "[INFO] MOCKING: $mock_cmd" +>&2 echo "[INFO] CMD: $cmd" + +tmp_dir="/tmp/ansible-remote/wsl/integration_test/wsl_distribution_${wsl_distribution}" + +mkdir -p "$tmp_dir" + +pushd "$tmp_dir" >/dev/null + +eval "$cmd" + +popd >/dev/null diff --git a/tests/integration/targets/connection_wsl/plugin-specific-tests.yml b/tests/integration/targets/connection_wsl/plugin-specific-tests.yml new file mode 100644 index 00000000000..41fe06cdb95 --- /dev/null +++ b/tests/integration/targets/connection_wsl/plugin-specific-tests.yml @@ -0,0 +1,32 @@ +--- +# Copyright (c) Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +- hosts: "{{ target_hosts }}" + gather_facts: false + serial: 1 + tasks: + - name: create file without content + copy: + content: "" + dest: "{{ remote_tmp }}/test_empty.txt" + force: no + mode: '0644' + + - name: assert file without content exists + stat: + path: "{{ remote_tmp }}/test_empty.txt" + register: empty_file_stat + + - name: verify file without content exists + assert: + that: + - empty_file_stat.stat.exists + fail_msg: "The file {{ remote_tmp }}/test_empty.txt does not exist." + + - name: verify file without content is empty + assert: + that: + - empty_file_stat.stat.size == 0 + fail_msg: "The file {{ remote_tmp }}/test_empty.txt is not empty." diff --git a/tests/integration/targets/connection_wsl/runme.sh b/tests/integration/targets/connection_wsl/runme.sh new file mode 100755 index 00000000000..95759fbd84d --- /dev/null +++ b/tests/integration/targets/connection_wsl/runme.sh @@ -0,0 +1,25 @@ +#!/usr/bin/env bash +# Copyright (c) 2025 Nils Stein (@mietzen) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +set -eux + +# signal the wsl connection plugin that its running under the integration testing mode. +# NB while running integration tests, the mock wsl.exe implementation is actually +# running on unix, instead of on running windows, so the wsl.exe command line +# construction must use unix rules instead of windows rules. +export _ANSIBLE_TEST_WSL_CONNECTION_PLUGIN_Waeri5tepheeSha2fae8=1 + +ANSIBLE_ROLES_PATH=../ \ + ansible-playbook dependencies.yml -v "$@" + +./test.sh "$@" + +ansible-playbook plugin-specific-tests.yml -i "./test_connection.inventory" \ + -e target_hosts=wsl \ + -e action_prefix= \ + -e local_tmp=/tmp/ansible-local \ + -e remote_tmp=/tmp/ansible-remote \ + "$@" diff --git a/tests/integration/targets/connection_wsl/test.sh b/tests/integration/targets/connection_wsl/test.sh new file mode 120000 index 00000000000..70aa5dbdba4 --- /dev/null +++ b/tests/integration/targets/connection_wsl/test.sh @@ -0,0 +1 @@ +../connection_posix/test.sh \ No newline at end of file diff --git a/tests/integration/targets/connection_wsl/test_connection.inventory b/tests/integration/targets/connection_wsl/test_connection.inventory new file mode 100644 index 00000000000..3fcfec13d37 --- /dev/null +++ b/tests/integration/targets/connection_wsl/test_connection.inventory @@ -0,0 +1,14 @@ +# Copyright (c) 2025 Nils Stein (@mietzen) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +[wsl] +wsl-pipelining ansible_ssh_pipelining=true +wsl-no-pipelining ansible_ssh_pipelining=false +[wsl:vars] +ansible_host=localhost +ansible_user=root +ansible_python_interpreter="{{ ansible_playbook_python }}" +ansible_connection=community.general.wsl +wsl_distribution=test diff --git a/tests/sanity/ignore-2.15.txt b/tests/sanity/ignore-2.15.txt index f042b888e81..6115954d00d 100644 --- a/tests/sanity/ignore-2.15.txt +++ b/tests/sanity/ignore-2.15.txt @@ -1,4 +1,5 @@ .azure-pipelines/scripts/publish-codecov.py replace-urlopen +plugins/connection/wsl.py yamllint:unparsable-with-libyaml plugins/inventory/gitlab_runners.py yamllint:unparsable-with-libyaml plugins/inventory/iocage.py yamllint:unparsable-with-libyaml plugins/inventory/linode.py yamllint:unparsable-with-libyaml diff --git a/tests/sanity/ignore-2.16.txt b/tests/sanity/ignore-2.16.txt index 6f3a7f038e7..8ac70d76d7f 100644 --- a/tests/sanity/ignore-2.16.txt +++ b/tests/sanity/ignore-2.16.txt @@ -1,3 +1,4 @@ +plugins/connection/wsl.py yamllint:unparsable-with-libyaml plugins/inventory/gitlab_runners.py yamllint:unparsable-with-libyaml plugins/inventory/iocage.py yamllint:unparsable-with-libyaml plugins/inventory/linode.py yamllint:unparsable-with-libyaml diff --git a/tests/unit/plugins/connection/test_wsl.py b/tests/unit/plugins/connection/test_wsl.py new file mode 100644 index 00000000000..5646ae33b16 --- /dev/null +++ b/tests/unit/plugins/connection/test_wsl.py @@ -0,0 +1,589 @@ +# -*- coding: utf-8 -*- +# Derived from test_proxmox_pct_remote.py (c) 2024 Nils Stein (@mietzen) +# Copyright (c) 2025 Rui Lopes (@rgl) +# Copyright (c) 2025 Ansible Project +# GNU General Public License v3.0+ (see LICENSES/GPL-3.0-or-later.txt or https://www.gnu.org/licenses/gpl-3.0.txt) +# SPDX-License-Identifier: GPL-3.0-or-later + +from __future__ import (annotations, absolute_import, division, print_function) +__metaclass__ = type + +import os +import pytest + +from ansible_collections.community.general.plugins.connection.wsl import authenticity_msg, MyAddPolicy +from ansible_collections.community.general.plugins.module_utils._filelock import FileLock, LockTimeout +from ansible.errors import AnsibleError, AnsibleAuthenticationFailure, AnsibleConnectionFailure +from ansible.module_utils.common.text.converters import to_bytes +from ansible.module_utils.compat.paramiko import paramiko +from ansible.playbook.play_context import PlayContext +from ansible.plugins.loader import connection_loader +from io import StringIO +from pathlib import Path +from unittest.mock import patch, MagicMock, mock_open + + +@pytest.fixture +def connection(): + play_context = PlayContext() + in_stream = StringIO() + conn = connection_loader.get('community.general.wsl', play_context, in_stream) + conn.set_option('remote_addr', '192.168.1.100') + conn.set_option('remote_user', 'root') + conn.set_option('password', 'password') + conn.set_option('wsl_distribution', 'test') + return conn + + +def test_connection_options(connection): + """ Test that connection options are properly set """ + assert connection.get_option('remote_addr') == '192.168.1.100' + assert connection.get_option('remote_user') == 'root' + assert connection.get_option('password') == 'password' + assert connection.get_option('wsl_distribution') == 'test' + + +def test_authenticity_msg(): + """ Test authenticity message formatting """ + msg = authenticity_msg('test.host', 'ssh-rsa', 'AA:BB:CC:DD') + assert 'test.host' in msg + assert 'ssh-rsa' in msg + assert 'AA:BB:CC:DD' in msg + + +def test_missing_host_key(connection): + """ Test MyAddPolicy missing_host_key method """ + + client = MagicMock() + key = MagicMock() + key.get_fingerprint.return_value = b'fingerprint' + key.get_name.return_value = 'ssh-rsa' + + policy = MyAddPolicy(connection) + + connection.set_option('host_key_auto_add', True) + policy.missing_host_key(client, 'test.host', key) + assert hasattr(key, '_added_by_ansible_this_time') + + connection.set_option('host_key_auto_add', False) + connection.set_option('host_key_checking', False) + policy.missing_host_key(client, 'test.host', key) + + connection.set_option('host_key_checking', True) + connection.set_option('host_key_auto_add', False) + connection.set_option('use_persistent_connections', False) + + with patch('ansible.utils.display.Display.prompt_until', return_value='yes'): + policy.missing_host_key(client, 'test.host', key) + + with patch('ansible.utils.display.Display.prompt_until', return_value='no'): + with pytest.raises(AnsibleError, match='host connection rejected by user'): + policy.missing_host_key(client, 'test.host', key) + + +def test_set_log_channel(connection): + """ Test setting log channel """ + connection._set_log_channel('test_channel') + assert connection._log_channel == 'test_channel' + + +def test_parse_proxy_command(connection): + """ Test proxy command parsing """ + connection.set_option('proxy_command', 'ssh -W %h:%p proxy.example.com') + connection.set_option('remote_addr', 'target.example.com') + connection.set_option('remote_user', 'testuser') + + result = connection._parse_proxy_command(port=2222) + assert 'sock' in result + assert isinstance(result['sock'], paramiko.ProxyCommand) + + +@patch('paramiko.SSHClient') +def test_connect_with_rsa_sha2_disabled(mock_ssh, connection): + """ Test connection with RSA SHA2 algorithms disabled """ + connection.set_option('use_rsa_sha2_algorithms', False) + mock_client = MagicMock() + mock_ssh.return_value = mock_client + + connection._connect() + + call_kwargs = mock_client.connect.call_args[1] + assert 'disabled_algorithms' in call_kwargs + assert 'pubkeys' in call_kwargs['disabled_algorithms'] + + +@patch('paramiko.SSHClient') +def test_connect_with_bad_host_key(mock_ssh, connection): + """ Test connection with bad host key """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_client.connect.side_effect = paramiko.ssh_exception.BadHostKeyException( + 'hostname', MagicMock(), MagicMock()) + + with pytest.raises(AnsibleConnectionFailure, match='host key mismatch'): + connection._connect() + + +@patch('paramiko.SSHClient') +def test_connect_with_invalid_host_key(mock_ssh, connection): + """ Test connection with bad host key """ + connection.set_option('host_key_checking', True) + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_client.load_system_host_keys.side_effect = paramiko.hostkeys.InvalidHostKey( + "Bad Line!", Exception('Something crashed!')) + + with pytest.raises(AnsibleConnectionFailure, match="Invalid host key: Bad Line!"): + connection._connect() + + +@patch('paramiko.SSHClient') +def test_connect_success(mock_ssh, connection): + """ Test successful SSH connection establishment """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + + connection._connect() + + assert mock_client.connect.called + assert connection._connected + + +@patch('paramiko.SSHClient') +def test_connect_authentication_failure(mock_ssh, connection): + """ Test SSH connection with authentication failure """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_client.connect.side_effect = paramiko.ssh_exception.AuthenticationException('Auth failed') + + with pytest.raises(AnsibleAuthenticationFailure): + connection._connect() + + +def test_any_keys_added(connection): + """ Test checking for added host keys """ + connection.ssh = MagicMock() + connection.ssh._host_keys = { + 'host1': { + 'ssh-rsa': MagicMock(_added_by_ansible_this_time=True), + 'ssh-ed25519': MagicMock(_added_by_ansible_this_time=False) + } + } + + assert connection._any_keys_added() is True + + connection.ssh._host_keys = { + 'host1': { + 'ssh-rsa': MagicMock(_added_by_ansible_this_time=False) + } + } + assert connection._any_keys_added() is False + + +@patch('os.path.exists') +@patch('os.stat') +@patch('tempfile.NamedTemporaryFile') +def test_save_ssh_host_keys(mock_tempfile, mock_stat, mock_exists, connection): + """ Test saving SSH host keys """ + mock_exists.return_value = True + mock_stat.return_value = MagicMock(st_mode=0o644, st_uid=1000, st_gid=1000) + mock_tempfile.return_value.__enter__.return_value.name = '/tmp/test_keys' + + connection.ssh = MagicMock() + connection.ssh._host_keys = { + 'host1': { + 'ssh-rsa': MagicMock( + get_base64=lambda: 'KEY1', + _added_by_ansible_this_time=True + ) + } + } + + mock_open_obj = mock_open() + with patch('builtins.open', mock_open_obj): + connection._save_ssh_host_keys('/tmp/test_keys') + + mock_open_obj().write.assert_called_with('host1 ssh-rsa KEY1\n') + + +def test_build_wsl_command(connection): + """ Test wsl command building with different users """ + cmd = connection._build_wsl_command('/bin/sh -c "ls -la"') + assert cmd == 'wsl.exe --distribution test -- /bin/sh -c "ls -la"' + + connection.set_option('wsl_user', 'test-user') + cmd = connection._build_wsl_command('/bin/sh -c "ls -la"') + assert cmd == 'wsl.exe --distribution test --user test-user -- /bin/sh -c "ls -la"' + + connection.set_option('become', True) + connection.set_option('become_user', 'test-become-user') + cmd = connection._build_wsl_command('/bin/sh -c "ls -la"') + assert cmd == 'wsl.exe --distribution test --user test-become-user -- /bin/sh -c "ls -la"' + + +@patch('paramiko.SSHClient') +def test_exec_command_success(mock_ssh, connection): + """ Test successful command execution """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + mock_channel.recv_exit_status.return_value = 0 + mock_channel.makefile.return_value = [to_bytes('stdout')] + mock_channel.makefile_stderr.return_value = [to_bytes("")] + + connection._connected = True + connection.ssh = mock_client + + returncode, stdout, stderr = connection.exec_command('ls -la') + + mock_transport.open_session.assert_called_once() + mock_transport.set_keepalive.assert_called_once_with(5) + + +@patch('paramiko.SSHClient') +def test_exec_command_wsl_not_found(mock_ssh, connection): + """ Test command execution when wsl.exe is not found """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + mock_channel.recv_exit_status.return_value = 1 + mock_channel.makefile.return_value = [to_bytes("")] + mock_channel.makefile_stderr.return_value = [to_bytes("'wsl.exe' is not recognized")] + + connection._connected = True + connection.ssh = mock_client + + with pytest.raises(AnsibleError, match='wsl.exe not found in path of host'): + connection.exec_command('ls -la') + + +@patch('paramiko.SSHClient') +def test_exec_command_session_open_failure(mock_ssh, connection): + """ Test exec_command when session opening fails """ + mock_client = MagicMock() + mock_transport = MagicMock() + mock_transport.open_session.side_effect = Exception('Failed to open session') + mock_client.get_transport.return_value = mock_transport + + connection._connected = True + connection.ssh = mock_client + + with pytest.raises(AnsibleConnectionFailure, match='Failed to open session'): + connection.exec_command('test command') + + +@patch('paramiko.SSHClient') +def test_exec_command_with_privilege_escalation(mock_ssh, connection): + """ Test exec_command with privilege escalation """ + mock_client = MagicMock() + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + connection._connected = True + connection.ssh = mock_client + + connection.become = MagicMock() + connection.become.expect_prompt.return_value = True + connection.become.check_success.return_value = False + connection.become.check_password_prompt.return_value = True + connection.become.get_option.return_value = 'sudo_password' + + mock_channel.recv.return_value = b'[sudo] password:' + mock_channel.recv_exit_status.return_value = 0 + mock_channel.makefile.return_value = [b""] + mock_channel.makefile_stderr.return_value = [b""] + + returncode, stdout, stderr = connection.exec_command('sudo test command') + + mock_channel.sendall.assert_called_once_with(b'sudo_password\n') + + +def test_put_file(connection): + """ Test putting a file to the remote system """ + connection.exec_command = MagicMock() + connection.exec_command.return_value = (0, b"", b"") + + with patch('builtins.open', create=True) as mock_open: + mock_open.return_value.__enter__.return_value.read.return_value = b'test content' + connection.put_file('/local/path', '/remote/path') + + connection.exec_command.assert_called_once_with("/bin/sh -c 'cat > /remote/path'", in_data=b'test content', sudoable=False) + + +@patch('paramiko.SSHClient') +def test_put_file_general_error(mock_ssh, connection): + """ Test put_file with general error """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + mock_channel.recv_exit_status.return_value = 1 + mock_channel.makefile.return_value = [to_bytes("")] + mock_channel.makefile_stderr.return_value = [to_bytes('Some error')] + + connection._connected = True + connection.ssh = mock_client + + with pytest.raises(AnsibleError, match='error occurred while putting file from /remote/path to /local/path'): + connection.put_file('/remote/path', '/local/path') + + +@patch('paramiko.SSHClient') +def test_put_file_cat_not_found(mock_ssh, connection): + """ Test command execution when cat is not found """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + mock_channel.recv_exit_status.return_value = 1 + mock_channel.makefile.return_value = [to_bytes("")] + mock_channel.makefile_stderr.return_value = [to_bytes('cat: not found')] + + connection._connected = True + connection.ssh = mock_client + + with pytest.raises(AnsibleError, match='cat not found in path of WSL distribution'): + connection.fetch_file('/remote/path', '/local/path') + + +def test_fetch_file(connection): + """ Test fetching a file from the remote system """ + connection.exec_command = MagicMock() + connection.exec_command.return_value = (0, b'test content', b"") + + with patch('builtins.open', create=True) as mock_open: + connection.fetch_file('/remote/path', '/local/path') + + connection.exec_command.assert_called_once_with("/bin/sh -c 'cat /remote/path'", sudoable=False) + mock_open.assert_called_with('/local/path', 'wb') + + +@patch('paramiko.SSHClient') +def test_fetch_file_general_error(mock_ssh, connection): + """ Test fetch_file with general error """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + mock_channel.recv_exit_status.return_value = 1 + mock_channel.makefile.return_value = [to_bytes("")] + mock_channel.makefile_stderr.return_value = [to_bytes('Some error')] + + connection._connected = True + connection.ssh = mock_client + + with pytest.raises(AnsibleError, match='error occurred while fetching file from /remote/path to /local/path'): + connection.fetch_file('/remote/path', '/local/path') + + +@patch('paramiko.SSHClient') +def test_fetch_file_cat_not_found(mock_ssh, connection): + """ Test command execution when cat is not found """ + mock_client = MagicMock() + mock_ssh.return_value = mock_client + mock_channel = MagicMock() + mock_transport = MagicMock() + + mock_client.get_transport.return_value = mock_transport + mock_transport.open_session.return_value = mock_channel + mock_channel.recv_exit_status.return_value = 1 + mock_channel.makefile.return_value = [to_bytes("")] + mock_channel.makefile_stderr.return_value = [to_bytes('cat: not found')] + + connection._connected = True + connection.ssh = mock_client + + with pytest.raises(AnsibleError, match='cat not found in path of WSL distribution'): + connection.fetch_file('/remote/path', '/local/path') + + +def test_close(connection): + """ Test connection close """ + mock_ssh = MagicMock() + connection.ssh = mock_ssh + connection._connected = True + + connection.close() + + assert mock_ssh.close.called, 'ssh.close was not called' + assert not connection._connected, 'self._connected is still True' + + +def test_close_with_lock_file(connection): + """ Test close method with lock file creation """ + connection._any_keys_added = MagicMock(return_value=True) + connection._connected = True + connection.keyfile = '/tmp/wsl-known_hosts-test' + connection.set_option('host_key_checking', True) + connection.set_option('lock_file_timeout', 5) + connection.set_option('record_host_keys', True) + connection.ssh = MagicMock() + + lock_file_path = os.path.join(os.path.dirname(connection.keyfile), + f'ansible-{os.path.basename(connection.keyfile)}.lock') + + try: + connection.close() + assert os.path.exists(lock_file_path), 'Lock file was not created' + + lock_stat = os.stat(lock_file_path) + assert lock_stat.st_mode & 0o777 == 0o600, 'Incorrect lock file permissions' + finally: + Path(lock_file_path).unlink(missing_ok=True) + + +@patch('pathlib.Path.unlink') +@patch('os.path.exists') +def test_close_lock_file_time_out_error_handling(mock_exists, mock_unlink, connection): + """ Test close method with lock file timeout error """ + connection._any_keys_added = MagicMock(return_value=True) + connection._connected = True + connection._save_ssh_host_keys = MagicMock() + connection.keyfile = '/tmp/wsl-known_hosts-test' + connection.set_option('host_key_checking', True) + connection.set_option('lock_file_timeout', 5) + connection.set_option('record_host_keys', True) + connection.ssh = MagicMock() + + mock_exists.return_value = False + matcher = f'writing lock file for {connection.keyfile} ran in to the timeout of {connection.get_option("lock_file_timeout")}s' + with pytest.raises(AnsibleError, match=matcher): + with patch('os.getuid', return_value=1000), \ + patch('os.getgid', return_value=1000), \ + patch('os.chmod'), patch('os.chown'), \ + patch('os.rename'), \ + patch.object(FileLock, 'lock_file', side_effect=LockTimeout()): + connection.close() + + +@patch('ansible_collections.community.general.plugins.module_utils._filelock.FileLock.lock_file') +@patch('tempfile.NamedTemporaryFile') +@patch('os.chmod') +@patch('os.chown') +@patch('os.rename') +@patch('os.path.exists') +def test_tempfile_creation_and_move(mock_exists, mock_rename, mock_chown, mock_chmod, mock_tempfile, mock_lock_file, connection): + """ Test tempfile creation and move during close """ + connection._any_keys_added = MagicMock(return_value=True) + connection._connected = True + connection._save_ssh_host_keys = MagicMock() + connection.keyfile = '/tmp/wsl-known_hosts-test' + connection.set_option('host_key_checking', True) + connection.set_option('lock_file_timeout', 5) + connection.set_option('record_host_keys', True) + connection.ssh = MagicMock() + + mock_exists.return_value = False + + mock_lock_file_instance = MagicMock() + mock_lock_file.return_value = mock_lock_file_instance + mock_lock_file_instance.__enter__.return_value = None + + mock_tempfile_instance = MagicMock() + mock_tempfile_instance.name = '/tmp/mock_tempfile' + mock_tempfile.return_value.__enter__.return_value = mock_tempfile_instance + + mode = 0o644 + uid = 1000 + gid = 1000 + key_dir = os.path.dirname(connection.keyfile) + + with patch('os.getuid', return_value=uid), patch('os.getgid', return_value=gid): + connection.close() + + connection._save_ssh_host_keys.assert_called_once_with('/tmp/mock_tempfile') + mock_chmod.assert_called_once_with('/tmp/mock_tempfile', mode) + mock_chown.assert_called_once_with('/tmp/mock_tempfile', uid, gid) + mock_rename.assert_called_once_with('/tmp/mock_tempfile', connection.keyfile) + mock_tempfile.assert_called_once_with(dir=key_dir, delete=False) + + +@patch('pathlib.Path.unlink') +@patch('tempfile.NamedTemporaryFile') +@patch('ansible_collections.community.general.plugins.module_utils._filelock.FileLock.lock_file') +@patch('os.path.exists') +def test_close_tempfile_error_handling(mock_exists, mock_lock_file, mock_tempfile, mock_unlink, connection): + """ Test tempfile creation error """ + connection._any_keys_added = MagicMock(return_value=True) + connection._connected = True + connection._save_ssh_host_keys = MagicMock() + connection.keyfile = '/tmp/wsl-known_hosts-test' + connection.set_option('host_key_checking', True) + connection.set_option('lock_file_timeout', 5) + connection.set_option('record_host_keys', True) + connection.ssh = MagicMock() + + mock_exists.return_value = False + + mock_lock_file_instance = MagicMock() + mock_lock_file.return_value = mock_lock_file_instance + mock_lock_file_instance.__enter__.return_value = None + + mock_tempfile_instance = MagicMock() + mock_tempfile_instance.name = '/tmp/mock_tempfile' + mock_tempfile.return_value.__enter__.return_value = mock_tempfile_instance + + with pytest.raises(AnsibleError, match='error occurred while writing SSH host keys!'): + with patch.object(os, 'chmod', side_effect=Exception()): + connection.close() + mock_unlink.assert_called_with(missing_ok=True) + + +@patch('ansible_collections.community.general.plugins.module_utils._filelock.FileLock.lock_file') +@patch('os.path.exists') +def test_close_with_invalid_host_key(mock_exists, mock_lock_file, connection): + """ Test load_system_host_keys on close with InvalidHostKey error """ + connection._any_keys_added = MagicMock(return_value=True) + connection._connected = True + connection._save_ssh_host_keys = MagicMock() + connection.keyfile = '/tmp/wsl-known_hosts-test' + connection.set_option('host_key_checking', True) + connection.set_option('lock_file_timeout', 5) + connection.set_option('record_host_keys', True) + connection.ssh = MagicMock() + connection.ssh.load_system_host_keys.side_effect = paramiko.hostkeys.InvalidHostKey( + "Bad Line!", Exception('Something crashed!')) + + mock_exists.return_value = False + + mock_lock_file_instance = MagicMock() + mock_lock_file.return_value = mock_lock_file_instance + mock_lock_file_instance.__enter__.return_value = None + + with pytest.raises(AnsibleConnectionFailure, match="Invalid host key: Bad Line!"): + connection.close() + + +def test_reset(connection): + """ Test connection reset """ + connection._connected = True + connection.close = MagicMock() + connection._connect = MagicMock() + + connection.reset() + + connection.close.assert_called_once() + connection._connect.assert_called_once() + + connection._connected = False + connection.reset() + assert connection.close.call_count == 1