diff --git a/README.md b/README.md index 78431a0..62a3738 100644 --- a/README.md +++ b/README.md @@ -20,6 +20,7 @@ sudo ./setup.py ### Plugin specific configuration The default location of the configuration file used by collectd-cloudwatch plugin is: `/opt/collectd-plugins/cloudwatch/config/plugin.conf`. The parameters in this file are optional when plugin is executed on EC2 instance. This file allows modification of the following parameters: * __credentials_path__ - Used to point to AWS account configuration file + * __arn_role__ - Used when you want to explicitly assume a role using AWS STS (Security Token Service). * __region__ - Manual override for [region](http://docs.aws.amazon.com/general/latest/gr/rande.html#cw_region) used to publish metrics * __host__ - Manual override for EC2 Instance ID and Host information propagated by collectd * __proxy_server_name__ - Manual override for proxy server name, used by plugin to connect aws cloudwatch at *.amazonaws.com. @@ -36,6 +37,7 @@ The default location of the configuration file used by collectd-cloudwatch plugi #### Example configuration file ``` credentials_path = "/home/user/.aws/credentials" +arn_role = "arn:aws:test:eu-west-1:1234567890:role/to_assume" region = "us-west-1" host = "Server1" proxy_server_name = "http://myproxyserver.com" @@ -82,6 +84,12 @@ aws_access_key = valid_access_key aws_secret_key = valid_secret_key ``` +### Assuming role using AWS STS(Security toke service) +By setting __arn_role__ in configuration you can explicitly assume a role. It's helpful when you want to send metrics from an account to another account. +Suppose you want send a metric from an EC2 instance in `account_A` to cloudwatch in `account_B`. Then, You should define a role in the `account_A` with +proper policies accessing cloudwatch of the `account A` and then create another role in the `account B` with a policy which grants assuming role already +defined in the `account_A`. Finally, attached the later role one to your EC2 IAM role. For more information visit [aws sts assume role](http://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html) documentation. + ### Whitelist configuration The CloudWatch collectd plugin allows users to select metrics to be published. This is done by adding metric names or regular expressions written in [python regex syntax](https://docs.python.org/2/library/re.html#regular-expression-syntax) to the whitelist config file. The default location of this configuration is: `/opt/collectd-plugins/cloudwatch/config/whitelist.conf`. diff --git a/src/cloudwatch/config/plugin.conf b/src/cloudwatch/config/plugin.conf index ea2d972..cf07da8 100644 --- a/src/cloudwatch/config/plugin.conf +++ b/src/cloudwatch/config/plugin.conf @@ -1,6 +1,10 @@ # The path to the AWS credentials file. This value has to be provided if plugin is used outside of EC2 instances #credentials_path = "/home/user/.aws/credentials" +# The arn of role used by sts to assuming and gets temporary credentials. +# SEE http://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_Examples for more information. +#arn_role = + # The target region which will be used to publish metric data # For list of valid regions visit: http://docs.aws.amazon.com/general/latest/gr/rande.html#cw_region #region = "us-west-1" diff --git a/src/cloudwatch/modules/awscredentials.py b/src/cloudwatch/modules/awscredentials.py index ec0b1db..8917d65 100644 --- a/src/cloudwatch/modules/awscredentials.py +++ b/src/cloudwatch/modules/awscredentials.py @@ -1,3 +1,7 @@ +from datetime import datetime + +AWS_CREDENTIALS_TIMEFORMAT = '%Y-%m-%dT%H:%M:%SZ' + class AWSCredentials(object): """ The AWSCredentials object encapsulates the credentials used for signing put requests. @@ -7,9 +11,24 @@ class AWSCredentials(object): secret_key -- the AWS secret key (default None) token -- the temporary security token obtained through a call to AWS Security Token Service when using IAM Role (default None) + expire_at -- The date string in ISO 8601 standard format(YYYYMMDDThhmmssZ) + on which the current credentials expire (default None, means never) """ - - def __init__(self, access_key=None, secret_key=None, token=None): + + def __init__(self, access_key=None, secret_key=None, token=None, expire_at=None): + self.access_key = access_key self.secret_key = secret_key self.token = token + + if expire_at: + self.expire_at = datetime.strptime(expire_at, AWS_CREDENTIALS_TIMEFORMAT) + else: + self.expire_at = None + + def is_expired(self): + """ True if credentials has been expired """ + now = datetime.utcnow() + return self.expire_at and self.expire_at < now + + diff --git a/src/cloudwatch/modules/client/assumerolereqbuilder.py b/src/cloudwatch/modules/client/assumerolereqbuilder.py new file mode 100644 index 0000000..a94cc37 --- /dev/null +++ b/src/cloudwatch/modules/client/assumerolereqbuilder.py @@ -0,0 +1,42 @@ +from baserequestbuilder import BaseRequestBuilder + +class AssumeRoleReqBuilder(BaseRequestBuilder): + """ + The request builder is responsible for building the AssumeRole requests using HTTP GET. + + Keyword arguments: + credentials -- The AWSCredentials object containing access and secret keys + region -- The region to which the data should be published + arn_role -- The arn_role to assume + """ + _SERVICE = "sts" + _ACTION = "AssumeRole" + _API_VERSION = "2011-06-15" + + def __init__(self, credentials, region): + super(self.__class__, self).__init__(credentials, region, self._SERVICE, self._ACTION, self._API_VERSION) + + def create_signed_request(self, request_map): + """ Creates a ready to send request with metrics from the metric list passed as parameter """ + self._init_timestamps() + canonical_querystring = self._create_canonical_querystring(request_map) + signature = self.signer.create_request_signature(canonical_querystring, self._get_credential_scope(), + self.aws_timestamp, self.datestamp, self._get_canonical_headers(), + self._get_signed_headers(), self.payload) + canonical_querystring += '&X-Amz-Signature=' + signature + return canonical_querystring + + def _create_canonical_querystring(self, request_map): + """ + Creates a canonical querystring as defined in the official AWS API documentation: + http://docs.aws.amazon.com/general/latest/gr/sigv4-create-canonical-request.html + """ + return self.querystring_builder.build_querystring_from_map(request_map, self._get_request_map()) + + def _get_host(self): + """ Returns the endpoint's hostname derived from the region """ + if self.region == "localhost": + return "localhost" + elif self.region.startswith("cn-"): + return "sts." + self.region + ".amazonaws.com.cn" + return "sts." + self.region + ".amazonaws.com" diff --git a/src/cloudwatch/modules/client/stsassumeroleclient.py b/src/cloudwatch/modules/client/stsassumeroleclient.py new file mode 100644 index 0000000..5cd4b10 --- /dev/null +++ b/src/cloudwatch/modules/client/stsassumeroleclient.py @@ -0,0 +1,124 @@ +import re +import os + +from ..plugininfo import PLUGIN_NAME, PLUGIN_VERSION +from assumerolereqbuilder import AssumeRoleReqBuilder +from ..logger.logger import get_logger +from requests.adapters import HTTPAdapter +from requests.sessions import Session +from requests import RequestException +from tempfile import gettempdir +import xml.etree.ElementTree as ET +from ..awscredentials import AWSCredentials + +class StsAssumRoleClient(object): + """ + This is a simple HTTPClient wrapper which supports assumeRole operation on sts endpoints. + + Keyword arguments: + region -- the region used for request signing. + endpoint -- the endpoint used for publishing metric data + credentials -- the AWSCredentials object containing access_key, secret_key or + IAM Role token used for request signing + connection_timeout -- the amount of time in seconds to wait for extablishing server connection + response_timeout -- the amount of time in seconds to wait for the server response + """ + + _LOGGER = get_logger(__name__) + _DEFAULT_CONNECTION_TIMEOUT = 1 + _DEFAULT_RESPONSE_TIMEOUT = 3 + _TOTAL_RETRIES = 1 + _LOG_FILE_MAX_SIZE = 10*1024*1024 + + + def __init__(self, credentials, endpoint='', region='', proxy_server_name='', proxy_server_port='', debug=False, connection_timeout=_DEFAULT_CONNECTION_TIMEOUT, response_timeout=_DEFAULT_RESPONSE_TIMEOUT): + self.assumerole_req_builder = AssumeRoleReqBuilder(credentials, region) + self._validate_and_set_endpoint(endpoint) + self.timeout = (connection_timeout, response_timeout) + self.proxy_server_name = proxy_server_name + self.proxy_server_port = proxy_server_port + self.debug = debug + self._prepare_session() + + def get_credentials(self, arn_role, role_session_name, duration_seconds): + """ + Requests a temporary keys by assumming give arn_role. Returns None in case of error. + """ + request_map = {} + request_map["RoleSessionName"] = role_session_name + request_map["RoleArn"] = arn_role + request_map["DurationSeconds"] = duration_seconds + request = self.assumerole_req_builder.create_signed_request(request_map) + + try: + xml_content = self._run_request(request).content + xmldoc = ET.fromstring(xml_content) + ns={'sts': 'https://sts.amazonaws.com/doc/2011-06-15/'} + cred_xml = xmldoc.find('sts:AssumeRoleResult/sts:Credentials',ns) + cred = {} + cred["session_token"] = cred_xml.find('sts:SessionToken', ns).text.strip() + cred["secret_access_key"] = cred_xml.find('sts:SecretAccessKey', ns).text.strip() + cred["access_key_id"] = cred_xml.find('sts:AccessKeyId', ns).text.strip() + cred["expiration"] = cred_xml.find('sts:Expiration', ns).text.strip() + + if not cred["session_token"] or not cred['secret_access_key'] or not cred["access_key_id"] or not cred["expiration"]: + raise ValueError("Incomplete credentials retrieved.") + except RequestException as e: + self._LOGGER.warning("Could not assume '" + arn_role + "' using the following endpoint: '" + self.endpoint +"'. [Exception: " + str(e) + "]") + self._LOGGER.warning("Request details: '" + request + "'") + raise ValueError(e) + except Exception as e: + raise ValueError(e) + + return AWSCredentials(cred['access_key_id'], cred['secret_access_key'], cred["session_token"], cred["expiration"]) + + def _prepare_session(self): + self.session = Session() + if self.proxy_server_name is not None: + proxy_server = self.proxy_server_name + self._LOGGER.info("Using proxy server: " + proxy_server) + if self.proxy_server_port is not None: + proxy_server = proxy_server + ":" + self.proxy_server_port + self._LOGGER.info("Using proxy server port: " + self.proxy_server_port) + proxies = {'https': proxy_server} + self.session.proxies.update(proxies) + else: + self._LOGGER.info("No proxy server is in use") + self.session.mount("http://", HTTPAdapter(max_retries=self._TOTAL_RETRIES)) + self.session.mount("https://", HTTPAdapter(max_retries=self._TOTAL_RETRIES)) + + def _validate_and_set_endpoint(self, endpoint): + pattern = re.compile("http[s]?://*/") + if pattern.match(endpoint) or "localhost" in endpoint: + self.endpoint = endpoint + else: + msg = "Provided endpoint '" + endpoint + "' is not a valid URL." + self._LOGGER.error(msg) + raise StsAssumRoleClient.InvalidEndpointException(msg) + + def _get_custom_headers(self): + """ Returns dictionary of HTTP headers to be attached to each request """ + return {"User-Agent": self._get_user_agent_header()} + + def _get_user_agent_header(self): + """ Returns the plugin name and version used as User-Agent information """ + return PLUGIN_NAME + "/" + str(PLUGIN_VERSION) + + def _run_request(self, request): + """ + Executes HTTP GET request with timeout using the endpoint defined upon client creation. + """ + if self.debug: + file_path = gettempdir() + "/collectd_plugin_request_trace_log" + if os.path.isfile(file_path) and os.path.getsize(file_path) > self._LOG_FILE_MAX_SIZE: + os.remove(file_path) + with open(file_path, "a") as logfile: + logfile.write("curl -i -v -connect-timeout 1 -m 3 -w %{http_code}:%{http_connect}:%{content_type}:%{time_namelookup}:%{time_redirect}:%{time_pretransfer}:%{time_connect}:%{time_starttransfer}:%{time_total}:%{speed_download} -A \"collectd/1.0\" \'" + self.endpoint + "?" + request + "\'") + logfile.write("\n\n") + + result = self.session.get(self.endpoint + "?" + request, headers=self._get_custom_headers(), timeout=self.timeout) + result.raise_for_status() + return result + + class InvalidEndpointException(Exception): + pass diff --git a/src/cloudwatch/modules/configuration/confighelper.py b/src/cloudwatch/modules/configuration/confighelper.py index f013085..501f3f1 100644 --- a/src/cloudwatch/modules/configuration/confighelper.py +++ b/src/cloudwatch/modules/configuration/confighelper.py @@ -5,6 +5,8 @@ from credentialsreader import CredentialsReader from whitelist import Whitelist, WhitelistConfigReader from ..client.ec2getclient import EC2GetClient +from ..client.stsassumeroleclient import StsAssumRoleClient +from ..plugininfo import PLUGIN_NAME, PLUGIN_VERSION import traceback class ConfigHelper(object): @@ -35,9 +37,11 @@ def __init__(self, config_path=_DEFAULT_CONFIG_PATH, metadata_server=_METADATA_S self._config_path = config_path self._metadata_server = metadata_server self._use_iam_role_credentials = False + self._arn_role = '' self.region = '' self.endpoint = '' self.ec2_endpoint = '' + self.sts_endpoint = '' self.host = '' self.asg_name = 'NONE' self.proxy_server_name = '' @@ -58,11 +62,19 @@ def credentials(self): Returns credentials. If IAM role is used, credentials will be updated. Otherwise old credentials are returned. """ - if self._use_iam_role_credentials: + if self._use_iam_role_credentials and self._credentials.is_expired(): try: self._credentials = self._get_credentials_from_iam_role() except: self._LOGGER.warning("Could not retrieve credentials using IAM Role. Using old credentials instead.") + elif self._arn_role and self._credentials and self._credentials.is_expired(): + try: + # First use iam role to query sts temporary credentials + self._credentials = self._get_credentials_from_iam_role() + self._credentials = self._get_credentials_by_sts_assuming_role() + except: + self._LOGGER.warning("Could not retrieve credentials assuming IAM Role. Using old credentials instead.") + return self._credentials @credentials.setter @@ -77,13 +89,16 @@ def _load_configuration(self): self._load_credentials() self._load_region() self._load_hostname() + self._load_arn_role() self._load_proxy_server_name() self._load_proxy_server_port() self.enable_high_resolution_metrics = self.config_reader.enable_high_resolution_metrics self._load_flush_interval_in_seconds() self._set_endpoint() self._set_ec2_endpoint() + self._set_sts_endpoint() self._load_autoscaling_group() + self._overwrite_credentials_by_assuming_role() self.debug = self.config_reader.debug self.pass_through = self.config_reader.pass_through self.push_asg = self.config_reader.push_asg @@ -107,6 +122,25 @@ def _load_credentials(self): self._use_iam_role_credentials = True self.credentials = self._get_credentials_from_iam_role() + def _overwrite_credentials_by_assuming_role(self): + """ + Tries to load and overwrite credentials with new credentials got by assuming role if + any arn_role was provided. + """ + if self._arn_role and self.credentials: + try: + self.credentials = self._get_credentials_by_sts_assuming_role() + self._use_iam_role_credentials = False + except Exception as e: + self._LOGGER.error("Failed to set credentials by assuming role. Continue by iam role credentials. Cause: " + str(e)) + + def _get_credentials_by_sts_assuming_role(self): + #Don't use credentials getter method, otherwise it runs in an infinit loop + stsAssumRoleClient = StsAssumRoleClient(self._credentials, self.sts_endpoint, self.region, self.proxy_server_name, self.proxy_server_port, self.debug) + duration_seconds = 3600 + role_session_name = PLUGIN_NAME + "_v" + str(PLUGIN_VERSION) + return stsAssumRoleClient.get_credentials(self._arn_role, role_session_name, duration_seconds) + def _get_credentials_from_iam_role(self): """ Queries IAM Role metadata for latest credentials """ return self.metadata_reader.get_iam_role_credentials(self.metadata_reader.get_iam_role_name()) @@ -137,7 +171,14 @@ def _load_hostname(self): except Exception as e: ConfigHelper._LOGGER.warning("Cannot retrieve Instance ID from the local metadata server. Cause: " + str(e) + " Using host information provided by Collectd.") - + + def _load_arn_role(self): + """ + Loads arn_role from plugin configuration file. + """ + if self.config_reader.arn_role: + self._arn_role = self.config_reader.arn_role + def _set_ec2_endpoint(self): """ Creates endpoint from region information """ if self.region is "localhost": @@ -146,6 +187,15 @@ def _set_ec2_endpoint(self): self.ec2_endpoint = "https://ec2." + self.region + ".amazonaws.com.cn/" else: self.ec2_endpoint = "https://ec2." + self.region + ".amazonaws.com/" + + def _set_sts_endpoint(self): + """ Creates endpoint from region information """ + if self.region is "localhost": + self.sts_endpoint = "http://" + self.region + "/" + elif self.region.startswith("cn-"): + self.sts_endpoint = "https://sts." + self.region + ".amazonaws.com.cn/" + else: + self.sts_endpoint = "https://sts." + self.region + ".amazonaws.com/" def _load_proxy_server_name(self): """ diff --git a/src/cloudwatch/modules/configuration/configreader.py b/src/cloudwatch/modules/configuration/configreader.py index 1da5a65..275c147 100644 --- a/src/cloudwatch/modules/configuration/configreader.py +++ b/src/cloudwatch/modules/configuration/configreader.py @@ -27,6 +27,7 @@ class ConfigReader(object): _PASS_THROUGH_DEFAULT_VALUE = False _PUSH_ASG_DEFAULT_VALUE = False _PUSH_CONSTANT_DEFAULT_VALUE = False + ARN_ROLE_CONFIG_KEY = "arn_role" REGION_CONFIG_KEY = "region" HOST_CONFIG_KEY = "host" CREDENTIALS_PATH_KEY = "credentials_path" @@ -43,6 +44,7 @@ class ConfigReader(object): def __init__(self, config_path): self.config_path = config_path self.credentials_path = "" + self.arn_role = '' self.region = '' self.host = '' self.pass_through = self._PASS_THROUGH_DEFAULT_VALUE @@ -67,6 +69,7 @@ def _parse_config_file(self): in format ['key=value', 'key2=value2'] """ self.credentials_path = self.reader_utils.get_string(self.CREDENTIALS_PATH_KEY) + self.arn_role = self.reader_utils.get_string(self.ARN_ROLE_CONFIG_KEY) self.host = self.reader_utils.get_string(self.HOST_CONFIG_KEY) self.region = self.reader_utils.get_string(self.REGION_CONFIG_KEY) self.proxy_server_name = self.reader_utils.get_string(self.PROXY_SERVER_NAME_KEY) diff --git a/src/cloudwatch/modules/configuration/metadatareader.py b/src/cloudwatch/modules/configuration/metadatareader.py index 5a43329..fb77796 100644 --- a/src/cloudwatch/modules/configuration/metadatareader.py +++ b/src/cloudwatch/modules/configuration/metadatareader.py @@ -43,8 +43,8 @@ def get_iam_role_credentials(self, role_name): """ Get the IAMRoleCredentials object with values from IAM metadata """ try: iam_data = loads(self._get_metadata(self._IAM_ROLE_CREDENTIAL_REQUEST + role_name)) - if iam_data['AccessKeyId'] and iam_data['SecretAccessKey'] and iam_data['Token']: - return AWSCredentials(iam_data['AccessKeyId'], iam_data['SecretAccessKey'], iam_data['Token']) + if iam_data['AccessKeyId'] and iam_data['SecretAccessKey'] and iam_data['Token'] and iam_data['Expiration']: + return AWSCredentials(iam_data['AccessKeyId'], iam_data['SecretAccessKey'], iam_data['Token'], iam_data['Expiration']) else: raise ValueError("Incomplete credentials retrieved.") except Exception as e: diff --git a/src/setup.py b/src/setup.py index ad62a71..b5654cf 100755 --- a/src/setup.py +++ b/src/setup.py @@ -274,6 +274,7 @@ def make_dirs(directory): class PluginConfig(object): CREDENTIALS_PATH_KEY = "credentials_path" + ARN_ROLE_KEY = "arn_role" REGION_KEY = "region" HOST_KEY = "host" PROXY_SERVER_NAME = "proxy_server_name" @@ -288,10 +289,11 @@ class PluginConfig(object): ENABLE_HIGH_DEFINITION_METRICS = "enable_high_resolution_metrics" FLUSH_INTERVAL_IN_SECONDS = "flush_interval_in_seconds" - def __init__(self, credentials_path=None, access_key=None, secret_key=None, region=None, host=None, proxy_server_name=None, proxy_server_port=None, push_asg=None, push_constant=None, constant_dimension_value=None, enable_high_resolution_metrics=False, flush_interval_in_seconds=None): + def __init__(self, credentials_path=None, access_key=None, secret_key=None, arn_role=None, region=None, host=None, proxy_server_name=None, proxy_server_port=None, push_asg=None, push_constant=None, constant_dimension_value=None, enable_high_resolution_metrics=False, flush_interval_in_seconds=None): self.credentials_path = credentials_path self.access_key = access_key self.secret_key = secret_key + self.arn_role = arn_role self.region = region self.host = host self.use_recommended_collectd_config = False @@ -310,6 +312,9 @@ def __init__(self, credentials_path=None, access_key=None, secret_key=None, regi class InteractiveConfigurator(object): DEFAULT_PROMPT = "Enter choice [" + Color.green("{default}") + "]: " + CREDENTIAL_TYPE_IAM_ROLE = "IAM_ROLE" + CREDENTIAL_TYPE_IAM_USER = "IAM_USER" + CREDENTIAL_TYPE_STS_ASSUME_ROLE = "STS_ASSUME_ROLE" def __init__(self, plugin_config, metadata_reader, collectd_info): self.config = plugin_config @@ -412,24 +417,36 @@ def _configure_flush_interval_in_seconds(self): def _get_flush_interval_in_seconds(self): return Prompt("\nEnter the customized flush interval ([1, 60] s):", default="60", allowed_values=[str(x) for x in range(1, 61)]).run() - + + def _get_arn_role(self): + self.config.arn_role = None + return Prompt("\nEnter arn of role used by AWS STS to get temporary credentials (e.g. arn:aws:sts::xxxxxxxxxx:assumed-role/exmaple):", default=None).run() + def _configure_credentials(self): - if self._is_iam_user_required(): + cred_type = self._get_credentials_type() + if cred_type == InteractiveConfigurator.CREDENTIAL_TYPE_IAM_USER: self.config.credentials_path = self._get_credentials_path() self.config.credentials_file_exist = path.exists(self.config.credentials_path) if not self.config.credentials_file_exist: self.config.access_key = Prompt(message="Enter access key: ").run() self.config.secret_key = Prompt(message="Enter secret key: ").run() + elif cred_type == InteractiveConfigurator.CREDENTIAL_TYPE_STS_ASSUME_ROLE: + self.config.arn_role = self._get_arn_role() - def _is_iam_user_required(self): + def _get_credentials_type(self): try: iam_role = self.metadata_reader.get_iam_role_name() - answer = Prompt("\nChoose authentication method:", ["IAM Role [" + iam_role + "]", "IAM User"], default="1").run() - return answer == "2" + answer = Prompt("\nChoose authentication method:", ["IAM Role [" + iam_role + "]", "IAM User", "Assume Role (AWS STS)"], default="1").run() + if answer == "2": + return InteractiveConfigurator.CREDENTIAL_TYPE_IAM_USER + elif answer == "3": + return InteractiveConfigurator.CREDENTIAL_TYPE_STS_ASSUME_ROLE + else: + return InteractiveConfigurator.CREDENTIAL_TYPE_IAM_ROLE except MetadataRequestException: print(Color.yellow("\nIAM Role could not be automatically detected.")) - return True - + return InteractiveConfigurator.CREDENTIAL_TYPE_IAM_USER + def _get_credentials_path(self): recommended_path = path.expanduser('~') + '/.aws/credentials' creds_path = "" @@ -495,6 +512,10 @@ class PluginConfigWriter(object): TEMPLATE = """# The path to the AWS credentials file. This value has to be provided if plugin is used outside of EC2 instances $credentials_path$ +# The arn of role used by sts to assuming and gets temporary credentials. +# SEE http://docs.aws.amazon.com/STS/latest/APIReference/API_AssumeRole.html#API_AssumeRole_Examples for more information. +$arn_role$ + # The target region which will be used to publish metric data # For list of valid regions visit: http://docs.aws.amazon.com/general/latest/gr/rande.html#cw_region $region$ @@ -574,6 +595,7 @@ def _prepare_config(self): config = self._replace_with_value(config, self.plugin_config.PUSH_ASG_KEY, self.plugin_config.push_asg) config = self._replace_with_value(config, self.plugin_config.PUSH_CONSTANT_KEY, self.plugin_config.push_constant) config = self._replace_with_value(config, self.plugin_config.CONSTANT_DIMENSION_VALUE_KEY, self.plugin_config.constant_dimension_value) + config = self._replace_with_value(config, self.plugin_config.ARN_ROLE_KEY, self.plugin_config.arn_role) return config def _replace_with_value(self, string, key, value): diff --git a/test/test_awscredentials.py b/test/test_awscredentials.py index f898c25..7ce612c 100644 --- a/test/test_awscredentials.py +++ b/test/test_awscredentials.py @@ -1,8 +1,9 @@ import unittest -from cloudwatch.modules.awscredentials import AWSCredentials +from cloudwatch.modules.awscredentials import AWSCredentials, AWS_CREDENTIALS_TIMEFORMAT +from datetime import datetime, timedelta class AWSCredentialsTest(unittest.TestCase): - + def test_aws_credentials_with_default_constructor(self): creds = AWSCredentials() assert_credentials_data(creds) @@ -11,10 +12,33 @@ def test_aws_credentials_with_custom_values(self): new_access_key = "accessKey" new_secret_key = "secretKey" new_token = "token" - credentials = AWSCredentials(new_access_key, new_secret_key, new_token) - assert_credentials_data(credentials, new_access_key, new_secret_key, new_token) + new_expire_at_str = '2012-12-03T20:48:03Z' + new_expire_at = datetime.strptime(new_expire_at_str, AWS_CREDENTIALS_TIMEFORMAT) + + credentials = AWSCredentials(new_access_key, new_secret_key, new_token, new_expire_at_str) + assert_credentials_data(credentials, new_access_key, new_secret_key, new_token, new_expire_at) -def assert_credentials_data(credentials, access_key=None, secret_key=None, token=None): + def test_aws_credentials_exception_invalid_expire_time(self): + with self.assertRaises(ValueError): + AWSCredentials(expire_at='2012-12-03T20:48:_invalid_03Z') + + def test_aws_credentials_is_expired(self): + already_expired = datetime.utcnow() - timedelta(hours=1) + cred = AWSCredentials(expire_at=already_expired.strftime(AWS_CREDENTIALS_TIMEFORMAT)) + self.assertTrue(cred.is_expired()) + + def test_aws_credentials_is_not_expired(self): + already_expired = datetime.utcnow() + timedelta(hours=1) + cred = AWSCredentials(expire_at=already_expired.strftime(AWS_CREDENTIALS_TIMEFORMAT)) + self.assertFalse(cred.is_expired()) + + def test_aws_credentials_is_not_expired_on_NONE(self): + cred = AWSCredentials() + self.assertFalse(cred.is_expired()) + +def assert_credentials_data(credentials, access_key=None, secret_key=None, token=None, expire_at=None): assert access_key == credentials.access_key assert secret_key == credentials.secret_key assert token == credentials.token + assert expire_at == credentials.expire_at + diff --git a/test/test_confighelper.py b/test/test_confighelper.py index da87b26..cce0a2a 100644 --- a/test/test_confighelper.py +++ b/test/test_confighelper.py @@ -1,12 +1,13 @@ import unittest from mock import Mock +from datetime import datetime import cloudwatch.modules.collectd as collectd from cloudwatch.modules.configuration.confighelper import ConfigHelper from cloudwatch.modules.configuration.metadatareader import MetadataReader from helpers.fake_http_server import FakeServer - +from cloudwatch.modules.awscredentials import AWS_CREDENTIALS_TIMEFORMAT class ConfigHelperTest(unittest.TestCase): CONFIG_DIR = "./test/config_files/" @@ -37,7 +38,8 @@ class ConfigHelperTest(unittest.TestCase): VALID_FLUSH_INTERVAL_IN_SECONDS = "flush_interval_in_seconds" FAKE_SERVER = None - + FAKE_STS_SERVER = None + @classmethod def setUpClass(cls): cls.FAKE_SERVER = FakeServer() @@ -166,18 +168,23 @@ def test_exception_is_handled_when_instance_id_cannot_be_retrieved(self): self.assertTrue(collectd.warning.called) def test_configuration_with_iam_role_credentials(self): - self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN") + self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN", "2030-12-03T20:48:03Z") + + def test_iam_role_creds_are_refreshed_on_expiration(self): + self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN", "2001-12-03T20:48:03Z") + creds_json = '{"AccessKeyId" : "NEW_ACCESS_KEY", "SecretAccessKey" : "NEW_SECRET_KEY", "Token" : "NEW_TOKEN", "Expiration" : "2030-12-03T20:48:03Z" }' + self._update_and_assert_iam_role_credentials(creds_json, "NEW_ACCESS_KEY", "NEW_SECRET_KEY", "NEW_TOKEN", "2030-12-03T20:48:03Z") + + def test_iam_role_creds_not_refreshed_if_not_expired(self): + self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN", "2030-12-03T20:48:03Z") + creds_json = '{"AccessKeyId" : "NEW_ACCESS_KEY", "SecretAccessKey" : "NEW_SECRET_KEY", "Token" : "NEW_TOKEN", "Expiration" : "2050-12-03T20:48:03Z" }' + self._update_and_assert_iam_role_credentials(creds_json, "ACCESS_KEY", "SECRET_KEY", "TOKEN", "2030-12-03T20:48:03Z") - def test_iam_role_creds_are_refreshed(self): - self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN") - creds_json = '{"AccessKeyId" : "NEW_ACCESS_KEY", "SecretAccessKey" : "NEW_SECRET_KEY", "Token" : "NEW_TOKEN" }' - self._update_and_assert_iam_role_credentials(creds_json, "NEW_ACCESS_KEY", "NEW_SECRET_KEY", "NEW_TOKEN") - def test_old_iam_role_creds_are_served_on_error(self): - self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN") + self._load_and_assert_iam_role_credentials("ACCESS_KEY", "SECRET_KEY", "TOKEN", "2030-12-03T20:48:03Z") creds_json = '{"AccessKeyId" : "NEW_ACCESS_KEY", "SecretAccessKey" : "NEW_SECRET_KEY",}' - self._update_and_assert_iam_role_credentials(creds_json, "ACCESS_KEY", "SECRET_KEY", "TOKEN") - + self._update_and_assert_iam_role_credentials(creds_json, "ACCESS_KEY", "SECRET_KEY", "TOKEN", "2030-12-03T20:48:03Z") + def test_whitelist_is_properly_configured_based_on_plugin_config_file(self): ConfigHelper.WHITELIST_CONFIG_PATH = self.PASS_THROUGH_WHITELIST_CONFIG self.config_helper = ConfigHelper(config_path=self.VALID_CONFIG_WITH_PASS_THROUGH_DISABLED) @@ -185,18 +192,41 @@ def test_whitelist_is_properly_configured_based_on_plugin_config_file(self): self.config_helper = ConfigHelper(config_path=self.VALID_CONFIG_WITH_PASS_THROUGH_ENABLED) self.assertTrue(self.config_helper.whitelist.is_whitelisted("random-metric-name")) - def _load_and_assert_iam_role_credentials(self, expected_access, expected_secret, expected_token): - creds_json = '{"AccessKeyId" : "' + expected_access +'", "SecretAccessKey" : "' + expected_secret + '", "Token" : "' + expected_token + '" }' + def _load_and_assert_iam_role_credentials(self, expected_access, expected_secret, expected_token, expected_expire_at): + creds_json = '{"AccessKeyId" : "' + expected_access +'", "SecretAccessKey" : "' + expected_secret + '", "Token" : "' + expected_token + '", "Expiration" : "' + expected_expire_at + '" }' self.server.set_expected_response(creds_json, 200) ConfigHelper._DEFAULT_CREDENTIALS_PATH = "" self.config_helper = ConfigHelper(config_path=ConfigHelperTest.VALID_CONFIG_WITHOUT_CREDS,metadata_server=self.server.get_url()) - assert_credentials(self.config_helper._credentials, expected_access, expected_secret, expected_token) + assert_credentials(self.config_helper._credentials, expected_access, expected_secret, expected_token, expected_expire_at) - def _update_and_assert_iam_role_credentials(self, json, expected_access, expected_secret, expected_token): + def _update_and_assert_iam_role_credentials(self, json, expected_access, expected_secret, expected_token, expected_expire_at): self.server.set_expected_response(json, 200) creds = self.config_helper.credentials - assert_credentials(creds, expected_access, expected_secret, expected_token) - + assert_credentials(creds, expected_access, expected_secret, expected_token, expected_expire_at) + + def test_overwrite_credentials_by_assuming_role_on_load(self): + self._load_and_assert_iam_role_credentials_with_arn_role("ACCESS_KEY", "SECRET_KEY", "TOKEN", "2030-12-03T20:48:03Z") + + def _load_and_assert_iam_role_credentials_with_arn_role(self, expected_access, expected_secret, expected_token, expected_expire_at): + self._load_and_assert_iam_role_credentials("E_" + expected_access, "E" + expected_secret, "E" + expected_token, "2100-12-03T20:48:03Z") + self.config_helper._arn_role = "arn:aws:test:eu-west-1:1111111111111:role/assume" + sts_cred_xml = ''' + + + + ''' + expected_token + ''' + ''' + expected_secret + ''' + ''' + expected_expire_at + ''' + ''' + expected_access + ''' + + + + ''' + self.config_helper.sts_endpoint = self.server.get_url() + self.server.set_expected_response(sts_cred_xml, 200) + self.config_helper._overwrite_credentials_by_assuming_role() + assert_credentials(self.config_helper._credentials, expected_access, expected_secret, expected_token, expected_expire_at) + @classmethod def tearDownClass(cls): cls.FAKE_SERVER.stop_server() @@ -204,8 +234,11 @@ def tearDownClass(cls): def assert_credentials(credentials, expected_access=ConfigHelperTest.VALID_ACCESS_KEY_STRING, - expected_secret=ConfigHelperTest.VALID_SECRET_KEY_STRING, expected_token=None): + expected_secret=ConfigHelperTest.VALID_SECRET_KEY_STRING, expected_token=None, expected_expire_at=None): assert credentials.access_key == expected_access assert credentials.secret_key == expected_secret assert credentials.token == expected_token - + if expected_expire_at: + assert credentials.expire_at == datetime.strptime(expected_expire_at, AWS_CREDENTIALS_TIMEFORMAT) + else: + assert credentials.expire_at == None diff --git a/test/test_metadatareader.py b/test/test_metadatareader.py index ed32e26..8af3d0c 100644 --- a/test/test_metadatareader.py +++ b/test/test_metadatareader.py @@ -2,6 +2,7 @@ import requests from helpers.fake_http_server import FakeServer from cloudwatch.modules.configuration.metadatareader import MetadataReader, MetadataRequestException +from cloudwatch.modules.awscredentials import AWS_CREDENTIALS_TIMEFORMAT class MetadataReaderTest(unittest.TestCase): FAKE_SERVER = None @@ -63,6 +64,7 @@ def test_can_create_iam_role_credentials_from_json(self): self.assertEquals("ACCESS_KEY", creds.access_key) self.assertEquals("SECRET_KEY", creds.secret_key) self.assertEquals("TOKEN", creds.token) + self.assertEquals('2015-08-27T09:22:57Z', creds.expire_at.strftime(AWS_CREDENTIALS_TIMEFORMAT)) def test_get_iam_role_credentials_raises_exception_on_invalid_json_format(self): json = '{"Code" - "Success", "LastUpdated" : "2015-08-27T09:22:57Z", "Type" : "AWS-HMAC", \ @@ -81,4 +83,4 @@ def test_get_iam_role_credentials_raises_exception_on_missing_values(self): @classmethod def tearDownClass(cls): cls.FAKE_SERVER.stop_server() - cls.FAKE_SERVER = None \ No newline at end of file + cls.FAKE_SERVER = None diff --git a/test/test_stsassumeroleclient.py b/test/test_stsassumeroleclient.py new file mode 100644 index 0000000..4823f0a --- /dev/null +++ b/test/test_stsassumeroleclient.py @@ -0,0 +1,194 @@ +import unittest +import requests +import time + +from requests.utils import quote +from mock import Mock, MagicMock +from cloudwatch.modules.client.stsassumeroleclient import StsAssumRoleClient +from cloudwatch.modules.plugininfo import PLUGIN_NAME, PLUGIN_VERSION +from helpers.fake_http_server import FakeServer +from cloudwatch.modules.awscredentials import AWSCredentials + +class StsAssumRoleClientTest(unittest.TestCase): + + FAKE_SERVER = None + USER_AGENT = PLUGIN_NAME + "/" + str(PLUGIN_VERSION) + XML_RESPONSE = ''' + + + + token_test + secret_key_test + 2011-07-15T23:28:33Z + access_key_test + + + + ''' + + @classmethod + def setUpClass(cls): + cls.FAKE_SERVER = FakeServer() + cls.FAKE_SERVER.start_server() + cls.FAKE_SERVER.serve_forever() + + def setUp(self): + self.server = StsAssumRoleClientTest.FAKE_SERVER + self.server.set_expected_response(StsAssumRoleClientTest.XML_RESPONSE, 200) + self.client = StsAssumRoleClient(AWSCredentials("access", "secret"), "http://localhost:57575/", "localhost") + self.logger = MagicMock() + self.logger.warning = Mock() + self.client.__class__._LOGGER = self.logger + + def server_restart(self): + self.server.stop_server() + self.server.start_server() + self.server.set_expected_response(StsAssumRoleClientTest.XML_RESPONSE, 200) + self.server.serve_forever() + + def server_get_received_request(self): + return open(FakeServer.REQUEST_FILE).read()[2:] # trim '/?' from the request + + @classmethod + def tearDownClass(cls): + cls.FAKE_SERVER.stop_server() + cls.FAKE_SERVER = None + + def test_constructor(self): + connection_timeout = 10 + response_timeout = 20 + client = StsAssumRoleClient(AWSCredentials("access", "secret"), "http://localhost:57575/", "localhost", connection_timeout=connection_timeout, response_timeout=response_timeout) + self.assertEquals("http://localhost:57575/", client.endpoint) + self.assertEquals((connection_timeout,response_timeout), client.timeout) + + def test_initialize_sts_assume_role_with_valid_endpoint(self): + self.client = StsAssumRoleClient(AWSCredentials("access", "secret"), "https://sts.eu-west-1.amazonaws.com", "localhost") + + def test_put_initialize_sts_assume_role_with_invalid_endpoint(self): + with self.assertRaises(StsAssumRoleClient.InvalidEndpointException): + self.client = StsAssumRoleClient(AWSCredentials("access", "secret"), "invalid_endpoint", "localhost") + self.assertTrue(self.logger.error.called) + + def test_get_user_agent_header(self): + header = self.client._get_user_agent_header() + self.assertTrue(StsAssumRoleClientTest.USER_AGENT in header) + + def test_get_custom_headers(self): + headers = self.client._get_custom_headers() + self.assertTrue(headers['User-Agent']) + self.assertTrue(StsAssumRoleClientTest.USER_AGENT in headers['User-Agent']) + + def test_get_request(self): + request = "Testing_Request" + self.server.set_expected_response("OK", 200) + result = self.client._run_request("Testing_Request") + self.assertEquals("OK", result.text) + self.assertEquals(200, result.status_code) + self.assertTrue(request in self.server_get_received_request()) + + def test_client_raise_exception_on_credentials_error(self): + self.server.set_expected_response("Client Error: Forbidden", 403) + self.assert_no_retry_on_error_request("arn_role_test", "arn_session_name_test", 3600) + + def test_client_raise_exception_on_service_unavailable_error(self): + self.server.set_expected_response("Service Unavailable", 503) + self.assert_no_retry_on_error_request("arn_role_test", "arn_session_name_test", 3600) + + def test_client_raise_exception_on_request_throttling(self): + self.server.set_expected_response("Request Throttled", 400) + self.assert_no_retry_on_error_request("arn_role_test", "arn_session_name_test", 3600) + + def test_server_received_user_agent_information(self): + self.client.get_credentials("arn_role_test", "arn_session_name_test", 3600) + received_request = self.server_get_received_request() + self.assertTrue(StsAssumRoleClientTest.USER_AGENT in received_request) + + def test_get_crendetials_with_iam_role_creds(self): + self.client = StsAssumRoleClient(AWSCredentials("access", "secret", "IAM_ROLE_TOKEN"), "http://localhost:57575/", "localhost") + self.client.get_credentials("arn_role_test", "arn_session_name_test", 3600) + received_request = self.server_get_received_request() + self.assertTrue("X-Amz-Security-Token=IAM_ROLE_TOKEN" in received_request) + + def test_get_crendetials_with_retry(self): + self.server.set_timeout_delay(StsAssumRoleClient._DEFAULT_RESPONSE_TIMEOUT * StsAssumRoleClient._TOTAL_RETRIES) + self.client = StsAssumRoleClient(AWSCredentials("access", "secret"), "http://localhost:57575/", "localhost") + self.client.get_credentials("arn_role_test", "arn_session_name_test", 3600) + received_request = self.server_get_received_request() + self.assertTrue("RoleArn" in received_request) + + def test_get_crendetials(self): + arn_role = "sample_arn_role" + duration_seconds = "3600" + role_session_name = "test_session_name" + self.client.get_credentials(arn_role, role_session_name, duration_seconds) + received_request = self.server_get_received_request() + self.assertTrue("RoleSessionName=" + role_session_name in received_request) + self.assertTrue("RoleArn=" + arn_role in received_request) + self.assertTrue("DurationSeconds=" + duration_seconds in received_request) + self.assertTrue("Action=AssumeRole" in received_request) + self.assertTrue("Version" in received_request) + self.assertTrue("X-Amz-Algorithm" in received_request) + self.assertTrue("X-Amz-Credential" in received_request) + self.assertTrue("X-Amz-Date" in received_request) + self.assertTrue("X-Amz-SignedHeaders" in received_request) + self.assertTrue("X-Amz-Signature" in received_request) + + def test_get_crendetials_valid_response(self): + session_token = 'AQoDYXdzEPT//////////wEXAMPLEtc764bNrC9SAPBSM22wDOk4x4HIZ8j4FZTwdQWLWsKWHGBuFqwAeMicRXmxfpSPfIeoIYRqTflfKD8YUuwthAx7mSEI/qkPpKPi/kMcGd+xo0rKwT38xVqr7ZD0u0iPPkUL64lIZbqBAz+scqKmlzm8FDrypNC9Yjc8fPOLn9FX9KSYvKTr4rvx3iSIlTJabIQwj2ICCR/oLxBA==' + expected_credential = AWSCredentials('AKIAIOSFODNN7EXAMPLE', 'wJalrXUtnFEMI/K7MDENG/bPxRfiCYzEXAMPLEKEY', session_token) + arn_role = "sample_arn_role" + duration_seconds = "3600" + role_session_name = "test_session_name" + resp_content = ''' + + + + + ''' + expected_credential.token + ''' + + + ''' + expected_credential.secret_key + ''' + + 2011-07-15T23:28:39Z + + ''' + expected_credential.access_key + ''' + + + + + ''' + arn_role + ''' + + ARO123EXAMPLE123:Bob + + 6 + + + c6104cbe-af31-11e0-8154-cbc7ccf896c7 + + + ''' + self.server.set_expected_response(resp_content, 200) + cred = self.client.get_credentials(arn_role, role_session_name, duration_seconds) + self.assertTrue(expected_credential.access_key == cred.access_key) + self.assertTrue(expected_credential.secret_key == cred.secret_key) + self.assertTrue(expected_credential.token == cred.token) + + + def test_get_crendetials_with_timeout(self): + self.server.set_timeout_delay(StsAssumRoleClient._DEFAULT_RESPONSE_TIMEOUT * (StsAssumRoleClient._TOTAL_RETRIES + 1)) + self.assertRaises(ValueError, self.client.get_credentials, "arn_role_test", "arn_session_name_test", 3600) + self.assertTrue(self.logger.warning.called) + + def test_get_request_timeout(self): + self.server.set_timeout_delay(StsAssumRoleClient._DEFAULT_RESPONSE_TIMEOUT * (StsAssumRoleClient._TOTAL_RETRIES + 1)) + with self.assertRaises(requests.ConnectionError): + self.client._run_request("request") + self.server_restart() + + def assert_no_retry_on_error_request(self, arn_role, role_session_name, duration_seconds): + start = time.time() + self.assertRaises(ValueError, self.client.get_credentials, arn_role, role_session_name, duration_seconds) + end = time.time() + delta = end - start + self.assertTrue(delta < self.client._DEFAULT_RESPONSE_TIMEOUT) + self.assertTrue(self.logger.warning.called)