Skip to content

connection/aws_ssm - add caching and caching_ttl options to improve velocity #2278

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 2 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
@@ -0,0 +1,3 @@
---
minor_changes:
- aws_ssm - New options ``ansible_aws_ssm_caching`` and ``ansible_aws_ssm_caching_ttl`` to improve velocity of the connection plugin (https://github.com/ansible-collections/community.aws/pull/2278).
32 changes: 26 additions & 6 deletions plugins/connection/aws_ssm.py
Original file line number Diff line number Diff line change
Expand Up @@ -148,6 +148,23 @@
version_added: 5.2.0
vars:
- name: ansible_aws_ssm_s3_addressing_style
caching:
description:
- The plugin will create a daemon starting a shell session to handle all command
sent to the managed host.
type: boolean
default: False
version_added: 10.0.0
vars:
- name: ansible_aws_ssm_caching
caching_ttl:
description:
- The time (in seconds) the daemon will wait before exit when there is no incoming request.
type: int
default: 30
version_added: 10.0.0
vars:
- name: ansible_aws_ssm_caching_ttl
"""

EXAMPLES = r"""
Expand Down Expand Up @@ -362,6 +379,7 @@

from ansible_collections.community.aws.plugins.plugin_utils.s3clientmanager import S3ClientManager
from ansible_collections.community.aws.plugins.plugin_utils.terminalmanager import TerminalManager
from ansible_collections.community.aws.plugins.plugin_utils.cache_client import exec_command_using_caching

display = Display()

Expand Down Expand Up @@ -502,8 +520,9 @@ def __del__(self) -> None:
def _connect(self) -> Any:
"""connect to the host via ssm"""
self._play_context.remote_user = getpass.getuser()

if not self._session_id:
self._init_clients()
caching = self.get_option("caching")
if not caching and not self._session_id:
self.start_session()
return self

Expand Down Expand Up @@ -572,8 +591,9 @@ def verbosity_display(self, level: int, message: str) -> None:
def reset(self) -> Any:
"""start a fresh ssm session"""
self.verbosity_display(4, "reset called on ssm connection")
self.close()
return self.start_session()
if not self.get_option("caching"):
self.close()
return self.start_session()

@property
def instance_id(self) -> str:
Expand Down Expand Up @@ -608,8 +628,6 @@ def start_session(self):

executable = self.get_executable()

self._init_clients()

self.verbosity_display(4, f"START SSM SESSION: {self.instance_id}")
start_session_args = dict(Target=self.instance_id, Parameters={})
document_name = self.get_option("ssm_document")
Expand Down Expand Up @@ -733,6 +751,8 @@ def exec_command(self, cmd: str, in_data: bool = None, sudoable: bool = True) ->
"""When running a command on the SSM host, uses generate_mark to get delimiting strings"""

super().exec_command(cmd, in_data=in_data, sudoable=sudoable)
if self.get_option("caching"):
return exec_command_using_caching(self, cmd)

self.verbosity_display(3, f"EXEC: {to_text(cmd)}")

Expand Down
143 changes: 143 additions & 0 deletions plugins/plugin_utils/cache_client.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,143 @@
# -*- coding: utf-8 -*-

# Copyright: Contributors to the Ansible project
# GNU General Public License v3.0+ (see COPYING or https://www.gnu.org/licenses/gpl-3.0.txt)

import json
import os
import pickle
import socket
import subprocess
import sys
import time
from typing import Any
from typing import Dict

from ansible.errors import AnsibleFileNotFound
from ansible.errors import AnsibleRuntimeError
from ansible.plugins.shell.powershell import _common_args


def _create_socket_path(instance_id: str, region_name: str) -> str:
return os.path.join(
os.environ["HOME"], ".ansible", "_".join(["connection_aws_ssm_caching", instance_id, region_name])
)


class SSMCachingSocket:
def __init__(self, conn_plugin: Any):
self.verbosity_display = conn_plugin.verbosity_display
self._verbosity_level = 1
self._region = conn_plugin.get_option("region") or "us-east-1"
self._socket_path = _create_socket_path(conn_plugin.instance_id, self._region)
self.verbosity_display(self._verbosity_level, f">>> SSM Caching Socket path = {self._socket_path}")
self.conn_plugin = conn_plugin
self._socket = None
self._bind()

def _bind(self):
running = False
self._socket = socket.socket(socket.AF_UNIX, socket.SOCK_STREAM)
for attempt in range(100, -1, -1):
try:
self._socket.connect(self._socket_path)
return True
except (ConnectionRefusedError, FileNotFoundError):
if not running:
running = self.start_server()
if attempt == 0:
raise
time.sleep(0.01)

def _mask_command(self, command: str) -> str:
if self.conn_plugin.get_option("access_key_id"):
command = command.replace(self.conn_plugin.get_option("access_key_id"), "*****")
if self.conn_plugin.get_option("secret_access_key"):
command = command.replace(self.conn_plugin.get_option("secret_access_key"), "*****")
if self.conn_plugin.get_option("session_token"):
command = command.replace(self.conn_plugin.get_option("session_token"), "*****")
return command

def start_server(self):
env = os.environ
parameters = [
"--fork",
"--socket-path",
self._socket_path,
"--region",
self._region,
"--executable",
self.conn_plugin.get_executable(),
]

pairing_options = {
"--instance-id": "instance_id",
"--ssm-timeout": "ssm_timeout",
"--reconnection-retries": "reconnection_retries",
"--access-key-id": "access_key_id",
"--secret-access-key": "secret_access_key",
"--session-token": "session_token",
"--profile": "profile",
"--ssm-document": "ssm_document",
"--is-windows": "is_windows",
"--ttl": "caching_ttl",
}
for opt, attr in pairing_options.items():
if hasattr(self.conn_plugin, attr):
if opt_value := getattr(self.conn_plugin, attr):
parameters.extend([opt, str(opt_value)])
elif opt_value := self.conn_plugin.get_option(attr):
parameters.extend([opt, str(opt_value)])

command = [sys.executable]
ansiblez_path = sys.path[0]
env.update({"PYTHONPATH": ansiblez_path})
parent_dir = os.path.dirname(__file__)
server_path = os.path.join(parent_dir, "cache_server.py")
if not os.path.exists(server_path):
raise AnsibleFileNotFound(f"The socket does not exist at expected path = '{server_path}'")
command += [server_path]
displayed_command = self._mask_command(" ".join(command + parameters))
self.verbosity_display(self._verbosity_level, f">>> SSM Caching socket command = '{displayed_command}'")
p = subprocess.Popen(
command + parameters,
env=env,
close_fds=True,
)
p.communicate()
self.verbosity_display(self._verbosity_level, f">>> SSM Caching socket process pid = '{p.pid}'")
return p.pid

def communicate(self, command, wait_sleep=0.01):
encoded_data = pickle.dumps(command)
self._socket.sendall(encoded_data)
self._socket.shutdown(socket.SHUT_WR)
raw_answer = b""
while True:
b = self._socket.recv((1024 * 1024))
if not b:
break
raw_answer += b
time.sleep(wait_sleep)
try:
result = json.loads(raw_answer.decode())
return result
except json.decoder.JSONDecodeError:
raise AnsibleRuntimeError(f"Cannot decode exec_command answer: {raw_answer}")

def __enter__(self) -> Any:
return self

def __exit__(self, type, value, traceback):
if self._socket:
self._socket.close()


def exec_command_using_caching(conn_plugin: Any, cmd: str) -> Dict:
with SSMCachingSocket(conn_plugin) as cache:
# Encode Windows command
if conn_plugin.is_windows:
if not cmd.startswith(" ".join(_common_args) + " -EncodedCommand"):
cmd = conn_plugin._shell._encode_script(cmd, preserve_rc=True)
result = cache.communicate(command=cmd)
return result.get("returncode"), result.get("stdout"), result.get("stderr")
Loading
Loading