Skip to content

DNM - connection/aw_sss - Use Port forwarding session for file transfer instead of S3 bucket #2265

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - Add for file transfer using SSM port forwarding session with netcat for Linux/MacOS EC2 managed nodes (https://github.com/ansible-collections/community.aws/pull/2265).
337 changes: 89 additions & 248 deletions plugins/connection/aws_ssm.py

Large diffs are not rendered by default.

213 changes: 213 additions & 0 deletions plugins/plugin_utils/ssm/command.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,213 @@
# -*- coding: utf-8 -*-

# Copyright: Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import argparse
import asyncio
import json
import os
import pickle
import pty
import random
import re
import select
import signal
import string
import subprocess
import sys
import traceback
import uuid
from datetime import datetime
from functools import wraps
from typing import Any
from typing import Callable
from typing import Iterator
from typing import List
from typing import Tuple

try:
import boto3
except ImportError:
pass

from ansible.module_utils._text import to_bytes
from ansible.module_utils._text import to_text
from ansible.plugins.shell.powershell import _common_args

from .common import SSMDisplay
from .common import StdoutPoller


@staticmethod
def generate_mark() -> str:
"""Generates a random string of characters to delimit SSM CLI commands"""
mark = "".join([random.choice(string.ascii_letters) for i in range(26)])
return mark


def chunks(lst: List, n: int) -> Iterator[List[Any]]:
"""Yield successive n-sized chunks from lst."""
for i in range(0, len(lst), n):
yield lst[i:i + n] # fmt: skip


def filter_ansi(line: str, is_windows: bool) -> str:
"""Remove any ANSI terminal control codes.

:param line: The input line.
:param is_windows: Whether the output is coming from a Windows host.
:returns: The result line.
"""
line = to_text(line)

if is_windows:
osc_filter = re.compile(r"\x1b\][^\x07]*\x07")
line = osc_filter.sub("", line)
ansi_filter = re.compile(r"(\x9B|\x1B\[)[0-?]*[ -/]*[@-~]")
line = ansi_filter.sub("", line)

# Replace or strip sequence (at terminal width)
line = line.replace("\r\r\n", "\n")
if len(line) == 201:
line = line[:-1]

return line


def encode_script(shell: Any, cmd: str) -> str:
result = cmd
if getattr(shell, "SHELL_FAMILY", "") == "powershell" and not cmd.startswith(
" ".join(_common_args) + " -EncodedCommand"
):
result = shell._encode_script(cmd, preserve_rc=True)
return result


class CommandManager(SSMDisplay):
def __init__(
self, is_windows: bool, session: Any, stdout_r: Any, ssm_timeout: int, verbosity_display: Callable
) -> None:
super(CommandManager, self).__init__(verbosity_display=verbosity_display)
stdout = os.fdopen(stdout_r, "rb", 0)
poller = select.poll()
poller.register(stdout, select.POLLIN)
self._poller = StdoutPoller(session=session, stdout=stdout, poller=poller, timeout=ssm_timeout)
self.is_windows = is_windows

@property
def poller(self) -> Any:
return self._poller

@property
def has_timeout(self) -> bool:
return self._poller._has_timeout

def _wrap_command(self, cmd: str, mark_start: str, mark_end: str) -> str:
"""Wrap command so stdout and status can be extracted"""
if self.is_windows:
cmd = cmd + "; echo " + mark_start + "\necho " + mark_end + "\n"
else:
cmd = (
f"printf '%s\\n' '{mark_start}';\n"
f"echo | {cmd};\n"
f"printf '\\n%s\\n%s\\n' \"$?\" '{mark_end}';\n"
) # fmt: skip

self.verbosity_display(4, f"_wrap_command: \n'{to_text(cmd)}'")
return cmd

def _post_process(self, stdout: str, mark_begin: str) -> Tuple[str, str]:
"""extract command status and strip unwanted lines"""

if not self.is_windows:
# Get command return code
returncode = int(stdout.splitlines()[-2])

# Throw away final lines
for _x in range(0, 3):
stdout = stdout[:stdout.rfind('\n')] # fmt: skip

return (returncode, stdout)

# Windows is a little more complex
# Value of $LASTEXITCODE will be the line after the mark
trailer = stdout[stdout.rfind(mark_begin):] # fmt: skip
last_exit_code = trailer.splitlines()[1]
if last_exit_code.isdigit:
returncode = int(last_exit_code)
else:
returncode = -1
# output to keep will be before the mark
stdout = stdout[:stdout.rfind(mark_begin)] # fmt: skip

# If the return code contains #CLIXML (like a progress bar) remove it
clixml_filter = re.compile(r"#<\sCLIXML\s<Objs.*</Objs>")
stdout = clixml_filter.sub("", stdout)

# If it looks like JSON remove any newlines
if stdout.startswith("{"):
stdout = stdout.replace("\n", "")

return (returncode, stdout)

def _exec_communicate(self, mark_start: str, mark_begin: str, mark_end: str) -> Tuple[int, str, str]:
"""Interact with session.
Read stdout between the markers until 'mark_end' is reached.

:param cmd: The command being executed.
:param mark_start: The marker which starts the output.
:param mark_begin: The begin marker.
:param mark_end: The end marker.
:returns: A tuple with the return code, the stdout and the stderr content.
"""
# Read stdout between the markers
stdout = ""
win_line = ""
begin = False
returncode = None
for poll_result in self._poller.poll():
if not poll_result:
continue

line = filter_ansi(self._poller.readline(), self.is_windows)
self.verbosity_display(4, f"EXEC stdout line: \n{line}")

if not begin and self.is_windows:
win_line = win_line + line
line = win_line

if mark_start in line:
begin = True
if not line.startswith(mark_start):
stdout = ""
continue
if begin:
if mark_end in line:
self.verbosity_display(4, f"POST_PROCESS: \n{to_text(stdout)}")
returncode, stdout = self._post_process(stdout, mark_begin)
self.verbosity_display(4, f"POST_PROCESSED: \n{to_text(stdout)}")
break
stdout = stdout + line

# see https://github.com/pylint-dev/pylint/issues/8909)
return (returncode, stdout, self._poller.flush_stderr()) # pylint: disable=unreachable

def exec_command(self, cmd: str) -> Tuple[int, str, str]:
self.verbosity_display(3, f"EXEC: {to_text(cmd)}")

mark_begin = generate_mark()
if self.is_windows:
mark_start = mark_begin + " $LASTEXITCODE"
else:
mark_start = mark_begin
mark_end = generate_mark()

# Wrap command in markers accordingly for the shell used
cmd = self._wrap_command(cmd, mark_start, mark_end)

self._poller.flush_stderr()
for chunk in chunks(cmd, 1024):
self._poller.stdin_write(to_bytes(chunk, errors="surrogate_or_strict"))

return self._exec_communicate(mark_start, mark_begin, mark_end)
159 changes: 159 additions & 0 deletions plugins/plugin_utils/ssm/common.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,159 @@
# -*- coding: utf-8 -*-

# Copyright: Ansible Project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import json
import os
import pty
import select
import subprocess
import time
from typing import Any
from typing import Callable
from typing import Dict
from typing import NoReturn
from typing import Optional
from typing import TypedDict
from typing import Union

from ansible.errors import AnsibleConnectionFailure


class CommandResult(TypedDict):
"""
A dictionary that contains the executed command results.
"""

returncode: int
stdout_combined: str
stderr_combined: str


class SSMDisplay:
def __init__(self, verbosity_display: Callable[[int, str], None]):
self.verbosity_display = verbosity_display


class StdoutPoller:
def __init__(self, session: Any, stdout: Any, poller: Any, timeout: int) -> None:
self._stdout = stdout
self._poller = poller
self._session = session
self._timeout = timeout
self._has_timeout = False

def readline(self):
return self._stdout.readline()

def has_data(self, timeout: int = 1000) -> bool:
return bool(self._poller.poll(timeout))

def read_stdout(self, length: int = 1024) -> str:
return self._stdout.read(length).decode("utf-8")

def stdin_write(self, value: Union[str | bytes]) -> None:
self._session.stdin.write(value)

def poll(self) -> NoReturn:
start = round(time.time())
yield self.has_data()
while self._session.poll() is None:
remaining = start + self._timeout - round(time.time())
if remaining < 0:
self._has_timeout = True
raise AnsibleConnectionFailure("StdoutPoller timeout...")
yield self.has_data()

def match_expr(self, expr: Union[str, callable]) -> str:
time_start = time.time()
content = ""
while (int(time.time()) - time_start) < self._timeout:
if self.poll():
content += self.read_stdout()
if callable(expr):
if expr(content):
return content
elif expr in content:
return content
raise TimeoutError(f"Unable to match expr '{expr}' from content")

def flush_stderr(self) -> str:
"""read and return stderr with minimal blocking"""

poll_stderr = select.poll()
poll_stderr.register(self._session.stderr, select.POLLIN)
stderr = ""
while self._session.poll() is None:
if not poll_stderr.poll(1):
break
line = self._session.stderr.readline()
stderr = stderr + line
return stderr


class SSMSessionManager(SSMDisplay):
def __init__(
self,
client: Any,
instance_id: str,
executable: str,
region: Optional[str],
profile: Optional[str],
ssm_timeout: int,
verbosity_display: Callable,
document_name: Optional[str] = None,
document_parameters: Optional[Dict] = None,
):
super(SSMSessionManager, self).__init__(verbosity_display=verbosity_display)

self._client = client
params = {"Target": instance_id}
if document_name:
params["DocumentName"] = document_name
if document_parameters:
params["Parameters"] = document_parameters

try:
response = self._client.start_session(**params)
self._session_id = response["SessionId"]
self.verbosity_display(4, f"Start session - SessionId: {self._session_id}")

cmd = [
executable,
json.dumps(response),
region,
"StartSession",
profile,
json.dumps({"Target": instance_id}),
self._client.meta.endpoint_url,
]

stdout_r, stdout_w = pty.openpty()
self._session = subprocess.Popen(
cmd,
stdin=subprocess.PIPE,
stdout=stdout_w,
stderr=subprocess.PIPE,
close_fds=True,
bufsize=0,
)

os.close(stdout_w)
stdout = os.fdopen(stdout_r, "rb", 0)
self._poller = StdoutPoller(
session=self._session,
stdout=stdout,
poller=select.poll().register(stdout, select.POLLIN),
timeout=ssm_timeout,
)
except Exception as e:
raise AnsibleConnectionFailure(f"failed to start session: {e}")

def __del__(self):
if self._session_id:
self._display.vvvv(f"Terminating AWS Session: {self._session_id}")
self._client.terminate_session(SessionId=self._session_id)
if self._session:
self._display.vvvv("Terminating subprocess.Popen session")
self._session.terminate()
Loading
Loading