Skip to content

Commit 56b0886

Browse files
aws_ssm: Refactor _init_clients Method (#2223)
SUMMARY Refer: https://issues.redhat.com/browse/ACA-2092 This PR Refactors the _init_clients method ISSUE TYPE Bugfix Pull Request Docs Pull Request Feature Pull Request New Module Pull Request COMPONENT NAME ADDITIONAL INFORMATION Reviewed-by: Mark Chappell Reviewed-by: Alina Buzachis
1 parent cfdcc05 commit 56b0886

File tree

3 files changed

+173
-59
lines changed

3 files changed

+173
-59
lines changed
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,2 @@
1+
minor_changes:
2+
- aws_ssm - Refactor _init_clients Method for Improved Clarity and Efficiency (https://github.com/ansible-collections/community.aws/pull/2223).

plugins/connection/aws_ssm.py

+92-59
Original file line numberDiff line numberDiff line change
@@ -5,6 +5,7 @@
55

66
# Based on the ssh connection plugin by Michael DeHaan
77

8+
89
DOCUMENTATION = r"""
910
name: aws_ssm
1011
author:
@@ -284,7 +285,6 @@
284285
name: nginx
285286
state: present
286287
"""
287-
288288
import os
289289
import getpass
290290
import json
@@ -295,6 +295,7 @@
295295
import string
296296
import subprocess
297297
import time
298+
from typing import Optional
298299

299300
try:
300301
import boto3
@@ -347,7 +348,10 @@ def wrapped(self, *args, **kwargs):
347348
if isinstance(e, AnsibleConnectionFailure):
348349
msg = f"ssm_retry: attempt: {attempt}, cmd ({cmd_summary}), pausing for {pause} seconds"
349350
else:
350-
msg = f"ssm_retry: attempt: {attempt}, caught exception({e}) from cmd ({cmd_summary}), pausing for {pause} seconds"
351+
msg = (
352+
f"ssm_retry: attempt: {attempt}, caught exception({e})"
353+
f"from cmd ({cmd_summary}),pausing for {pause} seconds"
354+
)
351355

352356
self._vv(msg)
353357

@@ -390,6 +394,90 @@ class Connection(ConnectionBase):
390394
_timeout = False
391395
MARK_LENGTH = 26
392396

397+
def __init__(self, *args, **kwargs):
398+
super().__init__(*args, **kwargs)
399+
400+
if not HAS_BOTO3:
401+
raise AnsibleError(missing_required_lib("boto3"))
402+
403+
self.host = self._play_context.remote_addr
404+
405+
if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
406+
self.delegate = None
407+
self.has_native_async = True
408+
self.always_pipeline_modules = True
409+
self.module_implementation_preferences = (".ps1", ".exe", "")
410+
self.protocol = None
411+
self.shell_id = None
412+
self._shell_type = "powershell"
413+
self.is_windows = True
414+
415+
def __del__(self):
416+
self.close()
417+
418+
def _connect(self):
419+
"""connect to the host via ssm"""
420+
421+
self._play_context.remote_user = getpass.getuser()
422+
423+
if not self._session_id:
424+
self.start_session()
425+
return self
426+
427+
def _init_clients(self) -> None:
428+
"""
429+
Initializes required AWS clients (SSM and S3).
430+
Delegates client initialization to specialized methods.
431+
"""
432+
433+
self._vvvv("INITIALIZE BOTO3 CLIENTS")
434+
profile_name = self.get_option("profile") or ""
435+
region_name = self.get_option("region")
436+
437+
# Initialize SSM client
438+
self._initialize_ssm_client(region_name, profile_name)
439+
440+
# Initialize S3 client
441+
self._initialize_s3_client(profile_name)
442+
443+
def _initialize_ssm_client(self, region_name: Optional[str], profile_name: str) -> None:
444+
"""
445+
Initializes the SSM client used to manage sessions.
446+
Args:
447+
region_name (Optional[str]): AWS region for the SSM client.
448+
profile_name (str): AWS profile name for authentication.
449+
450+
Returns:
451+
None
452+
"""
453+
454+
self._vvvv("SETUP BOTO3 CLIENTS: SSM")
455+
self._client = self._get_boto_client(
456+
"ssm",
457+
region_name=region_name,
458+
profile_name=profile_name,
459+
)
460+
461+
def _initialize_s3_client(self, profile_name: str) -> None:
462+
"""
463+
Initializes the S3 client used for accessing S3 buckets.
464+
465+
Args:
466+
profile_name (str): AWS profile name for authentication.
467+
468+
Returns:
469+
None
470+
"""
471+
472+
s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
473+
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
474+
self._s3_client = self._get_boto_client(
475+
"s3",
476+
region_name=s3_region_name,
477+
endpoint_url=s3_endpoint_url,
478+
profile_name=profile_name,
479+
)
480+
393481
def _display(self, f, message):
394482
if self.host:
395483
host_args = {"host": self.host}
@@ -447,62 +535,6 @@ def _get_bucket_endpoint(self):
447535

448536
return s3_bucket_client.meta.endpoint_url, s3_bucket_client.meta.region_name
449537

450-
def _init_clients(self):
451-
self._vvvv("INITIALIZE BOTO3 CLIENTS")
452-
profile_name = self.get_option("profile") or ""
453-
region_name = self.get_option("region")
454-
455-
# The SSM Boto client, currently used to initiate and manage the session
456-
# Note: does not handle the actual SSM session traffic
457-
self._vvvv("SETUP BOTO3 CLIENTS: SSM")
458-
ssm_client = self._get_boto_client(
459-
"ssm",
460-
region_name=region_name,
461-
profile_name=profile_name,
462-
)
463-
self._client = ssm_client
464-
465-
s3_endpoint_url, s3_region_name = self._get_bucket_endpoint()
466-
self._vvvv(f"SETUP BOTO3 CLIENTS: S3 {s3_endpoint_url}")
467-
s3_bucket_client = self._get_boto_client(
468-
"s3",
469-
region_name=s3_region_name,
470-
endpoint_url=s3_endpoint_url,
471-
profile_name=profile_name,
472-
)
473-
474-
self._s3_client = s3_bucket_client
475-
476-
def __init__(self, *args, **kwargs):
477-
super().__init__(*args, **kwargs)
478-
479-
if not HAS_BOTO3:
480-
raise AnsibleError(missing_required_lib("boto3"))
481-
482-
self.host = self._play_context.remote_addr
483-
484-
if getattr(self._shell, "SHELL_FAMILY", "") == "powershell":
485-
self.delegate = None
486-
self.has_native_async = True
487-
self.always_pipeline_modules = True
488-
self.module_implementation_preferences = (".ps1", ".exe", "")
489-
self.protocol = None
490-
self.shell_id = None
491-
self._shell_type = "powershell"
492-
self.is_windows = True
493-
494-
def __del__(self):
495-
self.close()
496-
497-
def _connect(self):
498-
"""connect to the host via ssm"""
499-
500-
self._play_context.remote_user = getpass.getuser()
501-
502-
if not self._session_id:
503-
self.start_session()
504-
return self
505-
506538
def reset(self):
507539
"""start a fresh ssm session"""
508540
self._vvvv("reset called on ssm connection")
@@ -853,7 +885,8 @@ def _generate_commands(self, bucket_name, s3_path, in_path, out_path):
853885
put_commands = [
854886
(
855887
"Invoke-WebRequest -Method PUT "
856-
f"-Headers @{{{put_command_headers}}} " # @{'key' = 'value'; 'key2' = 'value2'}
888+
# @{'key' = 'value'; 'key2' = 'value2'}
889+
f"-Headers @{{{put_command_headers}}} "
857890
f"-InFile '{in_path}' "
858891
f"-Uri '{put_url}' "
859892
f"-UseBasicParsing"

tests/unit/plugins/connection/test_aws_ssm.py

+79
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,85 @@
1616

1717

1818
class TestConnectionBaseClass:
19+
def test_init_clients(self):
20+
pc = PlayContext()
21+
new_stdin = StringIO()
22+
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)
23+
24+
# Mock get_option to return expected region and profile
25+
def mock_get_option(key):
26+
options = {
27+
"profile": "test-profile",
28+
"region": "us-east-1",
29+
}
30+
return options.get(key, None)
31+
32+
conn.get_option = MagicMock(side_effect=mock_get_option)
33+
34+
# Mock the _initialize_ssm_client and _initialize_s3_client methods
35+
conn._initialize_ssm_client = MagicMock()
36+
conn._initialize_s3_client = MagicMock()
37+
38+
conn._init_clients()
39+
40+
conn._initialize_ssm_client.assert_called_once_with("us-east-1", "test-profile")
41+
conn._initialize_s3_client.assert_called_once_with("test-profile")
42+
43+
@patch("boto3.client")
44+
def test_initialize_ssm_client(self, mock_boto3_client):
45+
"""
46+
Test for the _initialize_ssm_client method to ensure the SSM client is initialized correctly.
47+
"""
48+
pc = PlayContext()
49+
new_stdin = StringIO()
50+
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)
51+
52+
test_region_name = "us-west-2"
53+
test_profile_name = "test-profile"
54+
55+
# Mock the _get_boto_client method to return a mock client
56+
conn._get_boto_client = MagicMock(return_value=mock_boto3_client)
57+
58+
conn._initialize_ssm_client(test_region_name, test_profile_name)
59+
60+
conn._get_boto_client.assert_called_once_with(
61+
"ssm",
62+
region_name=test_region_name,
63+
profile_name=test_profile_name,
64+
)
65+
66+
assert conn._client is mock_boto3_client
67+
68+
@patch("boto3.client")
69+
def test_initialize_s3_client(self, mock_boto3_client):
70+
"""
71+
Test for the _initialize_s3_client method to ensure the S3 client is initialized correctly.
72+
"""
73+
74+
pc = PlayContext()
75+
new_stdin = StringIO()
76+
conn = connection_loader.get("community.aws.aws_ssm", pc, new_stdin)
77+
78+
test_profile_name = "test-profile"
79+
80+
# Mock the _get_bucket_endpoint method to return dummy values
81+
conn._get_bucket_endpoint = MagicMock(return_value=("http://example.com", "us-west-2"))
82+
83+
conn._get_boto_client = MagicMock(return_value=mock_boto3_client)
84+
85+
conn._initialize_s3_client(test_profile_name)
86+
87+
conn._get_bucket_endpoint.assert_called_once()
88+
89+
conn._get_boto_client.assert_called_once_with(
90+
"s3",
91+
region_name="us-west-2",
92+
endpoint_url="http://example.com",
93+
profile_name=test_profile_name,
94+
)
95+
96+
assert conn._s3_client is mock_boto3_client
97+
1998
@patch("os.path.exists")
2099
@patch("subprocess.Popen")
21100
@patch("select.poll")

0 commit comments

Comments
 (0)