diff --git a/src/cloudwatch/modules/client/ec2getclient.py b/src/cloudwatch/modules/client/ec2getclient.py index f1530ef..ad010f9 100644 --- a/src/cloudwatch/modules/client/ec2getclient.py +++ b/src/cloudwatch/modules/client/ec2getclient.py @@ -25,8 +25,14 @@ class EC2GetClient(object): _TOTAL_RETRIES = 1 def __init__(self, config_helper, connection_timeout=_DEFAULT_CONNECTION_TIMEOUT, response_timeout=_DEFAULT_RESPONSE_TIMEOUT): - self.request_builder = EC2RequestBuilder(config_helper.credentials, config_helper.region) + host_override = config_helper.ec2_endpoint + host_override = host_override.replace("https://", "") + host_override = host_override.replace("http://", "") + if host_override.endswith('/'): + host_override = host_override[:-1] + self.request_builder = EC2RequestBuilder(config_helper.credentials, config_helper.region, host_override) self._validate_and_set_endpoint(config_helper.ec2_endpoint) + self.ca_bundle_path = config_helper.ca_bundle_path self.timeout = (connection_timeout, response_timeout) def _validate_and_set_endpoint(self, endpoint): @@ -65,6 +71,8 @@ def _run_request(self, request): session = Session() session.mount("http://", HTTPAdapter(max_retries=self._TOTAL_RETRIES)) session.mount("https://", HTTPAdapter(max_retries=self._TOTAL_RETRIES)) + if self.ca_bundle_path: + session.verify = self.ca_bundle_path result = session.get(self.endpoint + "?" + request, headers=self._get_custom_headers(), timeout=self.timeout) result.raise_for_status() return result diff --git a/src/cloudwatch/modules/client/ec2requestbuilder.py b/src/cloudwatch/modules/client/ec2requestbuilder.py index 073cd1f..690ece1 100644 --- a/src/cloudwatch/modules/client/ec2requestbuilder.py +++ b/src/cloudwatch/modules/client/ec2requestbuilder.py @@ -13,8 +13,9 @@ class EC2RequestBuilder(BaseRequestBuilder): _ACTION = "DescribeTags" _API_VERSION = "2016-11-15" - def __init__(self, credentials, region): + def __init__(self, credentials, region, host_override=''): super(self.__class__, self).__init__(credentials, region, self._SERVICE, self._ACTION, self._API_VERSION) + self.host_override = host_override def create_signed_request(self, request_map): """ Creates a ready to send request with metrics from the metric list passed as parameter """ @@ -35,7 +36,9 @@ def _create_canonical_querystring(self, request_map): def _get_host(self): """ Returns the endpoint's hostname derived from the region """ - if self.region == "localhost": + if self.host_override: + return self.host_override + elif self.region == "localhost": return "localhost" elif self.region.startswith("cn-"): return "ec2." + self.region + ".amazonaws.com.cn" diff --git a/src/cloudwatch/modules/client/putclient.py b/src/cloudwatch/modules/client/putclient.py index 4fea5c1..d2a5442 100644 --- a/src/cloudwatch/modules/client/putclient.py +++ b/src/cloudwatch/modules/client/putclient.py @@ -29,13 +29,19 @@ class PutClient(object): _LOG_FILE_MAX_SIZE = 10*1024*1024 def __init__(self, config_helper, connection_timeout=_DEFAULT_CONNECTION_TIMEOUT, response_timeout=_DEFAULT_RESPONSE_TIMEOUT): - self.request_builder = RequestBuilder(config_helper.credentials, config_helper.region, config_helper.enable_high_resolution_metrics) + host_override = config_helper.endpoint + host_override = host_override.replace("https://", "") + host_override = host_override.replace("http://", "") + if host_override.endswith('/'): + host_override = host_override[:-1] + self.request_builder = RequestBuilder(config_helper.credentials, config_helper.region, config_helper.enable_high_resolution_metrics, host_override) self._validate_and_set_endpoint(config_helper.endpoint) self.timeout = (connection_timeout, response_timeout) self.proxy_server_name = config_helper.proxy_server_name self.proxy_server_port = config_helper.proxy_server_port self.debug = config_helper.debug self.config = config_helper + self.ca_bundle_path = config_helper.ca_bundle_path self._prepare_session() def _prepare_session(self): @@ -52,6 +58,8 @@ def _prepare_session(self): self._LOGGER.info("No proxy server is in use") self.session.mount("http://", HTTPAdapter(max_retries=self._TOTAL_RETRIES)) self.session.mount("https://", HTTPAdapter(max_retries=self._TOTAL_RETRIES)) + if self.ca_bundle_path: + self.session.verify = self.ca_bundle_path def _validate_and_set_endpoint(self, endpoint): pattern = re.compile("http[s]?://*/") diff --git a/src/cloudwatch/modules/client/requestbuilder.py b/src/cloudwatch/modules/client/requestbuilder.py index 85e3896..e26aaca 100644 --- a/src/cloudwatch/modules/client/requestbuilder.py +++ b/src/cloudwatch/modules/client/requestbuilder.py @@ -13,9 +13,10 @@ class RequestBuilder(BaseRequestBuilder): _ACTION = "PutMetricData" _API_VERSION = "2010-08-01" - def __init__(self, credentials, region, enable_high_resolution_metrics): + def __init__(self, credentials, region, enable_high_resolution_metrics, host_override=''): super(self.__class__, self).__init__(credentials, region, self._SERVICE, self._ACTION, self._API_VERSION, enable_high_resolution_metrics) self.namespace = "" + self.host_override = host_override def create_signed_request(self, namespace, metric_list): """ Creates a ready to send request with metrics from the metric list passed as parameter """ @@ -47,7 +48,9 @@ def _get_namespace_request_map(self): def _get_host(self): """ Returns the endpoint's hostname derived from the region """ - if self.region == "localhost": + if self.host_override: + return self.host_override + elif self.region == "localhost": return "localhost" elif self.region.startswith("cn-"): return "monitoring." + self.region + ".amazonaws.com.cn" diff --git a/src/cloudwatch/modules/configuration/confighelper.py b/src/cloudwatch/modules/configuration/confighelper.py index f013085..59740e7 100644 --- a/src/cloudwatch/modules/configuration/confighelper.py +++ b/src/cloudwatch/modules/configuration/confighelper.py @@ -49,6 +49,7 @@ def __init__(self, config_path=_DEFAULT_CONFIG_PATH, metadata_server=_METADATA_S self.constant_dimension_value = '' self.enable_high_resolution_metrics = False self.flush_interval_in_seconds = '' + self.ca_bundle_path = '' self._load_configuration() self.whitelist = Whitelist(WhitelistConfigReader(self.WHITELIST_CONFIG_PATH, self.pass_through).get_regex_list(), self.BLOCKED_METRIC_PATH) @@ -83,6 +84,7 @@ def _load_configuration(self): self._load_flush_interval_in_seconds() self._set_endpoint() self._set_ec2_endpoint() + self.ca_bundle_path = self.config_reader.ca_bundle_path self._load_autoscaling_group() self.debug = self.config_reader.debug self.pass_through = self.config_reader.pass_through @@ -140,7 +142,9 @@ def _load_hostname(self): def _set_ec2_endpoint(self): """ Creates endpoint from region information """ - if self.region is "localhost": + if self.config_reader.ec2_endpoint_override: + self.ec2_endpoint = self.config_reader.ec2_endpoint_override + elif self.region is "localhost": self.ec2_endpoint = "http://" + self.region + "/" elif self.region.startswith("cn-"): self.ec2_endpoint = "https://ec2." + self.region + ".amazonaws.com.cn/" @@ -180,7 +184,9 @@ def _load_flush_interval_in_seconds(self): def _set_endpoint(self): """ Creates endpoint from region information """ - if self.region is "localhost": + if self.config_reader.monitoring_endpoint_override: + self.endpoint = self.config_reader.monitoring_endpoint_override + elif self.region is "localhost": self.endpoint = "http://" + self.region + "/" elif self.region.startswith("cn-"): self.endpoint = "https://monitoring." + self.region + ".amazonaws.com.cn/" diff --git a/src/cloudwatch/modules/configuration/configreader.py b/src/cloudwatch/modules/configuration/configreader.py index 1da5a65..22d345b 100644 --- a/src/cloudwatch/modules/configuration/configreader.py +++ b/src/cloudwatch/modules/configuration/configreader.py @@ -39,6 +39,9 @@ class ConfigReader(object): PROXY_SERVER_PORT_KEY = "proxy_server_port" ENABLE_HIGH_DEFINITION_METRICS = "enable_high_resolution_metrics" FLUSH_INTERVAL_IN_SECONDS = "flush_interval_in_seconds" + MONITORING_ENDPOINT_OVERRIDE_KEY = "monitoring_endpoint_override" + EC2_ENDPOINT_OVERRIDE_KEY = "ec2_endpoint_override" + CA_BUNDLE_PATH_KEY = "ca_bundle_path" def __init__(self, config_path): self.config_path = config_path @@ -54,6 +57,9 @@ def __init__(self, config_path): self.proxy_server_port = '' self.enable_high_resolution_metrics = self._ENABLE_HIGH_DEFINITION_METRICS_DEFAULT_VALUE self.flush_interval_in_seconds = '' + self.monitoring_endpoint_override = '' + self.ec2_endpoint_override = '' + self.ca_bundle_path = '' try: self.reader_utils = ReaderUtils(config_path) self._parse_config_file() @@ -78,3 +84,6 @@ def _parse_config_file(self): self.push_asg = self.reader_utils.try_get_boolean(self.PUSH_ASG_KEY, self._PUSH_ASG_DEFAULT_VALUE) self.push_constant = self.reader_utils.try_get_boolean(self.PUSH_CONSTANT_KEY, self._PUSH_CONSTANT_DEFAULT_VALUE) self.constant_dimension_value = self.reader_utils.get_string(self.CONSTANT_DIMENSION_KEY) + self.monitoring_endpoint_override = self.reader_utils.get_string(self.MONITORING_ENDPOINT_OVERRIDE_KEY) + self.ec2_endpoint_override = self.reader_utils.get_string(self.EC2_ENDPOINT_OVERRIDE_KEY) + self.ca_bundle_path = self.reader_utils.get_string(self.CA_BUNDLE_PATH_KEY) diff --git a/test/config_files/valid_config_with_ca_bundle_path b/test/config_files/valid_config_with_ca_bundle_path new file mode 100644 index 0000000..3bb56fa --- /dev/null +++ b/test/config_files/valid_config_with_ca_bundle_path @@ -0,0 +1 @@ +ca_bundle_path = /path/to/bundle.pem \ No newline at end of file diff --git a/test/config_files/valid_config_with_endpoint_overrides b/test/config_files/valid_config_with_endpoint_overrides new file mode 100644 index 0000000..b7f57c7 --- /dev/null +++ b/test/config_files/valid_config_with_endpoint_overrides @@ -0,0 +1,2 @@ +ec2_endpoint_override = https://valid.url +monitoring_endpoint_override = https://valid.url \ No newline at end of file diff --git a/test/test_confighelper.py b/test/test_confighelper.py index fbc7dee..4e955dd 100644 --- a/test/test_confighelper.py +++ b/test/test_confighelper.py @@ -20,6 +20,8 @@ class ConfigHelperTest(unittest.TestCase): VALID_CONFIG_WITH_PASS_THROUGH_DISABLED = CONFIG_DIR + "valid_config_with_pass_through_disabled" VALID_CONFIG_WITH_PROXY_SERVER_NAME = CONFIG_DIR + "valid_config_with_proxy_server_name" VALID_CONFIG_WITH_PROXY_SERVER_PORT = CONFIG_DIR + "valid_config_with_proxy_server_port" + VALID_CONFIG_WITH_ENDPOINT_OVERRIDES = CONFIG_DIR + "valid_config_with_endpoint_overrides" + VALID_CONFIG_WITH_CA_BUNDLE_PATH = CONFIG_DIR + "valid_config_with_ca_bundle_path" VALID_CONFIG_WITHOUT_CREDS = CONFIG_DIR + "valid_config_without_creds" VALID_CREDENTIALS_FILE = CONFIG_DIR + "valid_credentials_file" MISSING_CONFIG = CONFIG_DIR + "no_config" @@ -35,6 +37,9 @@ class ConfigHelperTest(unittest.TestCase): VALID_PROXY_SERVER_PORT = "server_port" VALID_ENABLE_HIGH_DEFINITION_METRICS = "enable_high_resolution_metrics" VALID_FLUSH_INTERVAL_IN_SECONDS = "flush_interval_in_seconds" + VALID_EC2_ENDPOINT_OVERRIDE = "https://valid.url" + VALID_MONITORING_ENDPOINT_OVERRIDE = "https://valid.url" + VALID_CA_BUNDLE_PATH = "/path/to/bundle.pem" FAKE_SERVER = None @@ -192,6 +197,15 @@ def _load_and_assert_iam_role_credentials(self, expected_access, expected_secret self.config_helper = ConfigHelper(config_path=ConfigHelperTest.VALID_CONFIG_WITHOUT_CREDS,metadata_server=self.server.get_url()) assert_credentials(self.config_helper._credentials, expected_access, expected_secret, expected_token) + def test_endpoint_overrides(self): + self.config_helper = ConfigHelper(config_path=self.VALID_CONFIG_WITH_ENDPOINT_OVERRIDES) + self.assertEquals(self.VALID_EC2_ENDPOINT_OVERRIDE, self.config_helper.ec2_endpoint) + self.assertEquals(self.VALID_MONITORING_ENDPOINT_OVERRIDE, self.config_helper.endpoint) + + def test_ca_bundle_path(self): + self.config_helper = ConfigHelper(config_path=self.VALID_CONFIG_WITH_CA_BUNDLE_PATH) + self.assertEquals(self.VALID_CA_BUNDLE_PATH, self.config_helper.ca_bundle_path) + def _update_and_assert_iam_role_credentials(self, json, expected_access, expected_secret, expected_token): self.server.set_expected_response(json, 200) creds = self.config_helper.credentials diff --git a/test/test_configreader.py b/test/test_configreader.py index 63e3cb4..26abe90 100644 --- a/test/test_configreader.py +++ b/test/test_configreader.py @@ -17,6 +17,8 @@ class ConfigReaderTest(unittest.TestCase): VALID_CONFIG_WITH_PROXY_SERVER_PORT = CONFIG_DIR + "valid_config_with_proxy_server_port" VALID_CONFIG_WITH_PASS_THROUGH_ENABLED = CONFIG_DIR + "valid_config_with_pass_through_enabled" VALID_CONFIG_WITH_PASS_THROUGH_DISABLED = CONFIG_DIR + "valid_config_with_pass_through_disabled" + VALID_CONFIG_WITH_ENDPOINT_OVERRIDES = CONFIG_DIR + "valid_config_with_endpoint_overrides" + VALID_CONFIG_WITH_CA_BUNDLE_PATH = CONFIG_DIR + "valid_config_with_ca_bundle_path" INVALID_CONFIG_WITH_UNKNOWN_PARAMETER = CONFIG_DIR + "invalid_config_with_unknown_parameters" INVALID_CONFIG_WITH_SYNTAX_ERROR = CONFIG_DIR + "invalid_config_with_syntax_error" INVALID_CONFIG_WITH_SINGLE_KEY_MISSING = CONFIG_DIR + "invalid_config_full_with_single_key_missing" @@ -28,7 +30,10 @@ class ConfigReaderTest(unittest.TestCase): VALID_PUSH_ASG_AND_CONSTANT = CONFIG_DIR + "valid_config_push_constant_and_asg" VALID_PROXY_SERVER_NAME = "server_name" VALID_PROXY_SERVER_PORT = "server_port" - + VALID_EC2_ENDPOINT_OVERRIDE = "https://valid.url" + VALID_MONITORING_ENDPOINT_OVERRIDE = "https://valid.url" + VALID_CA_BUNDLE_PATH = "/path/to/bundle.pem" + def setUp(self): self.config_reader = None self.logger = MagicMock() @@ -63,7 +68,16 @@ def test_get_full_configuration(self): self.assertFalse(self.config_reader.debug) self.assertEquals(self.VALID_PROXY_SERVER_NAME, self.config_reader.proxy_server_name) self.assertEquals(self.VALID_PROXY_SERVER_PORT, self.config_reader.proxy_server_port) - + + def test_endpoint_overrides(self): + self.config_reader = ConfigReader(self.VALID_CONFIG_WITH_ENDPOINT_OVERRIDES) + self.assertEquals(self.VALID_EC2_ENDPOINT_OVERRIDE, self.config_reader.ec2_endpoint_override) + self.assertEquals(self.VALID_MONITORING_ENDPOINT_OVERRIDE, self.config_reader.monitoring_endpoint_override) + + def test_ca_bundle(self): + self.config_reader = ConfigReader(self.VALID_CONFIG_WITH_CA_BUNDLE_PATH) + self.assertEquals(self.VALID_CA_BUNDLE_PATH, self.config_reader.ca_bundle_path) + def test_valid_config_with_debug_enabled(self): self.config_reader = ConfigReader(self.VALID_CONFIG_WITH_DEBUG_ENABLED) self.assertTrue(self.config_reader.debug)