From ff20cf9a2afcdf20c49a9e79434763314a77b107 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 19:03:52 -0800 Subject: [PATCH 01/86] change default read timeout --- sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py | 2 ++ .../azure-cosmos/azure/cosmos/aio/_asynchronous_request.py | 2 ++ sdk/cosmos/azure-cosmos/azure/cosmos/documents.py | 1 + 3 files changed, 5 insertions(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 1c90bfa57150..68e37caf1d9d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -93,6 +93,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: + # always override database account call timeouts + read_timeout = connection_policy.DBAReadTimeout connection_timeout = connection_policy.DBAConnectionTimeout if client_timeout is not None: kwargs['timeout'] = client_timeout - (time.time() - start_time) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 377cb5d406b1..81430d8df42c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -62,6 +62,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if request_params.resource_type != http_constants.ResourceType.DatabaseAccount: await global_endpoint_manager.refresh_endpoint_list(None, **kwargs) else: + # always override database account call timeouts + read_timeout = connection_policy.DBAReadTimeout connection_timeout = connection_policy.DBAConnectionTimeout if client_timeout is not None: kwargs['timeout'] = client_timeout - (time.time() - start_time) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index fcf855f56921..8093b2f71fc0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -332,6 +332,7 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes __defaultRequestTimeout: int = 5 # seconds __defaultDBAConnectionTimeout: int = 3 # seconds __defaultReadTimeout: int = 65 # seconds + __defaultDBAReadTimeout: int = 3 # seconds __defaultMaxBackoff: int = 1 # seconds def __init__(self) -> None: From 40e43c40193fa7a7d8f87c2c570d9e82686c6185 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 20:02:53 -0800 Subject: [PATCH 02/86] fix tests --- sdk/cosmos/azure-cosmos/azure/cosmos/documents.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 8093b2f71fc0..57d6e75be534 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -339,6 +339,7 @@ def __init__(self) -> None: self.RequestTimeout: int = self.__defaultRequestTimeout self.DBAConnectionTimeout: int = self.__defaultDBAConnectionTimeout self.ReadTimeout: int = self.__defaultReadTimeout + self.DBAReadTimeout: int = self.__defaultDBAReadTimeout self.MaxBackoff: int = self.__defaultMaxBackoff self.ConnectionMode: int = ConnectionMode.Gateway 
self.SSLConfiguration: Optional[SSLConfiguration] = None From aefe30b42efa1d02dd282468e70276333da0b127 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 21:24:38 -0800 Subject: [PATCH 03/86] Add read timeout tests for database account calls --- sdk/cosmos/azure-cosmos/test/test_crud.py | 9 ++++++++- .../azure-cosmos/test/test_crud_async.py | 9 +++++++++ 2 files changed, 17 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_crud.py b/sdk/cosmos/azure-cosmos/test/test_crud.py index d41246f5d6cd..a96654cec6c0 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud.py @@ -1821,7 +1821,14 @@ def test_client_request_timeout(self): container = databaseForTest.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID) container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'}) - + def test_read_timeout(self): + connection_policy = documents.ConnectionPolicy() + # use a near-zero read timeout to guarantee the request times out + connection_policy.DBAReadTimeout = 0.000000000001 + with self.assertRaises(ServiceResponseError): + # client initialization makes a database account read call + with cosmos_client.CosmosClient(self.host, self.masterKey, connection_policy=connection_policy): + print('initialization') def test_client_request_timeout_when_connection_retry_configuration_specified(self): connection_policy = documents.ConnectionPolicy() diff --git a/sdk/cosmos/azure-cosmos/test/test_crud_async.py b/sdk/cosmos/azure-cosmos/test/test_crud_async.py index 1c40afc3edfa..00517d23ef8e 100644 --- a/sdk/cosmos/azure-cosmos/test/test_crud_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_crud_async.py @@ -1689,6 +1689,15 @@ async def test_client_request_timeout_async(self): await container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'}) print('Async initialization') + async def test_read_timeout_async(self): + connection_policy = documents.ConnectionPolicy() + # use a near-zero read timeout to guarantee the request times out + connection_policy.DBAReadTimeout = 0.000000000001 + with self.assertRaises(ServiceResponseError): + # client initialization makes a database account read call + async with CosmosClient(self.host, self.masterKey, connection_policy=connection_policy): + print('Async initialization') + async def test_client_request_timeout_when_connection_retry_configuration_specified_async(self): connection_policy = documents.ConnectionPolicy() From 9a234f87d902ea22925313864a91d7a3b864fbee Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 5 Feb 2025 21:36:21 -0800 Subject: [PATCH 04/86] fix timeout retry policy --- .../azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 145dfd947ccf..036061a17b07 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -30,7 +30,7 @@ def ShouldRetry(self, _exception): :rtype: bool """ if self.request: - if _OperationType.IsReadOnlyOperation(self.request.operation_type): + if not _OperationType.IsReadOnlyOperation(self.request.operation_type): return False if not self.connection_policy.EnableEndpointDiscovery: return False From 8859c9fc8972c7ba684c953b6b6cc526cdd7ca01 Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Wed, 5 Feb 2025
21:56:45 -0800 Subject: [PATCH 05/86] Fixed the timeout logic --- .../cosmos/_timeout_failover_retry_policy.py | 39 +++---------------- 1 file changed, 5 insertions(+), 34 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 036061a17b07..aa66cd4f76e7 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -18,7 +18,6 @@ def __init__(self, connection_policy, global_endpoint_manager, *args): self.global_endpoint_manager = global_endpoint_manager self.retry_count = 0 - self.location_index = 0 self.connection_policy = connection_policy self.request = args[0] if args else None @@ -29,14 +28,13 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ - if self.request: - if not _OperationType.IsReadOnlyOperation(self.request.operation_type): - return False + # we don't retry on write operations for timeouts or service unavailable + if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): + return False if not self.connection_policy.EnableEndpointDiscovery: return False - # Check if the next retry about to be done is safe if _exception.status_code == http_constants.StatusCodes.SERVICE_UNAVAILABLE and \ self.retry_count >= self._max_service_unavailable_retry_count: @@ -47,46 +45,19 @@ def ShouldRetry(self, _exception): return False if self.request: - # Update the last routed location to where this request was routed previously. - # So that we can check in location cache if we need to return the current or previous - # based on where the request was routed previously. 
- self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route - - if _OperationType.IsReadOnlyOperation(self.request.operation_type): - # We just directly got to the next location in case of read requests - # We don't retry again on the same region for regional endpoint - location_endpoint = self.resolve_next_region_service_endpoint() - else: - self.global_endpoint_manager.swap_regional_endpoint_values(self.request) - location_endpoint = self.resolve_current_region_service_endpoint() - # This is the case where both current and previous point to the same writable endpoint - # In this case we don't want to retry again, rather failover to the next region - if self.request.last_routed_location_endpoint_within_region == location_endpoint: - location_endpoint = self.resolve_next_region_service_endpoint() - + location_endpoint = self.resolve_next_region_service_endpoint() self.request.route_to_location(location_endpoint) return True - - # This function prepares the request to go to the second endpoint in the same region - def resolve_current_region_service_endpoint(self): - # clear previous location-based routing directive - self.request.clear_route_to_location() - # resolve the next service endpoint in the same region - # since we maintain 2 endpoints per region for write operations - self.request.route_to_location_with_preferred_location_flag(self.location_index, True) - return self.global_endpoint_manager.resolve_service_endpoint(self.request) - # This function prepares the request to go to the next region def resolve_next_region_service_endpoint(self): - self.location_index += 1 # clear previous location-based routing directive self.request.clear_route_to_location() # clear the last routed endpoint within same region since we are going to a new region now self.request.last_routed_location_endpoint_within_region = None # set location-based routing directive based on retry count # ensuring usePreferredLocations is set to True for retry - self.request.route_to_location_with_preferred_location_flag(self.location_index, True) + self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability return self.global_endpoint_manager.resolve_service_endpoint(self.request) From ac78da9632bf406e6ac38b4d9ea79cc03eb7d41d Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Wed, 5 Feb 2025 23:13:52 -0800 Subject: [PATCH 06/86] Fixed the timeout retry policy --- .../azure-cosmos/azure/cosmos/_retry_utility.py | 7 +++++-- .../azure/cosmos/_timeout_failover_retry_policy.py | 13 ++++--------- .../azure/cosmos/aio/_retry_utility_async.py | 5 +++-- .../azure-cosmos/azure/cosmos/http_constants.py | 1 - sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py | 2 +- 5 files changed, 13 insertions(+), 15 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 99784facecc4..927ed7a41baa 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -131,7 +131,6 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): sub_status_code=SubStatusCodes.THROUGHPUT_OFFER_NOT_FOUND) return result except exceptions.CosmosHttpResponseError as e: - retry_policy = defaultRetry_policy if request and 
_has_database_account_header(request.headers): retry_policy = database_account_retry_policy # Re-assign retry policy based on error code @@ -173,8 +172,12 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code in [StatusCodes.REQUEST_TIMEOUT, StatusCodes.SERVICE_UNAVAILABLE]: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT: retry_policy = timeout_failover_retry_policy + elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + retry_policy = timeout_failover_retry_policy + else: + retry_policy = defaultRetry_policy # If none of the retry policies applies or there is no retry needed, set the # throttle related response headers and re-throw the exception back arg[0] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index aa66cd4f76e7..60f0208e6351 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -5,18 +5,17 @@ Cosmos database service. """ from azure.cosmos.documents import _OperationType -from . import http_constants class _TimeoutFailoverRetryPolicy(object): def __init__(self, connection_policy, global_endpoint_manager, *args): - self._max_retry_attempt_count = 120 - self._max_service_unavailable_retry_count = 1 - self.retry_after_in_milliseconds = 0 + self.retry_after_in_milliseconds = 500 self.args = args self.global_endpoint_manager = global_endpoint_manager + # If an account only has 1 region, then we still want to retry once on the same region + self._max_retry_attempt_count = len(self.global_endpoint_manager.location_cache.read_regional_endpoints) + 1 self.retry_count = 0 self.connection_policy = connection_policy self.request = args[0] if args else None @@ -28,17 +27,13 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ - # we don't retry on write operations for timeouts or service unavailable + # we don't retry on write operations for timeouts or any internal server errors if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): return False if not self.connection_policy.EnableEndpointDiscovery: return False - # Check if the next retry about to be done is safe - if _exception.status_code == http_constants.StatusCodes.SERVICE_UNAVAILABLE and \ - self.retry_count >= self._max_service_unavailable_retry_count: - return False self.retry_count += 1 # Check if the next retry about to be done is safe if self.retry_count >= self._max_retry_attempt_count: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 74df8ea9479f..c4be5a1afc2c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -130,7 +130,6 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg return result except exceptions.CosmosHttpResponseError as e: - retry_policy = None if request and _has_database_account_header(request.headers): retry_policy = database_account_retry_policy elif e.status_code == StatusCodes.FORBIDDEN and e.sub_status in \ @@ -171,7 +170,9 @@ async def ExecuteAsync(client, 
global_endpoint_manager, function, *args, **kwarg retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code in [StatusCodes.REQUEST_TIMEOUT, StatusCodes.SERVICE_UNAVAILABLE]: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT: + retry_policy = timeout_failover_retry_policy + elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py index 8a7b57b7c93f..31a95d2600d6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/http_constants.py @@ -400,7 +400,6 @@ class StatusCodes: RETRY_WITH = 449 INTERNAL_SERVER_ERROR = 500 - SERVICE_UNAVAILABLE = 503 # Operation pause and cancel. These are FAKE status codes for QOS logging purpose only. OPERATION_PAUSED = 1200 diff --git a/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py b/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py index 5548b51839b3..83d5e2603e9f 100644 --- a/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py +++ b/sdk/cosmos/azure-cosmos/test/test_globaldb_mock.py @@ -166,7 +166,7 @@ def MockExecuteFunction(self, function, *args, **kwargs): def MockGetDatabaseAccountStub(self, endpoint): raise exceptions.CosmosHttpResponseError( - status_code=StatusCodes.SERVICE_UNAVAILABLE, message="Service unavailable") + status_code=StatusCodes.INTERNAL_SERVER_ERROR, message="Internal Server Error") def test_global_db_endpoint_discovery_retry_policy(self): connection_policy = documents.ConnectionPolicy() From 09aac90f49d2391bebfe29cfab1658fcb892b161 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 6 Feb 2025 02:30:58 -0800 Subject: [PATCH 07/86] Mock tests for timeout and failover retry policy --- .../test_timeout_and_failover_retry_policy.py | 135 +++++++++++++++++ ...timeout_and_failover_retry_policy_async.py | 137 ++++++++++++++++++ 2 files changed, 272 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py create mode 100644 sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py diff --git a/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py new file mode 100644 index 000000000000..342453524246 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy.py @@ -0,0 +1,135 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
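+# These tests replace _retry_utility.ExecuteFunction with a mock that raises a configurable
+# number of CosmosHttpResponseErrors before delegating to the real function, verifying that
+# the timeout/failover retry policy retries failed reads but fails write operations immediately.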
+ +import unittest +import uuid + +import pytest + +import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import _retry_utility, PartitionKey + +COLLECTION = "created_collection" +@pytest.fixture(scope="class") +def setup(): + if (TestTimeoutRetryPolicy.masterKey == '[YOUR_KEY_HERE]' or + TestTimeoutRetryPolicy.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = cosmos_client.CosmosClient(TestTimeoutRetryPolicy.host, TestTimeoutRetryPolicy.masterKey, consistency_level="Session", + connection_policy=TestTimeoutRetryPolicy.connectionPolicy) + created_database = client.get_database_client(TestTimeoutRetryPolicy.TEST_DATABASE_ID) + created_collection = created_database.create_container(TestTimeoutRetryPolicy.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + yield { + COLLECTION: created_collection + } + + created_database.delete_container(TestTimeoutRetryPolicy.TEST_CONTAINER_SINGLE_PARTITION_ID) + + + + +def error_codes(): + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.unittest +@pytest.mark.usefixtures("setup") +class TestTimeoutRetryPolicy: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + @pytest.mark.parametrize("error_code", error_codes()) + def test_timeout_failover_retry_policy_for_read_success(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility.ExecuteFunction + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) + _retry_utility.ExecuteFunction = mf + doc = setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + assert doc == created_document + finally: + _retry_utility.ExecuteFunction = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + def test_timeout_failover_retry_policy_for_read_failure(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility.ExecuteFunction + try: + # inject more consecutive errors than the policy will retry, so the read ultimately fails + mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) + _retry_utility.ExecuteFunction = mf + setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility.ExecuteFunction = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + def test_timeout_failover_retry_policy_for_write_failure(self, setup, error_code): + document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), + 'pk': 'pk', +
'name': 'sample document', + 'key': 'value'} + + self.original_execute_function = _retry_utility.ExecuteFunction + try: + # timeouts should fail immediately for writes + mf = self.MockExecuteFunction(self.original_execute_function,0, error_code) + _retry_utility.ExecuteFunction = mf + try: + setup[COLLECTION].create_item(body=document_definition) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility.ExecuteFunction = self.original_execute_function + + + + + class MockExecuteFunction(object): + def __init__(self, org_func, num_exceptions, status_code): + self.org_func = org_func + self.counter = 0 + self.num_exceptions = num_exceptions + self.status_code = status_code + + def __call__(self, func, *args, **kwargs): + if self.counter != 0 and self.counter >= self.num_exceptions: + return self.org_func(func, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + + + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py new file mode 100644 index 000000000000..90b2f46dc651 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_timeout_and_failover_retry_policy_async.py @@ -0,0 +1,137 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import unittest +import uuid + +import pytest +import pytest_asyncio + +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient, _retry_utility_async + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + if (TestTimeoutRetryPolicyAsync.masterKey == '[YOUR_KEY_HERE]' or + TestTimeoutRetryPolicyAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(TestTimeoutRetryPolicyAsync.host, TestTimeoutRetryPolicyAsync.masterKey, consistency_level="Session", + connection_policy=TestTimeoutRetryPolicyAsync.connectionPolicy) + created_database = client.get_database_client(TestTimeoutRetryPolicyAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestTimeoutRetryPolicyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestTimeoutRetryPolicyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + + + + +def error_codes(): + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestTimeoutRetryPolicyAsync: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_success_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 
'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + doc = await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + assert doc == created_document + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # inject more consecutive errors than the policy will retry, so the read ultimately fails + mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # timeouts should fail immediately for writes + mf = self.MockExecuteFunction(self.original_execute_function, 0, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + try: + await setup[COLLECTION].create_item(body=document_definition) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + + + + class MockExecuteFunction(object): + def __init__(self, org_func, num_exceptions, status_code): + self.org_func = org_func + self.counter = 0 + self.num_exceptions = num_exceptions + self.status_code = status_code + + def __call__(self, func, global_endpoint_manager, *args, **kwargs): + if self.counter != 0 and self.counter >= self.num_exceptions: + return self.org_func(func, global_endpoint_manager, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + + + +if __name__ == '__main__': + unittest.main() From f22e7d21d05e55eb8cf2ff06a1bb21d6ab0658de Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 02:22:44 +0000 Subject: [PATCH 08/86] Create test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 164 +++++++++++++++++++++ 1 file changed, 164 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/test/test_dummy.py diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py new file mode 100644 index
000000000000..4fe18ef001f1 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -0,0 +1,164 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +from collections.abc import MutableMapping +import logging +from typing import Any +import unittest +import uuid + +import pytest +import pytest_asyncio + +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient, _retry_utility_async +from azure.core.rest import HttpRequest, AsyncHttpResponse +import asyncio +import aiohttp +import sys +from azure.core.pipeline.transport import AioHttpTransport + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + if (TestDummyAsync.masterKey == '[YOUR_KEY_HERE]' or + TestDummyAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + logger = logging.getLogger('azure.cosmos') + logger.setLevel("INFO") + logger.setLevel(logging.DEBUG) + logger.addHandler(logging.StreamHandler(sys.stdout)) + custom_transport = TestDummyAsync.FaulInjectionTransport(logger) + client = CosmosClient(TestDummyAsync.host, TestDummyAsync.masterKey, consistency_level="Session", + connection_policy=TestDummyAsync.connectionPolicy, transport=custom_transport) + created_database = client.get_database_client(TestDummyAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + + +def error_codes(): + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestDummyAsync: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_success_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + doc = await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + assert doc == created_document + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + created_document = await setup[COLLECTION].create_item(body=document_definition) + 
self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # should retry once and then succeed + mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + await setup[COLLECTION].read_item(item=created_document['id'], + partition_key=created_document['pk']) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup, error_code): + document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), + 'pk': 'pk', + 'name': 'sample document', + 'key': 'value'} + + self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync + try: + # timeouts should fail immediately for writes + mf = self.MockExecuteFunction(self.original_execute_function,0, error_code) + _retry_utility_async.ExecuteFunctionAsync = mf + try: + await setup[COLLECTION].create_item(body=document_definition) + pytest.fail("Exception was not raised.") + except exceptions.CosmosHttpResponseError as err: + assert err.status_code == error_code + finally: + _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function + + class FaulInjectionTransport(AioHttpTransport): + def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): + self.logger = logger + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # Add custom logic before sending the request + self.logger.error(f"Sending request to {request.url}") + + # Call the base class's send method to actually send the request + try: + response = await super().send(request, stream=stream, proxies=proxies, **config) + except Exception as e: + self.logger.error(f"Error: {e}") + raise + + # Add custom logic after receiving the response + self.logger.info(f"Received response with status code {response.status_code}") + + return response + + class MockExecuteFunction(object): + def __init__(self, org_func, num_exceptions, status_code): + self.org_func = org_func + self.counter = 0 + self.num_exceptions = num_exceptions + self.status_code = status_code + + def __call__(self, func, global_endpoint_manager, *args, **kwargs): + if self.counter != 0 and self.counter >= self.num_exceptions: + return self.org_func(func, global_endpoint_manager, *args, **kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + +if __name__ == '__main__': + unittest.main() From dd8a466019ba9b574d093c7bffbf2d037817f2a8 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 02:24:06 +0000 Subject: [PATCH 09/86] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 22 ++++++++++++++++++++-- 1 file changed, 20 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 4fe18ef001f1..cd469e071ef8 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -3,7 +3,7 @@ from 
collections.abc import MutableMapping import logging -from typing import Any +from typing import Any, Callable import unittest import uuid @@ -125,11 +125,29 @@ async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup class FaulInjectionTransport(AioHttpTransport): def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): self.logger = logger + self.faults = [] + self.requestTransformationOverrides = [] + self.responseTransformationOverrides = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): + self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) + + def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): + self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) + + def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): + self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # find the first fault Factory with matching predicate if any + firstFaultFactory = next(lambda f: f["predicate"](f["predicate"]), self.faults, None) + + # Add custom logic before sending the request - self.logger.error(f"Sending request to {request.url}") + + + self.logger.info(f"Sending request to {request.url}") # Call the base class's send method to actually send the request try: From 8ac11c5a142a8091cf5a19249a6de981782059d2 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 07:01:31 +0000 Subject: [PATCH 10/86] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 256 +++++++++------------ 1 file changed, 112 insertions(+), 144 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index cd469e071ef8..4b2fda5d3437 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -3,6 +3,8 @@ from collections.abc import MutableMapping import logging +import time +from tokenize import String from typing import Any, Callable import unittest import uuid @@ -10,6 +12,8 @@ import pytest import pytest_asyncio +from azure.cosmos.aio._container import ContainerProxy +from azure.cosmos.aio._database import DatabaseProxy import azure.cosmos.exceptions as exceptions import test_config from azure.cosmos import PartitionKey @@ -21,162 +25,126 @@ from azure.core.pipeline.transport import AioHttpTransport COLLECTION = "created_collection" -@pytest_asyncio.fixture() -async def setup(): - if (TestDummyAsync.masterKey == '[YOUR_KEY_HERE]' or - TestDummyAsync.host == '[YOUR_ENDPOINT_HERE]'): - raise Exception( - "You must specify your Azure Cosmos account values for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") - - logger = logging.getLogger('azure.cosmos') - logger.setLevel("INFO") - logger.setLevel(logging.DEBUG) - logger.addHandler(logging.StreamHandler(sys.stdout)) - custom_transport = TestDummyAsync.FaulInjectionTransport(logger) - client = CosmosClient(TestDummyAsync.host, TestDummyAsync.masterKey, consistency_level="Session", - connection_policy=TestDummyAsync.connectionPolicy, 
transport=custom_transport) - created_database = client.get_database_client(TestDummyAsync.TEST_DATABASE_ID) - created_collection = await created_database.create_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk")) - yield { - COLLECTION: created_collection - } - - await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) - await client.close() - - -def error_codes(): - return [408, 500, 502, 503] +logger = logging.getLogger('azure.cosmos') +logger.setLevel("INFO") +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +host = test_config.TestConfig.host +masterKey = test_config.TestConfig.masterKey +connectionPolicy = test_config.TestConfig.connectionPolicy +TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + +@pytest.fixture() +def setup(): + return @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") class TestDummyAsync: - host = test_config.TestConfig.host - masterKey = test_config.TestConfig.masterKey - connectionPolicy = test_config.TestConfig.connectionPolicy - TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_timeout_failover_retry_policy_for_read_success_async(self, setup, error_code): - document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), - 'pk': 'pk', - 'name': 'sample document', - 'key': 'value'} - - created_document = await setup[COLLECTION].create_item(body=document_definition) - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - try: - # should retry once and then succeed - mf = self.MockExecuteFunction(self.original_execute_function, 1, error_code) - _retry_utility_async.ExecuteFunctionAsync = mf - doc = await setup[COLLECTION].read_item(item=created_document['id'], - partition_key=created_document['pk']) - assert doc == created_document - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_timeout_failover_retry_policy_for_read_failure_async(self, setup, error_code): - document_definition = {'id': 'failoverDoc-' + str(uuid.uuid4()), - 'pk': 'pk', + logger = logger + + async def setup(self, custom_transport: AioHttpTransport): + + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4()) + + if (masterKey == '[YOUR_KEY_HERE]' or + host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(host, masterKey, consistency_level="Session", + connection_policy=connectionPolicy, transport=custom_transport) + created_database: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) + created_collection = await created_database.create_container(TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk")) + return {"client": client, "db": created_database, "col": created_collection} + + async def test_throws_injected_error(self, setup): + custom_transport = FaulInjectionTransport(logger) + idValue: str = 
'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': idValue, + 'pk': idValue, 'name': 'sample document', 'key': 'value'} - created_document = await setup[COLLECTION].create_item(body=document_definition) - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - try: - # should retry once and then succeed - mf = self.MockExecuteFunction(self.original_execute_function, 5, error_code) - _retry_utility_async.ExecuteFunctionAsync = mf - await setup[COLLECTION].read_item(item=created_document['id'], - partition_key=created_document['pk']) - pytest.fail("Exception was not raised.") - except exceptions.CosmosHttpResponseError as err: - assert err.status_code == error_code - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_timeout_failover_retry_policy_for_write_failure_async(self, setup, error_code): - document_definition = {'id': 'failoverDoc' + str(uuid.uuid4()), - 'pk': 'pk', - 'name': 'sample document', - 'key': 'value'} - - self.original_execute_function = _retry_utility_async.ExecuteFunctionAsync - try: - # timeouts should fail immediately for writes - mf = self.MockExecuteFunction(self.original_execute_function,0, error_code) - _retry_utility_async.ExecuteFunctionAsync = mf - try: - await setup[COLLECTION].create_item(body=document_definition) - pytest.fail("Exception was not raised.") - except exceptions.CosmosHttpResponseError as err: - assert err.status_code == error_code - finally: - _retry_utility_async.ExecuteFunctionAsync = self.original_execute_function - - class FaulInjectionTransport(AioHttpTransport): - def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): - self.logger = logger - self.faults = [] - self.requestTransformationOverrides = [] - self.responseTransformationOverrides = [] - super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - - def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): - self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) - - def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): - self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) - - def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[], Exception | None]): - self.fault.append({"predicate": predicate, "faultFactory": faultFactory}) - - async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): - # find the first fault Factory with matching predicate if any - firstFaultFactory = next(lambda f: f["predicate"](f["predicate"]), self.faults, None) - - - # Add custom logic before sending the request - - + initializedObjects = await self.setup(custom_transport) + container: ContainerProxy = initializedObjects["col"] + + created_document = await container.create_item(body=document_definition) + start = time.perf_counter() + + while ((time.perf_counter() - start) < 7): + await container.read_item(idValue, partition_key=idValue) + await asyncio.sleep(2) + + created_database: DatabaseProxy = initializedObjects["db"] + await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + client: CosmosClient = initializedObjects["client"] + await 
client.close() + + +class FaulInjectionTransport(AioHttpTransport): + def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): + self.logger = logger + self.faults = [] + self.requestTransformations = [] + self.responseTransformations = [] + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): + self.faults.append({"predicate": predicate, "apply": faultFactory}) + + def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): + self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) + + def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): + self.responseTransformations.append({ + "predicate": predicate, + "apply": responseTransformation}) + + def firstItem(self, iterable, condition=lambda x: True): + """ + Returns the first item in the `iterable` that satisfies the `condition`. + + If no item satisfies the condition, it returns None. + """ + return next((x for x in iterable if condition(x)), None) + + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # find the first fault Factory with matching predicate if any + firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) + if (firstFaultFactory != None): + self.logger.info("") + return await firstFaultFactory["apply"]() + + # apply the chain of request transformations with matching predicates if any + matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matchingRequestTransformations: + request = await currentTransformation["apply"](request) + + firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) + + if (firstResonseTransformation != None): + self.logger.info(f"Invoking response transformation") + response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) + self.logger.info(f"Received response transformation result with status code {response.status_code}") + return response + else: self.logger.info(f"Sending request to {request.url}") - - # Call the base class's send method to actually send the request - try: - response = await super().send(request, stream=stream, proxies=proxies, **config) - except Exception as e: - self.logger.error(f"Error: {e}") - raise - - # Add custom logic after receiving the response + response = await getResponseTask self.logger.info(f"Received response with status code {response.status_code}") - return response - class MockExecuteFunction(object): - def __init__(self, org_func, num_exceptions, status_code): - self.org_func = org_func - self.counter = 0 - self.num_exceptions = num_exceptions - self.status_code = status_code - - def __call__(self, func, global_endpoint_manager, *args, **kwargs): - if self.counter != 0 and self.counter >= self.num_exceptions: - return self.org_func(func, global_endpoint_manager, *args, 
**kwargs) + else: + self.counter += 1 + raise exceptions.CosmosHttpResponseError( + status_code=self.status_code, + message="Some Exception", + response=test_config.FakeResponse({})) + +if __name__ == '__main__': + unittest.main() From b53e2e9ecbd1cecfc163ffd7666759d885845603 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 13:34:33 +0000 Subject: [PATCH 11/86] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 38 ++++++++++++++-------- 1 file changed, 25 insertions(+), 13 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 4b2fda5d3437..f30a2fc7099f 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -46,6 +46,19 @@ def setup(): class TestDummyAsync: logger = logger + async def cleanup(self, initializedObjects: dict[str, Any]): + created_database: DatabaseProxy = initializedObjects["db"] + try: + await created_database.delete_container(initializedObjects["col"]) + except Exception as containerDeleteError: + self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) + finally: + client: CosmosClient = initializedObjects["client"] + try: + await client.close() + except Exception as closeError: + self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + async def setup(self, custom_transport: AioHttpTransport): host = test_config.TestConfig.host @@ -77,19 +90,18 @@ async def test_throws_injected_error(self, setup): 'key': 'value'} initializedObjects = await self.setup(custom_transport) - container: ContainerProxy = initializedObjects["col"] - - created_document = await container.create_item(body=document_definition) - start = time.perf_counter() - - while ((time.perf_counter() - start) < 7): - await container.read_item(idValue, partition_key=idValue) - await asyncio.sleep(2) - - created_database: DatabaseProxy = initializedObjects["db"] - await created_database.delete_container(TestDummyAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) - client: CosmosClient = initializedObjects["client"] - await client.close() + try: + container: ContainerProxy = initializedObjects["col"] + + created_document = await container.create_item(body=document_definition) + start = time.perf_counter() + + while ((time.perf_counter() - start) < 7): + await container.read_item(idValue, partition_key=idValue) + await asyncio.sleep(2) + + finally: + self.cleanup(initializedObjects) From 973ec4412ae0fa847dedb86a96b4caea2ea34ae9 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 15:30:49 +0000 Subject: [PATCH 12/86] Iterating on fault injection tooling --- sdk/cosmos/azure-cosmos/pytest.ini | 3 + sdk/cosmos/azure-cosmos/test/test_dummy.py | 69 ++++++++++++++++++---- 2 files changed, 61 insertions(+), 11 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/pytest.ini diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini new file mode 100644 index 000000000000..e211052edef0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -0,0 +1,3 @@ +[pytest] +markers = + cosmosEmulator: marks tests as depending on the Cosmos DB Emulator \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index f30a2fc7099f..9a96c75eb598 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++
b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -14,7 +14,7 @@ from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy -import azure.cosmos.exceptions as exceptions +from azure.cosmos.exceptions import CosmosHttpResponseError import test_config from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient, _retry_utility_async @@ -44,20 +44,19 @@ def setup(): @pytest.mark.asyncio @pytest.mark.usefixtures("setup") class TestDummyAsync: - logger = logger async def cleanup(self, initializedObjects: dict[str, Any]): created_database: DatabaseProxy = initializedObjects["db"] try: await created_database.delete_container(initializedObjects["col"]) except Exception as containerDeleteError: - self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) + logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) finally: client: CosmosClient = initializedObjects["client"] try: await client.close() except Exception as closeError: - self.logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) async def setup(self, custom_transport: AioHttpTransport): @@ -75,15 +74,62 @@ async def setup(self, custom_transport: AioHttpTransport): "tests.") client = CosmosClient(host, masterKey, consistency_level="Session", - connection_policy=connectionPolicy, transport=custom_transport) + connection_policy=connectionPolicy, transport=custom_transport, + logger=logger) created_database: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) created_collection = await created_database.create_container(TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk")) return {"client": client, "db": created_database, "col": created_collection} + def predicate_url_contains_id(self, r: HttpRequest, id: str): + logger.info("FaultPredicate for request {} {}".format(r.method, r.url)); + return id in r.url; + + def predicate_req_payload_contains_id(self, r: HttpRequest, id: str): + logger.info("FaultPredicate for request {} {} - request payload {}".format( + r.method, + r.url, + "NONE" if r.body is None else r.body)); + + if (r.body == None): + return False + + + return '"id":"{}"'.format(id) in r.body; + + async def throw_after_delay(self, delayInMs: int, error: Exception): + await asyncio.sleep(delayInMs/1000.0) + raise error + async def test_throws_injected_error(self, setup): + idValue: str = str(uuid.uuid4()) + document_definition = {'id': idValue, + 'pk': idValue, + 'name': 'sample document', + 'key': 'value'} + + custom_transport = FaulInjectionTransport(logger) + predicate : Callable[[HttpRequest], bool] = lambda r: self.predicate_req_payload_contains_id(r, idValue) + custom_transport.addFault(predicate, lambda: self.throw_after_delay( + 500, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error."))) + + initializedObjects = await self.setup(custom_transport) + try: + container: ContainerProxy = initializedObjects["col"] + await container.create_item(body=document_definition) + pytest.fail("Expected exception not thrown") + except CosmosHttpResponseError as cosmosError: + if (cosmosError.status_code != 502): + raise cosmosError + finally: + await self.cleanup(initializedObjects) + + async def test_succeeds_with_multiple_endpoints(self, setup): custom_transport = 
FaulInjectionTransport(logger) - idValue: str = 'failoverDoc-' + str(uuid.uuid4()) + idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, 'pk': idValue, 'name': 'sample document', @@ -96,12 +142,12 @@ async def test_throws_injected_error(self, setup): created_document = await container.create_item(body=document_definition) start = time.perf_counter() - while ((time.perf_counter() - start) < 7): + while ((time.perf_counter() - start) < 2): await container.read_item(idValue, partition_key=idValue) - await asyncio.sleep(2) + await asyncio.sleep(0.2) finally: - self.cleanup(initializedObjects) + await self.cleanup(initializedObjects) class FaulInjectionTransport(AioHttpTransport): @@ -135,8 +181,9 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut # find the first fault Factory with matching predicate if any firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) if (firstFaultFactory != None): - self.logger.info("") - return await firstFaultFactory["apply"]() + injectedError = await firstFaultFactory["apply"]() + self.logger.info("Found to-be-injected error {}".format(injectedError)) + raise injectedError # apply the chain of request transformations with matching predicates if any matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) From 5d72848fbf91c295007acbff2a7e110572ca07c2 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 7 Feb 2025 16:31:23 +0000 Subject: [PATCH 13/86] Refactoring to have FaultInjectionTransport in its own file --- .../test/_fault_injection_transport.py | 89 +++++++++++++++++++ sdk/cosmos/azure-cosmos/test/test_dummy.py | 69 ++------------ 2 files changed, 94 insertions(+), 64 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py new file mode 100644 index 000000000000..aa137fa34234 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -0,0 +1,89 @@ +# The MIT License (MIT) +# Copyright (c) 2014 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
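The mechanism these patches lean on is that azure-core lets the caller supply the HTTP transport used at the very bottom of the client pipeline, so a test can subclass AioHttpTransport and decide per request whether to forward it or raise a synthetic error. A minimal, self-contained sketch of that hook (the class and names below are illustrative placeholders, not the committed helper):

    import asyncio
    from azure.core.pipeline.transport import AioHttpTransport
    from azure.cosmos.exceptions import CosmosHttpResponseError

    class Inject502OnWrites(AioHttpTransport):
        """Raise a synthetic 502 for document writes; pass everything else through."""

        async def send(self, request, *, stream=False, proxies=None, **config):
            if request.method == "POST":
                await asyncio.sleep(0.5)  # mimic a slow reverse proxy before failing
                raise CosmosHttpResponseError(status_code=502, message="injected 502")
            return await super().send(request, stream=stream, proxies=proxies, **config)

    # Wiring it up is just: CosmosClient(host, key, transport=Inject502OnWrites())

Because the transport sits below the SDK's retry and endpoint-discovery policies, an exception raised here exercises the same retry and failover code paths as a real gateway failure, which is exactly what these tests assert against.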
+ +"""AioHttpTransport allowing injection of faults between SDK and Cosmos Gateway +""" + +import asyncio +import aiohttp +import logging +import sys + +from azure.core.pipeline.transport import AioHttpTransport +from azure.core.rest import HttpRequest, AsyncHttpResponse +from collections.abc import MutableMapping +from typing import Any, Callable + +class FaulInjectionTransport(AioHttpTransport): + def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): + self.logger = logger + self.faults = [] + self.requestTransformations = [] + self.responseTransformations = [] + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): + self.faults.append({"predicate": predicate, "apply": faultFactory}) + + def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): + self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) + + def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): + self.responseTransformations.append({ + "predicate": predicate, + "apply": responseTransformation}) + + def firstItem(self, iterable, condition=lambda x: True): + """ + Returns the first item in the `iterable` that satisfies the `condition`. + + If no item satisfies the condition, it returns None. + """ + return next((x for x in iterable if condition(x)), None) + + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): + # find the first fault Factory with matching predicate if any + firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) + if (firstFaultFactory != None): + injectedError = await firstFaultFactory["apply"]() + self.logger.info("Found to-be-injected error {}".format(injectedError)) + raise injectedError + + # apply the chain of request transformations with matching predicates if any + matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matchingRequestTransformations: + request = await currentTransformation["apply"](request) + + firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) + + if (firstResonseTransformation != None): + self.logger.info(f"Invoking response transformation") + response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) + self.logger.info(f"Received response transformation result with status code {response.status_code}") + return response + else: + self.logger.info(f"Sending request to {request.url}") + response = await getResponseTask + self.logger.info(f"Received response with status code {response.status_code}") + return response + \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 9a96c75eb598..45408793b900 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -1,16 
+1,12 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -from collections.abc import MutableMapping import logging import time -from tokenize import String -from typing import Any, Callable import unittest import uuid import pytest -import pytest_asyncio from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy @@ -18,11 +14,12 @@ import test_config from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient, _retry_utility_async -from azure.core.rest import HttpRequest, AsyncHttpResponse +from azure.core.rest import HttpRequest import asyncio -import aiohttp import sys from azure.core.pipeline.transport import AioHttpTransport +from typing import Any, Callable +import _fault_injection_transport COLLECTION = "created_collection" logger = logging.getLogger('azure.cosmos') @@ -108,7 +105,7 @@ async def test_throws_injected_error(self, setup): 'name': 'sample document', 'key': 'value'} - custom_transport = FaulInjectionTransport(logger) + custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) predicate : Callable[[HttpRequest], bool] = lambda r: self.predicate_req_payload_contains_id(r, idValue) custom_transport.addFault(predicate, lambda: self.throw_after_delay( 500, @@ -128,7 +125,7 @@ async def test_throws_injected_error(self, setup): await self.cleanup(initializedObjects) async def test_succeeds_with_multiple_endpoints(self, setup): - custom_transport = FaulInjectionTransport(logger) + custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, 'pk': idValue, @@ -149,61 +146,5 @@ async def test_succeeds_with_multiple_endpoints(self, setup): finally: await self.cleanup(initializedObjects) - -class FaulInjectionTransport(AioHttpTransport): - def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): - self.logger = logger - self.faults = [] - self.requestTransformations = [] - self.responseTransformations = [] - super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - - def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): - self.faults.append({"predicate": predicate, "apply": faultFactory}) - - def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): - self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) - - def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]): - self.responseTransformations.append({ - "predicate": predicate, - "apply": responseTransformation}) - - def firstItem(self, iterable, condition=lambda x: True): - """ - Returns the first item in the `iterable` that satisfies the `condition`. - - If no item satisfies the condition, it returns None. 
- """ - return next((x for x in iterable if condition(x)), None) - - async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config): - # find the first fault Factory with matching predicate if any - firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) - if (firstFaultFactory != None): - injectedError = await firstFaultFactory["apply"]() - self.logger.info("Found to-be-injected error {}".format(injectedError)) - raise injectedError - - # apply the chain of request transformations with matching predicates if any - matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) - for currentTransformation in matchingRequestTransformations: - request = await currentTransformation["apply"](request) - - firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) - - getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) - - if (firstResonseTransformation != None): - self.logger.info(f"Invoking response transformation") - response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) - self.logger.info(f"Received response transformation result with status code {response.status_code}") - return response - else: - self.logger.info(f"Sending request to {request.url}") - response = await getResponseTask - self.logger.info(f"Received response with status code {response.status_code}") - return response - if __name__ == '__main__': unittest.main() From 8c9aa4b370afbc7596acd8303959d3aef0c486cd Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Mon, 10 Feb 2025 14:53:52 +0000 Subject: [PATCH 14/86] Update test_dummy.py --- sdk/cosmos/azure-cosmos/test/test_dummy.py | 83 ++++++++++++++-------- 1 file changed, 53 insertions(+), 30 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 45408793b900..91c9e219350b 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -20,8 +20,10 @@ from azure.core.pipeline.transport import AioHttpTransport from typing import Any, Callable import _fault_injection_transport +import os COLLECTION = "created_collection" +MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') logger.setLevel("INFO") logger.setLevel(logging.DEBUG) @@ -36,47 +38,66 @@ def setup(): return - @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") class TestDummyAsync: - - async def cleanup(self, initializedObjects: dict[str, Any]): - created_database: DatabaseProxy = initializedObjects["db"] + @classmethod + def setup_class(cls): + logger.info("starting class: {} execution".format(cls.__name__)) + cls.host = test_config.TestConfig.host + cls.masterKey = test_config.TestConfig.masterKey + + if (cls.masterKey == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + cls.connectionPolicy = test_config.TestConfig.connectionPolicy + cls.database_id = test_config.TestConfig.TEST_DATABASE_ID + cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4()) + + cls.mgmtClient = CosmosClient(host, masterKey, consistency_level="Session", + connection_policy=connectionPolicy, logger=logger) + created_database: DatabaseProxy 
From 8c9aa4b370afbc7596acd8303959d3aef0c486cd Mon Sep 17 00:00:00 2001
From: Fabian Meiswinkel
Date: Mon, 10 Feb 2025 14:53:52 +0000
Subject: [PATCH 14/86] Update test_dummy.py

---
 sdk/cosmos/azure-cosmos/test/test_dummy.py | 83 ++++++++++++++--------
 1 file changed, 53 insertions(+), 30 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py
index 45408793b900..91c9e219350b 100644
--- a/sdk/cosmos/azure-cosmos/test/test_dummy.py
+++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py
@@ -20,8 +20,10 @@
 from azure.core.pipeline.transport import AioHttpTransport
 from typing import Any, Callable
 import _fault_injection_transport
+import os

 COLLECTION = "created_collection"
+MGMT_TIMEOUT = 3.0
 logger = logging.getLogger('azure.cosmos')
 logger.setLevel("INFO")
 logger.setLevel(logging.DEBUG)
@@ -36,47 +38,66 @@
 def setup():
     return

-
 @pytest.mark.cosmosEmulator
 @pytest.mark.asyncio
 @pytest.mark.usefixtures("setup")
 class TestDummyAsync:
-
-    async def cleanup(self, initializedObjects: dict[str, Any]):
-        created_database: DatabaseProxy = initializedObjects["db"]
+    @classmethod
+    def setup_class(cls):
+        logger.info("starting class: {} execution".format(cls.__name__))
+        cls.host = test_config.TestConfig.host
+        cls.masterKey = test_config.TestConfig.masterKey
+
+        if (cls.masterKey == '[YOUR_KEY_HERE]' or
+                cls.host == '[YOUR_ENDPOINT_HERE]'):
+            raise Exception(
+                "You must specify your Azure Cosmos account values for "
+                "'masterKey' and 'host' at the top of this class to run the "
+                "tests.")
+
+        cls.connectionPolicy = test_config.TestConfig.connectionPolicy
+        cls.database_id = test_config.TestConfig.TEST_DATABASE_ID
+        cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4())
+
+        cls.mgmtClient = CosmosClient(host, masterKey, consistency_level="Session",
+                                      connection_policy=connectionPolicy, logger=logger)
+        created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id)
+        asyncio.run(asyncio.wait_for(
+            created_database.create_container(
+                cls.single_partition_container_name,
+                partition_key=PartitionKey("/pk")),
+            MGMT_TIMEOUT))
+
+    @classmethod
+    def teardown_class(cls):
+        logger.info("tearing down class: {}".format(cls.__name__))
+        created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id)
         try:
-            await created_database.delete_container(initializedObjects["col"])
+            asyncio.run(asyncio.wait_for(
+                created_database.delete_container(cls.single_partition_container_name),
+                MGMT_TIMEOUT))
         except Exception as containerDeleteError:
             logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError))
         finally:
-            client: CosmosClient = initializedObjects["client"]
             try:
-                await client.close()
+                asyncio.run(asyncio.wait_for(cls.mgmtClient.close(), MGMT_TIMEOUT))
             except Exception as closeError:
                 logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError))

-    async def setup(self, custom_transport: AioHttpTransport):
-
-        host = test_config.TestConfig.host
-        masterKey = test_config.TestConfig.masterKey
-        connectionPolicy = test_config.TestConfig.connectionPolicy
-        TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
-        TEST_CONTAINER_SINGLE_PARTITION_ID = "test-timeout-retry-policy-container-" + str(uuid.uuid4())
-
-        if (masterKey == '[YOUR_KEY_HERE]' or
-                host == '[YOUR_ENDPOINT_HERE]'):
-            raise Exception(
-                "You must specify your Azure Cosmos account values for "
-                "'masterKey' and 'host' at the top of this class to run the "
-                "tests.")
-
+    def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport):
         client = CosmosClient(host, masterKey, consistency_level="Session",
                               connection_policy=connectionPolicy, transport=custom_transport,
                               logger=logger)
-        created_database: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID)
-        created_collection = await created_database.create_container(TEST_CONTAINER_SINGLE_PARTITION_ID,
-                                                                     partition_key=PartitionKey("/pk"))
-        return {"client": client, "db": created_database, "col": created_collection}
+        db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID)
+        container: ContainerProxy = db.get_container_client(self.single_partition_container_name)
+        return {"client": client, "db": db, "col": container}
+
+    def cleanup_method(self, initializedObjects: dict[str, Any]):
+        method_client: CosmosClient = initializedObjects["client"]
+        try:
+            asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT))
+        except Exception as closeError:
+            logger.warning("Exception trying to close method client.")

     def predicate_url_contains_id(self, r: HttpRequest, id: str):
         logger.info("FaultPredicate for request {} {}".format(r.method, r.url));
@@ -113,7 +134,7 @@ async def test_throws_injected_error(self, setup):
                 status_code=502,
                 message="Some random reverse proxy error.")))

-        initializedObjects = await self.setup(custom_transport)
+        initializedObjects = self.setup_method_with_custom_transport(custom_transport)
         try:
             container: ContainerProxy = initializedObjects["col"]
             await container.create_item(body=document_definition)
@@ -122,7 +143,9 @@ async def test_throws_injected_error(self, setup):
             if (cosmosError.status_code != 502):
                 raise cosmosError
         finally:
-            await self.cleanup(initializedObjects)
+            cleanupOp = self.cleanup_method(initializedObjects)
+            if (cleanupOp != None):
+                await cleanupOp

     async def test_succeeds_with_multiple_endpoints(self, setup):
         custom_transport = _fault_injection_transport.FaulInjectionTransport(logger)
@@ -132,7 +155,7 @@ async def test_succeeds_with_multiple_endpoints(self, setup):
                                'name': 'sample document',
                                'key': 'value'}

-        initializedObjects = await self.setup(custom_transport)
+        initializedObjects = self.setup_method_with_custom_transport(custom_transport)

         try:
             container: ContainerProxy = initializedObjects["col"]
@@ -144,7 +167,7 @@ async def test_succeeds_with_multiple_endpoints(self, setup):
                 await asyncio.sleep(0.2)

         finally:
-            self.cleanup(initializedObjects)
+            self.cleanup_method(initializedObjects)

 if __name__ == '__main__':
     unittest.main()

From 7260e9d156ffc4321f0f3c13c89ddd136e8ea46a Mon Sep 17 00:00:00 2001
From: Fabian Meiswinkel
Date: Tue, 18 Feb 2025 12:43:39 +0000
Subject: [PATCH 15/86] Refactoring FaultInjectionTransport

---
 .../test/_fault_injection_transport.py     | 42 ++++++++++++++----
 sdk/cosmos/azure-cosmos/test/test_dummy.py | 41 ++++++------------
 2 files changed, 49 insertions(+), 34 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py
index aa137fa34234..123658b1516c 100644
--- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py
+++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py
@@ -25,14 +25,13 @@

 import asyncio
 import aiohttp
 import logging
-import sys

 from azure.core.pipeline.transport import AioHttpTransport
 from azure.core.rest import HttpRequest, AsyncHttpResponse
 from collections.abc import MutableMapping
 from typing import Any, Callable

-class FaulInjectionTransport(AioHttpTransport):
+class FaultInjectionTransport(AioHttpTransport):
     def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config):
         self.logger = logger
         self.faults = []
@@ -46,7 +45,7 @@ def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Calla

     def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]):
         self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation})

-    def addResponseTransformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse]):
+    def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]):
         self.responseTransformations.append({
             "predicate": predicate,
             "apply": responseTransformation})
@@ -59,7 +58,7 @@ def firstItem(self, iterable, condition=lambda x: True):
         """
         return next((x for x in iterable if condition(x)), None)

-    async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config):
+    async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse:
         # find the first fault Factory with matching predicate if any
         firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request))
         if (firstFaultFactory != None):
@@ -74,7 +73,7 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut

        firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f:
f["predicate"](request)) - getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) + getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) if (firstResonseTransformation != None): self.logger.info(f"Invoking response transformation") @@ -86,4 +85,35 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut response = await getResponseTask self.logger.info(f"Received response with status code {response.status_code}") return response - \ No newline at end of file + + @staticmethod + def predicate_url_contains_id(r: HttpRequest, id: str) -> bool: + return id in r.url + + @staticmethod + def predicate_req_payload_contains_id(r: HttpRequest, id: str): + if r.body is None: + return False + + return '"id":"{}"'.format(id) in r.body + + @staticmethod + def predicate_req_for_document_with_id(r: HttpRequest, id: str) -> bool: + return (FaultInjectionTransport.predicate_url_contains_id(r, id) + or FaultInjectionTransport.predicate_req_payload_contains_id(r, id)) + + @staticmethod + def predicate_is_database_account_call(r: HttpRequest) -> bool: + return (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') + + @staticmethod + async def throw_after_delay(delay_in_ms: int, error: Exception): + await asyncio.sleep(delay_in_ms / 1000.0) + raise error + + @staticmethod + async def transform_pass_through(r: HttpRequest, + inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + await asyncio.sleep(1) + return await inner() \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 91c9e219350b..91fa0e7283b0 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -14,12 +14,12 @@ import test_config from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient, _retry_utility_async -from azure.core.rest import HttpRequest +from azure.core.rest import HttpRequest, AsyncHttpResponse import asyncio import sys from azure.core.pipeline.transport import AioHttpTransport from typing import Any, Callable -import _fault_injection_transport +from _fault_injection_transport import FaultInjectionTransport import os COLLECTION = "created_collection" @@ -36,7 +36,8 @@ @pytest.fixture() def setup(): - return + return + @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -87,7 +88,7 @@ def teardown_class(cls): def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport): client = CosmosClient(host, masterKey, consistency_level="Session", connection_policy=connectionPolicy, transport=custom_transport, - logger=logger) + logger=logger, enable_diagnostics_logging=True) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) container: ContainerProxy = db.get_container_client(self.single_partition_container_name) return {"client": client, "db": db, "col": container} @@ -99,26 +100,6 @@ def cleanup_method(self, initializedObjects: dict[str, Any]): except Exception as closeError: logger.warning("Exception trying to close method client.") - def predicate_url_contains_id(self, r: HttpRequest, id: str): - logger.info("FaultPredicate for request {} {}".format(r.method, r.url)); - return id in r.url; - - def predicate_req_payload_contains_id(self, r: HttpRequest, id: str): - logger.info("FaultPredicate for request {} {} - request payload {}".format( - r.method, 
- r.url, - "NONE" if r.body is None else r.body)); - - if (r.body == None): - return False - - - return '"id":"{}"'.format(id) in r.body; - - async def throw_after_delay(self, delayInMs: int, error: Exception): - await asyncio.sleep(delayInMs/1000.0) - raise error - async def test_throws_injected_error(self, setup): idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, @@ -126,9 +107,9 @@ async def test_throws_injected_error(self, setup): 'name': 'sample document', 'key': 'value'} - custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) - predicate : Callable[[HttpRequest], bool] = lambda r: self.predicate_req_payload_contains_id(r, idValue) - custom_transport.addFault(predicate, lambda: self.throw_after_delay( + custom_transport = FaultInjectionTransport(logger) + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, idValue) + custom_transport.addFault(predicate, lambda: FaultInjectionTransport.throw_after_delay( 500, CosmosHttpResponseError( status_code=502, @@ -148,7 +129,11 @@ async def test_throws_injected_error(self, setup): await cleanupOp async def test_succeeds_with_multiple_endpoints(self, setup): - custom_transport = _fault_injection_transport.FaulInjectionTransport(logger) + custom_transport = FaultInjectionTransport(logger) + predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = lambda r, inner: FaultInjectionTransport.transform_pass_through(r, inner) + custom_transport.add_response_transformation(predicate, transformation) + idValue: str = str(uuid.uuid4()) document_definition = {'id': idValue, 'pk': idValue, From 0705aeb76a367c34c4f7bdfdc09ff3f78753d7ea Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Wed, 19 Feb 2025 20:04:39 +0000 Subject: [PATCH 16/86] Iterating on tests --- .../azure/cosmos/aio/_asynchronous_request.py | 1 + .../aio/_global_endpoint_manager_async.py | 5 + .../test/_fault_injection_transport.py | 92 +++++++++++++++---- sdk/cosmos/azure-cosmos/test/test_dummy.py | 49 +++++----- 4 files changed, 105 insertions(+), 42 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 81430d8df42c..f8ebf6ccdbb8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -117,6 +117,7 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p response = response.http_response headers = copy.copy(response.headers) + await response.load_body() data = response.body() if data: data = data.decode("utf-8") diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 374fd940c184..365ef9c9b395 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -126,13 +126,18 @@ async def _endpoints_health_check(self, **kwargs): """ all_endpoints = [self.location_cache.read_regional_endpoints[0]] all_endpoints.extend(self.location_cache.write_regional_endpoints) + validated_endpoints = {} count = 0 for endpoint in all_endpoints: + if (endpoint.get_current() 
in validated_endpoints): + continue + count += 1 if count > 3: break try: await self.client._GetDatabaseAccountCheck(endpoint.get_current(), **kwargs) + validated_endpoints[endpoint.get_current()] = "" except (exceptions.CosmosHttpResponseError, AzureError): if endpoint in self.location_cache.read_regional_endpoints: self.mark_endpoint_unavailable_for_read(endpoint.get_current(), False) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index 123658b1516c..0dd75b3c0e60 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -23,17 +23,24 @@ """ import asyncio -import aiohttp +import json import logging +import sys +from collections.abc import MutableMapping +from typing import Callable +import aiohttp from azure.core.pipeline.transport import AioHttpTransport from azure.core.rest import HttpRequest, AsyncHttpResponse -from collections.abc import MutableMapping -from typing import Any, Callable + +from azure.cosmos.exceptions import CosmosHttpResponseError + class FaultInjectionTransport(AioHttpTransport): - def __init__(self, logger: logging.Logger, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): - self.logger = logger + logger = logging.getLogger('azure.cosmos.fault_injection_transport') + logger.setLevel(logging.DEBUG) + + def __init__(self, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config): self.faults = [] self.requestTransformations = [] self.responseTransformations = [] @@ -59,37 +66,50 @@ def firstItem(self, iterable, condition=lambda x: True): return next((x for x in iterable if condition(x)), None) async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) if (firstFaultFactory != None): + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") injectedError = await firstFaultFactory["apply"]() - self.logger.info("Found to-be-injected error {}".format(injectedError)) + FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injectedError)) raise injectedError # apply the chain of request transformations with matching predicates if any matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) for currentTransformation in matchingRequestTransformations: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") request = await currentTransformation["apply"](request) - firstResonseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) - - getResponseTask = super().send(request, stream=stream, proxies=proxies, **config) - - if (firstResonseTransformation != None): - self.logger.info(f"Invoking response transformation") - response = await firstResonseTransformation["apply"](request, lambda: getResponseTask) - self.logger.info(f"Received response transformation result with status code {response.status_code}") + firstResponseTransformation = 
self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + getResponseTask = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) + FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + + if (firstResponseTransformation != None): + FaultInjectionTransport.logger.info(f"Invoking response transformation") + response = await firstResponseTransformation["apply"](request, lambda: getResponseTask) + FaultInjectionTransport.logger.info(f"Received response transformation result with status code {response.status_code}") return response else: - self.logger.info(f"Sending request to {request.url}") + FaultInjectionTransport.logger.info(f"Sending request to {request.url}") response = await getResponseTask - self.logger.info(f"Received response with status code {response.status_code}") + FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") return response @staticmethod def predicate_url_contains_id(r: HttpRequest, id: str) -> bool: return id in r.url + @staticmethod + def print_call_stack(): + print("Call stack:") + frame = sys._getframe() + while frame: + print(f"File: {frame.f_code.co_filename}, Line: {frame.f_lineno}, Function: {frame.f_code.co_name}") + frame = frame.f_back + @staticmethod def predicate_req_payload_contains_id(r: HttpRequest, id: str): if r.body is None: @@ -104,16 +124,48 @@ def predicate_req_for_document_with_id(r: HttpRequest, id: str) -> bool: @staticmethod def predicate_is_database_account_call(r: HttpRequest) -> bool: - return (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + isDbAccountRead = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') + return isDbAccountRead + + @staticmethod + def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: + isWriteDocumentOperation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' + and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Read' + and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'ReadFeed' + and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Query') + + return isWriteDocumentOperation and uri_prefix in r.url + + @staticmethod async def throw_after_delay(delay_in_ms: int, error: Exception): await asyncio.sleep(delay_in_ms / 1000.0) raise error @staticmethod - async def transform_pass_through(r: HttpRequest, + async def throw_write_forbidden(): + raise CosmosHttpResponseError( + status_code=403, + message="Injected error disallowing writes in this region.", + response=None, + sub_status_code=3, + ) + + @staticmethod + async def transform_convert_emulator_to_single_master_read_multi_region_account(r: HttpRequest, inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: - await asyncio.sleep(1) - return await inner() \ No newline at end of file + + response = await inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + await response.load_body() + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + result["readableLocations"].append({"name": "East US", "databaseAccountEndpoint" : "https://localhost:8888/"}) + 
FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + return response \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_dummy.py index 91fa0e7283b0..1300717de2c3 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_dummy.py @@ -1,31 +1,30 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio import logging +import os +import sys import time import unittest import uuid +from typing import Any, Callable import pytest +from azure.core.pipeline.transport import AioHttpTransport +from azure.core.rest import HttpRequest, AsyncHttpResponse +import test_config +from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError -import test_config -from azure.cosmos import PartitionKey -from azure.cosmos.aio import CosmosClient, _retry_utility_async -from azure.core.rest import HttpRequest, AsyncHttpResponse -import asyncio -import sys -from azure.core.pipeline.transport import AioHttpTransport -from typing import Any, Callable -from _fault_injection_transport import FaultInjectionTransport -import os COLLECTION = "created_collection" MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') -logger.setLevel("INFO") logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -107,7 +106,7 @@ async def test_throws_injected_error(self, setup): 'name': 'sample document', 'key': 'value'} - custom_transport = FaultInjectionTransport(logger) + custom_transport = FaultInjectionTransport() predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, idValue) custom_transport.addFault(predicate, lambda: FaultInjectionTransport.throw_after_delay( 500, @@ -129,10 +128,15 @@ async def test_throws_injected_error(self, setup): await cleanupOp async def test_succeeds_with_multiple_endpoints(self, setup): - custom_transport = FaultInjectionTransport(logger) - predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = lambda r, inner: FaultInjectionTransport.transform_pass_through(r, inner) - custom_transport.add_response_transformation(predicate, transformation) + custom_transport = FaultInjectionTransport() + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + is_write_operation_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, "https://localhost") + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account(r, inner) + + custom_transport.addFault(is_write_operation_predicate, lambda: FaultInjectionTransport.throw_write_forbidden()) + custom_transport.add_response_transformation(is_get_account_predicate, emulator_as_multi_region_sm_account_transformation) idValue: str = 
str(uuid.uuid4()) document_definition = {'id': idValue, @@ -143,16 +147,17 @@ async def test_succeeds_with_multiple_endpoints(self, setup): initializedObjects = self.setup_method_with_custom_transport(custom_transport) try: container: ContainerProxy = initializedObjects["col"] - + created_document = await container.create_item(body=document_definition) start = time.perf_counter() - - while ((time.perf_counter() - start) < 2): - await container.read_item(idValue, partition_key=idValue) - await asyncio.sleep(0.2) + + + #while ((time.perf_counter() - start) < 2): + # await container.read_item(idValue, partition_key=idValue) + # await asyncio.sleep(0.2) finally: - self.cleanup_method(initializedObjects) + self.cleanup_method(initializedObjects) if __name__ == '__main__': unittest.main() From baf7aea226895998c136fb0298c492ac4811f2a3 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Thu, 20 Feb 2025 16:55:42 +0000 Subject: [PATCH 17/86] Prettifying tests --- .../test/_fault_injection_transport.py | 119 +++++++++++------- ...> test_fault_injection_transport_async.py} | 116 +++++++++-------- 2 files changed, 144 insertions(+), 91 deletions(-) rename sdk/cosmos/azure-cosmos/test/{test_dummy.py => test_fault_injection_transport_async.py} (50%) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index 0dd75b3c0e60..efe58bae3032 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -27,10 +27,10 @@ import logging import sys from collections.abc import MutableMapping -from typing import Callable +from typing import Callable, Optional import aiohttp -from azure.core.pipeline.transport import AioHttpTransport +from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse from azure.cosmos.exceptions import CosmosHttpResponseError @@ -46,18 +46,16 @@ def __init__(self, *, session: aiohttp.ClientSession | None = None, loop=None, s self.responseTransformations = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - def addFault(self, predicate: Callable[[HttpRequest], bool], faultFactory: Callable[[HttpRequest], asyncio.Task[Exception]]): - self.faults.append({"predicate": predicate, "apply": faultFactory}) + def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], asyncio.Task[Exception]]): + self.faults.append({"predicate": predicate, "apply": fault_factory}) - def addRequestTransformation(self, predicate: Callable[[HttpRequest], bool], requestTransformation: Callable[[HttpRequest], asyncio.Task[HttpRequest]]): - self.requestTransformations.append({"predicate": predicate, "apply": requestTransformation}) - - def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], responseTransformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]): + def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]): self.responseTransformations.append({ "predicate": predicate, - "apply": responseTransformation}) + "apply": response_transformation}) - def firstItem(self, iterable, condition=lambda x: True): + @staticmethod + def __first_item(iterable, 
condition=lambda x: True): """ Returns the first item in the `iterable` that satisfies the `condition`. @@ -68,39 +66,41 @@ def firstItem(self, iterable, condition=lambda x: True): async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] | None = None, **config) -> AsyncHttpResponse: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any - firstFaultFactory = self.firstItem(iter(self.faults), lambda f: f["predicate"](request)) - if (firstFaultFactory != None): + first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) + if first_fault_factory: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") - injectedError = await firstFaultFactory["apply"]() - FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injectedError)) - raise injectedError + injected_error = await first_fault_factory["apply"](request) + FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error)) + raise injected_error # apply the chain of request transformations with matching predicates if any - matchingRequestTransformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) - for currentTransformation in matchingRequestTransformations: + matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matching_request_transformations: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") request = await currentTransformation["apply"](request) - firstResponseTransformation = self.firstItem(iter(self.responseTransformations), lambda f: f["predicate"](request)) + first_response_transformation = FaultInjectionTransport.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") - getResponseTask = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) + get_response_task = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") - if (firstResponseTransformation != None): + if first_response_transformation: FaultInjectionTransport.logger.info(f"Invoking response transformation") - response = await firstResponseTransformation["apply"](request, lambda: getResponseTask) + response = await first_response_transformation["apply"](request, lambda: get_response_task) + response.headers["_request"] = request FaultInjectionTransport.logger.info(f"Received response transformation result with status code {response.status_code}") return response else: FaultInjectionTransport.logger.info(f"Sending request to {request.url}") - response = await getResponseTask + response = await get_response_task + response.headers["_request"] = request FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") return response @staticmethod - def predicate_url_contains_id(r: HttpRequest, id: str) -> bool: - return id in r.url + def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool: + return id_value in r.url @staticmethod def print_call_stack(): @@ -111,42 +111,41 @@ def print_call_stack(): 
frame = frame.f_back @staticmethod - def predicate_req_payload_contains_id(r: HttpRequest, id: str): + def predicate_req_payload_contains_id(r: HttpRequest, id_value: str): if r.body is None: return False - return '"id":"{}"'.format(id) in r.body + return '"id":"{}"'.format(id_value) in r.body @staticmethod - def predicate_req_for_document_with_id(r: HttpRequest, id: str) -> bool: - return (FaultInjectionTransport.predicate_url_contains_id(r, id) - or FaultInjectionTransport.predicate_req_payload_contains_id(r, id)) + def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool: + return (FaultInjectionTransport.predicate_url_contains_id(r, id_value) + or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value)) @staticmethod def predicate_is_database_account_call(r: HttpRequest) -> bool: - isDbAccountRead = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + is_db_account_read = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') - return isDbAccountRead + return is_db_account_read @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: - isWriteDocumentOperation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' + is_write_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Read' and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'ReadFeed' and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Query') - return isWriteDocumentOperation and uri_prefix in r.url - + return is_write_document_operation and uri_prefix in r.url @staticmethod - async def throw_after_delay(delay_in_ms: int, error: Exception): + async def error_after_delay(delay_in_ms: int, error: Exception) -> Exception: await asyncio.sleep(delay_in_ms / 1000.0) - raise error + return error @staticmethod - async def throw_write_forbidden(): - raise CosmosHttpResponseError( + async def error_write_forbidden() -> Exception: + return CosmosHttpResponseError( status_code=403, message="Injected error disallowing writes in this region.", response=None, @@ -154,8 +153,11 @@ async def throw_write_forbidden(): ) @staticmethod - async def transform_convert_emulator_to_single_master_read_multi_region_account(r: HttpRequest, - inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + async def transform_convert_emulator_to_single_master_read_multi_region_account( + additional_region: str, + artificial_uri: str, + r: HttpRequest, + inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: response = await inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -166,6 +168,39 @@ async def transform_convert_emulator_to_single_master_read_multi_region_account( if response.status_code == 200 and data: data = data.decode("utf-8") result = json.loads(data) - result["readableLocations"].append({"name": "East US", "databaseAccountEndpoint" : "https://localhost:8888/"}) + result["readableLocations"].append({"name": additional_region, "databaseAccountEndpoint" : artificial_uri}) FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) - return response \ No newline at end of file + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + + 
class MockHttpResponse(AioHttpTransportResponse): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, any]]): + self.request: HttpRequest = request + # This is actually never None, and set by all implementations after the call to + # __init__ of this class. This class is also a legacy impl, so it's risky to change it + # for low benefits The new "rest" implementation does define correctly status_code + # as non-optional. + self.status_code: int = status_code + self.headers: MutableMapping[str, str] = {} + self.reason: Optional[str] = None + self.content_type: Optional[str] = None + self.block_size: int = 4096 # Default to same as R + self.content: Optional[dict[str, any]] = None + self.json_text: Optional[str] = None + self.bytes: Optional[bytes] = None + if content: + self.content:Optional[dict[str, any]] = content + self.json_text:Optional[str] = json.dumps(content) + self.bytes:bytes = self.json_text.encode("utf-8") + + + def body(self) -> bytes: + return self.bytes + + def text(self, encoding: Optional[str] = None) -> str: + return self.json_text + + async def load_body(self) -> None: + return \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/test/test_dummy.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py similarity index 50% rename from sdk/cosmos/azure-cosmos/test/test_dummy.py rename to sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index 1300717de2c3..cba2c1074ec4 100644 --- a/sdk/cosmos/azure-cosmos/test/test_dummy.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -29,8 +29,8 @@ logger.addHandler(logging.StreamHandler(sys.stdout)) host = test_config.TestConfig.host -masterKey = test_config.TestConfig.masterKey -connectionPolicy = test_config.TestConfig.connectionPolicy +master_key = test_config.TestConfig.masterKey +connection_policy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID @pytest.fixture() @@ -41,27 +41,27 @@ def setup(): @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") -class TestDummyAsync: +class TestFaultInjectionTransportAsync: @classmethod def setup_class(cls): logger.info("starting class: {} execution".format(cls.__name__)) cls.host = test_config.TestConfig.host - cls.masterKey = test_config.TestConfig.masterKey + cls.master_key = test_config.TestConfig.masterKey - if (cls.masterKey == '[YOUR_KEY_HERE]' or + if (cls.master_key == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): raise Exception( "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.connectionPolicy = test_config.TestConfig.connectionPolicy + cls.connection_policy = test_config.TestConfig.connectionPolicy cls.database_id = test_config.TestConfig.TEST_DATABASE_ID cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4()) - cls.mgmtClient = CosmosClient(host, masterKey, consistency_level="Session", - connection_policy=connectionPolicy, logger=logger) - created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id) + cls.mgmt_client = CosmosClient(host, master_key, consistency_level="Session", + connection_policy=connection_policy, logger=logger) + created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( cls.single_partition_container_name, @@ -71,93 +71,111 @@ 
def setup_class(cls): @classmethod def teardown_class(cls): logger.info("tearing down class: {}".format(cls.__name__)) - created_database: DatabaseProxy = cls.mgmtClient.get_database_client(cls.database_id) + created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) try: asyncio.run(asyncio.wait_for( created_database.delete_container(cls.single_partition_container_name), MGMT_TIMEOUT)) except Exception as containerDeleteError: - logger.warn("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) + logger.warning("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) finally: try: - asyncio.run(asyncio.wait_for(cls.mgmtClient.close(), MGMT_TIMEOUT)) + asyncio.run(asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT)) except Exception as closeError: - logger.warn("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + logger.warning("Exception trying to delete database {}. {}".format(created_database.id, closeError)) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport): - client = CosmosClient(host, masterKey, consistency_level="Session", - connection_policy=connectionPolicy, transport=custom_transport, - logger=logger, enable_diagnostics_logging=True) + def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, **kwargs): + client = CosmosClient(host, master_key, consistency_level="Session", + connection_policy=connection_policy, transport=custom_transport, + logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) container: ContainerProxy = db.get_container_client(self.single_partition_container_name) return {"client": client, "db": db, "col": container} - def cleanup_method(self, initializedObjects: dict[str, Any]): - method_client: CosmosClient = initializedObjects["client"] + @staticmethod + def cleanup_method(initialized_objects: dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] try: asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) - except Exception as closeError: - logger.warning("Exception trying to close method client.") + except Exception as close_error: + logger.warning(f"Exception trying to close method client. 
{close_error}") async def test_throws_injected_error(self, setup): - idValue: str = str(uuid.uuid4()) - document_definition = {'id': idValue, - 'pk': idValue, + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, 'name': 'sample document', 'key': 'value'} custom_transport = FaultInjectionTransport() - predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, idValue) - custom_transport.addFault(predicate, lambda: FaultInjectionTransport.throw_after_delay( + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( 500, CosmosHttpResponseError( status_code=502, - message="Some random reverse proxy error."))) + message="Some random reverse proxy error.")))) - initializedObjects = self.setup_method_with_custom_transport(custom_transport) + initialized_objects = self.setup_method_with_custom_transport(custom_transport) try: - container: ContainerProxy = initializedObjects["col"] + container: ContainerProxy = initialized_objects["col"] await container.create_item(body=document_definition) pytest.fail("Expected exception not thrown") except CosmosHttpResponseError as cosmosError: - if (cosmosError.status_code != 502): + if cosmosError.status_code != 502: raise cosmosError finally: - cleanupOp = self.cleanup_method(initializedObjects) - if (cleanupOp != None): - await cleanupOp + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_succeeds_with_multiple_endpoints(self, setup): + localhost_uri: str = test_config.TestConfig.local_host + alternate_localhost_uri: str = localhost_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) is_write_operation_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, "https://localhost") - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account(r, inner) + r: FaultInjectionTransport.predicate_is_write_operation(r, localhost_uri) - custom_transport.addFault(is_write_operation_predicate, lambda: FaultInjectionTransport.throw_write_forbidden()) - custom_transport.add_response_transformation(is_get_account_predicate, emulator_as_multi_region_sm_account_transformation) + # Emulator uses "South Central US" with Uri https://127.0.0.1:8888 - idValue: str = str(uuid.uuid4()) - document_definition = {'id': idValue, - 'pk': idValue, + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ + lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account( + additional_region="East US", + artificial_uri=localhost_uri, + r=r, + inner=inner) + + custom_transport.add_fault( + is_write_operation_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + custom_transport.add_response_transformation( + is_get_account_predicate, + 
emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, 'name': 'sample document', 'key': 'value'} - initializedObjects = self.setup_method_with_custom_transport(custom_transport) + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["East US", "South Central US"]) try: - container: ContainerProxy = initializedObjects["col"] + container: ContainerProxy = initialized_objects["col"] created_document = await container.create_item(body=document_definition) - start = time.perf_counter() - + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(alternate_localhost_uri) + start:float = time.perf_counter() - #while ((time.perf_counter() - start) < 2): - # await container.read_item(idValue, partition_key=idValue) - # await asyncio.sleep(0.2) + while (time.perf_counter() - start) < 2: + read_document = await container.read_item(id_value, partition_key=id_value) + request: HttpRequest = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(localhost_uri) finally: - self.cleanup_method(initializedObjects) + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) if __name__ == '__main__': unittest.main() From e90b722d30e123bdf3ccf15c87f0923074354ddf Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 21 Feb 2025 18:30:44 +0000 Subject: [PATCH 18/86] small refactoring --- .../test/_fault_injection_transport.py | 13 +++++--- .../test_fault_injection_transport_async.py | 33 ++++++++++--------- 2 files changed, 26 insertions(+), 20 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index efe58bae3032..46ea8c2c6ce2 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -33,6 +33,7 @@ from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse +import test_config from azure.cosmos.exceptions import CosmosHttpResponseError @@ -153,9 +154,9 @@ async def error_write_forbidden() -> Exception: ) @staticmethod - async def transform_convert_emulator_to_single_master_read_multi_region_account( - additional_region: str, - artificial_uri: str, + async def transform_topology_swr_mrr( + write_region_name: str, + read_region_name: str, r: HttpRequest, inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: @@ -168,7 +169,11 @@ async def transform_convert_emulator_to_single_master_read_multi_region_account( if response.status_code == 200 and data: data = data.decode("utf-8") result = json.loads(data) - result["readableLocations"].append({"name": additional_region, "databaseAccountEndpoint" : artificial_uri}) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = write_region_name + writable_locations[0]["name"] = write_region_name + readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host}) FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) request: 
HttpRequest = response.request return FaultInjectionTransport.MockHttpResponse(request, 200, result) diff --git a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index cba2c1074ec4..320992fe221b 100644 --- a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -127,25 +127,26 @@ async def test_throws_injected_error(self, setup): TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_succeeds_with_multiple_endpoints(self, setup): - localhost_uri: str = test_config.TestConfig.local_host - alternate_localhost_uri: str = localhost_uri.replace("localhost", "127.0.0.1") + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - is_write_operation_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, localhost_uri) + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) - # Emulator uses "South Central US" with Uri https://127.0.0.1:8888 + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_convert_emulator_to_single_master_read_multi_region_account( - additional_region="East US", - artificial_uri=localhost_uri, + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", r=r, inner=inner) - - custom_transport.add_fault( - is_write_operation_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) custom_transport.add_response_transformation( is_get_account_predicate, emulator_as_multi_region_sm_account_transformation) @@ -158,21 +159,21 @@ async def test_succeeds_with_multiple_endpoints(self, setup): initialized_objects = self.setup_method_with_custom_transport( custom_transport, - preferred_locations=["East US", "South Central US"]) + preferred_locations=["Read Region", "Write Region"]) try: container: ContainerProxy = initialized_objects["col"] created_document = await container.create_item(body=document_definition) request: HttpRequest = created_document.get_response_headers()["_request"] # Validate the response comes from "South Central US" (the write region) - assert request.url.startswith(alternate_localhost_uri) + assert request.url.startswith(expected_write_region_uri) start:float = time.perf_counter() while (time.perf_counter() - start) < 2: read_document = await 
container.read_item(id_value, partition_key=id_value) request: HttpRequest = read_document.get_response_headers()["_request"] # Validate the response comes from "East US" (the most preferred read-only region) - assert request.url.startswith(localhost_uri) + assert request.url.startswith(expected_read_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) From cb58896447b57b7ec76c0dfd24d51e570a33e854 Mon Sep 17 00:00:00 2001 From: Fabian Meiswinkel Date: Fri, 21 Feb 2025 18:58:50 +0000 Subject: [PATCH 19/86] Adding MM topology on Emulator --- .../test/_fault_injection_transport.py | 31 ++++++++++++ .../test_fault_injection_transport_async.py | 47 ++++++++++++++++++- 2 files changed, 77 insertions(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py index 46ea8c2c6ce2..bd76d795452a 100644 --- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py @@ -180,6 +180,37 @@ async def transform_topology_swr_mrr( return response + @staticmethod + async def transform_topology_mwr( + first_region_name: str, + second_region_name: str, + r: HttpRequest, + inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + + response = await inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + await response.load_body() + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = first_region_name + writable_locations[0]["name"] = first_region_name + readable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + writable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + result["enableMultipleWriteLocations"] = True + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + class MockHttpResponse(AioHttpTransportResponse): def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, any]]): self.request: HttpRequest = request diff --git a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index 320992fe221b..817c263cad34 100644 --- a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -126,7 +126,7 @@ async def test_throws_injected_error(self, setup): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_succeeds_with_multiple_endpoints(self, setup): + async def test_swr_mrr_succeeds(self, setup): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -178,5 +178,50 @@ async def test_succeeds_with_multiple_endpoints(self, setup): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + async def test_mwr_succeeds(self, 
setup):
+        first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
+        second_region_uri: str = test_config.TestConfig.local_host
+        custom_transport = FaultInjectionTransport()
+
+        # Inject topology transformation that would make Emulator look like a multiple write region
+        # account with two write regions
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_mwr(
+                first_region_name="First Region",
+                second_region_name="Second Region",
+                r=r,
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_write_region_account_transformation)
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            preferred_locations=["First Region", "Second Region"])
+        try:
+            container: ContainerProxy = initialized_objects["col"]
+
+            created_document = await container.create_item(body=document_definition)
+            request: HttpRequest = created_document.get_response_headers()["_request"]
+            # Validate the response comes from "First Region" (the most preferred write region)
+            assert request.url.startswith(first_region_uri)
+            start:float = time.perf_counter()
+
+            while (time.perf_counter() - start) < 2:
+                read_document = await container.read_item(id_value, partition_key=id_value)
+                request: HttpRequest = read_document.get_response_headers()["_request"]
+                # Validate the response comes from "First Region" (the most preferred region)
+                assert request.url.startswith(first_region_uri)
+
+        finally:
+            TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)
+
 if __name__ == '__main__':
     unittest.main()

From 46ec31ca1eef20f815c658fb51d5410cb3a60390 Mon Sep 17 00:00:00 2001
From: Fabian Meiswinkel
Date: Sat, 22 Feb 2025 00:05:07 +0000
Subject: [PATCH 20/86] Adding cross region retry tests

---
 .../test/_fault_injection_transport.py        |  18 ++-
 .../test_fault_injection_transport_async.py   | 136 +++++++++++++++++-
 2 files changed, 149 insertions(+), 5 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py
index bd76d795452a..37c00667544a 100644
--- a/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py
+++ b/sdk/cosmos/azure-cosmos/test/_fault_injection_transport.py
@@ -35,7 +35,7 @@
 import test_config
 from azure.cosmos.exceptions import CosmosHttpResponseError
-
+from azure.core.exceptions import ServiceRequestError
 
 class FaultInjectionTransport(AioHttpTransport):
     logger = logging.getLogger('azure.cosmos.fault_injection_transport')
@@ -103,6 +103,10 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Mut
     def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool:
         return id_value in r.url
 
+    @staticmethod
+    def predicate_targets_region(r: HttpRequest, region_endpoint: str) -> bool:
+        return r.url.startswith(region_endpoint)
+
     @staticmethod
     def print_call_stack():
         print("Call stack:")
@@ -130,6 +134,12 @@ def predicate_is_database_account_call(r: HttpRequest) -> bool:
 
         return is_db_account_read
 
+    @staticmethod
+    def predicate_is_document_operation(r: 
HttpRequest) -> bool: + is_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs') + + return is_document_operation + @staticmethod def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: is_write_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs' @@ -153,6 +163,12 @@ async def error_write_forbidden() -> Exception: sub_status_code=3, ) + @staticmethod + async def error_region_down() -> Exception: + return ServiceRequestError( + message="Injected region down.", + ) + @staticmethod async def transform_topology_swr_mrr( write_region_name: str, diff --git a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py index 817c263cad34..d09f017febbf 100644 --- a/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_fault_injection_transport_async.py @@ -84,8 +84,8 @@ def teardown_class(cls): except Exception as closeError: logger.warning("Exception trying to delete database {}. {}".format(created_database.id, closeError)) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, **kwargs): - client = CosmosClient(host, master_key, consistency_level="Session", + def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, master_key, consistency_level="Session", connection_policy=connection_policy, transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) @@ -165,19 +165,147 @@ async def test_swr_mrr_succeeds(self, setup): created_document = await container.create_item(body=document_definition) request: HttpRequest = created_document.get_response_headers()["_request"] - # Validate the response comes from "South Central US" (the write region) + # Validate the response comes from "Write Region" (the write region) assert request.url.startswith(expected_write_region_uri) start:float = time.perf_counter() while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) request: HttpRequest = read_document.get_response_headers()["_request"] - # Validate the response comes from "East US" (the most preferred read-only region) + # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_read_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + async def test_swr_mrr_region_down_read_succeeds(self, setup): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) + + 
custom_transport.add_fault(
+            is_request_to_read_region,
+            lambda r: asyncio.create_task(FaultInjectionTransport.error_region_down()))
+
+        # Inject topology transformation that would make Emulator look like a single write region
+        # account with two read regions
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr(
+                write_region_name="Write Region",
+                read_region_name="Read Region",
+                r=r,
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_region_sm_account_transformation)
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            default_endpoint=expected_write_region_uri,
+            preferred_locations=["Read Region", "Write Region"])
+        try:
+            container: ContainerProxy = initialized_objects["col"]
+
+            created_document = await container.create_item(body=document_definition)
+            request: HttpRequest = created_document.get_response_headers()["_request"]
+            # Validate the response comes from "Write Region" (the write region)
+            assert request.url.startswith(expected_write_region_uri)
+            start:float = time.perf_counter()
+
+            while (time.perf_counter() - start) < 2:
+                read_document = await container.read_item(id_value, partition_key=id_value)
+                request: HttpRequest = read_document.get_response_headers()["_request"]
+                # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down)
+                assert request.url.startswith(expected_write_region_uri)
+
+        finally:
+            TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)
+
+    async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup):
+        expected_read_region_uri: str = test_config.TestConfig.local_host
+        expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1")
+        custom_transport = FaultInjectionTransport()
+        # Inject rule to disallow writes in the read-only region
+        is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri)
+
+        custom_transport.add_fault(
+            is_write_operation_in_read_region_predicate,
+            lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden()))
+
+        # Inject rule to simulate regional outage in "Read Region"
+        is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and \
+               FaultInjectionTransport.predicate_is_document_operation(r)
+
+        custom_transport.add_fault(
+            is_request_to_read_region,
+            lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay(
+                35000,
+                CosmosHttpResponseError(
+                    status_code=502,
+                    message="Some random reverse proxy error."))))
+
+        # Inject topology transformation that would make Emulator look like a single write region
+        # account with two read regions
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_region_sm_account_transformation: 
Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr(
+                write_region_name="Write Region",
+                read_region_name="Read Region",
+                r=r,
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_region_sm_account_transformation)
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            default_endpoint=expected_write_region_uri,
+            preferred_locations=["Read Region", "Write Region"])
+        try:
+            container: ContainerProxy = initialized_objects["col"]
+
+            created_document = await container.create_item(body=document_definition)
+            request: HttpRequest = created_document.get_response_headers()["_request"]
+            # Validate the response comes from "Write Region" (the write region)
+            assert request.url.startswith(expected_write_region_uri)
+            start:float = time.perf_counter()
+
+            while (time.perf_counter() - start) < 2:
+                read_document = await container.read_item(id_value, partition_key=id_value)
+                request: HttpRequest = read_document.get_response_headers()["_request"]
+                # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down)
+                assert request.url.startswith(expected_write_region_uri)
+
+        finally:
+            TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)
+
+
     async def test_mwr_succeeds(self, setup):
         first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
         second_region_uri: str = test_config.TestConfig.local_host

From f03f51f3f6035a8a73b864bb41debd18f97cf6df Mon Sep 17 00:00:00 2001
From: Allen Kim
Date: Mon, 31 Mar 2025 12:25:32 -0700
Subject: [PATCH 21/86] Add Excluded Locations Feature

---
 sdk/cosmos/azure-cosmos/azure/cosmos/_base.py |   1 +
 .../azure/cosmos/_cosmos_client_connection.py |  10 ++
 .../azure/cosmos/_global_endpoint_manager.py  |   4 +-
 .../azure/cosmos/_location_cache.py           | 101 +++++++++---
 .../azure/cosmos/_request_object.py           |  25 ++-
 .../azure/cosmos/aio/_container.py            |  30 ++++
 .../aio/_cosmos_client_connection_async.py    |  10 ++
 .../aio/_global_endpoint_manager_async.py     |   6 +-
 .../azure-cosmos/azure/cosmos/container.py    |  30 ++++
 .../azure/cosmos/cosmos_client.py             |   4 +
 .../azure-cosmos/azure/cosmos/documents.py    |   8 +
 .../samples/excluded_locations.py             | 110 +++++++++++++
 .../azure-cosmos/tests/test_health_check.py   |   6 +-
 .../tests/test_health_check_async.py          |  12 +-
 .../azure-cosmos/tests/test_location_cache.py | 148 +++++++++++++++++-
 .../tests/test_retry_policy_async.py          |   1 +
 16 files changed, 462 insertions(+), 44 deletions(-)
 create mode 100644 sdk/cosmos/azure-cosmos/samples/excluded_locations.py

diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
index 654b23c5d71f..bcfc611456ec 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
@@ -63,6 +63,7 @@
     'priority': 'priorityLevel',
     'no_response': 'responsePayloadOnWriteDisabled',
     'max_item_count': 'maxItemCount',
+    'excluded_locations': 'excludedLocations',
 }
 
 # Cosmos resource ID validation regex breakdown:
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
index 3934c23bcf99..d64da38defb1 100644
--- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2044,6 +2044,7 @@ def PatchItem( documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2132,6 +2133,7 @@ def _Batch( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch) + request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2192,6 +2194,7 @@ def DeleteAllItemsByPartitionKey( headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) request_params = RequestObject("partitionkey", documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2647,6 +2650,7 @@ def Create( # Create will use WriteEndpoint since it uses POST operation request_params = RequestObject(typ, documents._OperationType.Create) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2693,6 +2697,7 @@ def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2736,6 +2741,7 @@ def Replace( options) # Replace will use WriteEndpoint since it uses PUT operation request_params = RequestObject(typ, documents._OperationType.Replace) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2777,6 +2783,7 @@ def Read( headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation request_params = RequestObject(typ, documents._OperationType.Read) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2816,6 +2823,7 @@ def DeleteResource( options) # Delete will use WriteEndpoint since it uses DELETE operation request_params = RequestObject(typ, documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3052,6 +3060,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: resource_type, documents._OperationType.QueryPlan if 
is_query_plan else documents._OperationType.ReadFeed ) + request_params.set_excluded_location_from_options(options) headers = base.GetHeaders( self, initial_headers, @@ -3090,6 +3099,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) request_params = RequestObject(resource_type, documents._OperationType.SqlQuery) + request_params.set_excluded_location_from_options(options) req_headers = base.GetHeaders( self, initial_headers, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index e167871dd4a5..944b684e392b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -50,10 +50,8 @@ def __init__(self, client): self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( - self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy ) self.refresh_needed = False self.refresh_lock = threading.RLock() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 96651d5c8b7f..02b293e29b4b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -25,12 +25,13 @@ import collections import logging import time -from typing import Set +from typing import Set, Mapping, List from urllib.parse import urlparse from . import documents from . 
import http_constants
 from .documents import _OperationType
+from ._request_object import RequestObject
 
 # pylint: disable=protected-access
 
@@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations,
     except Exception as e:
         raise e
 
-    return endpoints_by_location, parsed_locations
+    # Also store a reverse map of the location name for each endpoint
+    locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()}
+
+    return endpoints_by_location, locations_by_endpoints, parsed_locations
 
 def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool:
     if endpoint in preferred_endpoints:
@@ -150,6 +154,21 @@ def _get_health_check_endpoints(
 
     return endpoints
 
+def _get_applicable_regional_endpoints(endpoints: List[RegionalRoutingContext],
+                                       location_name_by_endpoint: Mapping[str, str],
+                                       fall_back_endpoint: RegionalRoutingContext,
+                                       exclude_location_list: List[str]) -> List[RegionalRoutingContext]:
+    # filter endpoints by excluded locations
+    applicable_endpoints = []
+    for endpoint in endpoints:
+        if location_name_by_endpoint.get(endpoint.get_primary()) not in exclude_location_list:
+            applicable_endpoints.append(endpoint)
+
+    # if the filtered list is empty, fall back to the provided fallback endpoint
+    if not applicable_endpoints:
+        applicable_endpoints.append(fall_back_endpoint)
+
+    return applicable_endpoints
 
 class LocationCache(object):  # pylint: disable=too-many-public-methods,too-many-instance-attributes
     def current_time_millis(self):
@@ -157,15 +176,10 @@ def current_time_millis(self):
 
     def __init__(
         self,
-        preferred_locations,
         default_endpoint,
-        enable_endpoint_discovery,
-        use_multiple_write_locations,
+        connection_policy,
     ):
-        self.preferred_locations = preferred_locations
         self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint)
-        self.enable_endpoint_discovery = enable_endpoint_discovery
-        self.use_multiple_write_locations = use_multiple_write_locations
         self.enable_multiple_writable_locations = False
         self.write_regional_routing_contexts = [self.default_regional_routing_context]
         self.read_regional_routing_contexts = [self.default_regional_routing_context]
@@ -173,8 +187,11 @@ def __init__(
         self.last_cache_update_time_stamp = 0
         self.account_read_regional_routing_contexts_by_location = {}  # pylint: disable=name-too-long
         self.account_write_regional_routing_contexts_by_location = {}  # pylint: disable=name-too-long
+        self.account_locations_by_read_regional_routing_context = {}  # pylint: disable=name-too-long
+        self.account_locations_by_write_regional_routing_context = {}  # pylint: disable=name-too-long
         self.account_write_locations = []
         self.account_read_locations = []
+        self.connection_policy = connection_policy
 
     def get_write_regional_routing_contexts(self):
         return self.write_regional_routing_contexts
@@ -207,6 +224,44 @@ def get_ordered_write_locations(self):
     def get_ordered_read_locations(self):
         return self.account_read_locations
 
+    def _get_configured_excluded_locations(self, request: RequestObject):
+        # If excluded locations were configured on request, use request level excluded locations. 
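+        # Request-level exclusions take precedence over the client-level
+        # ConnectionPolicy.ExcludedLocations list; an explicit empty list on the
+        # request re-enables all preferred locations for that call.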
+ excluded_locations = request.excluded_locations + if excluded_locations is None: + # If excluded locations were only configured on client(connection_policy), use client level + excluded_locations = self.connection_policy.ExcludedLocations + return excluded_locations + + def _get_applicable_read_regional_endpoints(self, request: RequestObject): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. + if excluded_locations: + return _get_applicable_regional_endpoints( + self.get_read_regional_routing_contexts(), + self.account_locations_by_read_regional_routing_context, + self.get_write_regional_routing_contexts()[0], + excluded_locations) + + # Else, return all regional endpoints + return self.get_read_regional_routing_contexts() + + def _get_applicable_write_regional_endpoints(self, request: RequestObject): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. + if excluded_locations: + return _get_applicable_regional_endpoints( + self.get_write_regional_routing_contexts(), + self.account_locations_by_write_regional_routing_context, + self.default_regional_routing_context, + excluded_locations) + + # Else, return all regional endpoints + return self.get_write_regional_routing_contexts() + def resolve_service_endpoint(self, request): if request.location_endpoint_to_route: return request.location_endpoint_to_route @@ -227,7 +282,7 @@ def resolve_service_endpoint(self, request): # For non-document resource types in case of client can use multiple write locations # or when client cannot use multiple write locations, flip-flop between the # first and the second writable region in DatabaseAccount (for manual failover) - if self.enable_endpoint_discovery and self.account_write_locations: + if self.connection_policy.EnableEndpointDiscovery and self.account_write_locations: location_index = min(location_index % 2, len(self.account_write_locations) - 1) write_location = self.account_write_locations[location_index] if (self.account_write_regional_routing_contexts_by_location @@ -247,9 +302,9 @@ def resolve_service_endpoint(self, request): return self.default_regional_routing_context.get_primary() regional_routing_contexts = ( - self.get_write_regional_routing_contexts() + self._get_applicable_write_regional_endpoints(request) if documents._OperationType.IsWriteOperation(request.operation_type) - else self.get_read_regional_routing_contexts() + else self._get_applicable_read_regional_endpoints(request) ) regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)] if ( @@ -263,12 +318,14 @@ def resolve_service_endpoint(self, request): return regional_routing_context.get_primary() def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements - most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None + most_preferred_location = self.connection_policy.PreferredLocations[0] \ + if self.connection_policy.PreferredLocations else None # we should schedule refresh in background if we are unable to target the user's most preferredLocation. 
- if self.enable_endpoint_discovery: + if self.connection_policy.EnableEndpointDiscovery: - should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations + should_refresh = (self.connection_policy.UseMultipleWriteLocations + and not self.enable_multiple_writable_locations) if (most_preferred_location and most_preferred_location in self.account_read_regional_routing_contexts_by_location): @@ -358,25 +415,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if enable_multiple_writable_locations: self.enable_multiple_writable_locations = enable_multiple_writable_locations - if self.enable_endpoint_discovery: + if self.connection_policy.EnableEndpointDiscovery: if read_locations: (self.account_read_regional_routing_contexts_by_location, + self.account_locations_by_read_regional_routing_context, self.account_read_locations) = get_endpoints_by_location( read_locations, self.account_read_regional_routing_contexts_by_location, self.default_regional_routing_context, False, - self.use_multiple_write_locations + self.connection_policy.UseMultipleWriteLocations ) if write_locations: (self.account_write_regional_routing_contexts_by_location, + self.account_locations_by_write_regional_routing_context, self.account_write_locations) = get_endpoints_by_location( write_locations, self.account_write_regional_routing_contexts_by_location, self.default_regional_routing_context, True, - self.use_multiple_write_locations + self.connection_policy.UseMultipleWriteLocations ) self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts( @@ -399,18 +458,18 @@ def get_preferred_regional_routing_contexts( regional_endpoints = [] # if enableEndpointDiscovery is false, we always use the defaultEndpoint that # user passed in during documentClient init - if self.enable_endpoint_discovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks + if self.connection_policy.EnableEndpointDiscovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks if ( self.can_use_multiple_write_locations() or expected_available_operation == EndpointOperationType.ReadType ): unavailable_endpoints = [] - if self.preferred_locations: + if self.connection_policy.PreferredLocations: # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. 
- for location in self.preferred_locations: + for location in self.connection_policy.PreferredLocations: regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \ else None if regional_endpoint: @@ -436,7 +495,7 @@ def get_preferred_regional_routing_contexts( return regional_endpoints def can_use_multiple_write_locations(self): - return self.use_multiple_write_locations and self.enable_multiple_writable_locations + return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long return self.can_use_multiple_write_locations() and ( diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index a220c6af42c2..94805934ce74 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,8 @@ """Represents a request object. """ -from typing import Optional +from typing import Optional, Mapping, Any + class RequestObject(object): def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: @@ -33,6 +34,7 @@ def __init__(self, resource_type: str, operation_type: str, endpoint_override: O self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None + self.excluded_locations = None def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -52,3 +54,24 @@ def clear_route_to_location(self) -> None: self.location_index_to_route = None self.use_preferred_locations = None self.location_endpoint_to_route = None + + def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: + # If resource types for requests are not one of the followings, excluded locations cannot be set + if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey']: + return False + + # If 'excludedLocations' wasn't in the options, excluded locations cannot be set + if (options is None + or 'excludedLocations' not in options): + return False + + # The 'excludedLocations' cannot be None + if options['excludedLocations'] is None: + raise ValueError("Excluded locations cannot be None. " + "If you want to remove all excluded locations, try passing an empty list.") + + return True + + def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None: + if self._can_set_excluded_location(options): + self.excluded_locations = options['excludedLocations'] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 0142e215f318..590f43331652 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -224,6 +224,8 @@ async def create_item( :keyword bool enable_automatic_id_generation: Enable automatic id generation if no id present. :keyword str session_token: Token for use with Session consistency. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -303,6 +305,8 @@ async def read_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :returns: A CosmosDict representing the retrieved item. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -361,6 +365,8 @@ def read_all_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] """ @@ -441,6 +447,8 @@ def query_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -537,6 +545,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -575,6 +585,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. 
:paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -601,6 +613,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -639,6 +653,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -675,6 +691,8 @@ def query_items_change_feed( # pylint: disable=unused-argument ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -748,6 +766,8 @@ async def upsert_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. @@ -830,6 +850,8 @@ async def replace_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. 
The dict will be empty if `no_response` @@ -906,6 +928,8 @@ async def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty if @@ -973,6 +997,8 @@ async def delete_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1223,6 +1249,8 @@ async def delete_all_items_by_partition_key( :keyword str pre_trigger_include: trigger id to be used as pre operation trigger. :keyword str post_trigger_include: trigger id to be used as post operation trigger. :keyword str session_token: Token for use with Session consistency. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :rtype: None """ @@ -1278,6 +1306,8 @@ async def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 49219533a7e6..9008b46bb1c1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -768,6 +768,7 @@ async def Create( # Create will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -907,6 +908,7 @@ async def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1208,6 +1210,7 @@ async def Read( options) # Read will use ReadEndpoint since it uses GET operation request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1466,6 +1469,7 @@ async def PatchItem( documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1570,6 +1574,7 @@ async def Replace( options) # Replace will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1893,6 +1898,7 @@ async def DeleteResource( options) # Delete will use WriteEndpoint since it uses DELETE operation request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2006,6 +2012,7 @@ async def _Batch( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) + request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2861,6 +2868,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: typ, documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed ) + request_params.set_excluded_location_from_options(options) 
headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, options, partition_key_range_id) @@ -2890,6 +2898,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) + request_params.set_excluded_location_from_options(options) req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, options, partition_key_range_id) @@ -3259,6 +3268,7 @@ async def DeleteAllItemsByPartitionKey( headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) + request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 4d00a7ef5629..f576e97d8e0b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,7 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError +from asyncio import CancelledError # pylint: disable=do-not-import-asyncio from typing import Tuple from azure.core.exceptions import AzureError @@ -53,10 +53,8 @@ def __init__(self, client): self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( - self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy ) self.startup = True self.refresh_task = None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index efa5e7c09a50..a815c9110471 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -233,6 +233,8 @@ def read_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: A CosmosDict representing the item to be retrieved. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -298,6 +300,8 @@ def read_all_items( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. 
Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :returns: An Iterable of items (dicts). :rtype: Iterable[Dict[str, Any]] """ @@ -364,6 +368,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -403,6 +409,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -429,6 +437,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -466,6 +476,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -501,6 +513,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations.
The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :param Any args: args @@ -601,6 +615,8 @@ def query_items( # pylint:disable=docstring-missing-param :keyword bool populate_index_metrics: Used to obtain the index metrics to understand how the query engine used existing indexes and how it could use potential new indexes. Please note that this options will incur overhead, so it should be enabled only when debugging slow queries. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :returns: An Iterable of items (dicts). :rtype: ItemPaged[Dict[str, Any]] @@ -716,6 +732,8 @@ def replace_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -790,6 +808,8 @@ def upsert_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -879,6 +899,8 @@ def create_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Item with the given ID already exists. :returns: A CosmosDict representing the new item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -970,6 +992,8 @@ def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads.
When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty @@ -1030,6 +1054,8 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: [Callable[[Mapping[str, str], List[Dict[str, Any]]], None] :returns: A CosmosList representing the items after the batch operations went through. @@ -1102,6 +1128,8 @@ def delete_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1377,6 +1405,8 @@ def delete_all_items_by_partition_key( :keyword str pre_trigger_include: trigger id to be used as pre operation trigger. :keyword str post_trigger_include: trigger id to be used as post operation trigger. :keyword str session_token: Token for use with Session consistency. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword response_hook: A callable invoked with the response metadata.
:paramtype response_hook: Callable[[Mapping[str, str], None], None] = None, :rtype: None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py index 10543f97c47b..b7a6ea94bd2b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py @@ -93,6 +93,8 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations) + # TODO: Consider storing a callback method instead, such as 'Supplier' in the Java SDK + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config @@ -181,6 +183,8 @@ class CosmosClient: # pylint: disable=client-accepts-api-version-keyword :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts. (Default: True) :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts. + :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the Azure Cosmos locations, like 'West US', 'East US', and so on. :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy. Must be used along with a logger to work. :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. Can be passed in at client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 40fbed24451f..9e04829be52f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -308,6 +308,13 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes locations in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US', 'Central India' and so on. :vartype PreferredLocations: List[str] + :ivar ExcludedLocations: + Gets or sets the excluded locations for geo-replicated database + accounts. When ExcludedLocations is non-empty, the client will skip this + set of locations from the final location evaluation. The locations in + this list are specified as the names of the Azure Cosmos locations, like + 'West US', 'East US', 'Central India', and so on. + :vartype ExcludedLocations: List[str] :ivar RetryOptions: Gets or sets the retry options to be applied to all requests when retrying.
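For context, a minimal usage sketch (not part of the patch) of how the new setting is expected to be wired, mirroring what _build_connection_policy does with the excluded_locations keyword; HOST and MASTER_KEY are placeholder account values:

    from azure.cosmos import CosmosClient, documents

    # Configure exclusions through ConnectionPolicy directly, as the tests in
    # this series do; an empty list removes all exclusions.
    policy = documents.ConnectionPolicy()
    policy.PreferredLocations = ['West US', 'East US', 'Central India']
    policy.ExcludedLocations = ['Central India']
    client = CosmosClient(HOST, MASTER_KEY, connection_policy=policy)

    # Equivalent shorthand using the keyword argument added above.
    client = CosmosClient(HOST, MASTER_KEY,
                          preferred_locations=['West US', 'East US', 'Central India'],
                          excluded_locations=['Central India'])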
@@ -347,6 +354,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations: List[str] = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py new file mode 100644 index 000000000000..06228c1a8cea --- /dev/null +++ b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py @@ -0,0 +1,110 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.cosmos import CosmosClient +from azure.cosmos.partition_key import PartitionKey +import config + +# ---------------------------------------------------------------------------------------------------------- +# Prerequisites - +# +# 1. An Azure Cosmos account - +# https://learn.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account +# +# 2. Microsoft Azure Cosmos +# pip install azure-cosmos>=4.3.0b4 +# ---------------------------------------------------------------------------------------------------------- +# Sample - demonstrates how to use excluded locations at the client level and the request level +# ---------------------------------------------------------------------------------------------------------- +# Note: +# This sample creates a Container to your database account. +# Each time a Container is created the account will be billed for 1 hour of usage based on +# the provisioned throughput (RU/s) of that account. +# ---------------------------------------------------------------------------------------------------------- + +HOST = config.settings["host"] +MASTER_KEY = config.settings["master_key"] + +TENANT_ID = config.settings["tenant_id"] +CLIENT_ID = config.settings["client_id"] +CLIENT_SECRET = config.settings["client_secret"] + +DATABASE_ID = config.settings["database_id"] +CONTAINER_ID = config.settings["container_id"] +PARTITION_KEY = PartitionKey(path="/id") + + +def get_test_item(num): + test_item = { + 'id': 'Item_' + str(num), + 'test_object': True, + 'lastName': 'Smith' + } + return test_item + +def clean_up_db(client): + try: + client.delete_database(DATABASE_ID) + except Exception: + # best-effort cleanup; the database may not exist yet + pass + +def excluded_locations_client_level_sample(): + preferred_locations = ['West US 3', 'West US', 'East US 2'] + excluded_locations = ['West US 3', 'West US'] + client = CosmosClient( + HOST, + MASTER_KEY, + preferred_locations=preferred_locations, + excluded_locations=excluded_locations + ) + clean_up_db(client) + + db = client.create_database(DATABASE_ID) + container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) + + # For write operations with a single master account, the write endpoint will be the default endpoint, + # since preferred_locations and excluded_locations are not used for writes + container.create_item(get_test_item(0)) + + # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
+ # In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'], + # therefore 'East US 2' will be the read endpoint, and items will be read from the 'East US 2' location + item = container.read_item(item='Item_0', partition_key='Item_0') + + clean_up_db(client) + +def excluded_locations_request_level_sample(): + preferred_locations = ['West US 3', 'West US', 'East US 2'] + excluded_locations_on_client = ['West US 3', 'West US'] + excluded_locations_on_request = ['West US 3'] + client = CosmosClient( + HOST, + MASTER_KEY, + preferred_locations=preferred_locations, + excluded_locations=excluded_locations_on_client + ) + clean_up_db(client) + + db = client.create_database(DATABASE_ID) + container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) + + # For write operations with a single master account, the write endpoint will be the default endpoint, + # since preferred_locations and excluded_locations are not used for writes + container.create_item(get_test_item(0)) + + # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'. + # However, in our sample, since `excluded_locations` was passed with the read request, the client-level + # `excluded_locations` will be replaced with the locations from the request, ['West US 3']. The + # `excluded_locations` on the request always takes the highest priority! + # With the excluded_locations on the request, the read endpoints will be ['West US', 'East US 2'] + # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2'] + # Therefore, items will be read from the 'West US' or 'East US 2' location + item = container.read_item(item='Item_0', partition_key='Item_0', excluded_locations=excluded_locations_on_request) + + clean_up_db(client) + +if __name__ == "__main__": + excluded_locations_client_level_sample() + excluded_locations_request_level_sample() diff --git a/sdk/cosmos/azure-cosmos/tests/test_health_check.py b/sdk/cosmos/azure-cosmos/tests/test_health_check.py index 0d313e6c911c..75db9deacc41 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_health_check.py +++ b/sdk/cosmos/azure-cosmos/tests/test_health_check.py @@ -126,14 +126,14 @@ def test_health_check_timeouts_on_unavailable_endpoints(self, setup): locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestHealthCheck.host, REGION_1) setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.mark_endpoint_unavailable_for_read( locational_endpoint, True) - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = REGIONS + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = REGIONS try: setup[COLLECTION].create_item(body={'id': 'item' + str(uuid.uuid4()), 'pk': 'pk'}) finally: _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations class MockGetDatabaseAccountCheck(object): def
__init__(self, client_connection=None, endpoint_unavailable=False): diff --git a/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py b/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py index ae2bf13fd8a7..a92eca0dd778 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py @@ -153,8 +153,8 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g # checks the background health check works as expected when all endpoints healthy self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub self.original_getDatabaseAccountCheck = _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = preferred_location + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = preferred_location mock_get_database_account_check = self.MockGetDatabaseAccountCheck() _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = ( self.MockGetDatabaseAccount(REGIONS, use_write_global_endpoint, use_read_global_endpoint)) @@ -168,7 +168,7 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g finally: _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations expected_regional_routing_contexts = [] locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) @@ -189,8 +189,8 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = ( self.MockGetDatabaseAccount(REGIONS, use_write_global_endpoint, use_read_global_endpoint)) - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = preferred_location + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = preferred_location try: setup[COLLECTION].client_connection._global_endpoint_manager.startup = False @@ -201,7 +201,7 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g await asyncio.sleep(1) finally: _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + 
setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations if not use_write_global_endpoint: num_unavailable_endpoints = len(REGIONS) diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index a957094f1790..f65a1f1a3d21 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -3,8 +3,10 @@ import time import unittest +from typing import Mapping, Any import pytest +from azure.cosmos import documents from azure.cosmos.documents import DatabaseAccount, _OperationType from azure.cosmos.http_constants import ResourceType @@ -35,15 +37,15 @@ def create_database_account(enable_multiple_writable_locations): return db_acc -def refresh_location_cache(preferred_locations, use_multiple_write_locations): - lc = LocationCache(preferred_locations=preferred_locations, - default_endpoint=default_endpoint, - enable_endpoint_discovery=True, - use_multiple_write_locations=use_multiple_write_locations) +def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=None): + # default to None rather than a ConnectionPolicy() literal so that one mutable + # default instance is not shared (and mutated) across calls + if connection_policy is None: + connection_policy = documents.ConnectionPolicy() + connection_policy.PreferredLocations = preferred_locations + connection_policy.UseMultipleWriteLocations = use_multiple_write_locations + lc = LocationCache(default_endpoint=default_endpoint, + connection_policy=connection_policy) return lc @pytest.mark.cosmosEmulator -class TestLocationCache(unittest.TestCase): +class TestLocationCache: def test_mark_endpoint_unavailable(self): lc = refresh_location_cache([], False) @@ -136,6 +138,140 @@ def test_resolve_request_endpoint_preferred_regions(self): assert read_resolved == write_resolved assert read_resolved == default_endpoint + @pytest.mark.parametrize("test_type",["OnClient", "OnRequest", "OnBoth"]) + def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): + # Init test data + if test_type == "OnClient": + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + excluded_locations_on_requests_list = [None] * 5 + elif test_type == "OnRequest": + excluded_locations_on_client_list = [[]] * 5 + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + else: + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name, location3_name], + [location1_name, location2_name], + [location2_name], + [location1_name, location2_name, location3_name], + ] + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + + expected_read_endpoints_list = [ + [location2_endpoint], + [location1_endpoint], + [location1_endpoint], + [location1_endpoint, location2_endpoint], + [location1_endpoint, location2_endpoint], + ] + expected_write_endpoints_list = [ + [location2_endpoint, location3_endpoint], + [location3_endpoint], + [default_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + ] + + # Loop over each test case + for excluded_locations_on_client, excluded_locations_on_requests, expected_read_endpoints, expected_write_endpoints in zip(excluded_locations_on_client_list,
excluded_locations_on_requests_list, expected_read_endpoints_list, expected_write_endpoints_list): + # Init excluded_locations in ConnectionPolicy + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Init requests and set excluded regions on requests + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + write_doc_request.excluded_locations = excluded_locations_on_requests + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request.excluded_locations = excluded_locations_on_requests + + # Test if read endpoints were correctly filtered on client level + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + # Test if write endpoints were correctly filtered on client level + write_doc_endpoint = location_cache._get_applicable_write_regional_endpoints(write_doc_request) + write_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in write_doc_endpoint] + assert write_doc_endpoint == expected_write_endpoints + + def test_set_excluded_locations_for_requests(self): + # Init excluded_locations in ConnectionPolicy + excluded_locations_on_client = [location1_name, location2_name] + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Test setting excluded locations + excluded_locations = [location1_name] + options: Mapping[str, Any] = {"excludedLocations": excluded_locations} + + expected_excluded_locations = excluded_locations + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location2_endpoint] + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + # Test setting excluded locations with invalid resource types + expected_excluded_locations = None + for resource_type in [ResourceType.Offer, ResourceType.Conflict]: + options: Mapping[str, Any] = {"excludedLocations": [location1_name]} + read_doc_request = RequestObject(resource_type, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location1_endpoint] + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = 
[regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + + # Test setting excluded locations with None value + expected_error_message = ("Excluded locations cannot be None. " + "If you want to remove all excluded locations, try passing an empty list.") + with pytest.raises(ValueError) as e: + options: Mapping[str, Any] = {"excludedLocations": None} + doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + doc_request.set_excluded_location_from_options(options) + assert str( + e.value) == expected_error_message + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py index be1683d1504d..4faef31c9495 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py @@ -42,6 +42,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations = None self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False From cf42098ac85b35a818ffddb774bb5f404738b70e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 31 Mar 2025 15:47:16 -0700 Subject: [PATCH 22/86] initial ppcb changes --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_base.py | 1 + .../azure-cosmos/azure/cosmos/_constants.py | 20 +++ .../azure/cosmos/_cosmos_client_connection.py | 52 ++++--- .../aio/execution_dispatcher.py | 6 +- .../execution_dispatcher.py | 11 +- .../azure/cosmos/_global_endpoint_manager.py | 11 +- .../azure/cosmos/_location_cache.py | 77 +++++++++- .../azure/cosmos/_request_object.py | 39 ++++- .../azure/cosmos/_routing/routing_range.py | 22 +++ .../cosmos/_service_request_retry_policy.py | 13 +- .../cosmos/_timeout_failover_retry_policy.py | 6 +- .../azure/cosmos/aio/_container.py | 32 ++++ .../aio/_cosmos_client_connection_async.py | 58 ++++--- .../aio/_global_endpoint_manager_async.py | 21 ++- .../azure/cosmos/aio/_retry_utility_async.py | 8 + .../azure-cosmos/azure/cosmos/container.py | 32 ++++ .../azure/cosmos/cosmos_client.py | 4 + .../azure-cosmos/azure/cosmos/documents.py | 8 + .../azure-cosmos/tests/test_location_cache.py | 143 +++++++++++++++++- .../tests/test_query_hybrid_search.py | 13 ++ .../tests/test_query_hybrid_search_async.py | 14 +- .../tests/test_query_vector_similarity.py | 19 ++- .../test_query_vector_similarity_async.py | 19 ++- 24 files changed, 550 insertions(+), 80 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 0a4906c4cefe..c5f836ce2a03 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.10.0b3 (Unreleased) #### Features Added +* Per partition circuit breaker support. It can be enabled through environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. 
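As a rough illustration (an assumption, not part of the changelog), the feature is opted into through the process environment before the client is created; the flag name and its "False" default come from the _constants.py hunk below, and the comparison against "True" mirrors the pattern this patch uses for NON_STREAMING_ORDER_BY_DISABLED_CONFIG:

    import os

    # Opt in to per partition circuit breaker before constructing CosmosClient;
    # the SDK reads AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER, which defaults to "False".
    os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True"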
See #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 654b23c5d71f..bcfc611456ec 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -63,6 +63,7 @@ 'priority': 'priorityLevel', 'no_response': 'responsePayloadOnWriteDisabled', 'max_item_count': 'maxItemCount', + 'excluded_locations': 'excludedLocations', } # Cosmos resource ID validation regex breakdown: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 890a172ccee6..2af40c74d77a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -44,6 +44,26 @@ class _Constants: # ServiceDocument Resource EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations" + # Environment variables + NON_STREAMING_ORDER_BY_DISABLED_CONFIG: str = "AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY" + NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT: str = "False" + HS_MAX_ITEMS_CONFIG: str = "AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS" + HS_MAX_ITEMS_CONFIG_DEFAULT: int = 1000 + MAX_ITEM_BUFFER_VS_CONFIG: str = "AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH" + MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 + CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" + CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" + # Only applicable when circuit breaker is enabled ------------------------- + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10 + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 + FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" + FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 70 + STALE_PARTITION_UNAVAILABILITY_CHECK = "AZURE_COSMOS_STALE_PARTITION_UNAVAILABILITY_CHECK_IN_SECONDS" + STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT: int = 120 + # ------------------------------------------------------------------------- + # Error code translations ERROR_TRANSLATIONS: Dict[int, str] = { 400: "BAD_REQUEST - Request being sent is invalid.", diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3934c23bcf99..298d3032877b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2043,7 +2043,8 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(resource_type, documents._OperationType.Patch) + request_params = RequestObject(resource_type, documents._OperationType.Patch, headers) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2131,7 +2132,8 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = 
RequestObject("docs", documents._OperationType.Batch) + request_params = RequestObject("docs", documents._OperationType.Batch, headers) + request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2191,7 +2193,8 @@ def DeleteAllItemsByPartitionKey( collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) - request_params = RequestObject("partitionkey", documents._OperationType.Delete) + request_params = RequestObject("partitionkey", documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2362,7 +2365,8 @@ def ExecuteStoredProcedure( documents._OperationType.ExecuteJavaScript, options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers = self.__Post(path, request_params, params, headers, **kwargs) return result @@ -2558,7 +2562,7 @@ def GetDatabaseAccount( headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = RequestObject("databaseaccount", documents._OperationType.Read, headers, url_connection) result, last_response_headers = self.__Get("", request_params, headers, **kwargs) self.last_response_headers = last_response_headers database_account = DatabaseAccount() @@ -2607,7 +2611,7 @@ def _GetDatabaseAccountCheck( headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = RequestObject("databaseaccount", documents._OperationType.Read, headers, url_connection) self.__Get("", request_params, headers, **kwargs) @@ -2646,7 +2650,8 @@ def Create( options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create) + request_params = RequestObject(typ, documents._OperationType.Create, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2692,7 +2697,8 @@ def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params = RequestObject(typ, documents._OperationType.Upsert, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2735,7 +2741,8 @@ def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, 
documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace) + request_params = RequestObject(typ, documents._OperationType.Replace, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2776,7 +2783,8 @@ def Read( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read) + request_params = RequestObject(typ, documents._OperationType.Read, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2815,7 +2823,8 @@ def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete) + request_params = RequestObject(typ, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3047,11 +3056,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. 
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - request_params = RequestObject( - resource_type, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed - ) headers = base.GetHeaders( self, initial_headers, @@ -3059,11 +3065,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, resource_id, resource_type, - request_params.operation_type, + op_typ, options, partition_key_range_id ) + request_params = RequestObject( + resource_type, + op_typ, + headers + ) + request_params.set_excluded_location_from_options(options) + change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: change_feed_state.populate_request_headers(self._routing_map_provider, headers) @@ -3089,7 +3102,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = RequestObject(resource_type, documents._OperationType.SqlQuery) req_headers = base.GetHeaders( self, initial_headers, @@ -3102,6 +3114,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key_range_id ) + request_params = RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) + # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) if isPrefixPartitionQuery: @@ -3183,7 +3198,8 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs: documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): + if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, + Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 70df36d6d015..9b4d32a4598b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -34,6 +34,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes +from ..._constants import _Constants as Constants # pylint: disable=protected-access @@ -117,8 +118,9 @@ async def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): - raise ValueError("Executing a vector search query with more items than the max is not allowed." 
+ + if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, + Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): + raise ValueError("Executing a vector search query with more items than the max is not allowed. " + "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator =\ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 1155e537c68c..453a4dc38d25 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -26,6 +26,7 @@ import json import os from azure.cosmos.exceptions import CosmosHttpResponseError +from .._constants import _Constants as Constants from azure.cosmos._execution_context import endpoint_component, multi_execution_aggregator from azure.cosmos._execution_context import non_streaming_order_by_aggregator, hybrid_search_aggregator from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase @@ -56,8 +57,9 @@ def _verify_valid_hybrid_search_query(hybrid_search_query_info): raise ValueError("Executing a hybrid search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your hybrid search query.") - if hybrid_search_query_info['take'] > os.environ.get('AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS', 1000): - raise ValueError("Executing a hybrid search query with more items than the max is not allowed." + + if hybrid_search_query_info['take'] > int(os.environ.get(Constants.HS_MAX_ITEMS_CONFIG, + Constants.HS_MAX_ITEMS_CONFIG_DEFAULT)): + raise ValueError("Executing a hybrid search query with more items than the max is not allowed. " + "Please ensure you are using a limit smaller than the max, or change the max.") @@ -149,8 +151,9 @@ def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): - raise ValueError("Executing a vector search query with more items than the max is not allowed." + + if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, + Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): + raise ValueError("Executing a vector search query with more items than the max is not allowed. " + "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator = \ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index e167871dd4a5..8aa7c388a9b6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -31,7 +31,7 @@ from . import _constants as constants from . 
import exceptions from .documents import DatabaseAccount -from ._location_cache import LocationCache +from ._location_cache import LocationCache, current_time_millis # pylint: disable=protected-access @@ -53,7 +53,8 @@ def __init__(self, client): self.PreferredLocations, self.DefaultEndpoint, self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy.UseMultipleWriteLocations, + client.connection_policy ) self.refresh_needed = False self.refresh_lock = threading.RLock() @@ -98,7 +99,7 @@ def update_location_cache(self): self.location_cache.update_location_cache() def refresh_endpoint_list(self, database_account, **kwargs): - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: with self.refresh_lock: @@ -114,11 +115,11 @@ def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() # this will perform getDatabaseAccount calls to check endpoint health self._endpoints_health_check(**kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 96651d5c8b7f..d40a99f7c69f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -113,7 +113,10 @@ def get_endpoints_by_location(new_locations, except Exception as e: raise e - return endpoints_by_location, parsed_locations + # Also build a reverse map of the location name for each regional endpoint + locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()} + + return endpoints_by_location, locations_by_endpoints, parsed_locations def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool: if endpoint in preferred_endpoints: @@ -150,10 +153,24 @@ def _get_health_check_endpoints( return endpoints +def get_applicable_regional_endpoints(endpoints, location_name_by_endpoint, fall_back_endpoint, + exclude_location_list): + # filter endpoints by excluded locations + applicable_endpoints = [] + for endpoint in endpoints: + if location_name_by_endpoint.get(endpoint.get_primary()) not in exclude_location_list: + applicable_endpoints.append(endpoint) + + # if the filtered list is empty, add the fallback endpoint + if not applicable_endpoints: + applicable_endpoints.append(fall_back_endpoint) + + return applicable_endpoints + +def current_time_millis(): + return int(round(time.time() * 1000)) class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes - def current_time_millis(self): - return int(round(time.time() * 1000)) def __init__( self, @@ -161,6 +178,7 @@ def __init__( default_endpoint, enable_endpoint_discovery, use_multiple_write_locations, + connection_policy, ): self.preferred_locations = preferred_locations self.default_regional_routing_context = RegionalRoutingContext(default_endpoint,
default_endpoint) @@ -173,8 +191,11 @@ def __init__( self.last_cache_update_time_stamp = 0 self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long + self.account_locations_by_read_regional_endpoints = {} # pylint: disable=name-too-long + self.account_locations_by_write_regional_endpoints = {} # pylint: disable=name-too-long self.account_write_locations = [] self.account_read_locations = [] + self.connection_policy = connection_policy def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -182,6 +203,10 @@ def get_write_regional_routing_contexts(self): def get_read_regional_routing_contexts(self): return self.read_regional_routing_contexts + def get_location_from_endpoint(self, endpoint: str) -> str: + # account_locations_by_read_regional_endpoints is keyed by the primary endpoint + # string (see get_endpoints_by_location), so look the endpoint up directly + return self.account_locations_by_read_regional_endpoints[endpoint] + def get_write_regional_routing_context(self): return self.get_write_regional_routing_contexts()[0].get_primary() @@ -207,6 +232,45 @@ def get_ordered_write_locations(self): def get_ordered_read_locations(self): return self.account_read_locations + # TODO: @tvaron3 client excluded locations should be appended to if using circuit breaker excluded regions + def _get_configured_excluded_locations(self, request): + # If excluded locations were configured on the request, use the request-level excluded locations. + excluded_locations = request.excluded_locations + if excluded_locations is None: + # If excluded locations were only configured on the client (connection_policy), use the client level + excluded_locations = self.connection_policy.ExcludedLocations + return excluded_locations + + def get_applicable_read_regional_endpoints(self, request): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return regional endpoints filtered by excluded locations. + if excluded_locations: + return get_applicable_regional_endpoints( + self.get_read_regional_routing_contexts(), + self.account_locations_by_read_regional_endpoints, + self.get_write_regional_routing_contexts()[0], + excluded_locations) + + # Else, return all regional endpoints + return self.get_read_regional_routing_contexts() + + def get_applicable_write_regional_endpoints(self, request): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return regional endpoints filtered by excluded locations.
+ if excluded_locations: + return get_applicable_regional_endpoints( + self.get_write_regional_routing_contexts(), + self.account_locations_by_write_regional_endpoints, + self.default_regional_routing_context, + excluded_locations) + + # Else, return all regional endpoints + return self.get_write_regional_routing_contexts() + def resolve_service_endpoint(self, request): if request.location_endpoint_to_route: return request.location_endpoint_to_route @@ -247,9 +311,9 @@ def resolve_service_endpoint(self, request): return self.default_regional_routing_context.get_primary() regional_routing_contexts = ( - self.get_write_regional_routing_contexts() + self.get_applicable_write_regional_endpoints(request) if documents._OperationType.IsWriteOperation(request.operation_type) - else self.get_read_regional_routing_contexts() + else self.get_applicable_read_regional_endpoints(request) ) regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)] if ( @@ -361,6 +425,7 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if self.enable_endpoint_discovery: if read_locations: (self.account_read_regional_routing_contexts_by_location, + self.account_locations_by_read_regional_endpoints, self.account_read_locations) = get_endpoints_by_location( read_locations, self.account_read_regional_routing_contexts_by_location, @@ -371,6 +436,7 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if write_locations: (self.account_write_regional_routing_contexts_by_location, + self.account_locations_by_write_regional_endpoints, self.account_write_locations) = get_endpoints_by_location( write_locations, self.account_write_regional_routing_contexts_by_location, @@ -391,7 +457,6 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl EndpointOperationType.ReadType, self.write_regional_routing_contexts[0] ) - self.last_cache_update_timestamp = self.current_time_millis() # pylint: disable=attribute-defined-outside-init def get_preferred_regional_routing_contexts( self, endpoints_by_location, orderedLocations, expected_available_operation, fallback_endpoint diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index a220c6af42c2..afc9fa4d30a9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,18 +21,27 @@ """Represents a request object. 
""" -from typing import Optional +from typing import Optional, Mapping, Any, List, Dict + class RequestObject(object): - def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: + def __init__( + self, + resource_type: str, + operation_type: str, + headers: Dict[str, Any], + endpoint_override: Optional[str] = None, + ) -> None: self.resource_type = resource_type self.operation_type = operation_type self.endpoint_override = endpoint_override self.should_clear_session_token_on_session_read_failure: bool = False # pylint: disable=name-too-long + self.headers = headers self.use_preferred_locations: Optional[bool] = None self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None + self.excluded_locations: Optional[List[str]] = None def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -52,3 +61,29 @@ def clear_route_to_location(self) -> None: self.location_index_to_route = None self.use_preferred_locations = None self.location_endpoint_to_route = None + + def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: + # If resource types for requests are one of the followings, excluded locations cannot be set + if self.resource_type.lower() in ['offers', 'conflicts']: + return False + + # If 'excludedLocations' wasn't in the options, excluded locations cannot be set + if (options is None + or 'excludedLocations' not in options): + return False + + # The 'excludedLocations' cannot be None + if options['excludedLocations'] is None: + raise ValueError("Excluded locations cannot be None. " + "If you want to remove all excluded locations, try passing an empty list.") + + return True + + def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None: + if self._can_set_excluded_location(options): + self.excluded_locations = options['excludedLocations'] + + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: + if self.excluded_locations: + self.excluded_locations.extend(excluded_locations) + self.excluded_locations = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 452bc32e5b34..4e3d603ef0d8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -226,3 +226,25 @@ def is_subset(self, parent_range: 'Range') -> bool: normalized_child_range = self.to_normalized_range() return (normalized_parent_range.min <= normalized_child_range.min and normalized_parent_range.max >= normalized_child_range.max) + +class PartitionKeyRangeWrapper(object): + """Internal class for a representation of a unique partition for an account + """ + + def __init__(self, partition_key_range: Range, collection_rid: str) -> None: + self.partition_key_range = partition_key_range + self.collection_rid = collection_rid + + + def __str__(self) -> str: + return ( + f"PartitionKeyRangeWrapper(" + f"partition_key_range={self.partition_key_range}, " + f"collection_rid={self.collection_rid}, " + ) + + def __eq__(self, other): + if not isinstance(other, PartitionKeyRangeWrapper): + return False + return self.partition_key_range == other.partition_key_range and self.collection_rid == other.collection_rid + diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index 9b34d048e3a6..8630714bc6f3 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -44,9 +44,12 @@ def ShouldRetry(self): if self.request.resource_type == ResourceType.DatabaseAccount: return False - refresh_cache = self.request.last_routed_location_endpoint_within_region is not None - # This logic is for the last retry and mark the region unavailable - self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) + if self.global_endpoint_manager.is_circuit_breaker_applicable(self.request): + self.global_endpoint_manager.mark_partition_unavailable(self.request) + else: + refresh_cache = self.request.last_routed_location_endpoint_within_region is not None + # This logic is for the last retry and mark the region unavailable + self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) # Check if it is safe to do another retry if self.in_region_retry_count >= self.total_in_region_retries: @@ -65,7 +68,7 @@ def ShouldRetry(self): self.failover_retry_count += 1 if self.failover_retry_count >= self.total_retries: return False - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region location_endpoint = self.resolve_next_region_service_endpoint() else: location_endpoint = self.resolve_current_region_service_endpoint() @@ -80,7 +83,7 @@ def ShouldRetry(self): # and we reset the in region retry count self.in_region_retry_count = 0 self.failover_retry_count += 1 - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region if self.failover_retry_count >= self.total_retries: return False location_endpoint = self.resolve_next_region_service_endpoint() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index b170fb4fd9d2..505a8edf3d06 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -5,14 +5,13 @@ Cosmos database service. 
""" from azure.cosmos.documents import _OperationType +from azure.cosmos.http_constants import HttpHeaders class _TimeoutFailoverRetryPolicy(object): def __init__(self, connection_policy, global_endpoint_manager, *args): self.retry_after_in_milliseconds = 500 - self.args = args - self.global_endpoint_manager = global_endpoint_manager # If an account only has 1 region, then we still want to retry once on the same region self._max_retry_attempt_count = (len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) @@ -28,6 +27,9 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ + # record the failure for circuit breaker tracking + self.global_endpoint_manager.record_failure(self.request) + # we don't retry on write operations for timeouts or any internal server errors if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): return False diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 59bf8ee71ba3..d5e5f967be4d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -166,6 +166,8 @@ async def read( :keyword bool populate_quota_info: Enable returning collection storage quota information in response headers. :keyword str session_token: Token for use with Session consistency. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Dict[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -227,6 +229,8 @@ async def create_item( has changed, and act according to the condition specified by the `match_condition` parameter. :keyword match_condition: The match condition to use upon the etag. :paramtype match_condition: ~azure.core.MatchConditions + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Dict[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -297,6 +301,8 @@ async def read_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :returns: A CosmosDict representing the retrieved item. 
:rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -359,6 +365,8 @@ def read_all_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] """ @@ -443,6 +451,8 @@ def query_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -541,6 +551,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -578,6 +590,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], AsyncItemPaged[Dict[str, Any]]], None] :returns: An AsyncItemPaged of items (dicts). @@ -604,6 +618,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], AsyncItemPaged[Dict[str, Any]]], None] :returns: An AsyncItemPaged of items (dicts). 
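Taken together, the keyword additions above give every async data-plane call a per-request region filter. The sketch below shows how a caller might use it; it assumes the `excluded_locations` keyword is forwarded from kwargs into the request options exactly as these docstrings describe, and the endpoint, key, and resource names are placeholders rather than values from this patch.

import asyncio
from azure.cosmos.aio import CosmosClient

async def main():
    # Placeholder endpoint and key, not values from this patch.
    async with CosmosClient("https://<account>.documents.azure.com:443/", credential="<key>") as client:
        container = client.get_database_client("<db>").get_container_client("<container>")
        # Skip 'West US' for this read only; other requests keep the
        # client-level preferred and excluded location settings.
        item = await container.read_item(
            item="<item-id>",
            partition_key="<partition-key>",
            excluded_locations=["West US"],
        )
        print(item["id"])

asyncio.run(main())

Per `_get_configured_excluded_locations` in the location cache, a request-level list like this takes precedence over any client-level excluded locations.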
@@ -642,6 +658,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], AsyncItemPaged[Dict[str, Any]]], None] :returns: An AsyncItemPaged of items (dicts). @@ -678,6 +696,8 @@ def query_items_change_feed( # pylint: disable=unused-argument ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], Mapping[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -760,6 +780,8 @@ async def upsert_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. @@ -833,6 +855,8 @@ async def replace_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -908,6 +932,8 @@ async def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. 
The dict will be empty if @@ -975,6 +1001,8 @@ async def delete_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Dict[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1230,6 +1258,8 @@ async def delete_all_items_by_partition_key( :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. :keyword ~azure.core.MatchConditions match_condition: The match condition to use upon the etag. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :rtype: None """ @@ -1281,6 +1311,8 @@ async def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 49219533a7e6..0d3d61e87fa0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -54,6 +54,7 @@ from .. import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreaker from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants @@ -63,7 +64,6 @@ from .. import _runtime_constants as runtime_constants from .. import _request_object from . import _asynchronous_request as asynchronous_request -from . 
import _global_endpoint_manager_async as global_endpoint_manager_async from .._routing.aio.routing_map_provider import SmartRoutingMapProvider from ._retry_utility_async import _ConnectionRetryPolicy from .. import _session @@ -169,7 +169,7 @@ def __init__( # pylint: disable=too-many-statements # Keeps the latest response headers from the server. self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager_async._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy): @@ -415,7 +415,8 @@ async def GetDatabaseAccount( documents._OperationType.Read, {}, client_id=self.client_id) # path # id # type - request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, + headers, url_connection) result, self.last_response_headers = await self.__Get("", request_params, headers, **kwargs) database_account = documents.DatabaseAccount() @@ -465,7 +466,9 @@ async def _GetDatabaseAccountCheck( documents._OperationType.Read, {}, client_id=self.client_id) # path # id # type - request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, + headers, + url_connection) await self.__Get("", request_params, headers, **kwargs) async def CreateDatabase( @@ -729,7 +732,9 @@ async def ExecuteStoredProcedure( documents._OperationType.ExecuteJavaScript, options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + request_params = _request_object.RequestObject("sprocs", + documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers = await self.__Post(path, request_params, params, headers, **kwargs) return result @@ -767,7 +772,8 @@ async def Create( documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -906,7 +912,8 @@ async def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1207,7 +1214,8 @@ async def Read( headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, 
options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1465,7 +1473,8 @@ async def PatchItem( headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1569,7 +1578,8 @@ async def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1892,7 +1902,8 @@ async def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2005,7 +2016,8 @@ async def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) + request_params = _request_object.RequestObject("docs", documents._OperationType.Batch, headers) + request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2856,13 +2868,16 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. 
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) + headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, op_typ, + options, partition_key_range_id) request_params = _request_object.RequestObject( typ, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed + op_typ, + headers ) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) + request_params.set_excluded_location_from_options(options) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -2889,9 +2904,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, + documents._OperationType.SqlQuery, options, partition_key_range_id) + request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) @@ -3195,7 +3212,8 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kw documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): + if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, + Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + @@ -3258,7 +3276,9 @@ async def DeleteAllItemsByPartitionKey( initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) - request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) + request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, + headers) + request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 7c845841b224..ab07f90c411d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,6 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError from typing import Tuple from azure.core.exceptions import AzureError @@ -33,8 +32,7 @@ from .. 
import _constants as constants from .. import exceptions -from .._location_cache import LocationCache - +from .._location_cache import LocationCache, current_time_millis # pylint: disable=protected-access @@ -48,15 +46,15 @@ class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attrib def __init__(self, client): self.client = client - self.EnableEndpointDiscovery = client.connection_policy.EnableEndpointDiscovery self.PreferredLocations = client.connection_policy.PreferredLocations self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy.EnableEndpointDiscovery, + client.connection_policy.UseMultipleWriteLocations, + client.connection_policy ) self.startup = True self.refresh_task = None @@ -65,6 +63,7 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None + # TODO: @tvaron3 fix this def get_use_multiple_write_locations(self): return self.location_cache.can_use_multiple_write_locations() @@ -105,9 +104,9 @@ async def refresh_endpoint_list(self, database_account, **kwargs): try: await self.refresh_task self.refresh_task = None - except (Exception, CancelledError) as exception: #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) as exception: #pylint: disable=broad-exception-caught logger.exception("Health check task failed: %s", exception) - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: async with self.refresh_lock: @@ -123,11 +122,11 @@ async def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account and not self.startup: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() if not self.startup: # this will perform getDatabaseAccount calls to check endpoint health # in background @@ -216,5 +215,5 @@ async def close(self): self.refresh_task.cancel() try: await self.refresh_task - except (Exception, CancelledError) : #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) : #pylint: disable=broad-exception-caught pass diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index ef5bad070014..c020a5d4e31c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -104,6 +104,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) + global_endpoint_manager.record_success(request) if not client.last_response_headers: 
client.last_response_headers = {} @@ -198,6 +199,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: + global_endpoint_manager.record_failure(request) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -279,6 +281,8 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): + # TODO: @tvaron3 record failure here + request.context.global_endpoint_manager.record_failure(request.context.options['request']) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -296,6 +300,8 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): + # TODO: @tvaron3 record failure here + request.context.global_endpoint_manager.record_failure(request.context.options['request']) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -310,6 +316,8 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): + # TODO: @tvaron3 record failure here ? Not sure + request.context.global_endpoint_manager.record_failure(request.context.options['request']) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 69be9491f27b..42c3d31acae6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -168,6 +168,8 @@ def read( # pylint:disable=docstring-missing-param request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Raised if the container couldn't be retrieved. This includes if the container does not exist. @@ -225,6 +227,8 @@ def read_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: A CosmosDict representing the item to be retrieved. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -292,6 +296,8 @@ def read_all_items( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: Iterable[Dict[str, Any]] """ @@ -359,6 +365,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. Note that due to the nature of combining calls to build the results, this function may be called with a either single dict or iterable of dicts :type response_hook: @@ -400,6 +408,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :returns: An Iterable of items (dicts). @@ -426,6 +436,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :returns: An Iterable of items (dicts). @@ -463,6 +475,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :returns: An Iterable of items (dicts). @@ -498,6 +512,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, Any], ItemPaged[Dict[str, Any]]], None] :param Any args: args @@ -604,6 +620,8 @@ def query_items( # pylint:disable=docstring-missing-param :keyword bool populate_index_metrics: Used to obtain the index metrics to understand how the query engine used existing indexes and how it could use potential new indexes. Please note that this options will incur overhead, so it should be enabled only when debugging slow queries. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: ItemPaged[Dict[str, Any]] @@ -719,6 +737,8 @@ def replace_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -794,6 +814,8 @@ def upsert_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -876,6 +898,8 @@ def create_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. 
+ :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Item with the given ID already exists. :returns: A CosmosDict representing the new item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -954,6 +978,8 @@ def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty @@ -1016,6 +1042,8 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. @@ -1075,6 +1103,8 @@ def delete_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. :raises ~azure.cosmos.exceptions.CosmosResourceNotFoundError: The item does not exist in the container. @@ -1351,6 +1381,8 @@ def delete_all_items_by_partition_key( :keyword str etag: An ETag value, or the wildcard character (*). Used to check if the resource has changed, and act according to the condition specified by the `match_condition` parameter. :keyword ~azure.core.MatchConditions match_condition: The match condition to use upon the etag. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
        :keyword Callable response_hook: A callable invoked with the response metadata.
        :rtype: None
        """
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
index cc4bd43d13a2..3c4399ad6d60 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py
@@ -94,6 +94,8 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy:
     policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration)
     policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery)
     policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations)
+    # TODO: Consider storing a callback method instead, such as the 'Supplier' used in the Java SDK
+    policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations)
     policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations)

     # SSL config
@@ -182,6 +184,8 @@ class CosmosClient:  # pylint: disable=client-accepts-api-version-keyword
    :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts.
        (Default: True)
    :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts.
+    :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations
+        in this list are specified as the names of the Azure Cosmos locations like 'West US', 'East US' and so on.
    :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy.
        Must be used along with a logger to work.
    :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. Can be passed in at client
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
index 942e426934b7..b5e18f3680e6 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
@@ -308,6 +308,13 @@ class ConnectionPolicy:  # pylint: disable=too-many-instance-attributes
        locations in this list are specified as the names of the azure
        Cosmos locations like, 'West US', 'East US', 'Central India' and so on.
    :vartype PreferredLocations: List[str]
+    :ivar ExcludedLocations:
+        Gets or sets the excluded locations for geo-replicated database
+        accounts. When ExcludedLocations is non-empty, the client will skip this
+        set of locations from the final location evaluation. The locations in
+        this list are specified as the names of the Azure Cosmos locations like
+        'West US', 'East US', 'Central India' and so on.
+    :vartype ExcludedLocations: List[str]
    :ivar RetryOptions:
        Gets or sets the retry options to be applied to all requests when
        retrying.
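Because `_build_connection_policy` above pops the new `excluded_locations` keyword into `ConnectionPolicy.ExcludedLocations`, the filter can also be set once for the whole client. A minimal sketch, with a placeholder account endpoint and key:

from azure.cosmos import CosmosClient

client = CosmosClient(
    "https://<account>.documents.azure.com:443/",
    credential="<key>",
    preferred_locations=["East US", "West US", "Central India"],
    # Client-level default: consulted only when a request does not carry
    # its own 'excludedLocations' option.
    excluded_locations=["West US"],
)

The same list can be assigned directly to `ConnectionPolicy.ExcludedLocations` when a policy object is constructed by hand.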
@@ -347,6 +354,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations: List[str] = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index a957094f1790..f540797bc112 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -3,8 +3,10 @@ import time import unittest +from typing import Mapping, Any import pytest +from azure.cosmos import documents from azure.cosmos.documents import DatabaseAccount, _OperationType from azure.cosmos.http_constants import ResourceType @@ -35,15 +37,16 @@ def create_database_account(enable_multiple_writable_locations): return db_acc -def refresh_location_cache(preferred_locations, use_multiple_write_locations): +def refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=documents.ConnectionPolicy()): lc = LocationCache(preferred_locations=preferred_locations, default_endpoint=default_endpoint, enable_endpoint_discovery=True, - use_multiple_write_locations=use_multiple_write_locations) + use_multiple_write_locations=use_multiple_write_locations, + connection_policy=connection_policy) return lc @pytest.mark.cosmosEmulator -class TestLocationCache(unittest.TestCase): +class TestLocationCache: def test_mark_endpoint_unavailable(self): lc = refresh_location_cache([], False) @@ -136,6 +139,140 @@ def test_resolve_request_endpoint_preferred_regions(self): assert read_resolved == write_resolved assert read_resolved == default_endpoint + @pytest.mark.parametrize("test_type",["OnClient", "OnRequest", "OnBoth"]) + def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): + # Init test data + if test_type == "OnClient": + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + excluded_locations_on_requests_list = [None] * 5 + elif test_type == "OnRequest": + excluded_locations_on_client_list = [[]] * 5 + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + else: + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name, location3_name], + [location1_name, location2_name], + [location2_name], + [location1_name, location2_name, location3_name], + ] + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + + expected_read_endpoints_list = [ + [location2_endpoint], + [location1_endpoint], + [location1_endpoint], + [location1_endpoint, location2_endpoint], + [location1_endpoint, location2_endpoint], + ] + expected_write_endpoints_list = [ + [location2_endpoint, location3_endpoint], + [location3_endpoint], + [default_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + ] + + # Loop over each test cases + for excluded_locations_on_client, excluded_locations_on_requests, expected_read_endpoints, 
expected_write_endpoints in zip(excluded_locations_on_client_list, excluded_locations_on_requests_list, expected_read_endpoints_list, expected_write_endpoints_list): + # Init excluded_locations in ConnectionPolicy + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Init requests and set excluded regions on requests + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + write_doc_request.excluded_locations = excluded_locations_on_requests + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request.excluded_locations = excluded_locations_on_requests + + # Test if read endpoints were correctly filtered on client level + read_doc_endpoint = location_cache.get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + # Test if write endpoints were correctly filtered on client level + write_doc_endpoint = location_cache.get_applicable_write_regional_endpoints(write_doc_request) + write_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in write_doc_endpoint] + assert write_doc_endpoint == expected_write_endpoints + + def test_set_excluded_locations_for_requests(self): + # Init excluded_locations in ConnectionPolicy + excluded_locations_on_client = [location1_name, location2_name] + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Test setting excluded locations + excluded_locations = [location1_name] + options: Mapping[str, Any] = {"excludedLocations": excluded_locations} + + expected_excluded_locations = excluded_locations + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location2_endpoint] + read_doc_endpoint = location_cache.get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + # Test setting excluded locations with invalid resource types + expected_excluded_locations = None + for resource_type in [ResourceType.Offer, ResourceType.Conflict]: + options: Mapping[str, Any] = {"excludedLocations": [location1_name]} + read_doc_request = RequestObject(resource_type, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location1_endpoint] + read_doc_endpoint = 
location_cache.get_applicable_read_regional_endpoints(read_doc_request)
+            read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint]
+            assert read_doc_endpoint == expected_read_endpoints
+
+        # Test setting excluded locations with None value
+        expected_error_message = ("Excluded locations cannot be None. "
+                                  "If you want to remove all excluded locations, try passing an empty list.")
+        with pytest.raises(ValueError) as e:
+            options: Mapping[str, Any] = {"excludedLocations": None}
+            doc_request = RequestObject(ResourceType.Document, _OperationType.Create)
+            doc_request.set_excluded_location_from_options(options)
+        assert str(e.value) == expected_error_message
+
 if __name__ == "__main__":
     unittest.main()
diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py
index 429498b3071f..3a9e8527992e 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py
@@ -2,6 +2,7 @@
 # Copyright (c) Microsoft Corporation. All rights reserved.
 
 import json
+import os
 import time
 import unittest
 import uuid
@@ -94,6 +95,18 @@ def test_wrong_hybrid_search_queries(self):
             assert ("One of the input values is invalid" in e.message or
                     "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." in e.message)
 
+    def test_hybrid_search_env_variables(self):
+        os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1"
+        try:
+            query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \
+                    "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])"
+            results = self.test_container.query_items(query)
+            [item for item in results]
+            pytest.fail("Config was not applied properly.")
+        except ValueError as e:
+            assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. "
+                                 "Please ensure you are using a limit smaller than the max, or change the max.")
+
     def test_hybrid_search_queries(self):
         query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \
                 "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])"
diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py
index 964d9579a2e9..4223cc6bdd50 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py
@@ -1,6 +1,6 @@
 # The MIT License (MIT)
 # Copyright (c) Microsoft Corporation. All rights reserved.
-
+import os
 import time
 import unittest
 import uuid
@@ -98,6 +98,18 @@ async def test_wrong_hybrid_search_queries_async(self):
             assert ("One of the input values is invalid" in e.message or
                     "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." in e.message)
 
+    async def test_hybrid_search_env_variables_async(self):
+        os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1"
+        try:
+            query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \
+                    "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])"
+            results = self.test_container.query_items(query)
+            [item async for item in results]
+            pytest.fail("Config was not applied properly.")
+        except ValueError as e:
+            assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. 
" + "Please ensure you are using a limit smaller than the max, or change the max.") + async def test_hybrid_search_queries_async(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py index e2614c5fb85f..96e3eee02936 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import os import unittest import uuid @@ -121,6 +121,23 @@ def test_wrong_vector_search_queries(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." in e.message) + + def test_vector_search_environment_variables(self): + vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") + query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ + "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" + try: + [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] + pytest.fail("Config was not set correctly.") + except ValueError as e: + assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. " + "Please ensure you are using a limit smaller than the max, or change the max.") + + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" + os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" + [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] + def test_ordering_distances(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies # load up previously calculated embedding for the given string diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py index 0cb031847a6f..716150358ff3 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. - +import os import unittest import uuid @@ -127,6 +127,23 @@ async def test_wrong_vector_search_queries_async(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." 
in e.message) + async def test_vector_search_environment_variables_async(self): + vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") + query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ + "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" + try: + [item async for item in self.created_large_container.query_items(query=query)] + pytest.fail("Config was not set correctly.") + except ValueError as e: + assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. " + "Please ensure you are using a limit smaller than the max, or change the max.") + + os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" + + os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" + [item async for item in self.created_large_container.query_items(query=query)] + async def test_ordering_distances_async(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies From 9d461220a745a7652ce8a82199a35288d29ba475 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 1 Apr 2025 09:18:04 -0700 Subject: [PATCH 23/86] add missing changes --- ...tition_endpoint_manager_circuit_breaker.py | 136 ++++++++++ .../azure/cosmos/_partition_health_tracker.py | 256 ++++++++++++++++++ ..._endpoint_manager_circuit_breaker_async.py | 136 ++++++++++ .../samples/excluded_locations.py | 110 ++++++++ 4 files changed, 638 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py create mode 100644 sdk/cosmos/azure-cosmos/samples/excluded_locations.py diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py new file mode 100644 index 000000000000..849630c83d4d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -0,0 +1,136 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+
+"""Internal class for global endpoint manager for circuit breaker.
+"""
+import logging
+import os
+from typing import TYPE_CHECKING
+
+from azure.cosmos import documents
+
+from azure.cosmos._partition_health_tracker import PartitionHealthTracker
+from azure.cosmos._routing.routing_range import Range, PartitionKeyRangeWrapper
+from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager
+from azure.cosmos._location_cache import EndpointOperationType
+from azure.cosmos._request_object import RequestObject
+from azure.cosmos.http_constants import ResourceType, HttpHeaders
+from azure.cosmos._constants import _Constants as Constants
+if TYPE_CHECKING:
+    from azure.cosmos._cosmos_client_connection import CosmosClientConnection
+
+
+logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreaker")
+
+class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager):
+    """
+    This internal class implements the logic for partition endpoint management for
+    geo-replicated database accounts.
+    """
+
+
+    def __init__(self, client: "CosmosClientConnection"):
+        super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client)
+        self.partition_health_tracker = PartitionHealthTracker()
+
+    def is_circuit_breaker_applicable(self, request: RequestObject) -> bool:
+        """
+        Check if circuit breaker is applicable for a request.
+        """
+        if not request:
+            return False
+
+        circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG,
+                                                 Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True"
+        if not circuit_breaker_enabled:
+            return False
+
+        if (not self.location_cache.can_use_multiple_write_locations_for_request(request)
+                and documents._OperationType.IsWriteOperation(request.operation_type)):
+            return False
+
+        if request.resource_type != ResourceType.Document:
+            return False
+
+        # the circuit breaker does not track query plan calls
+        if request.operation_type == documents._OperationType.QueryPlan:
+            return False
+
+        return True
+
+    def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper:
+        """
+        Create a PartitionKeyRangeWrapper object.
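+
+        The wrapper pairs the collection rid from the request headers with the
+        physical partition key range that owns the request's partition key,
+        resolved through the client's routing map provider.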
+        """
+        container_rid = request.headers[HttpHeaders.IntendedCollectionRID]
+        partition_key = request.headers[HttpHeaders.PartitionKey]
+        # get the partition key range for the given partition key
+        target_container_link = None
+        for container_link, properties in self.client._container_properties_cache.items():
+            # TODO: @tvaron3 consider moving this to a constant with other usages
+            if properties["_rid"] == container_rid:
+                target_container_link = container_link
+        # TODO: raise an exception if the container link is not found
+        pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key)
+        return PartitionKeyRangeWrapper(pkrange, container_rid)
+
+    def record_failure(
+            self,
+            request: RequestObject
+    ) -> None:
+        if self.is_circuit_breaker_applicable(request):
+            #convert operation_type to EndpointOperationType
+            endpoint_operation_type = EndpointOperationType.WriteType if (
+                documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType
+            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
+            pkrange_wrapper = self._create_pkrange_wrapper(request)
+            self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location)
+
+    def resolve_service_endpoint(self, request):
+        if self.is_circuit_breaker_applicable(request):
+            pkrange_wrapper = self._create_pkrange_wrapper(request)
+            request.set_excluded_locations_from_circuit_breaker(
+                self.partition_health_tracker.get_excluded_locations(pkrange_wrapper)
+            )
+        return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request)
+
+
+    def mark_partition_unavailable(self, request: RequestObject) -> None:
+        """
+        Mark the partition unavailable from the given request.
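+
+        The region that served the request is recorded as unavailable for the
+        request's partition key range until the partition health tracker
+        transitions it back towards healthy.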
+ """ + location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + + def record_success( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + #convert operation_type to either Read or Write + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType + location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) + +# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py new file mode 100644 index 000000000000..d7c5c4a2cb31 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -0,0 +1,256 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for partition health tracker for circuit breaker. 
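+
+Health state transitions per partition and region, as implemented below:
+healthy -> unhealthy_tentative when a failure threshold is first exceeded,
+unhealthy_tentative -> healthy_tentative once INITIAL_UNAVAILABLE_TIME passes,
+unhealthy -> healthy_tentative once the stale partition check window passes,
+healthy_tentative -> healthy on a success and any state -> unhealthy on a
+further tracked failure.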
+""" +import os +from typing import Dict, Set, Any +from ._constants import _Constants as Constants +from azure.cosmos._location_cache import current_time_millis, EndpointOperationType +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range + + +MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 +REFRESH_INTERVAL = 60 * 1000 # milliseconds +INITIAL_UNAVAILABLE_TIME = 60 * 1000 # milliseconds +# partition is unhealthy if sdk tried to recover and failed +UNHEALTHY = "unhealthy" +# partition is unhealthy tentative when it initially marked unavailable +UNHEALTHY_TENTATIVE = "unhealthy_tentative" +# partition is healthy tentative when sdk is trying to recover +HEALTHY_TENTATIVE = "healthy_tentative" +# unavailability info keys +LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp" +HEALTH_STATUS = "healthStatus" + + +def _has_exceeded_failure_rate_threshold( + successes: int, + failures: int, + failure_rate_threshold: int, +) -> bool: + if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE: + return False + return (failures / successes * 100) >= failure_rate_threshold + +class _PartitionHealthInfo(object): + """ + This internal class keeps the health and statistics for a partition. + """ + + def __init__(self) -> None: + self.write_failure_count: int = 0 + self.read_failure_count: int = 0 + self.write_success_count: int = 0 + self.read_success_count: int = 0 + self.read_consecutive_failure_count: int = 0 + self.write_consecutive_failure_count: int = 0 + self.unavailability_info: Dict[str, Any] = {} + + + def reset_health_stats(self) -> None: + self.write_failure_count = 0 + self.read_failure_count = 0 + self.write_success_count = 0 + self.read_success_count = 0 + self.read_consecutive_failure_count = 0 + self.write_consecutive_failure_count = 0 + + +class PartitionHealthTracker(object): + """ + This internal class implements the logic for tracking health thresholds for a partition. 
+    """
+
+
+    def __init__(self) -> None:
+        # partition -> regions -> health info
+        self.pkrange_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {}
+        self.last_refresh = current_time_millis()
+
+    # TODO: @tvaron3 look for useful places to add logs
+
+    def mark_partition_unavailable(self, pkrange_wrapper: PartitionKeyRangeWrapper, location: str) -> None:
+        # mark the partition key range as unavailable
+        self._transition_health_status_on_failure(pkrange_wrapper, location)
+
+    def _transition_health_status_on_failure(
+            self,
+            pkrange_wrapper: PartitionKeyRangeWrapper,
+            location: str
+    ) -> None:
+        current_time = current_time_millis()
+        if pkrange_wrapper not in self.pkrange_wrapper_to_health_info:
+            # healthy -> unhealthy tentative
+            partition_health_info = _PartitionHealthInfo()
+            partition_health_info.unavailability_info = {
+                LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time,
+                HEALTH_STATUS: UNHEALTHY_TENTATIVE
+            }
+            self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {
+                location: partition_health_info
+            }
+        else:
+            region_to_partition_health = self.pkrange_wrapper_to_health_info[pkrange_wrapper]
+            if location in region_to_partition_health:
+                # healthy tentative -> unhealthy
+                # an entry for this location already exists, so a recovery attempt just failed
+                self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = UNHEALTHY
+                # reset the last unavailability check time stamp
+                self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] = current_time
+            else:
+                # healthy -> unhealthy tentative
+                # no entry for this location yet, so it is newly marked unavailable
+                partition_health_info = _PartitionHealthInfo()
+                partition_health_info.unavailability_info = {
+                    LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time,
+                    HEALTH_STATUS: UNHEALTHY_TENTATIVE
+                }
+                self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info
+
+    def _transition_health_status_on_success(
+            self,
+            pkrange_wrapper: PartitionKeyRangeWrapper,
+            location: str
+    ) -> None:
+        if pkrange_wrapper in self.pkrange_wrapper_to_health_info:
+            # healthy tentative -> healthy
+            self.pkrange_wrapper_to_health_info[pkrange_wrapper].pop(location, None)
+
+    def _check_stale_partition_info(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> None:
+        current_time = current_time_millis()
+
+        stale_partition_unavailability_check = int(os.getenv(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK,
+                                                   Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000
+        if pkrange_wrapper in self.pkrange_wrapper_to_health_info:
+            for location, partition_health_info in self.pkrange_wrapper_to_health_info[pkrange_wrapper].items():
+                # locations tracked only for statistics have no unavailability info
+                if not partition_health_info.unavailability_info:
+                    continue
+                elapsed_time = current_time - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]
+                current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS]
+                # check if the partition key range is still unavailable
+                if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check)
+                    or (current_health_status == UNHEALTHY_TENTATIVE
+                            and elapsed_time > INITIAL_UNAVAILABLE_TIME)):
+                    # unhealthy or unhealthy tentative -> healthy tentative
+                    self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE
+
+        if current_time - self.last_refresh > REFRESH_INTERVAL:
+            # all partition stats reset every minute
+            self.last_refresh = current_time
+
            self._reset_partition_health_tracker_stats()
+
+
+    def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> Set[str]:
+        self._check_stale_partition_info(pkrange_wrapper)
+        if pkrange_wrapper in self.pkrange_wrapper_to_health_info:
+            # only locations currently marked unavailable are excluded; locations
+            # tracked purely for statistics have an empty unavailability_info
+            return {location for location, health_info
+                    in self.pkrange_wrapper_to_health_info[pkrange_wrapper].items()
+                    if health_info.unavailability_info}
+        else:
+            return set()
+
+
+    def add_failure(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None:
+        # Retrieve the failure rate threshold from the environment.
+        failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED,
+                                               Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT))
+
+        # Ensure that the health info dictionary is properly initialized.
+        if pkrange_wrapper not in self.pkrange_wrapper_to_health_info:
+            self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {}
+        if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]:
+            self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo()
+
+        health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location]
+
+        # Determine attribute names and environment variables based on the operation type.
+        if operation_type == EndpointOperationType.WriteType:
+            success_attr = 'write_success_count'
+            failure_attr = 'write_failure_count'
+            consecutive_attr = 'write_consecutive_failure_count'
+            env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE
+            default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT
+        else:
+            success_attr = 'read_success_count'
+            failure_attr = 'read_failure_count'
+            consecutive_attr = 'read_consecutive_failure_count'
+            env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ
+            default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT
+
+        # Increment failure and consecutive failure counts.
+        setattr(health_info, failure_attr, getattr(health_info, failure_attr) + 1)
+        setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1)
+
+        # Retrieve the consecutive failure threshold from the environment.
+        consecutive_failure_threshold = int(os.getenv(env_key, default_consec_threshold))
+
+        # Call the threshold checker with the current stats.
+        self._check_thresholds(
+            pkrange_wrapper,
+            getattr(health_info, success_attr),
+            getattr(health_info, failure_attr),
+            getattr(health_info, consecutive_attr),
+            location,
+            failure_rate_threshold,
+            consecutive_failure_threshold
+        )
+
+    def _check_thresholds(
+            self,
+            pkrange_wrapper: PartitionKeyRangeWrapper,
+            successes: int,
+            failures: int,
+            consecutive_failures: int,
+            location: str,
+            failure_rate_threshold: int,
+            consecutive_failure_threshold: int,
+    ) -> None:
+
+        # check that the failure rate threshold was not exceeded
+        if _has_exceeded_failure_rate_threshold(
+                successes,
+                failures,
+                failure_rate_threshold
+        ):
+            self._transition_health_status_on_failure(pkrange_wrapper, location)
+
+        # check that the consecutive failure threshold was not exceeded
+        if consecutive_failures >= consecutive_failure_threshold:
+            self._transition_health_status_on_failure(pkrange_wrapper, location)
+
+    def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None:
+        # Ensure that the health info dictionary is initialized.
+        if pkrange_wrapper not in self.pkrange_wrapper_to_health_info:
+            self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {}
+        if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]:
+            self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo()
+
+        health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location]
+
+        if operation_type == EndpointOperationType.WriteType:
+            health_info.write_success_count += 1
+            health_info.write_consecutive_failure_count = 0
+        else:
+            health_info.read_success_count += 1
+            health_info.read_consecutive_failure_count = 0
+        self._transition_health_status_on_success(pkrange_wrapper, location)
+
+
+    def _reset_partition_health_tracker_stats(self) -> None:
+        for pkrange_wrapper in self.pkrange_wrapper_to_health_info:
+            for location in self.pkrange_wrapper_to_health_info[pkrange_wrapper]:
+                self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].reset_health_stats()
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py
new file mode 100644
index 000000000000..849630c83d4d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py
@@ -0,0 +1,136 @@
+# The MIT License (MIT)
+# Copyright (c) 2021 Microsoft Corporation
+
+# Permission is hereby granted, free of charge, to any person obtaining a copy
+# of this software and associated documentation files (the "Software"), to deal
+# in the Software without restriction, including without limitation the rights
+# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
+# copies of the Software, and to permit persons to whom the Software is
+# furnished to do so, subject to the following conditions:
+
+# The above copyright notice and this permission notice shall be included in all
+# copies or substantial portions of the Software.
+
+# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
+# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
+# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
+# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
+# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
+# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
+# SOFTWARE.
+
+"""Internal class for global endpoint manager for circuit breaker.
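+
+This is the asynchronous counterpart of
+_global_partition_endpoint_manager_circuit_breaker; the partition health
+tracking logic in PartitionHealthTracker is shared between both clients.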
+""" +import logging +import os +from typing import TYPE_CHECKING + +from azure.cosmos import documents + +from azure.cosmos._partition_health_tracker import PartitionHealthTracker +from azure.cosmos._routing.routing_range import Range, PartitionKeyRangeWrapper +from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager +from azure.cosmos._location_cache import EndpointOperationType +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import ResourceType, HttpHeaders +from azure.cosmos._constants import _Constants as Constants +if TYPE_CHECKING: + from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection + + +logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreaker") + +class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. + """ + + + def __init__(self, client: "CosmosClientConnection"): + super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) + self.partition_health_tracker = PartitionHealthTracker() + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + """ + Check if circuit breaker is applicable for a request. + """ + if not request: + return False + + circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, + Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" + if not circuit_breaker_enabled: + return False + + if (not self.location_cache.can_use_multiple_write_locations_for_request(request) + and documents._OperationType.IsWriteOperation(request.operation_type)): + return False + + if request.resource_type != ResourceType.Document: + return False + + if request.operation_type != documents._OperationType.QueryPlan: + return False + + return True + + def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + """ + Create a PartitionKeyRangeWrapper object. 
+        """
+        container_rid = request.headers[HttpHeaders.IntendedCollectionRID]
+        partition_key = request.headers[HttpHeaders.PartitionKey]
+        # get the partition key range for the given partition key
+        target_container_link = None
+        for container_link, properties in self.client._container_properties_cache.items():
+            # TODO: @tvaron3 consider moving this to a constant with other usages
+            if properties["_rid"] == container_rid:
+                target_container_link = container_link
+        # TODO: raise an exception if the container link is not found
+        pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key)
+        return PartitionKeyRangeWrapper(pkrange, container_rid)
+
+    def record_failure(
+            self,
+            request: RequestObject
+    ) -> None:
+        if self.is_circuit_breaker_applicable(request):
+            #convert operation_type to EndpointOperationType
+            endpoint_operation_type = EndpointOperationType.WriteType if (
+                documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType
+            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
+            pkrange_wrapper = self._create_pkrange_wrapper(request)
+            self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location)
+
+    def resolve_service_endpoint(self, request):
+        if self.is_circuit_breaker_applicable(request):
+            pkrange_wrapper = self._create_pkrange_wrapper(request)
+            request.set_excluded_locations_from_circuit_breaker(
+                self.partition_health_tracker.get_excluded_locations(pkrange_wrapper)
+            )
+        return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request)
+
+
+    def mark_partition_unavailable(self, request: RequestObject) -> None:
+        """
+        Mark the partition unavailable from the given request.
+        """
+        location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
+        pkrange_wrapper = self._create_pkrange_wrapper(request)
+        self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location)
+
+    def record_success(
+            self,
+            request: RequestObject
+    ) -> None:
+        if self.is_circuit_breaker_applicable(request):
+            #convert operation_type to either Read or Write
+            endpoint_operation_type = EndpointOperationType.WriteType if (
+                documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType
+            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
+            pkrange_wrapper = self._create_pkrange_wrapper(request)
+            self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location)
+
+# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative -----------------------
+
diff --git a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py
new file mode 100644
index 000000000000..06228c1a8cea
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py
@@ -0,0 +1,110 @@
+# -------------------------------------------------------------------------
+# Copyright (c) Microsoft Corporation. All rights reserved.
+# Licensed under the MIT License. See LICENSE.txt in the project root for
+# license information.
+# -------------------------------------------------------------------------
+from azure.cosmos import CosmosClient
+from azure.cosmos.partition_key import PartitionKey
+import config
+
+# ----------------------------------------------------------------------------------------------------------
+# Prerequisites -
+#
+# 1. An Azure Cosmos account -
+#    https://learn.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account
+#
+# 2. Microsoft Azure Cosmos
+#    pip install azure-cosmos>=4.3.0b4
+# ----------------------------------------------------------------------------------------------------------
+# Sample - demonstrates how to use excluded locations at the client level and at the request level
+# ----------------------------------------------------------------------------------------------------------
+# Note:
+# This sample creates a Container in your database account.
+# Each time a Container is created the account will be billed for 1 hour of usage based on
+# the provisioned throughput (RU/s) of that account.
+# ----------------------------------------------------------------------------------------------------------
+
+HOST = config.settings["host"]
+MASTER_KEY = config.settings["master_key"]
+
+TENANT_ID = config.settings["tenant_id"]
+CLIENT_ID = config.settings["client_id"]
+CLIENT_SECRET = config.settings["client_secret"]
+
+DATABASE_ID = config.settings["database_id"]
+CONTAINER_ID = config.settings["container_id"]
+PARTITION_KEY = PartitionKey(path="/id")
+
+
+def get_test_item(num):
+    test_item = {
+        'id': 'Item_' + str(num),
+        'test_object': True,
+        'lastName': 'Smith'
+    }
+    return test_item
+
+def clean_up_db(client):
+    try:
+        client.delete_database(DATABASE_ID)
+    except Exception:
+        pass
+
+def excluded_locations_client_level_sample():
+    preferred_locations = ['West US 3', 'West US', 'East US 2']
+    excluded_locations = ['West US 3', 'West US']
+    client = CosmosClient(
+        HOST,
+        MASTER_KEY,
+        preferred_locations=preferred_locations,
+        excluded_locations=excluded_locations
+    )
+    clean_up_db(client)
+
+    db = client.create_database(DATABASE_ID)
+    container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)
+
+    # For write operations with a single master account, the write endpoint will be the default
+    # endpoint, since preferred_locations and excluded_locations are ignored for writes
+    container.create_item(get_test_item(0))
+
+    # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
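+    # (a set difference that preserves the order of the remaining preferred locations)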
+    # In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'],
+    # therefore 'East US 2' will be the read endpoint, and items will be read from the 'East US 2' location
+    item = container.read_item(item='Item_0', partition_key='Item_0')
+
+    clean_up_db(client)
+
+def excluded_locations_request_level_sample():
+    preferred_locations = ['West US 3', 'West US', 'East US 2']
+    excluded_locations_on_client = ['West US 3', 'West US']
+    excluded_locations_on_request = ['West US 3']
+    client = CosmosClient(
+        HOST,
+        MASTER_KEY,
+        preferred_locations=preferred_locations,
+        excluded_locations=excluded_locations_on_client
+    )
+    clean_up_db(client)
+
+    db = client.create_database(DATABASE_ID)
+    container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY)
+
+    # For write operations with a single master account, the write endpoint will be the default
+    # endpoint, since preferred_locations and excluded_locations are ignored for writes
+    container.create_item(get_test_item(0))
+
+    # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
+    # However, since excluded_locations were also passed with the read request, the client-level
+    # excluded_locations are overridden by the request-level value, ['West US 3']; excluded_locations
+    # on the request always take the highest priority.
+    # With the excluded_locations on the request, the read endpoints will be ['West US', 'East US 2']:
+    #   ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2']
+    # Therefore, items will be read from the 'West US' or 'East US 2' location
+    item = container.read_item(item='Item_0', partition_key='Item_0', excluded_locations=excluded_locations_on_request)
+
+    clean_up_db(client)
+
+if __name__ == "__main__":
+    excluded_locations_client_level_sample()
+    excluded_locations_request_level_sample()

From 4efa9ad0ed683b5e0553ec1c8c733d5f1e0ba3c5 Mon Sep 17 00:00:00 2001
From: tvaron3
Date: Tue, 1 Apr 2025 14:38:01 -0700
Subject: [PATCH 24/86] fix mypy errors

---
 .../tests/_fault_injection_transport.py       | 22 ++++++++---------
 .../test_fault_injection_transport_async.py   | 24 +++++++++----------
 2 files changed, 23 insertions(+), 23 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
index 37c00667544a..7f83608b3cd7 100644
--- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
+++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
@@ -27,7 +27,7 @@
 import logging
 import sys
 from collections.abc import MutableMapping
-from typing import Callable, Optional
+from typing import Callable, Optional, Any
 
 import aiohttp
 from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse
@@ -41,7 +41,7 @@ class FaultInjectionTransport(AioHttpTransport):
     logger = logging.getLogger('azure.cosmos.fault_injection_transport')
     logger.setLevel(logging.DEBUG)
 
-    def __init__(self, *, session: aiohttp.ClientSession | None = None, loop=None, session_owner: bool = True, **config):
+    def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner: bool = True, **config):
         self.faults = []
         self.requestTransformations = []
         self.responseTransformations = []
@@ -64,7 +64,7 @@ def __first_item(iterable, condition=lambda x: True):
         """
         return next((x for x in iterable if condition(x)), None)
 
-    async def send(self, request: HttpRequest, *, stream: bool = False, proxies: MutableMapping[str, str] |
None = None, **config) -> AsyncHttpResponse: + async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Optional[MutableMapping[str, str]] = None, **config) -> AsyncHttpResponse: FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) @@ -201,7 +201,7 @@ async def transform_topology_mwr( first_region_name: str, second_region_name: str, r: HttpRequest, - inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]: + inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> AsyncHttpResponse: response = await inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -228,7 +228,7 @@ async def transform_topology_mwr( return response class MockHttpResponse(AioHttpTransportResponse): - def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, any]]): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, Any]]): self.request: HttpRequest = request # This is actually never None, and set by all implementations after the call to # __init__ of this class. This class is also a legacy impl, so it's risky to change it @@ -239,19 +239,19 @@ def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict self.reason: Optional[str] = None self.content_type: Optional[str] = None self.block_size: int = 4096 # Default to same as R - self.content: Optional[dict[str, any]] = None + self.content: Optional[dict[str, Any]] = None self.json_text: Optional[str] = None self.bytes: Optional[bytes] = None if content: - self.content:Optional[dict[str, any]] = content - self.json_text:Optional[str] = json.dumps(content) - self.bytes:bytes = self.json_text.encode("utf-8") + self.content = content + self.json_text = json.dumps(content) + self.bytes = self.json_text.encode("utf-8") - def body(self) -> bytes: + def body(self) -> Optional[bytes]: return self.bytes - def text(self, encoding: Optional[str] = None) -> str: + def text(self, encoding: Optional[str] = None) -> Optional[str]: return self.json_text async def load_body(self) -> None: diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index d09f017febbf..71b6824ef240 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -61,7 +61,7 @@ def setup_class(cls): cls.mgmt_client = CosmosClient(host, master_key, consistency_level="Session", connection_policy=connection_policy, logger=logger) - created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) + created_database = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( cls.single_partition_container_name, @@ -71,7 +71,7 @@ def setup_class(cls): @classmethod def teardown_class(cls): logger.info("tearing down class: {}".format(cls.__name__)) - created_database: DatabaseProxy = cls.mgmt_client.get_database_client(cls.database_id) + created_database = cls.mgmt_client.get_database_client(cls.database_id) try: asyncio.run(asyncio.wait_for( created_database.delete_container(cls.single_partition_container_name), @@ -100,7 +100,7 @@ def 
cleanup_method(initialized_objects: dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. {close_error}") - async def test_throws_injected_error(self, setup): + async def test_throws_injected_error(self, setup: object): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, @@ -126,7 +126,7 @@ async def test_throws_injected_error(self, setup): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_succeeds(self, setup): + async def test_swr_mrr_succeeds(self, setup: object): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -167,18 +167,18 @@ async def test_swr_mrr_succeeds(self, setup): request: HttpRequest = created_document.get_response_headers()["_request"] # Validate the response comes from "Write Region" (the write region) assert request.url.startswith(expected_write_region_uri) - start:float = time.perf_counter() + start: float = time.perf_counter() while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_read_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_read_succeeds(self, setup): + async def test_swr_mrr_region_down_read_succeeds(self, setup: object): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -232,14 +232,14 @@ async def test_swr_mrr_region_down_read_succeeds(self, setup): while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) assert request.url.startswith(expected_write_region_uri) finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): + async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransport() @@ -298,7 +298,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): while (time.perf_counter() - start) < 2: read_document = await container.read_item(id_value, partition_key=id_value) - request: HttpRequest = read_document.get_response_headers()["_request"] + request = read_document.get_response_headers()["_request"] # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) assert request.url.startswith(expected_write_region_uri) @@ -306,7 +306,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup): 
            TestFaultInjectionTransportAsync.cleanup_method(initialized_objects)
 
-    async def test_mwr_succeeds(self, setup):
+    async def test_mwr_succeeds(self, setup: object):
         first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
         second_region_uri: str = test_config.TestConfig.local_host
         custom_transport = FaultInjectionTransport()
@@ -344,7 +344,7 @@ async def test_mwr_succeeds(self, setup):
 
         while (time.perf_counter() - start) < 2:
             read_document = await container.read_item(id_value, partition_key=id_value)
-            request: HttpRequest = read_document.get_response_headers()["_request"]
+            request = read_document.get_response_headers()["_request"]
             # Validate the response comes from "East US" (the most preferred read-only region)
             assert request.url.startswith(first_region_uri)
 
From d86d381c66b54876c6a1f1773a05edc742361680 Mon Sep 17 00:00:00 2001
From: tvaron3
Date: Wed, 2 Apr 2025 10:01:53 -0700
Subject: [PATCH 25/86] Refactor global endpoint manager for per-partition
 circuit breaker and hook up retry configurations with failure tracking

---
 sdk/cosmos/azure-cosmos/CHANGELOG.md          |   2 +-
 ...tition_endpoint_manager_circuit_breaker.py |  80 ++---------
 ...n_endpoint_manager_circuit_breaker_core.py | 131 ++++++++++++++++++
 .../azure/cosmos/_location_cache.py           |   4 +-
 .../azure/cosmos/_partition_health_tracker.py |   9 +-
 .../azure/cosmos/_request_object.py           |  11 +-
 .../azure/cosmos/aio/_asynchronous_request.py |   4 +
 .../aio/_cosmos_client_connection_async.py    |   4 +-
 ..._endpoint_manager_circuit_breaker_async.py |  87 ++----------
 .../azure/cosmos/aio/_retry_utility_async.py  |  12 +-
 10 files changed, 179 insertions(+), 165 deletions(-)
 create mode 100644 sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py

diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md
index c5f836ce2a03..475203a1e213 100644
--- a/sdk/cosmos/azure-cosmos/CHANGELOG.md
+++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md
@@ -3,7 +3,7 @@
 ### 4.10.0b3 (Unreleased)
 
 #### Features Added
-* Per partition circuit breaker support. It can be enabled through environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See
+* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See
 [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302)
 
 #### Breaking Changes
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
index 849630c83d4d..072b946c1402 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
@@ -21,24 +21,17 @@
 """Internal class for global endpoint manager for circuit breaker.
""" -import logging -import os from typing import TYPE_CHECKING -from azure.cosmos import documents +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos._partition_health_tracker import PartitionHealthTracker -from azure.cosmos._routing.routing_range import Range, PartitionKeyRangeWrapper from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager -from azure.cosmos._location_cache import EndpointOperationType from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpHeaders -from azure.cosmos._constants import _Constants as Constants if TYPE_CHECKING: - from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection + from azure.cosmos._cosmos_client_connection import CosmosClientConnection -logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreaker") class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): """ @@ -49,88 +42,33 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) - self.partition_health_tracker = PartitionHealthTracker() + self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: """ Check if circuit breaker is applicable for a request. """ - if not request: - return False - - circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, - Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" - if not circuit_breaker_enabled: - return False - - if (not self.location_cache.can_use_multiple_write_locations_for_request(request) - and documents._OperationType.IsWriteOperation(request.operation_type)): - return False - - if request.resource_type != ResourceType.Document: - return False - - if request.operation_type != documents._OperationType.QueryPlan: - return False - - return True - - def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - """ - Create a PartitionKeyRangeWrapper object. 
-        """
-        container_rid = request.headers[HttpHeaders.IntendedCollectionRID]
-        partition_key = request.headers[HttpHeaders.PartitionKey]
-        # get the partition key range for the given partition key
-        target_container_link = None
-        for container_link, properties in self.client._container_properties_cache.items():
-            # TODO: @tvaron3 consider moving this to a constant with other usages
-            if properties["_rid"] == container_rid:
-                target_container_link = container_link
-        # TODO: raise an exception if the container link is not found
-        pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key)
-        return PartitionKeyRangeWrapper(pkrange, container_rid)
+        return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request)
 
     def record_failure(
             self,
             request: RequestObject
     ) -> None:
-        if self.is_circuit_breaker_applicable(request):
-            #convert operation_type to EndpointOperationType
-            endpoint_operation_type = EndpointOperationType.WriteType if (
-                documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType
-            location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route)
-            pkrange_wrapper = self._create_pkrange_wrapper(request)
-            self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location)
+        self.global_partition_endpoint_manager_core.record_failure(request)
 
     def resolve_service_endpoint(self, request):
-        if self.is_circuit_breaker_applicable(request):
-            pkrange_wrapper = self._create_pkrange_wrapper(request)
-            request.set_excluded_locations_from_circuit_breaker(
-                self.partition_health_tracker.get_excluded_locations(pkrange_wrapper)
-            )
+        request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request)
         return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request)
 
-
     def mark_partition_unavailable(self, request: RequestObject) -> None:
         """
         Mark the partition unavailable from the given request.
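 
         The region that served the request is recorded as unavailable for the
         request's partition key range until the partition health tracker
         transitions it back towards healthy.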
""" - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) def record_success( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to either Read or Write - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) - -# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- + self.global_partition_endpoint_manager_core.record_success(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py new file mode 100644 index 000000000000..23a7c50047ca --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -0,0 +1,131 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. +""" +import logging +import os + +from azure.cosmos import documents + +from azure.cosmos._partition_health_tracker import PartitionHealthTracker +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._location_cache import EndpointOperationType, LocationCache +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import ResourceType, HttpHeaders +from azure.cosmos._constants import _Constants as Constants + + +logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreakerCore") + +class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. 
+ """ + + + def __init__(self, client, location_cache: LocationCache): + self.partition_health_tracker = PartitionHealthTracker() + self.location_cache = location_cache + self.client = client + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + """ + Check if circuit breaker is applicable for a request. + """ + if not request: + return False + + circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, + Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" + if not circuit_breaker_enabled: + return False + + if (not self.location_cache.can_use_multiple_write_locations_for_request(request) + and documents._OperationType.IsWriteOperation(request.operation_type)): + return False + + if request.resource_type != ResourceType.Document: + return False + + if request.operation_type != documents._OperationType.QueryPlan: + return False + + return True + + def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + """ + Create a PartitionKeyRangeWrapper object. + """ + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.client._container_properties_cache: + # TODO: @tvaron3 consider moving this to a constant with other usages + if properties["_rid"] == container_rid: + target_container_link = container_link + # throw exception if it is not found + pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) + return PartitionKeyRangeWrapper(pkrange, container_rid) + + def record_failure( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + #convert operation_type to EndpointOperationType + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + + def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: + if self.is_circuit_breaker_applicable(request): + pkrange_wrapper = self._create_pkrange_wrapper(request) + request.set_excluded_locations_from_circuit_breaker( + self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) + ) + return request + + def mark_partition_unavailable(self, request: RequestObject) -> None: + """ + Mark the partition unavailable from the given request. 
+ """ + location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + + def record_success( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + #convert operation_type to either Read or Write + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType + location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + pkrange_wrapper = self._create_pkrange_wrapper(request) + self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) + +# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index d40a99f7c69f..ae6cac6feddc 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -25,7 +25,7 @@ import collections import logging import time -from typing import Set +from typing import Set, Optional from urllib.parse import urlparse from . import documents @@ -232,13 +232,13 @@ def get_ordered_write_locations(self): def get_ordered_read_locations(self): return self.account_read_locations - # Todo: @tvaron3 client should be appeneded to if using circuitbreaker exclude regions def _get_configured_excluded_locations(self, request): # If excluded locations were configured on request, use request level excluded locations. excluded_locations = request.excluded_locations if excluded_locations is None: # If excluded locations were only configured on client(connection_policy), use client level excluded_locations = self.connection_policy.ExcludedLocations + excluded_locations.union(request.excluded_locations_circuit_breaker) return excluded_locations def get_applicable_read_regional_endpoints(self, request): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index d7c5c4a2cb31..74665a5e7eb5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -22,7 +22,7 @@ """Internal class for partition health tracker for circuit breaker. """ import os -from typing import Dict, Set, Any +from typing import Dict, Set, Any, Optional from ._constants import _Constants as Constants from azure.cosmos._location_cache import current_time_millis, EndpointOperationType from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range @@ -164,7 +164,12 @@ def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> S return set() - def add_failure(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: + def add_failure( + self, + pkrange_wrapper: PartitionKeyRangeWrapper, + operation_type: str, + location: Optional[str] + ) -> None: # Retrieve the failure rate threshold from the environment. 
failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index afc9fa4d30a9..50dd4c7fc4d1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,7 @@ """Represents a request object. """ -from typing import Optional, Mapping, Any, List, Dict +from typing import Optional, Mapping, Any, Dict, Set class RequestObject(object): @@ -41,7 +41,8 @@ def __init__( self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None - self.excluded_locations: Optional[List[str]] = None + self.excluded_locations: Optional[Set[str]] = None + self.excluded_locations_circuit_breaker: Set[str] = set() def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -83,7 +84,5 @@ def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None if self._can_set_excluded_location(options): self.excluded_locations = options['excludedLocations'] - def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: - if self.excluded_locations: - self.excluded_locations.extend(excluded_locations) - self.excluded_locations = excluded_locations + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: + self.excluded_locations_circuit_breaker = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 81430d8df42c..25f6ac203d85 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -101,6 +101,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -111,6 +113,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 0d3d61e87fa0..8fbdb0fb9a83 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -54,7 +54,7 @@ from .. 
import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState -from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreaker +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants @@ -169,7 +169,7 @@ def __init__( # pylint: disable=too-many-statements # Keeps the latest response headers from the server. self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreakerAsync(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 849630c83d4d..71bb628c31a0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -21,26 +21,19 @@ """Internal class for global endpoint manager for circuit breaker. """ -import logging -import os -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Optional -from azure.cosmos import documents +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos._partition_health_tracker import PartitionHealthTracker -from azure.cosmos._routing.routing_range import Range, PartitionKeyRangeWrapper from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager -from azure.cosmos._location_cache import EndpointOperationType from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpHeaders -from azure.cosmos._constants import _Constants as Constants if TYPE_CHECKING: from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection -logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreaker") -class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for geo-replicated database accounts. @@ -48,89 +41,33 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): def __init__(self, client: "CosmosClientConnection"): - super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) - self.partition_health_tracker = PartitionHealthTracker() + super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) + self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: """ Check if circuit breaker is applicable for a request. 
""" - if not request: - return False - - circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, - Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" - if not circuit_breaker_enabled: - return False - - if (not self.location_cache.can_use_multiple_write_locations_for_request(request) - and documents._OperationType.IsWriteOperation(request.operation_type)): - return False - - if request.resource_type != ResourceType.Document: - return False - - if request.operation_type != documents._OperationType.QueryPlan: - return False - - return True - - def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - """ - Create a PartitionKeyRangeWrapper object. - """ - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache: - # TODO: @tvaron3 consider moving this to a constant with other usages - if properties["_rid"] == container_rid: - target_container_link = container_link - # throw exception if it is not found - pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pkrange, container_rid) + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) def record_failure( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to EndpointOperationType - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + self.global_partition_endpoint_manager_core.record_failure(request) def resolve_service_endpoint(self, request): - if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pkrange_wrapper(request) - request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) - ) - return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) - + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) + return super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).resolve_service_endpoint(request) def mark_partition_unavailable(self, request: RequestObject) -> None: """ Mark the partition unavailable from the given request. 
""" - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) def record_success( self, request: RequestObject ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to either Read or Write - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) - pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) - -# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- - + self.global_partition_endpoint_manager_core.record_success(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index c020a5d4e31c..c5613530994d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -28,6 +28,7 @@ from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import AsyncRetryPolicy +from ._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .. import _default_retry_policy, _database_account_retry_policy from .. import _endpoint_discovery_retry_policy from .. 
import _gone_retry_policy @@ -257,6 +258,8 @@ async def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) + request_params = request.context.options.get('request_params', None) + global_endpoint_manager = request.context.options.get('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -281,8 +284,7 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): - # TODO: @tvaron3 record failure here - request.context.global_endpoint_manager.record_failure(request.context.options['request']) + global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -300,8 +302,7 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): - # TODO: @tvaron3 record failure here - request.context.global_endpoint_manager.record_failure(request.context.options['request']) + global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -316,8 +317,7 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): - # TODO: @tvaron3 record failure here ? Not sure - request.context.global_endpoint_manager.record_failure(request.context.options['request']) + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) From 9e51011b472a89ad34b4cc001d586da4ab350fc7 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 12:05:14 -0700 Subject: [PATCH 26/86] fix use multiple write locations bug --- sdk/cosmos/azure-cosmos/CHANGELOG.md | 2 +- .../azure-cosmos/azure/cosmos/_global_endpoint_manager.py | 4 ++-- .../azure/cosmos/_service_request_retry_policy.py | 2 +- .../azure/cosmos/aio/_global_endpoint_manager_async.py | 6 +++--- 4 files changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index ab16ba4f1cd9..97a88849f18c 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,7 +3,7 @@ ### 4.10.0b5 (Unreleased) #### Features Added -* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302) +* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). 
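As a side note on the changelog entry above, a minimal opt-in sketch. It assumes only what this patch series shows (the `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER` environment variable, read per request in `is_circuit_breaker_applicable`) plus the public `CosmosClient` API; the endpoint, key, and resource names are placeholders:

```python
import os

# Opt in before issuing requests; the flag is read from the environment.
os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True"

from azure.cosmos import CosmosClient  # the aio client mirrors this

client = CosmosClient("<account-endpoint>", "<account-key>")  # placeholders
container = client.get_database_client("<db>").get_container_client("<container>")

# With the flag enabled, repeated failures of document operations against one
# region mark that (partition key range, region) pair unavailable, and later
# requests are routed around it via circuit-breaker excluded locations.
item = container.read_item(item="<item-id>", partition_key="<pk>")
```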
#### Breaking Changes
 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
index 8aa7c388a9b6..52b0b64cc27e 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
@@ -61,8 +61,8 @@ def __init__(self, client):
         self.last_refresh_time = 0
         self._database_account_cache = None
 
-    def get_use_multiple_write_locations(self):
-        return self.location_cache.can_use_multiple_write_locations()
+    def get_use_multiple_write_locations(self, request):
+        return self.location_cache.can_use_multiple_write_locations_for_request(request)
 
     def get_refresh_time_interval_in_ms_stub(self):
         return constants._Constants.DefaultEndpointsRefreshTime
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py
index 8630714bc6f3..18dd17c9f5ad 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py
@@ -61,7 +61,7 @@ def ShouldRetry(self):
             self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route
 
         if (_OperationType.IsReadOnlyOperation(self.request.operation_type)
-                or self.global_endpoint_manager.get_use_multiple_write_locations()):
+                or self.global_endpoint_manager.get_use_multiple_write_locations(self.request)):
             self.update_location_cache()
             # We just directly got to the next location in case of read requests
             # We don't retry again on the same region for regional endpoint
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py
index f2bf7dcbe5fd..f001772b246a 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py
@@ -33,6 +33,7 @@
 from .. import _constants as constants
 from ..
import exceptions from .._location_cache import LocationCache, current_time_millis +from .._request_object import RequestObject # pylint: disable=protected-access @@ -63,9 +64,8 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - # TODO: @tvaron3 fix this - def get_use_multiple_write_locations(self): - return self.location_cache.can_use_multiple_write_locations() + def get_use_multiple_write_locations(self, request: RequestObject): + return self.location_cache.can_use_multiple_write_locations_for_request(request) def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime From 8d276515de28d8850ab4d5b0d5e54486f9ae6f75 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 13:25:42 -0700 Subject: [PATCH 27/86] clean up and revert vs env variable changes --- .../azure-cosmos/azure/cosmos/_constants.py | 6 ------ .../azure/cosmos/_cosmos_client_connection.py | 11 +---------- .../aio/execution_dispatcher.py | 6 ++---- .../execution_dispatcher.py | 11 ++++------- .../azure/cosmos/_global_endpoint_manager.py | 3 --- ...tition_endpoint_manager_circuit_breaker.py | 2 +- ...n_endpoint_manager_circuit_breaker_core.py | 2 +- .../cosmos/_service_request_retry_policy.py | 2 +- .../cosmos/_timeout_failover_retry_policy.py | 1 - .../aio/_cosmos_client_connection_async.py | 11 +---------- .../aio/_global_endpoint_manager_async.py | 3 --- .../tests/test_query_hybrid_search.py | 13 ------------- .../tests/test_query_hybrid_search_async.py | 14 +------------- .../tests/test_query_vector_similarity.py | 19 +------------------ .../test_query_vector_similarity_async.py | 19 +------------------ 15 files changed, 14 insertions(+), 109 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index 2af40c74d77a..cf029179f1a1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -45,12 +45,6 @@ class _Constants: EnableMultipleWritableLocations: Literal["enableMultipleWriteLocations"] = "enableMultipleWriteLocations" # Environment variables - NON_STREAMING_ORDER_BY_DISABLED_CONFIG: str = "AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY" - NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT: str = "False" - HS_MAX_ITEMS_CONFIG: str = "AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS" - HS_MAX_ITEMS_CONFIG_DEFAULT: int = 1000 - MAX_ITEM_BUFFER_VS_CONFIG: str = "AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH" - MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" # Only applicable when circuit breaker is enabled ------------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 7bb379a6b6e7..3c8a11218738 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2045,7 +2045,6 @@ def PatchItem( # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2135,7 +2134,6 @@ 
def _Batch( documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2197,7 +2195,6 @@ def DeleteAllItemsByPartitionKey( "partitionkey", documents._OperationType.Delete, options) request_params = RequestObject("partitionkey", documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2655,7 +2652,6 @@ def Create( request_params = RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2703,7 +2699,6 @@ def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2748,7 +2743,6 @@ def Replace( # Replace will use WriteEndpoint since it uses PUT operation request_params = RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2791,7 +2785,6 @@ def Read( # Read will use ReadEndpoint since it uses GET operation request_params = RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2832,7 +2825,6 @@ def DeleteResource( # Delete will use WriteEndpoint since it uses DELETE operation request_params = RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3206,8 +3198,7 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kwargs: documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, - Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": + if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + 
documents._QueryFeature.Distinct + "," + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index 9b4d32a4598b..70df36d6d015 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -34,7 +34,6 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.exceptions import CosmosHttpResponseError from azure.cosmos.http_constants import StatusCodes -from ..._constants import _Constants as Constants # pylint: disable=protected-access @@ -118,9 +117,8 @@ async def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your vector search query.") - if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG, - Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)): - raise ValueError("Executing a vector search query with more items than the max is not allowed. " + + if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000): + raise ValueError("Executing a vector search query with more items than the max is not allowed." + "Please ensure you are using a limit smaller than the max, or change the max.") execution_context_aggregator =\ non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 453a4dc38d25..1155e537c68c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -26,7 +26,6 @@ import json import os from azure.cosmos.exceptions import CosmosHttpResponseError -from .._constants import _Constants as Constants from azure.cosmos._execution_context import endpoint_component, multi_execution_aggregator from azure.cosmos._execution_context import non_streaming_order_by_aggregator, hybrid_search_aggregator from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase @@ -57,9 +56,8 @@ def _verify_valid_hybrid_search_query(hybrid_search_query_info): raise ValueError("Executing a hybrid search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. Please ensure you are using one" + " of the two filters with your hybrid search query.") - if hybrid_search_query_info['take'] > int(os.environ.get(Constants.HS_MAX_ITEMS_CONFIG, - Constants.HS_MAX_ITEMS_CONFIG_DEFAULT)): - raise ValueError("Executing a hybrid search query with more items than the max is not allowed. " + + if hybrid_search_query_info['take'] > os.environ.get('AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS', 1000): + raise ValueError("Executing a hybrid search query with more items than the max is not allowed." + "Please ensure you are using a limit smaller than the max, or change the max.") @@ -151,9 +149,8 @@ def _create_pipelined_execution_context(self, query_execution_info): raise ValueError("Executing a vector search query without TOP or LIMIT can consume many" + " RUs very fast and have long runtimes. 
Please ensure you are using one" +
                             " of the two filters with your vector search query.")
-        if total_item_buffer > int(os.environ.get(Constants.MAX_ITEM_BUFFER_VS_CONFIG,
-                                                  Constants.MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT)):
-            raise ValueError("Executing a vector search query with more items than the max is not allowed. " +
+        if total_item_buffer > os.environ.get('AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH', 50000):
+            raise ValueError("Executing a vector search query with more items than the max is not allowed." +
                             "Please ensure you are using a limit smaller than the max, or change the max.")
         execution_context_aggregator = \
             non_streaming_order_by_aggregator._NonStreamingOrderByContextAggregator(self._client,
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
index 35a4c60c6ca4..8cf9d06d5486 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py
@@ -58,9 +58,6 @@ def __init__(self, client):
         self.last_refresh_time = 0
         self._database_account_cache = None
 
-    def get_use_multiple_write_locations(self, request):
-        return self.location_cache.can_use_multiple_write_locations_for_request(request)
-
     def get_refresh_time_interval_in_ms_stub(self):
         return constants._Constants.DefaultEndpointsRefreshTime
 
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
index 072b946c1402..b95ef0a2a7da 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py
@@ -26,7 +26,7 @@
 from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \
     _GlobalPartitionEndpointManagerForCircuitBreakerCore
-from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager
+from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager
 from azure.cosmos._request_object import RequestObject
 if TYPE_CHECKING:
     from azure.cosmos._cosmos_client_connection import CosmosClientConnection
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py
index 23a7c50047ca..142c51a1ed19 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py
@@ -128,4 +128,4 @@ def record_success(
         pkrange_wrapper = self._create_pkrange_wrapper(request)
         self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location)
 
-# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative -----------------------
\ No newline at end of file
+# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative -----------------------
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py
index 18dd17c9f5ad..edd15f20337f 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py
@@ -61,7 +61,7 @@ def
ShouldRetry(self): self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route if (_OperationType.IsReadOnlyOperation(self.request.operation_type) - or self.global_endpoint_manager.get_use_multiple_write_locations(self.request)): + or self.global_endpoint_manager.can_use_multiple_write_locations(self.request)): self.update_location_cache() # We just directly got to the next location in case of read requests # We don't retry again on the same region for regional endpoint diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index 505a8edf3d06..f70e27bae70c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -5,7 +5,6 @@ Cosmos database service. """ from azure.cosmos.documents import _OperationType -from azure.cosmos.http_constants import HttpHeaders class _TimeoutFailoverRetryPolicy(object): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index cfef48d64765..09add40f5785 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -774,7 +774,6 @@ async def Create( request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -915,7 +914,6 @@ async def Upsert( # Upsert will use WriteEndpoint since it uses POST operation request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1218,7 +1216,6 @@ async def Read( # Read will use ReadEndpoint since it uses GET operation request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1478,7 +1475,6 @@ async def PatchItem( # Patch will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1584,7 +1580,6 @@ async def Replace( # Replace will use WriteEndpoint since it uses PUT operation request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, 
last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1909,7 +1904,6 @@ async def DeleteResource( # Delete will use WriteEndpoint since it uses DELETE operation request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2024,7 +2018,6 @@ async def _Batch( documents._OperationType.Batch, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -3219,8 +3212,7 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, **kw documents._QueryFeature.NonStreamingOrderBy + "," + documents._QueryFeature.HybridSearch + "," + documents._QueryFeature.CountIf) - if os.environ.get(Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG, - Constants.NON_STREAMING_ORDER_BY_DISABLED_CONFIG_DEFAULT) == "True": + if os.environ.get('AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY', False): supported_query_features = (documents._QueryFeature.Aggregate + "," + documents._QueryFeature.CompositeAggregate + "," + documents._QueryFeature.Distinct + "," + @@ -3286,7 +3278,6 @@ async def DeleteAllItemsByPartitionKey( request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) - request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 0917b3de94ca..00438cc2214e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -62,9 +62,6 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - def get_use_multiple_write_locations(self, request: RequestObject): - return self.location_cache.can_use_multiple_write_locations_for_request(request) - def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py index 3a9e8527992e..429498b3071f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search.py @@ -2,7 +2,6 @@ # Copyright (c) Microsoft Corporation. All rights reserved. import json -import os import time import unittest import uuid @@ -95,18 +94,6 @@ def test_wrong_hybrid_search_queries(self): assert ("One of the input values is invalid" in e.message or "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." 
in e.message) - def test_hybrid_search_env_variables_async(self): - os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1" - try: - query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ - "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" - results = self.test_container.query_items(query) - [item for item in results] - pytest.fail("Config was not applied properly.") - except ValueError as e: - assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - def test_hybrid_search_queries(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py index 4223cc6bdd50..964d9579a2e9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_hybrid_search_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import os + import time import unittest import uuid @@ -98,18 +98,6 @@ async def test_wrong_hybrid_search_queries_async(self): assert ("One of the input values is invalid" in e.message or "Specifying a sort order (ASC or DESC) in the ORDER BY RANK clause is not allowed." in e.message) - async def test_hybrid_search_env_variables_async(self): - os.environ["AZURE_COSMOS_HYBRID_SEARCH_MAX_ITEMS"] = "1" - try: - query = "SELECT TOP 1 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ - "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" - results = self.test_container.query_items(query) - [item async for item in results] - pytest.fail("Config was not applied properly.") - except ValueError as e: - assert e.args[0] == ("Executing a hybrid search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - async def test_hybrid_search_queries_async(self): query = "SELECT TOP 10 c.index, c.title FROM c WHERE FullTextContains(c.title, 'John') OR " \ "FullTextContains(c.text, 'John') ORDER BY RANK FullTextScore(c.title, ['John'])" diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py index 96e3eee02936..e2614c5fb85f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import os + import unittest import uuid @@ -121,23 +121,6 @@ def test_wrong_vector_search_queries(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." 
in e.message) - - def test_vector_search_environment_variables(self): - vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") - query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ - "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" - try: - [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] - pytest.fail("Config was not set correctly.") - except ValueError as e: - assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. " - "Please ensure you are using a limit smaller than the max, or change the max.") - - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" - [item for item in self.created_large_container.query_items(query=query, enable_cross_partition_query=True)] - def test_ordering_distances(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies # load up previously calculated embedding for the given string diff --git a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py index 716150358ff3..0cb031847a6f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query_vector_similarity_async.py @@ -1,6 +1,6 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import os + import unittest import uuid @@ -127,23 +127,6 @@ async def test_wrong_vector_search_queries_async(self): assert ("One of the input values is invalid." in e.message or "Specifying a sorting order (ASC or DESC) with VectorDistance function is not supported." in e.message) - async def test_vector_search_environment_variables_async(self): - vector_string = vector_test_data.get_embedding_string("I am having a wonderful day.") - query = "SELECT TOP 10 c.text, VectorDistance(c.embedding, [{}]) AS " \ - "SimilarityScore FROM c ORDER BY VectorDistance(c.embedding, [{}])".format(vector_string, vector_string) - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "1" - try: - [item async for item in self.created_large_container.query_items(query=query)] - pytest.fail("Config was not set correctly.") - except ValueError as e: - assert e.args[0] == ("Executing a vector search query with more items than the max is not allowed. 
" - "Please ensure you are using a limit smaller than the max, or change the max.") - - os.environ["AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH"] = "50000" - - os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "False" - [item async for item in self.created_large_container.query_items(query=query)] - async def test_ordering_distances_async(self): # Besides ordering distances, we also verify that the query text properly replaces any set embedding policies From 90fe5c2ff3a58b9c74a2fb2f17de51e4c3553f80 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 14:26:31 -0700 Subject: [PATCH 28/86] remove async await --- .../azure-cosmos/azure/cosmos/aio/_asynchronous_request.py | 1 - 1 file changed, 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index f8ebf6ccdbb8..81430d8df42c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -117,7 +117,6 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p response = response.http_response headers = copy.copy(response.headers) - await response.load_body() data = response.body() if data: data = data.decode("utf-8") From 206be781b0cfa5c87a1ca69e50b5407cdca31e6e Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 16:25:50 -0700 Subject: [PATCH 29/86] refactor and fix tests --- ...py => _fault_injection_transport_async.py} | 29 ++--- .../azure-cosmos/tests/test_crud_async.py | 3 - .../test_fault_injection_transport_async.py | 104 ++++++++---------- 3 files changed, 59 insertions(+), 77 deletions(-) rename sdk/cosmos/azure-cosmos/tests/{_fault_injection_transport.py => _fault_injection_transport_async.py} (91%) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py similarity index 91% rename from sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py rename to sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 7f83608b3cd7..230d8f89dfe5 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -27,30 +27,31 @@ import logging import sys from collections.abc import MutableMapping -from typing import Callable, Optional, Any +from typing import Callable, Optional, Any, Dict, List import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse +from azure.cosmos import documents import test_config from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -class FaultInjectionTransport(AioHttpTransport): +class FaultInjectionTransportAsync(AioHttpTransport): logger = logging.getLogger('azure.cosmos.fault_injection_transport') logger.setLevel(logging.DEBUG) def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner: bool = True, **config): - self.faults = [] - self.requestTransformations = [] - self.responseTransformations = [] + self.faults: List[Dict[str, Any]] = [] + self.requestTransformations: List[Dict[str, Any]] = [] + self.responseTransformations: List[Dict[str, Any]] = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) def add_fault(self, predicate: Callable[[HttpRequest], 
bool], fault_factory: Callable[[HttpRequest], asyncio.Task[Exception]]):
         self.faults.append({"predicate": predicate, "apply": fault_factory})
 
-    def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], asyncio.Task[AsyncHttpResponse]]):
+    def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AioHttpTransportResponse]], AioHttpTransportResponse]):
         self.responseTransformations.append({
             "predicate": predicate,
             "apply": response_transformation})
@@ -142,10 +143,8 @@ def predicate_is_document_operation(r: HttpRequest) -> bool:
 
     @staticmethod
     def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool:
-        is_write_document_operation = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs'
-                                       and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Read'
-                                       and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'ReadFeed'
-                                       and r.headers.get('x-ms-thinclient-proxy-operation-type') != 'Query')
+        is_write_document_operation = documents._OperationType.IsWriteOperation(
+            str(r.headers.get('x-ms-thinclient-proxy-operation-type')))
 
         return is_write_document_operation and uri_prefix in r.url
 
@@ -173,14 +172,12 @@ async def error_region_down() -> Exception:
     async def transform_topology_swr_mrr(
             write_region_name: str,
             read_region_name: str,
-            r: HttpRequest,
-            inner: Callable[[],asyncio.Task[AsyncHttpResponse]]) -> asyncio.Task[AsyncHttpResponse]:
+            inner: Callable[[],asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse:
 
         response = await inner()
-        if not FaultInjectionTransport.predicate_is_database_account_call(response.request):
+        if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request):
             return response
 
-        await response.load_body()
         data = response.body()
         if response.status_code == 200 and data:
             data = data.decode("utf-8")
@@ -200,14 +197,12 @@ async def transform_topology_swr_mrr(
     async def transform_topology_mwr(
             first_region_name: str,
             second_region_name: str,
-            r: HttpRequest,
-            inner: Callable[[], asyncio.Task[AsyncHttpResponse]]) -> AsyncHttpResponse:
+            inner: Callable[[], asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse:
 
         response = await inner()
-        if not FaultInjectionTransport.predicate_is_database_account_call(response.request):
+        if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request):
             return response
 
-        await response.load_body()
         data = response.body()
         if response.status_code == 200 and data:
             data = data.decode("utf-8")
@@ -251,7 +246,7 @@ def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict
     def body(self) -> Optional[bytes]:
         return self.bytes
 
-    def text(self, encoding: Optional[str] = None) -> Optional[str]:
+    def text(self) -> Optional[str]:
         return self.json_text
 
     async def load_body(self) -> None:
diff --git a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py
index 160b5bb93cc6..ca6cfad8287d 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_crud_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_crud_async.py
@@ -4,8 +4,6 @@
 """End-to-end test.
""" -import json -import os.path import time import unittest import urllib.parse as urllib @@ -17,7 +15,6 @@ from azure.core.exceptions import AzureError, ServiceResponseError from azure.core.pipeline.transport import AsyncioRequestsTransport, AsyncioRequestsTransportResponse -import azure.cosmos._base as base import azure.cosmos.documents as documents import azure.cosmos.exceptions as exceptions import test_config diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 71b6824ef240..2227019cbe63 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -12,18 +12,18 @@ import pytest from azure.core.pipeline.transport import AioHttpTransport +from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse import test_config -from _fault_injection_transport import FaultInjectionTransport +from _fault_injection_transport_async import FaultInjectionTransportAsync from azure.cosmos import PartitionKey from azure.cosmos.aio import CosmosClient from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError -COLLECTION = "created_collection" -MGMT_TIMEOUT = 3.0 +MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -32,21 +32,16 @@ master_key = test_config.TestConfig.masterKey connection_policy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID - -@pytest.fixture() -def setup(): - return - +single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio -@pytest.mark.usefixtures("setup") class TestFaultInjectionTransportAsync: @classmethod def setup_class(cls): logger.info("starting class: {} execution".format(cls.__name__)) - cls.host = test_config.TestConfig.host - cls.master_key = test_config.TestConfig.masterKey + cls.host = host + cls.master_key = master_key if (cls.master_key == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]'): @@ -55,12 +50,12 @@ def setup_class(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") - cls.connection_policy = test_config.TestConfig.connectionPolicy - cls.database_id = test_config.TestConfig.TEST_DATABASE_ID - cls.single_partition_container_name= os.path.basename(__file__) + str(uuid.uuid4()) + cls.connection_policy = connection_policy + cls.database_id = TEST_DATABASE_ID + cls.single_partition_container_name = single_partition_container_name - cls.mgmt_client = CosmosClient(host, master_key, consistency_level="Session", - connection_policy=connection_policy, logger=logger) + cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", + connection_policy=cls.connection_policy, logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( @@ -89,7 +84,7 @@ def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, connection_policy=connection_policy, transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - 
container: ContainerProxy = db.get_container_client(self.single_partition_container_name) + container: ContainerProxy = db.get_container_client(single_partition_container_name) return {"client": client, "db": db, "col": container} @staticmethod @@ -100,16 +95,16 @@ def cleanup_method(initialized_objects: dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. {close_error}") - async def test_throws_injected_error(self, setup: object): + async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, 'name': 'sample document', 'key': 'value'} - custom_transport = FaultInjectionTransport() - predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, id_value) - custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( + custom_transport = FaultInjectionTransportAsync() + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 500, CosmosHttpResponseError( status_code=502, @@ -126,26 +121,25 @@ async def test_throws_injected_error(self, setup: object): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_succeeds(self, setup: object): + async def test_swr_mrr_succeeds(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) custom_transport.add_fault( is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[], AsyncHttpResponse]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -178,34 +172,33 @@ async def test_swr_mrr_succeeds(self, setup: object): finally: 
TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_read_succeeds(self, setup: object): + async def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) custom_transport.add_fault( is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) # Inject rule to simulate regional outage in "Read Region" is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_targets_region(r, expected_read_region_uri) custom_transport.add_fault( is_request_to_read_region, - lambda r: asyncio.create_task(FaultInjectionTransport.error_region_down())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AsyncHttpResponse]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -239,26 +232,26 @@ async def test_swr_mrr_region_down_read_succeeds(self, setup: object): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): + async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) custom_transport.add_fault( 
is_write_operation_in_read_region_predicate, - lambda r: asyncio.create_task(FaultInjectionTransport.error_write_forbidden())) + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) # Inject rule to simulate regional outage in "Read Region" is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ - r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and \ - FaultInjectionTransport.predicate_is_document_operation(r) + r: FaultInjectionTransportAsync.predicate_targets_region(r, expected_read_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) custom_transport.add_fault( is_request_to_read_region, - lambda r: asyncio.create_task(FaultInjectionTransport.error_after_delay( + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 35000, CosmosHttpResponseError( status_code=502, @@ -266,12 +259,11 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", - r=r, inner=inner) custom_transport.add_response_transformation( is_get_account_predicate, @@ -306,19 +298,17 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self, setup: object): TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_mwr_succeeds(self, setup: object): + async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") - second_region_uri: str = test_config.TestConfig.local_host - custom_transport = FaultInjectionTransport() + custom_transport = FaultInjectionTransportAsync() # Inject topology transformation that would make Emulator look like a single write region # account with two read regions - is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) - emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AsyncHttpResponse] = \ - lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( first_region_name="First Region", second_region_name="Second Region", - r=r, inner=inner) 
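As an aside for readers following these tests, here is a minimal, self-contained sketch of what the topology transformation wired up here does to the database-account payload; the helper name make_mwr_topology, the sample region name, and the endpoint values are illustrative assumptions, not part of the SDK or of this test suite:

    # A minimal sketch (not the SDK implementation): rewrite an emulator
    # database-account payload so the client sees a multi-write account with
    # two regions, mirroring what transform_topology_mwr does to the real
    # response. `make_mwr_topology` and the endpoints are illustrative.
    import json
    from typing import Any, Dict

    def make_mwr_topology(account: Dict[str, Any],
                          first_region: str,
                          second_region: str,
                          endpoint: str) -> Dict[str, Any]:
        for locations in (account["readableLocations"], account["writableLocations"]):
            # Rename the single emulator region and add a second region that
            # points at the same endpoint.
            locations[0]["name"] = first_region
            locations.append({"name": second_region,
                              "databaseAccountEndpoint": endpoint})
        account["enableMultipleWriteLocations"] = True
        return account

    sample = {
        "readableLocations": [{"name": "South Central US",
                               "databaseAccountEndpoint": "https://localhost:8081/"}],
        "writableLocations": [{"name": "South Central US",
                               "databaseAccountEndpoint": "https://localhost:8081/"}],
    }
    print(json.dumps(make_mwr_topology(sample, "First Region", "Second Region",
                                       "https://localhost:8081/"), indent=2))

Because both regions resolve to the emulator, the client's location cache treats them as distinct write regions while every request still lands on the same endpoint, which is what lets a single emulator exercise cross-region failover in these tests.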
custom_transport.add_response_transformation( is_get_account_predicate, From 622589fced33fa4ec415be94bbe8cb051156a804 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 16:31:59 -0700 Subject: [PATCH 30/86] Fix refactoring --- .../tests/_fault_injection_transport_async.py | 40 +++++++++---------- 1 file changed, 20 insertions(+), 20 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 230d8f89dfe5..dec1699b8742 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -66,38 +66,38 @@ def __first_item(iterable, condition=lambda x: True): return next((x for x in iterable if condition(x)), None) async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Optional[MutableMapping[str, str]] = None, **config) -> AsyncHttpResponse: - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) # find the first fault factory with matching predicate if any - first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) + first_fault_factory = FaultInjectionTransportAsync.__first_item(iter(self.faults), lambda f: f["predicate"](request)) if first_fault_factory: - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") injected_error = await first_fault_factory["apply"](request) - FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error)) + FaultInjectionTransportAsync.logger.info("Found to-be-injected error {}".format(injected_error)) raise injected_error # apply the chain of request transformations with matching predicates if any matching_request_transformations = filter(lambda f: f["predicate"](request), iter(self.requestTransformations)) for currentTransformation in matching_request_transformations: - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") request = await currentTransformation["apply"](request) - first_response_transformation = FaultInjectionTransport.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) + first_response_transformation = FaultInjectionTransportAsync.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) - FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") get_response_task = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) - FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + FaultInjectionTransportAsync.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") if first_response_transformation: - FaultInjectionTransport.logger.info(f"Invoking response transformation") + FaultInjectionTransportAsync.logger.info(f"Invoking response transformation") response = await first_response_transformation["apply"](request, lambda:
get_response_task) response.headers["_request"] = request - FaultInjectionTransport.logger.info(f"Received response transformation result with status code {response.status_code}") + FaultInjectionTransportAsync.logger.info(f"Received response transformation result with status code {response.status_code}") return response else: - FaultInjectionTransport.logger.info(f"Sending request to {request.url}") + FaultInjectionTransportAsync.logger.info(f"Sending request to {request.url}") response = await get_response_task response.headers["_request"] = request - FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") + FaultInjectionTransportAsync.logger.info(f"Received response with status code {response.status_code}") return response @staticmethod @@ -125,8 +125,8 @@ def predicate_req_payload_contains_id(r: HttpRequest, id_value: str): @staticmethod def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool: - return (FaultInjectionTransport.predicate_url_contains_id(r, id_value) - or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value)) + return (FaultInjectionTransportAsync.predicate_url_contains_id(r, id_value) + or FaultInjectionTransportAsync.predicate_req_payload_contains_id(r, id_value)) @staticmethod def predicate_is_database_account_call(r: HttpRequest) -> bool: @@ -175,7 +175,7 @@ async def transform_topology_swr_mrr( inner: Callable[[],asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() - if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): return response data = response.body() @@ -187,9 +187,9 @@ async def transform_topology_swr_mrr( readable_locations[0]["name"] = write_region_name writable_locations[0]["name"] = write_region_name readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host}) - FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request - return FaultInjectionTransport.MockHttpResponse(request, 200, result) + return FaultInjectionTransportAsync.MockHttpResponse(request, 200, result) return response @@ -200,7 +200,7 @@ async def transform_topology_mwr( inner: Callable[[], asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() - if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): return response data = response.body() @@ -216,9 +216,9 @@ async def transform_topology_mwr( writable_locations.append( {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) result["enableMultipleWriteLocations"] = True - FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request - return FaultInjectionTransport.MockHttpResponse(request, 200, result) + return FaultInjectionTransportAsync.MockHttpResponse(request, 200, result) return response From 4dd17ea97d5a877635f9f0472463fb9196144a2a Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 17:01:29 -0700 
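The refactored send() above applies the injected behaviors in a fixed order: the first fault whose predicate matches the request wins, every matching request transformation then runs, and at most one response transformation wraps the actual network call. A condensed sketch of that ordering, with simplified types and names chosen purely for illustration and not the real FaultInjectionTransportAsync API:

    # Condensed sketch of the transport's fault-injection pipeline ordering;
    # names and types are simplified assumptions for illustration only.
    import asyncio
    from typing import Any, Awaitable, Callable, Dict, List

    Rule = Dict[str, Any]  # {"predicate": Callable, "apply": Callable}

    async def pipeline_send(request: Any,
                            faults: List[Rule],
                            request_transforms: List[Rule],
                            response_transforms: List[Rule],
                            real_send: Callable[[Any], Awaitable[Any]]) -> Any:
        # 1. A single matching fault short-circuits the call with an error.
        fault = next((f for f in faults if f["predicate"](request)), None)
        if fault:
            raise await fault["apply"](request)
        # 2. Every matching request transformation rewrites the request.
        for t in request_transforms:
            if t["predicate"](request):
                request = await t["apply"](request)
        # 3. The first matching response transformation wraps the real send.
        send_task = asyncio.ensure_future(real_send(request))
        transform = next((t for t in response_transforms
                          if t["predicate"](request)), None)
        if transform:
            return await transform["apply"](request, lambda: send_task)
        return await send_task

Note that every predicate, including the request-transformation filter, is evaluated against the incoming request rather than against the rule itself.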
Subject: [PATCH 31/86] Fix tests --- .../tests/test_fault_injection_transport_async.py | 9 ++------- 1 file changed, 2 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 2227019cbe63..05fbd6f1f658 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -30,7 +30,6 @@ host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey -connection_policy = test_config.TestConfig.connectionPolicy TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) @@ -49,13 +48,10 @@ def setup_class(cls): "You must specify your Azure Cosmos account values for " "'masterKey' and 'host' at the top of this class to run the " "tests.") - - cls.connection_policy = connection_policy cls.database_id = TEST_DATABASE_ID cls.single_partition_container_name = single_partition_container_name - cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", - connection_policy=cls.connection_policy, logger=logger) + cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) asyncio.run(asyncio.wait_for( created_database.create_container( @@ -81,8 +77,7 @@ def teardown_class(cls): def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", - connection_policy=connection_policy, transport=custom_transport, - logger=logger, enable_diagnostics_logging=True, **kwargs) + transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) container: ContainerProxy = db.get_container_client(single_partition_container_name) return {"client": client, "db": db, "col": container} From e631b74e51d2c136b79cb3d54a5263f5c44eb1db Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 17:26:00 -0700 Subject: [PATCH 32/86] fix tests --- sdk/cosmos/azure-cosmos/pytest.ini | 2 +- .../tests/_fault_injection_transport_async.py | 8 +-- .../test_fault_injection_transport_async.py | 49 ++++++++++++++++++- 3 files changed, 52 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index e211052edef0..647aac1464f8 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -1,3 +1,3 @@ [pytest] markers = - cosmosEmulator: marks tests as depending in Cosmos DB Emulator \ No newline at end of file + cosmosEmulator: marks tests as depending in Cosmos DB Emulator diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index dec1699b8742..428059231796 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -27,7 +27,7 @@ import logging import sys from collections.abc import MutableMapping -from typing import Callable, Optional, Any, Dict, List +from typing import Callable, Optional, Any, Dict, List, Awaitable import aiohttp from azure.core.pipeline.transport import AioHttpTransport, 
AioHttpTransportResponse @@ -48,7 +48,7 @@ def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None self.responseTransformations: List[Dict[str, Any]] = [] super().__init__(session=session, loop=loop, session_owner=session_owner, **config) - def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], asyncio.Task[Exception]]): + def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Awaitable[Exception]]): self.faults.append({"predicate": predicate, "apply": fault_factory}) def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], AioHttpTransportResponse]], AioHttpTransportResponse]): @@ -172,7 +172,7 @@ async def error_region_down() -> Exception: async def transform_topology_swr_mrr( write_region_name: str, read_region_name: str, - inner: Callable[[],asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): @@ -197,7 +197,7 @@ async def transform_topology_swr_mrr( async def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], asyncio.Task[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 05fbd6f1f658..0a186afbb875 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -75,7 +75,8 @@ def teardown_class(cls): except Exception as closeError: logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, closeError)) - def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + @staticmethod + def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) @@ -247,7 +248,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjection custom_transport.add_fault( is_request_to_read_region, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 35000, + 500, CosmosHttpResponseError( status_code=502, message="Some random reverse proxy error.")))) @@ -336,5 +337,49 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + # add a test for delays + # add test for complete failures + + async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync"): + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransportAsync() + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + + start:float = time.perf_counter() + while (time.perf_counter() - start) < 2: + upsert_document = await container.upsert_item(body=document_definition) + request = upsert_document.get_response_headers()["_request"] + assert request.url.startswith(second_region_uri) + read_document = await container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "East US" (the most preferred read-only region) + assert request.url.startswith(second_region_uri) + + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + if __name__ == '__main__': unittest.main() From 2d5f0d7d624658c9f4a36d205010904b6ba3aa6d Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 17:52:56 -0700 Subject: [PATCH 33/86] add more tests --- .../tests/_fault_injection_transport_async.py | 4 +-- .../test_fault_injection_transport_async.py | 25 ++++++++++++++++--- 2 files changed, 22 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 
428059231796..f5b9c8b126dd 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -26,9 +26,7 @@ import json import logging import sys -from collections.abc import MutableMapping -from typing import Callable, Optional, Any, Dict, List, Awaitable - +from typing import Callable, Optional, Any, Dict, List, Awaitable, MutableMapping import aiohttp from azure.core.pipeline.transport import AioHttpTransport, AioHttpTransportResponse from azure.core.rest import HttpRequest, AsyncHttpResponse diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 0a186afbb875..c6e0e5b39b5c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -101,17 +101,21 @@ async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): custom_transport = FaultInjectionTransportAsync() predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 500, + 10000, CosmosHttpResponseError( status_code=502, message="Some random reverse proxy error.")))) initialized_objects = self.setup_method_with_custom_transport(custom_transport) + start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] await container.create_item(body=document_definition) pytest.fail("Expected exception not thrown") except CosmosHttpResponseError as cosmosError: + end = time.perf_counter() - start + # validate response took more than 10 seconds + assert end > 10 if cosmosError.status_code != 502: raise cosmosError finally: @@ -298,7 +302,7 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() - # Inject topology transformation that would make Emulator look like a single write region + # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ @@ -337,14 +341,14 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - # add a test for delays # add test for complete failures async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync"): + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host custom_transport = FaultInjectionTransportAsync() - # Inject topology transformation that would make Emulator look like a single write region + # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: 
FaultInjectionTransportAsync.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ @@ -356,6 +360,19 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" is_get_account_predicate, emulator_as_multi_write_region_account_transformation) + # Inject rule to simulate regional outage in "First Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_targets_region(r, first_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 500, + CosmosHttpResponseError( + status_code=408, + message="Induced Request Timeout")))) + id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, From f04a50695dc8ad74f2ce25090639e17a67a4320b Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 2 Apr 2025 17:54:44 -0700 Subject: [PATCH 34/86] add more tests --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index c6e0e5b39b5c..238e10efba3e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -351,7 +351,7 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + emulator_as_multi_write_region_account_transformation = \ lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( first_region_name="First Region", second_region_name="Second Region", From bcee9cf4888f1e8326ff307454136aba4d8edf54 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 18:24:06 -0700 Subject: [PATCH 35/86] Add tests --- .../test_fault_injection_transport_async.py | 139 ++++++++++++++++-- 1 file changed, 126 insertions(+), 13 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 238e10efba3e..5fcc953df5f2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -22,6 +22,7 @@ from azure.cosmos.aio._container import ContainerProxy from azure.cosmos.aio._database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.core.exceptions import ServiceRequestError MGMT_TIMEOUT = 3.0 logger = logging.getLogger('azure.cosmos') @@ -91,7 +92,7 @@ def cleanup_method(initialized_objects: dict[str, Any]): except Exception as close_error: logger.warning(f"Exception trying to close method client. 
{close_error}") - async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): + async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsync"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, @@ -121,7 +122,7 @@ async def test_throws_injected_error(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -172,7 +173,7 @@ async def test_swr_mrr_succeeds(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -232,7 +233,7 @@ async def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransp finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -298,7 +299,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjection TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") custom_transport = FaultInjectionTransportAsync() @@ -343,7 +344,7 @@ async def test_mwr_succeeds(self: "TestFaultInjectionTransportAsync"): # add test for complete failures - async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync"): + async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host custom_transport = FaultInjectionTransportAsync() @@ -361,17 +362,13 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" emulator_as_multi_write_region_account_transformation) # Inject rule to simulate regional outage in "First Region" - is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ r: FaultInjectionTransportAsync.predicate_targets_region(r, first_region_uri) and \ FaultInjectionTransportAsync.predicate_is_document_operation(r) custom_transport.add_fault( - is_request_to_read_region, - lambda r: 
asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( - 500, - CosmosHttpResponseError( - status_code=408, - message="Induced Request Timeout")))) + is_request_to_first_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, @@ -387,6 +384,7 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" start:float = time.perf_counter() while (time.perf_counter() - start) < 2: + # reads and writes should failover to second region upsert_document = await container.upsert_item(body=document_definition) request = upsert_document.get_response_headers()["_request"] assert request.url.startswith(second_region_uri) @@ -398,5 +396,120 @@ async def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransportAsync" finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjectionTransportAsync"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransportAsync() + + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri) and + FaultInjectionTransportAsync.predicate_is_document_operation(r) and + not FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_write_region_uri)) + + + # Inject rule to simulate regional outage in "Second Region" + is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \ + r: (FaultInjectionTransportAsync.predicate_targets_region(r, expected_read_region_uri) and + FaultInjectionTransportAsync.predicate_is_document_operation(r) and + not FaultInjectionTransportAsync.predicate_is_write_operation(r, expected_write_region_uri)) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + custom_transport.add_fault( + is_request_to_second_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = 
self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + await container.upsert_item(body=document_definition) + with pytest.raises(ServiceRequestError): + await container.read_item(id_value, partition_key=id_value) + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + + async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsync"): + + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + second_region_uri: str = test_config.TestConfig.local_host + custom_transport = FaultInjectionTransportAsync() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + # Inject rule to simulate regional outage in "First Region" + is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_targets_region(r, first_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) + + # Inject rule to simulate regional outage in "Second Region" + is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_targets_region(r, second_region_uri) and \ + FaultInjectionTransportAsync.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_first_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + custom_transport.add_fault( + is_request_to_second_region, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down())) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + try: + container: ContainerProxy = initialized_objects["col"] + with pytest.raises(ServiceRequestError): + await container.upsert_item(body=document_definition) + finally: + TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + if __name__ == '__main__': unittest.main() From 779f9d19b6a8bacca8d3cee2481b4a0cf34c5dcd Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 23:27:56 -0700 Subject: [PATCH 36/86] fix tests --- sdk/cosmos/azure-cosmos/pytest.ini | 5 ++++- .../tests/_fault_injection_transport_async.py | 6 +++--- .../tests/test_fault_injection_transport_async.py | 12 ++++++------ .../azure-cosmos/tests/test_feed_range_async.py | 2 +- 4 files changed, 14 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/pytest.ini b/sdk/cosmos/azure-cosmos/pytest.ini index 647aac1464f8..0ea65741e343 100644 --- a/sdk/cosmos/azure-cosmos/pytest.ini +++ b/sdk/cosmos/azure-cosmos/pytest.ini @@ -1,3 +1,6 @@ [pytest] markers = - cosmosEmulator: marks tests as depending in Cosmos DB 
Emulator + cosmosEmulator: marks tests as depending on Cosmos DB Emulator. + cosmosLong: marks tests to be run on a Cosmos DB live account. + cosmosQuery: marks tests running queries on Cosmos DB live account. + cosmosSplit: marks tests where there are partition splits on a Cosmos DB live account. diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index f5b9c8b126dd..4551b0235bad 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -221,7 +221,7 @@ async def transform_topology_mwr( return response class MockHttpResponse(AioHttpTransportResponse): - def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict[str, Any]]): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]): self.request: HttpRequest = request # This is actually never None, and set by all implementations after the call to # __init__ of this class. This class is also a legacy impl, so it's risky to change it @@ -232,7 +232,7 @@ def __init__(self, request: HttpRequest, status_code: int, content:Optional[dict self.reason: Optional[str] = None self.content_type: Optional[str] = None self.block_size: int = 4096 # Default to same as R - self.content: Optional[dict[str, Any]] = None + self.content: Optional[Dict[str, Any]] = None self.json_text: Optional[str] = None self.bytes: Optional[bytes] = None if content: @@ -248,4 +248,4 @@ def text(self) -> Optional[str]: return self.json_text async def load_body(self) -> None: - return \ No newline at end of file + return diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 5fcc953df5f2..3f0507acc038 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -8,7 +8,7 @@ import time import unittest import uuid -from typing import Any, Callable +from typing import Any, Callable, Awaitable, Dict import pytest from azure.core.pipeline.transport import AioHttpTransport @@ -24,7 +24,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -MGMT_TIMEOUT = 3.0 +MGMT_TIMEOUT = 10.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -74,7 +74,7 @@ def teardown_class(cls): try: asyncio.run(asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT)) except Exception as closeError: - logger.warning("Exception trying to delete database {}. {}".format(created_database.id, closeError)) + logger.warning("Exception trying to close client {}. 
{}".format(created_database.id, closeError)) @staticmethod def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): @@ -85,7 +85,7 @@ def setup_method_with_custom_transport(custom_transport: AioHttpTransport, defau return {"client": client, "db": db, "col": container} @staticmethod - def cleanup_method(initialized_objects: dict[str, Any]): + def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] try: asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) @@ -261,7 +261,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj # Inject topology transformation that would make Emulator look like a single write region # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + emulator_as_multi_region_sm_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], Awaitable[AsyncHttpResponse]]], AioHttpTransportResponse] = \ lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( write_region_name="Write Region", read_region_name="Read Region", @@ -306,7 +306,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): # Inject topology transformation that would make Emulator look like a multiple write region account # account with two read regions is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) - emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], asyncio.Task[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], Awaitable[AsyncHttpResponse]]], AioHttpTransportResponse] = \ lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( first_region_name="First Region", second_region_name="Second Region", diff --git a/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py b/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py index b540a5a70423..84318f4dc5bb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_feed_range_async.py @@ -58,7 +58,7 @@ async def test_feed_range_is_subset_from_pk_async(self): True, False)).to_dict() epk_child_feed_range = await self.container_for_test.feed_range_from_partition_key("1") - assert self.container_for_test.is_feed_range_subset(epk_parent_feed_range, epk_child_feed_range) + assert await self.container_for_test.is_feed_range_subset(epk_parent_feed_range, epk_child_feed_range) if __name__ == '__main__': unittest.main() From b4db22e8c7b2cec48c61030ae8dc377297281271 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 2 Apr 2025 23:29:04 -0700 Subject: [PATCH 37/86] fix tests --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 -- 1 file changed, 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 3f0507acc038..b781526a807d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ 
b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -342,8 +342,6 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): finally: TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) - # add test for complete failures - async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") second_region_uri: str = test_config.TestConfig.local_host From 9a94d028217a5af4540731d68932a8cd103dc3d2 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 3 Apr 2025 09:01:48 -0700 Subject: [PATCH 38/86] fix tests --- sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py | 2 +- sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index d81884d14e55..873c032d7ead 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -235,7 +235,7 @@ def _get_configured_excluded_locations(self, request: RequestObject): if excluded_locations is None: # If excluded locations were only configured on client(connection_policy), use client level excluded_locations = self.connection_policy.ExcludedLocations - excluded_locations.union(request.excluded_locations_circuit_breaker) + excluded_locations.extend(request.excluded_locations_circuit_breaker) return excluded_locations def _get_applicable_read_regional_endpoints(self, request: RequestObject): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 4f02ab0bb52b..28dc2fefd73b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -21,7 +21,7 @@ """Represents a request object. 
""" -from typing import Optional, Mapping, Any, Dict, Set +from typing import Optional, Mapping, Any, Dict, Set, List class RequestObject(object): @@ -41,7 +41,7 @@ def __init__( self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None - self.excluded_locations: Optional[Set[str]] = None + self.excluded_locations: Optional[List[str]] = None self.excluded_locations_circuit_breaker: Set[str] = set() def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long From eab1b63f2b1f9e2260c60cbca4b3f96a514fad3d Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 09:50:10 -0700 Subject: [PATCH 39/86] fix test --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index b781526a807d..40c73c31381b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -24,7 +24,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -MGMT_TIMEOUT = 10.0 +MGMT_TIMEOUT = 1.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) From 93c2d7ddac7bd7f305f0f7b419ae6ab7279455da Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 11:12:03 -0700 Subject: [PATCH 40/86] fix test --- .../test_fault_injection_transport_async.py | 53 ++++++++++--------- 1 file changed, 27 insertions(+), 26 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 40c73c31381b..4f3a78c28760 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -9,6 +9,7 @@ import unittest import uuid from typing import Any, Callable, Awaitable, Dict +from unittest import IsolatedAsyncioTestCase import pytest from azure.core.pipeline.transport import AioHttpTransport @@ -36,9 +37,9 @@ @pytest.mark.cosmosEmulator @pytest.mark.asyncio -class TestFaultInjectionTransportAsync: +class TestFaultInjectionTransportAsync(IsolatedAsyncioTestCase): @classmethod - def setup_class(cls): + async def asyncSetUp(cls): logger.info("starting class: {} execution".format(cls.__name__)) cls.host = host cls.master_key = master_key @@ -54,30 +55,30 @@ def setup_class(cls): cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) - asyncio.run(asyncio.wait_for( + await asyncio.wait_for( created_database.create_container( cls.single_partition_container_name, partition_key=PartitionKey("/pk")), - MGMT_TIMEOUT)) + MGMT_TIMEOUT) @classmethod - def teardown_class(cls): + async def asyncTearDown(cls): logger.info("tearing down class: {}".format(cls.__name__)) created_database = cls.mgmt_client.get_database_client(cls.database_id) try: - asyncio.run(asyncio.wait_for( + await asyncio.wait_for( created_database.delete_container(cls.single_partition_container_name), - MGMT_TIMEOUT)) + MGMT_TIMEOUT) except Exception as 
containerDeleteError: logger.warning("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) finally: try: - asyncio.run(asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT)) + await asyncio.wait_for(cls.mgmt_client.close(), MGMT_TIMEOUT) except Exception as closeError: logger.warning("Exception trying to close client {}. {}".format(created_database.id, closeError)) @staticmethod - def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + async def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) @@ -85,7 +86,7 @@ def setup_method_with_custom_transport(custom_transport: AioHttpTransport, defau return {"client": client, "db": db, "col": container} @staticmethod - def cleanup_method(initialized_objects: Dict[str, Any]): + async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] try: asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) @@ -107,7 +108,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy status_code=502, message="Some random reverse proxy error.")))) - initialized_objects = self.setup_method_with_custom_transport(custom_transport) + initialized_objects = await self.setup_method_with_custom_transport(custom_transport) start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] @@ -120,7 +121,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy if cosmosError.status_code != 502: raise cosmosError finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -152,7 +153,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["Read Region", "Write Region"]) try: @@ -171,7 +172,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): assert request.url.startswith(expected_read_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -211,7 +212,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -231,7 +232,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection assert 
request.url.startswith(expected_write_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -276,7 +277,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -296,7 +297,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj assert request.url.startswith(expected_write_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): @@ -321,7 +322,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -340,7 +341,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): assert request.url.startswith(first_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransportAsync"): first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") @@ -374,7 +375,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -392,7 +393,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport assert request.url.startswith(second_region_uri) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjectionTransportAsync"): expected_read_region_uri: str = test_config.TestConfig.local_host @@ -446,7 +447,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -455,7 +456,7 @@ async def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection with pytest.raises(ServiceRequestError): await container.read_item(id_value, partition_key=id_value) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) async 
def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsync"): @@ -499,7 +500,7 @@ async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsyn 'name': 'sample document', 'key': 'value'} - initialized_objects = self.setup_method_with_custom_transport( + initialized_objects = await self.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -507,7 +508,7 @@ async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsyn with pytest.raises(ServiceRequestError): await container.upsert_item(body=document_definition) finally: - TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) + await TestFaultInjectionTransportAsync.cleanup_method(initialized_objects) if __name__ == '__main__': unittest.main() From 345f3901a8c79690d698ceb6781a2377591332f6 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 11:46:06 -0700 Subject: [PATCH 41/86] fix tests --- .../tests/test_fault_injection_transport_async.py | 10 ++++------ 1 file changed, 4 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 4f3a78c28760..f3b56766b4d1 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -25,7 +25,7 @@ from azure.cosmos.exceptions import CosmosHttpResponseError from azure.core.exceptions import ServiceRequestError -MGMT_TIMEOUT = 1.0 +MGMT_TIMEOUT = 5.0 logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) logger.addHandler(logging.StreamHandler(sys.stdout)) @@ -33,7 +33,6 @@ host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID -single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -51,7 +50,7 @@ async def asyncSetUp(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") cls.database_id = TEST_DATABASE_ID - cls.single_partition_container_name = single_partition_container_name + cls.single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) @@ -77,12 +76,11 @@ async def asyncTearDown(cls): except Exception as closeError: logger.warning("Exception trying to close client {}. 
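The conversion completed above moves the suite onto unittest's IsolatedAsyncioTestCase, so lifecycle code awaits asyncio.wait_for(...) directly instead of spinning up a fresh loop with asyncio.run(...). A self-contained sketch of that pattern (class, queue, and timeout value are illustrative, not from the SDK):

import asyncio
import unittest
from unittest import IsolatedAsyncioTestCase

TIMEOUT = 5.0  # seconds

class ExampleAsyncTest(IsolatedAsyncioTestCase):
    async def asyncSetUp(self):
        # Runs on the loop the test framework owns, so awaiting is safe here;
        # asyncio.run() would raise RuntimeError because a loop is already running.
        self.queue = asyncio.Queue()
        await asyncio.wait_for(self.queue.put("ready"), TIMEOUT)

    async def test_resource_ready(self):
        self.assertEqual(await asyncio.wait_for(self.queue.get(), TIMEOUT), "ready")

    async def asyncTearDown(self):
        # Teardown is awaited per test on the same loop.
        await asyncio.wait_for(asyncio.sleep(0), TIMEOUT)

if __name__ == "__main__":
    unittest.main()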
{}".format(created_database.id, closeError)) - @staticmethod - async def setup_method_with_custom_transport(custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, master_key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(single_partition_container_name) + container: ContainerProxy = db.get_container_client(self.single_partition_container_name) return {"client": client, "db": db, "col": container} @staticmethod From fe74aa0ab0fb106d03a995e6ac52dea833e3636c Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 13:43:34 -0700 Subject: [PATCH 42/86] fix async in test --- .../azure-cosmos/tests/test_fault_injection_transport_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index f3b56766b4d1..1df1de05936d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -87,7 +87,7 @@ async def setup_method_with_custom_transport(self, custom_transport: AioHttpTran async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] try: - asyncio.run(asyncio.wait_for(method_client.close(), MGMT_TIMEOUT)) + await asyncio.wait_for(method_client.close(), MGMT_TIMEOUT) except Exception as close_error: logger.warning(f"Exception trying to close method client. 
{close_error}") From 5bb9f1fb166a5a9803af8287eac54b506da3c8b5 Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Thu, 3 Apr 2025 14:47:32 -0700 Subject: [PATCH 43/86] Added multi-region tests --- sdk/cosmos/live-platform-matrix.json | 37 ++++++++++++++++++++++++++++ sdk/cosmos/test-resources.bicep | 6 +++++ 2 files changed, 43 insertions(+) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 485a15ca92e8..bca59256d05d 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -88,6 +88,43 @@ "TestMarkArgument": "cosmosLong" } } + }, + { + "WindowsConfig": { + "Windows2022_38_multi_region": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.8", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion", + "ArmConfig": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } + }, + "Windows2022_310_multi_region": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.10", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion", + "ArmConfig": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } + }, + "Windows2022_312_multi_region": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.12", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion", + "ArmConfig": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } + } + } } ] } diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 17d88b0be92a..61588a526eed 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -41,6 +41,12 @@ var multiRegionConfiguration = [ failoverPriority: 1 isZoneRedundant: false } + { + locationName: 'West US 2' + provisioningState: 'Succeeded' + failoverPriority: 2 + isZoneRedundant: false + } ] var locationsConfiguration = (enableMultipleRegions ? 
multiRegionConfiguration : singleRegionConfiguration) var roleDefinitionId = guid(baseName, 'roleDefinitionId') From 996217ae3fec5740cf0bc3eb180d2ba6af725953 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Thu, 3 Apr 2025 15:08:49 -0700 Subject: [PATCH 44/86] Fix _AddPartitionKey to pass options to sub methods --- .../azure/cosmos/_cosmos_client_connection.py | 10 +++++++--- .../azure-cosmos/azure/cosmos/_request_object.py | 2 +- 2 files changed, 8 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index d64da38defb1..acc9ac0010af 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -3265,7 +3265,7 @@ def _AddPartitionKey( options: Mapping[str, Any] ) -> Dict[str, Any]: collection_link = base.TrimBeginningAndEndingSlashes(collection_link) - partitionKeyDefinition = self._get_partition_key_definition(collection_link) + partitionKeyDefinition = self._get_partition_key_definition(collection_link, options) new_options = dict(options) # If the collection doesn't have a partition key definition, skip it as it's a legacy collection if partitionKeyDefinition: @@ -3367,7 +3367,11 @@ def _UpdateSessionIfRequired( # update session self.session.update_session(response_result, response_headers) - def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]: + def _get_partition_key_definition( + self, + collection_link: str, + options: Mapping[str, Any] + ) -> Optional[Dict[str, Any]]: partition_key_definition: Optional[Dict[str, Any]] # If the document collection link is present in the cache, then use the cached partitionkey definition if collection_link in self.__container_properties_cache: partition_key_definition = cached_container.get("partitionKey") # Else read the collection from backend and add it to the cache else: - container = self.ReadContainer(collection_link) + container = self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") self.__container_properties_cache[collection_link] = _set_properties_cache(container) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 94805934ce74..185aa1d89cb8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -57,7 +57,7 @@ def clear_route_to_location(self) -> None: def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: # If resource types for requests are not one of the followings, excluded locations cannot be set - if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey']: + if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey', 'colls']: return False # If 'excludedLocations' wasn't in the options, excluded locations cannot be set From 41fc9176bec2687e41bc80e7f5254763754c0930 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Thu, 3 Apr 2025 15:10:39 -0700 Subject: [PATCH 45/86] Added initial live tests --- .../tests/test_excluded_locations.py | 478 ++++++++++++++++++ 1 file changed, 478 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py diff --git
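Patch 44 above threads the caller's request options into the partition-key metadata lookup, so per-request settings such as excluded locations (now that 'colls' is an allowed resource type) also govern the implicit container read on a cache miss. A reduced sketch of that cache-or-fetch shape; every name here is illustrative rather than the SDK's:

from typing import Any, Callable, Dict, Mapping, Optional

_cache: Dict[str, Dict[str, Any]] = {}

def get_partition_key_definition(
    collection_link: str,
    options: Mapping[str, Any],
    read_container: Callable[[str, Mapping[str, Any]], Dict[str, Any]],
) -> Optional[Dict[str, Any]]:
    # Serve from cache when possible; on a miss, fall back to a backend read
    # that honors the same per-request options the caller supplied.
    if collection_link in _cache:
        return _cache[collection_link].get("partitionKey")
    container = read_container(collection_link, options)
    _cache[collection_link] = container
    return container.get("partitionKey")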
a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py new file mode 100644 index 000000000000..01d1e9e9cf7e --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -0,0 +1,478 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import logging +import unittest +import uuid +import test_config +import pytest + +import azure.cosmos.cosmos_client as cosmos_client +from azure.cosmos.partition_key import PartitionKey +from azure.cosmos.exceptions import CosmosResourceNotFoundError + + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +# L0 = "Default" +# L1 = "West US 3" +# L2 = "West US" +# L3 = "East US 2" +# L4 = "Central US" + +L0 = "Default" +L1 = "East US 2" +L2 = "East US" +L3 = "West US 2" +L4 = "Central US" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2, L3], [], None], + # 1. Single excluded location + [[L1, L2, L3], [L1], None], + # 2. Multiple excluded locations + [[L1, L2, L3], [L1, L2], None], + # 3. Exclude all locations + [[L1, L2, L3], [L1, L2, L3], None], + # 4. Exclude a location not in preferred locations + [[L1, L2, L3], [L4], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2, L3], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2, L3], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2, L3], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2, L3], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2, L3], [L1, L2, L3], [L1, L2, L3]], + # 5. No common excluded locations + [[L1, L2, L3], [L1], [L2, L3]], + # 6. Request excluded location not in preferred locations + [[L1, L2, L3], [L1, L2, L3], [L4]], + # 7.
Empty excluded locations, remove all client excluded locations + [[L1, L2, L3], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA +# ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA +# ALL_INPUT_TEST_DATA = CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L3], # 2 + [L1], # 3 + [L1] # 4 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L3], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1], #0 + [L2, L2, L2], #1 + [L3, L3, L3], #2 + [L1, L1, L1], #3 + [L1, L1, L1] #4 + ] + client_and_request_output_data = [ + [L1, L1, L2], #0 + [L2, L2, L2], #1 + [L3, L3, L2], #2 + [L2, L2, L3], #3 + [L1, L1, L1], #4 + [L2, L2, L1], #5 + [L1, L1, L1], #6 + [L3, L3, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1, L1], #0 + [L2, L2], #1 + [L3, L3], #2 + [L1, L0], #3 + [L1, L1] #4 + ] + client_and_request_output_data = [ + [L2, L2], #0 + [L2, L2], #1 + [L2, L2], #2 + [L3, L3], #3 + [L1, L0], #4 + [L1, L1], #5 + [L1, L1], #6 + [L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L3], #2 + [L0], #3 + [L1] #4 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L3], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + container = cosmos_client.CosmosClient(HOST, KEY).get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + container.create_item(body=TEST_ITEM) + + yield + # Code to run after tests + print("Teardown: This runs after all tests") + +@pytest.mark.cosmosMultiRegion +class TestExcludedLocations: + def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = cosmos_client.CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = client.get_database_client(DATABASE_ID) + container = db.get_container_client(CONTAINER_ID) + MOCK_HANDLER.reset() + + return client, db, container + + def _verify_endpoint(self, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. 
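The matrices above encode one behavioral rule the expected-location tables rely on: a request-level excluded_locations value, when supplied, wholly replaces the client-level list (case 7 shows an empty list wiping out the client's exclusions). A toy resolution helper expressing that rule, inferred from these test expectations rather than quoted from the SDK:

from typing import List, Optional

def resolve_excluded_locations(
    client_excluded: List[str],
    request_excluded: Optional[List[str]],
) -> List[str]:
    # None means "not specified", so the client default applies;
    # any supplied list, even an empty one, takes full precedence.
    return client_excluded if request_excluded is None else list(request_excluded)

assert resolve_excluded_locations(["East US 2"], None) == ["East US 2"]
assert resolve_excluded_locations(["East US 2"], []) == []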
+ location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. + location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + msgs = MOCK_HANDLER.messages + req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break + + assert actual_locations == expected_locations + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_read_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_item + if request_excluded_locations is None: + container.read_item(ITEM_ID, ITEM_PK_VALUE) + else: + container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_read_all_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_all_items + if request_excluded_locations is None: + list(container.read_all_items()) + else: + list(container.read_all_items(excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_query_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items + if request_excluded_locations is None: + list(container.query_items(None)) + else: + list(container.query_items(None, excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + def test_query_items_change_feed(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items_change_feed + if request_excluded_locations is None: + list(container.query_items_change_feed()) + else: + list(container.query_items_change_feed(excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + self._verify_endpoint(client, expected_locations) + + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_replace_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, 
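_verify_endpoint above recovers a region name for each logged 'Request URL:' entry by prefix-matching against the endpoint-to-location map held by the location cache. In isolation the lookup reduces to the following; the account name and endpoints are made up:

LOCATION_BY_ENDPOINT = {
    "https://myaccount-eastus2.documents.azure.com": "East US 2",
    "https://myaccount-westus2.documents.azure.com": "West US 2",
}
DEFAULT_ENDPOINT = "https://myaccount.documents.azure.com"

def location_for(request_url: str) -> str:
    # The default (hub) endpoint wins first, mirroring the L0 case above.
    if request_url.startswith(DEFAULT_ENDPOINT):
        return "Default"
    for endpoint, location in LOCATION_BY_ENDPOINT.items():
        if request_url.startswith(endpoint):
            return location
    raise ValueError(f"unknown endpoint: {request_url}")

assert location_for("https://myaccount-westus2.documents.azure.com/dbs/db1") == "West US 2"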
request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: replace_item + if request_excluded_locations is None: + container.replace_item(ITEM_ID, body=TEST_ITEM) + else: + container.replace_item(ITEM_ID, body=TEST_ITEM, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_upsert_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: upsert_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + container.upsert_item(body=body) + else: + container.upsert_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_create_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: create_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + container.create_item(body=body) + else: + container.create_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_patch_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: patch_item + operations = [ + {"op": "add", "path": "/test_data", "value": f'Data-{str(uuid.uuid4())}'}, + ] + if request_excluded_locations is None: + container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations) + else: + container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) 
+ def test_execute_item_batch(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: execute_item_batch + batch_operations = [] + for i in range(3): + batch_operations.append(("create", ({"id": f'Doc-{str(uuid.uuid4())}', PARTITION_KEY: ITEM_PK_VALUE},))) + + if request_excluded_locations is None: + container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE,) + else: + container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + #create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + container.delete_item(item_id, ITEM_PK_VALUE) + else: + container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) + + # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature + # @pytest.mark.parametrize('test_data', patch_item_test_data()) + # def test_delete_all_items_by_partition_key(self, test_data): + # # Init test variables + # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + # + # for multiple_write_locations in [True, False]: + # # Client setup + # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + # + # #create before delete + # item_id = f'doc2-{str(uuid.uuid4())}' + # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' + # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) + # MOCK_HANDLER.reset() + # + # # API call: delete_all_items_by_partition_key + # if request_excluded_locations is None: + # container.delete_all_items_by_partition_key(pk_value) + # else: + # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) + # + # # Verify endpoint locations + # if multiple_write_locations: + # self._verify_endpoint(client, expected_locations) + # else: + # self._verify_endpoint(client, [L1]) + +if __name__ == "__main__": + unittest.main() From 07b8f39aeaba9a68e7656777bb0732655a20bf4a Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Thu, 3 Apr 2025 15:28:38 -0700 Subject: [PATCH 46/86] Updated live-platform-matrix for multi-region tests --- sdk/cosmos/live-platform-matrix.json | 14 +++++++------- 1 file
changed, 7 insertions(+), 7 deletions(-) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index bca59256d05d..b3242623be78 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -86,11 +86,7 @@ "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosLong" - } - } - }, - { - "WindowsConfig": { + }, "Windows2022_38_multi_region": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", @@ -110,7 +106,9 @@ "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion", "ArmConfig": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } } }, "Windows2022_312_multi_region": { @@ -121,7 +119,9 @@ "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion", "ArmConfig": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" + } } } } From 1b09739533408a2699443a88df1fe7f59e7fa617 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Thu, 3 Apr 2025 15:31:56 -0700 Subject: [PATCH 47/86] initial sync version of fault injection --- .../tests/_fault_injection_transport.py | 253 ++++++++++++++++++ .../tests/test_fault_injection_transport.py | 99 +++++++ 2 files changed, 352 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py new file mode 100644 index 000000000000..9cdd18936945 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -0,0 +1,253 @@ +# The MIT License (MIT) +# Copyright (c) 2014 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. 
+ +"""AioHttpTransport allowing injection of faults between SDK and Cosmos Gateway +""" + +import json +import logging +import sys +from time import sleep +from typing import Callable, Optional, Any, Dict, List, MutableMapping + +from azure.core.pipeline.transport import HttpRequest, HttpResponse +from azure.core.pipeline.transport._requests_basic import RequestsTransport, RequestsTransportResponse +from requests import Session + +from azure.cosmos import documents + +import test_config +from azure.cosmos.exceptions import CosmosHttpResponseError +from azure.core.exceptions import ServiceRequestError + +class FaultInjectionTransport(RequestsTransport): + logger = logging.getLogger('azure.cosmos.fault_injection_transport') + logger.setLevel(logging.DEBUG) + + def __init__(self, *, session: Optional[Session] = None, loop=None, session_owner: bool = True, **config): + self.faults: List[Dict[str, Any]] = [] + self.requestTransformations: List[Dict[str, Any]] = [] + self.responseTransformations: List[Dict[str, Any]] = [] + super().__init__(session=session, loop=loop, session_owner=session_owner, **config) + + def add_fault(self, predicate: Callable[[HttpRequest], bool], fault_factory: Callable[[HttpRequest], Exception]): + self.faults.append({"predicate": predicate, "apply": fault_factory}) + + def add_response_transformation(self, predicate: Callable[[HttpRequest], bool], response_transformation: Callable[[HttpRequest, Callable[[HttpRequest], RequestsTransportResponse]], RequestsTransportResponse]): + self.responseTransformations.append({ + "predicate": predicate, + "apply": response_transformation}) + + @staticmethod + def __first_item(iterable, condition=lambda x: True): + """ + Returns the first item in the `iterable` that satisfies the `condition`. + + If no item satisfies the condition, it returns None. 
+ """ + return next((x for x in iterable if condition(x)), None) + + def send(self, request: HttpRequest, *, proxies: Optional[MutableMapping[str, str]] = None, **kwargs) -> HttpResponse: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) + # find the first fault Factory with matching predicate if any + first_fault_factory = FaultInjectionTransport.__first_item(iter(self.faults), lambda f: f["predicate"](request)) + if first_fault_factory: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") + injected_error = first_fault_factory["apply"](request) + FaultInjectionTransport.logger.info("Found to-be-injected error {}".format(injected_error)) + raise injected_error + + # apply the chain of request transformations with matching predicates if any + matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) + for currentTransformation in matching_request_transformations: + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") + request = currentTransformation["apply"](request) + + first_response_transformation = FaultInjectionTransport.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) + + FaultInjectionTransport.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + get_response_task = super().send(request, proxies=proxies, **kwargs) + FaultInjectionTransport.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + + if first_response_transformation: + FaultInjectionTransport.logger.info(f"Invoking response transformation") + response = first_response_transformation["apply"](request, lambda: get_response_task) + response.headers["_request"] = request + FaultInjectionTransport.logger.info(f"Received response transformation result with status code {response.status_code}") + return response + else: + FaultInjectionTransport.logger.info(f"Sending request to {request.url}") + response = get_response_task + response.headers["_request"] = request + FaultInjectionTransport.logger.info(f"Received response with status code {response.status_code}") + return response + + @staticmethod + def predicate_url_contains_id(r: HttpRequest, id_value: str) -> bool: + return id_value in r.url + + @staticmethod + def predicate_targets_region(r: HttpRequest, region_endpoint: str) -> bool: + return r.url.startswith(region_endpoint) + + @staticmethod + def print_call_stack(): + print("Call stack:") + frame = sys._getframe() + while frame: + print(f"File: {frame.f_code.co_filename}, Line: {frame.f_lineno}, Function: {frame.f_code.co_name}") + frame = frame.f_back + + @staticmethod + def predicate_req_payload_contains_id(r: HttpRequest, id_value: str): + if r.body is None: + return False + + return '"id":"{}"'.format(id_value) in r.body + + @staticmethod + def predicate_req_for_document_with_id(r: HttpRequest, id_value: str) -> bool: + return (FaultInjectionTransport.predicate_url_contains_id(r, id_value) + or FaultInjectionTransport.predicate_req_payload_contains_id(r, id_value)) + + @staticmethod + def predicate_is_database_account_call(r: HttpRequest) -> bool: + is_db_account_read = (r.headers.get('x-ms-thinclient-proxy-resource-type') == 'databaseaccount' + and r.headers.get('x-ms-thinclient-proxy-operation-type') == 'Read') + + return is_db_account_read + + @staticmethod + def predicate_is_document_operation(r: HttpRequest) -> bool: + is_document_operation = 
(r.headers.get('x-ms-thinclient-proxy-resource-type') == 'docs') + + return is_document_operation + + @staticmethod + def predicate_is_write_operation(r: HttpRequest, uri_prefix: str) -> bool: + is_write_document_operation = documents._OperationType.IsWriteOperation( + str(r.headers.get('x-ms-thinclient-proxy-operation-type'))) + + return is_write_document_operation and uri_prefix in r.url + + @staticmethod + def error_after_delay(delay_in_ms: int, error: Exception) -> Exception: + sleep(delay_in_ms / 1000.0) + return error + + @staticmethod + def error_write_forbidden() -> Exception: + return CosmosHttpResponseError( + status_code=403, + message="Injected error disallowing writes in this region.", + response=None, + sub_status_code=3, + ) + + @staticmethod + def error_region_down() -> Exception: + return ServiceRequestError( + message="Injected region down.", + ) + + @staticmethod + def transform_topology_swr_mrr( + write_region_name: str, + read_region_name: str, + inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + + response = inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = write_region_name + writable_locations[0]["name"] = write_region_name + readable_locations.append({"name": read_region_name, "databaseAccountEndpoint" : test_config.TestConfig.local_host}) + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + + @staticmethod + def transform_topology_mwr( + first_region_name: str, + second_region_name: str, + inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + + response = inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + data = response.body() + if response.status_code == 200 and data: + data = data.decode("utf-8") + result = json.loads(data) + readable_locations = result["readableLocations"] + writable_locations = result["writableLocations"] + readable_locations[0]["name"] = first_region_name + writable_locations[0]["name"] = first_region_name + readable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + writable_locations.append( + {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + result["enableMultipleWriteLocations"] = True + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + + class MockHttpResponse(RequestsTransportResponse): + def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]): + self.request: HttpRequest = request + # This is actually never None, and set by all implementations after the call to + # __init__ of this class. This class is also a legacy impl, so it's risky to change it + # for low benefits The new "rest" implementation does define correctly status_code + # as non-optional. 
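transform_topology_mwr above fabricates a second writable region by rewriting the account-topology payload before the SDK parses it. The essential JSON edit, applied here to a plain dict with placeholder region names and the emulator endpoint:

import json

def fake_multi_write_topology(raw: bytes, second_region: str, endpoint: str) -> dict:
    result = json.loads(raw.decode("utf-8"))
    # Advertise one extra region as both readable and writable, and flip the
    # multi-write flag so the SDK will route writes to either region.
    result["readableLocations"].append(
        {"name": second_region, "databaseAccountEndpoint": endpoint})
    result["writableLocations"].append(
        {"name": second_region, "databaseAccountEndpoint": endpoint})
    result["enableMultipleWriteLocations"] = True
    return result

raw = json.dumps({
    "readableLocations": [{"name": "First Region",
                           "databaseAccountEndpoint": "https://localhost:8081/"}],
    "writableLocations": [{"name": "First Region",
                           "databaseAccountEndpoint": "https://localhost:8081/"}],
}).encode("utf-8")
assert fake_multi_write_topology(raw, "Second Region",
                                 "https://localhost:8081/")["enableMultipleWriteLocations"]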
+ self.status_code: int = status_code + self.headers: MutableMapping[str, str] = {} + self.reason: Optional[str] = None + self.content_type: Optional[str] = None + self.block_size: int = 4096 # Default to same as R + self.content: Optional[Dict[str, Any]] = None + self.json_text: str = "" + self.bytes: bytes = b"" + if content: + self.content = content + self.json_text = json.dumps(content) + self.bytes = self.json_text.encode("utf-8") + + + def body(self) -> bytes: + return self.bytes + + def text(self, encoding: Optional[str] = None) -> str: + return self.json_text + + def load_body(self) -> None: + return diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py new file mode 100644 index 000000000000..291163159696 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -0,0 +1,99 @@ +import logging +import os +import sys +import time +import uuid +from typing import Callable + +import pytest +from azure.core.pipeline.transport._requests_basic import RequestsTransport +from azure.core.rest import HttpRequest + +import test_config +from _fault_injection_transport_async import FaultInjectionTransportAsync +from azure.cosmos import PartitionKey +from azure.cosmos import CosmosClient +from azure.cosmos.container import ContainerProxy +from azure.cosmos.database import DatabaseProxy +from azure.cosmos.exceptions import CosmosHttpResponseError + +logger = logging.getLogger('azure.cosmos') +logger.setLevel(logging.DEBUG) +logger.addHandler(logging.StreamHandler(sys.stdout)) + +host = test_config.TestConfig.host +master_key = test_config.TestConfig.masterKey +TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID +SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4()) + +@pytest.mark.unittest +@pytest.mark.cosmosEmulator +class TestFaultInjectionTransport: + + @classmethod + async def setup_class(cls): + logger.info("starting class: {} execution".format(cls.__name__)) + cls.host = host + cls.master_key = master_key + + if (cls.master_key == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + cls.database_id = TEST_DATABASE_ID + cls.single_partition_container_name = SINGLE_PARTITION_CONTAINER_NAME + + cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) + created_database = cls.mgmt_client.get_database_client(cls.database_id) + created_database.create_container( + cls.single_partition_container_name, + partition_key=PartitionKey("/pk")) + + + @classmethod + async def teardown_class(cls): + logger.info("tearing down class: {}".format(cls.__name__)) + created_database = cls.mgmt_client.get_database_client(cls.database_id) + try: + created_database.delete_container(cls.single_partition_container_name), + except Exception as containerDeleteError: + logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, containerDeleteError)) + + @staticmethod + def setup_method_with_custom_transport(custom_transport: RequestsTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) + db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) + container: ContainerProxy = db.get_container_client(SINGLE_PARTITION_CONTAINER_NAME) + return {"client": client, "db": db, "col": container} + + + def test_throws_injected_error_async(self: "TestFaultInjectionTransport"): + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + custom_transport = FaultInjectionTransportAsync() + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: FaultInjectionTransportAsync.error_after_delay( + 10000, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error."))) + + initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport(custom_transport) + start: float = time.perf_counter() + try: + container: ContainerProxy = initialized_objects["col"] + container.create_item(body=document_definition) + pytest.fail("Expected exception not thrown") + except CosmosHttpResponseError as cosmosError: + end = time.perf_counter() - start + # validate response took more than 10 seconds + assert end > 10 + if cosmosError.status_code != 502: + raise cosmosError \ No newline at end of file From 2fb3dc93c455cf90c9008e3a670c85068f4b9e8c Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 15:59:57 -0700 Subject: [PATCH 48/86] add all sync tests --- .../tests/test_fault_injection_transport.py | 392 +++++++++++++++++- 1 file changed, 381 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index 291163159696..ef1f28b14adc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -2,6 +2,7 @@ import os import sys import time +import unittest import uuid from typing import Callable @@ -10,12 +11,13 @@ from azure.core.rest import HttpRequest import test_config -from _fault_injection_transport_async import FaultInjectionTransportAsync from azure.cosmos import PartitionKey from azure.cosmos import CosmosClient from azure.cosmos.container import ContainerProxy from azure.cosmos.database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError +from tests._fault_injection_transport import FaultInjectionTransport +from azure.core.exceptions import ServiceRequestError logger = logging.getLogger('azure.cosmos') logger.setLevel(logging.DEBUG) @@ -31,7 +33,7 @@ class TestFaultInjectionTransport: @classmethod - async def setup_class(cls): + def setup_class(cls): logger.info("starting class: {} execution".format(cls.__name__)) cls.host = host cls.master_key = master_key @@ -47,13 +49,11 @@ async def setup_class(cls): cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) - created_database.create_container( - 
cls.single_partition_container_name, - partition_key=PartitionKey("/pk")) + created_database.create_container(cls.single_partition_container_name, partition_key=PartitionKey("/pk")) @classmethod - async def teardown_class(cls): + def teardown_class(cls): logger.info("tearing down class: {}".format(cls.__name__)) created_database = cls.mgmt_client.get_database_client(cls.database_id) try: @@ -70,16 +70,16 @@ def setup_method_with_custom_transport(custom_transport: RequestsTransport, defa return {"client": client, "db": db, "col": container} - def test_throws_injected_error_async(self: "TestFaultInjectionTransport"): + def test_throws_injected_error(self: "TestFaultInjectionTransport"): id_value: str = str(uuid.uuid4()) document_definition = {'id': id_value, 'pk': id_value, 'name': 'sample document', 'key': 'value'} - custom_transport = FaultInjectionTransportAsync() - predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) - custom_transport.add_fault(predicate, lambda r: FaultInjectionTransportAsync.error_after_delay( + custom_transport = FaultInjectionTransport() + predicate : Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: FaultInjectionTransport.error_after_delay( 10000, CosmosHttpResponseError( status_code=502, @@ -96,4 +96,374 @@ def test_throws_injected_error_async(self: "TestFaultInjectionTransport"): # validate response took more than 10 seconds assert end > 10 if cosmosError.status_code != 502: - raise cosmosError \ No newline at end of file + raise cosmosError + + def test_swr_mrr_succeeds(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["Read Region", "Write Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" (the write region) + assert 
request.url.startswith(expected_write_region_uri) + start: float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + + + def test_swr_mrr_region_down_read_succeeds(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: FaultInjectionTransport.error_region_down()) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + default_endpoint=expected_write_region_uri, + preferred_locations=["Read Region", "Write Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(expected_write_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) + assert request.url.startswith(expected_write_region_uri) + + + def test_swr_mrr_region_down_envoy_read_succeeds(self: "TestFaultInjectionTransport"): + expected_read_region_uri: str = test_config.TestConfig.local_host + expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \ + r: 
FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: FaultInjectionTransport.error_write_forbidden()) + + # Inject rule to simulate regional outage in "Read Region" + is_request_to_read_region: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and \ + FaultInjectionTransport.predicate_is_document_operation(r) + + custom_transport.add_fault( + is_request_to_read_region, + lambda r: FaultInjectionTransport.error_after_delay( + 500, + CosmosHttpResponseError( + status_code=502, + message="Some random reverse proxy error."))) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + default_endpoint=expected_write_region_uri, + preferred_locations=["Read Region", "Write Region"]) + container: ContainerProxy = initialized_objects["col"] + + created_document = container.create_item(body=document_definition) + request: HttpRequest = created_document.get_response_headers()["_request"] + # Validate the response comes from "South Central US" (the write region) + assert request.url.startswith(expected_write_region_uri) + start:float = time.perf_counter() + + while (time.perf_counter() - start) < 2: + read_document = container.read_item(id_value, partition_key=id_value) + request = read_document.get_response_headers()["_request"] + # Validate the response comes from "Write Region" ("Read Region" the most preferred read-only region is down) + assert request.url.startswith(expected_write_region_uri) + + + + def test_mwr_succeeds(self: "TestFaultInjectionTransport"): + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': id_value, + 'name': 'sample document', + 'key': 'value'} + + initialized_objects = self.setup_method_with_custom_transport( + custom_transport, + preferred_locations=["First Region", "Second Region"]) + container: ContainerProxy = initialized_objects["col"] 
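The region-down rules in these tests compose predicates with plain boolean operators, so a fault fires only for document operations aimed at one specific endpoint. The same composition on a toy request shape (a dict-based stand-in for HttpRequest):

from typing import Callable, Dict

Request = Dict[str, str]  # e.g. {"url": ..., "resource_type": ...}

def targets_region(endpoint: str) -> Callable[[Request], bool]:
    return lambda r: r["url"].startswith(endpoint)

def is_document_operation(r: Request) -> bool:
    return r["resource_type"] == "docs"

region_down = lambda r: targets_region("https://127.0.0.1:8081")(r) and is_document_operation(r)

assert region_down({"url": "https://127.0.0.1:8081/dbs/d/colls/c/docs/x",
                    "resource_type": "docs"})
assert not region_down({"url": "https://localhost:8081/dbs",
                        "resource_type": "docs"})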
+
+    def test_mwr_succeeds(self: "TestFaultInjectionTransport"):
+        first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
+        custom_transport = FaultInjectionTransport()
+
+        # Inject topology transformation that would make the Emulator look like a multi-write
+        # region account with two write regions
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_write_region_account_transformation = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_mwr(
+                first_region_name="First Region",
+                second_region_name="Second Region",
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_write_region_account_transformation)
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            preferred_locations=["First Region", "Second Region"])
+        container: ContainerProxy = initialized_objects["col"]
+
+        created_document = container.create_item(body=document_definition)
+        request: HttpRequest = created_document.get_response_headers()["_request"]
+        # Validate the response comes from "First Region" (the most preferred write region)
+        assert request.url.startswith(first_region_uri)
+        start: float = time.perf_counter()
+
+        while (time.perf_counter() - start) < 2:
+            read_document = container.read_item(id_value, partition_key=id_value)
+            request = read_document.get_response_headers()["_request"]
+            # Validate the response comes from "First Region" (the most preferred read region)
+            assert request.url.startswith(first_region_uri)
+
+    def test_mwr_region_down_succeeds(self: "TestFaultInjectionTransport"):
+        first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
+        second_region_uri: str = test_config.TestConfig.local_host
+        custom_transport = FaultInjectionTransport()
+
+        # Inject topology transformation that would make the Emulator look like a multi-write
+        # region account with two write regions
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_write_region_account_transformation = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_mwr(
+                first_region_name="First Region",
+                second_region_name="Second Region",
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_write_region_account_transformation)
+
+        # Inject rule to simulate regional outage in "First Region"
+        is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_targets_region(r, first_region_uri) and \
+               FaultInjectionTransport.predicate_is_document_operation(r)
+
+        custom_transport.add_fault(
+            is_request_to_first_region,
+            lambda r: FaultInjectionTransport.error_region_down())
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            preferred_locations=["First Region", "Second Region"])
+        container: ContainerProxy = initialized_objects["col"]
+
+        start: float = time.perf_counter()
+        while (time.perf_counter() - start) < 2:
+            # reads and writes should fail over to the second region
+            upsert_document = container.upsert_item(body=document_definition)
+            request = upsert_document.get_response_headers()["_request"]
+            assert request.url.startswith(second_region_uri)
+            read_document = container.read_item(id_value, partition_key=id_value)
+            request = read_document.get_response_headers()["_request"]
+            # Validate the response comes from "Second Region" (the next preferred region)
+            assert request.url.startswith(second_region_uri)
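The next test needs a finer-grained outage: document reads must fail in both regions while writes still reach the write region, which is why its outage predicates also negate predicate_is_write_operation. A minimal sketch of that composition, reusing the static helpers already defined in _fault_injection_transport.py; only the wrapper name reads_down_in_region is invented here.

# Sketch: matches document reads (never writes) aimed at region_uri, so an
# injected outage can take reads down while writes keep succeeding.
from typing import Callable
from azure.core.pipeline.transport import HttpRequest
from _fault_injection_transport import FaultInjectionTransport

def reads_down_in_region(region_uri: str, write_region_uri: str) -> Callable[[HttpRequest], bool]:
    return lambda r: (
        FaultInjectionTransport.predicate_targets_region(r, region_uri)
        and FaultInjectionTransport.predicate_is_document_operation(r)
        # exclude writes so they can still land in the write region
        and not FaultInjectionTransport.predicate_is_write_operation(r, write_region_uri))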
+    def test_swr_mrr_all_regions_down_for_read(self: "TestFaultInjectionTransport"):
+        expected_read_region_uri: str = test_config.TestConfig.local_host
+        expected_write_region_uri: str = expected_read_region_uri.replace("localhost", "127.0.0.1")
+        custom_transport = FaultInjectionTransport()
+
+        # Inject rule to disallow writes in the read-only region
+        is_write_operation_in_read_region_predicate: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_is_write_operation(r, expected_read_region_uri)
+
+        custom_transport.add_fault(
+            is_write_operation_in_read_region_predicate,
+            lambda r: FaultInjectionTransport.error_write_forbidden())
+
+        # Inject topology transformation that would make the Emulator look like a single write
+        # region account with an additional read-only region
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_region_sm_account_transformation = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_swr_mrr(
+                write_region_name="Write Region",
+                read_region_name="Read Region",
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_region_sm_account_transformation)
+
+        # Inject rule to simulate a regional outage for reads in "Write Region"
+        is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \
+            r: (FaultInjectionTransport.predicate_targets_region(r, expected_write_region_uri) and
+                FaultInjectionTransport.predicate_is_document_operation(r) and
+                not FaultInjectionTransport.predicate_is_write_operation(r, expected_write_region_uri))
+
+        # Inject rule to simulate a regional outage for reads in "Read Region"
+        is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \
+            r: (FaultInjectionTransport.predicate_targets_region(r, expected_read_region_uri) and
+                FaultInjectionTransport.predicate_is_document_operation(r) and
+                not FaultInjectionTransport.predicate_is_write_operation(r, expected_write_region_uri))
+
+        custom_transport.add_fault(
+            is_request_to_first_region,
+            lambda r: FaultInjectionTransport.error_region_down())
+        custom_transport.add_fault(
+            is_request_to_second_region,
+            lambda r: FaultInjectionTransport.error_region_down())
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            preferred_locations=["Read Region", "Write Region"])
+        container: ContainerProxy = initialized_objects["col"]
+        container.upsert_item(body=document_definition)
+        with pytest.raises(ServiceRequestError):
+            container.read_item(id_value, partition_key=id_value)
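This test and test_mwr_all_regions_down below both end in ServiceRequestError, the transport-level error that error_region_down injects. A rough sketch, not the SDK's actual retry code, of why exhausting every preferred location lets that error surface to the caller; all names below are invented for illustration.

# Sketch of the failover loop's terminal behavior (not the SDK implementation):
# each preferred location is tried in order, and when every attempt raises a
# transport-level error the last one propagates to the caller.
from typing import Callable, Sequence
from azure.core.exceptions import ServiceRequestError

def send_with_failover(endpoints: Sequence[str],
                       send: Callable[[str], object]) -> object:
    last_error: Exception = ServiceRequestError(message="no endpoint available")
    for endpoint in endpoints:          # preferred_locations order
        try:
            return send(endpoint)       # per-region transport call (faulted above)
        except ServiceRequestError as err:
            last_error = err            # region down: move to the next region
    raise last_error                    # all regions down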
+
+    def test_mwr_all_regions_down(self: "TestFaultInjectionTransport"):
+        first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1")
+        second_region_uri: str = test_config.TestConfig.local_host
+        custom_transport = FaultInjectionTransport()
+
+        # Inject topology transformation that would make the Emulator look like a multi-write
+        # region account with two write regions
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_write_region_account_transformation = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_mwr(
+                first_region_name="First Region",
+                second_region_name="Second Region",
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_write_region_account_transformation)
+
+        # Inject rule to simulate regional outage in "First Region"
+        is_request_to_first_region: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_targets_region(r, first_region_uri) and \
+               FaultInjectionTransport.predicate_is_document_operation(r)
+
+        # Inject rule to simulate regional outage in "Second Region"
+        is_request_to_second_region: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_targets_region(r, second_region_uri) and \
+               FaultInjectionTransport.predicate_is_document_operation(r)
+
+        custom_transport.add_fault(
+            is_request_to_first_region,
+            lambda r: FaultInjectionTransport.error_region_down())
+        custom_transport.add_fault(
+            is_request_to_second_region,
+            lambda r: FaultInjectionTransport.error_region_down())
+
+        id_value: str = str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': id_value,
+                               'name': 'sample document',
+                               'key': 'value'}
+
+        initialized_objects = self.setup_method_with_custom_transport(
+            custom_transport,
+            preferred_locations=["First Region", "Second Region"])
+        container: ContainerProxy = initialized_objects["col"]
+        with pytest.raises(ServiceRequestError):
+            container.upsert_item(body=document_definition)
+
+
+if __name__ == '__main__':
+    unittest.main()
\ No newline at end of file
From 7b81482b7bd282883a3aefb0c5cd0674484dc203 Mon Sep 17 00:00:00 2001
From: Tomas Varon Saldarriaga
Date: Thu, 3 Apr 2025 16:08:49 -0700
Subject: [PATCH 49/86] add new error and fix logs

---
 .../tests/_fault_injection_transport.py       | 10 ++++++++--
 .../tests/_fault_injection_transport_async.py | 20 ++++++++++++-------
 2 files changed, 21 insertions(+), 9 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
index 9cdd18936945..628456d95158 100644
--- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
+++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py
@@ -19,7 +19,7 @@
 # OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
 # SOFTWARE.
 
-"""AioHttpTransport allowing injection of faults between SDK and Cosmos Gateway
+"""RequestsTransport allowing injection of faults between SDK and Cosmos Gateway
 """
 
 import json
@@ -36,7 +36,7 @@
 import test_config
 from azure.cosmos.exceptions import CosmosHttpResponseError
-from azure.core.exceptions import ServiceRequestError
+from azure.core.exceptions import ServiceRequestError, ServiceResponseError
 
 class FaultInjectionTransport(RequestsTransport):
     logger = logging.getLogger('azure.cosmos.fault_injection_transport')
     logger.setLevel(logging.DEBUG)
@@ -168,6 +168,12 @@ def error_region_down() -> Exception:
             message="Injected region down.",
         )
 
+    @staticmethod
+    def error_service_response() -> Exception:
+        return ServiceResponseError(
+            message="Injected Service Response Error.",
+        )
+
     @staticmethod
     def transform_topology_swr_mrr(
             write_region_name: str,
diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py
index 4551b0235bad..13dda0dc7e20 100644
--- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py
@@ -34,10 +34,10 @@
 import test_config
 from azure.cosmos.exceptions import CosmosHttpResponseError
-from azure.core.exceptions import ServiceRequestError
+from azure.core.exceptions import ServiceRequestError, ServiceResponseError
 
 class FaultInjectionTransportAsync(AioHttpTransport):
-    logger = logging.getLogger('azure.cosmos.fault_injection_transport')
+    logger = logging.getLogger('azure.cosmos.fault_injection_transport_async')
     logger.setLevel(logging.DEBUG)
 
     def __init__(self, *, session: Optional[aiohttp.ClientSession] = None, loop=None, session_owner: bool = True, **config):
@@ -64,11 +64,11 @@ def __first_item(iterable, condition=lambda x: True):
         return next((x for x in iterable if condition(x)), None)
 
     async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Optional[MutableMapping[str, str]] = None, **config) -> 
AsyncHttpResponse: - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.Send {} {}".format(request.method, request.url)) + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.Send {} {}".format(request.method, request.url)) # find the first fault Factory with matching predicate if any first_fault_factory = FaultInjectionTransportAsync.__first_item(iter(self.faults), lambda f: f["predicate"](request)) if first_fault_factory: - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyFaultInjection") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.ApplyFaultInjection") injected_error = await first_fault_factory["apply"](request) FaultInjectionTransportAsync.logger.info("Found to-be-injected error {}".format(injected_error)) raise injected_error @@ -76,14 +76,14 @@ async def send(self, request: HttpRequest, *, stream: bool = False, proxies: Opt # apply the chain of request transformations with matching predicates if any matching_request_transformations = filter(lambda f: f["predicate"](f["predicate"]), iter(self.requestTransformations)) for currentTransformation in matching_request_transformations: - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.ApplyRequestTransformation") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.ApplyRequestTransformation") request = await currentTransformation["apply"](request) first_response_transformation = FaultInjectionTransportAsync.__first_item(iter(self.responseTransformations), lambda f: f["predicate"](request)) - FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransport.BeforeGetResponseTask") + FaultInjectionTransportAsync.logger.info("--> FaultInjectionTransportAsync.BeforeGetResponseTask") get_response_task = asyncio.create_task(super().send(request, stream=stream, proxies=proxies, **config)) - FaultInjectionTransportAsync.logger.info("<-- FaultInjectionTransport.AfterGetResponseTask") + FaultInjectionTransportAsync.logger.info("<-- FaultInjectionTransportAsync.AfterGetResponseTask") if first_response_transformation: FaultInjectionTransportAsync.logger.info(f"Invoking response transformation") @@ -166,6 +166,12 @@ async def error_region_down() -> Exception: message="Injected region down.", ) + @staticmethod + async def error_service_response() -> Exception: + return ServiceResponseError( + message="Injected Service Response Error.", + ) + @staticmethod async def transform_topology_swr_mrr( write_region_name: str, From f355e306d4c998b9e897f24a6d21b881da1bb730 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Thu, 3 Apr 2025 16:28:45 -0700 Subject: [PATCH 50/86] fix test --- sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index ef1f28b14adc..304fa8d50f0d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -16,7 +16,7 @@ from azure.cosmos.container import ContainerProxy from azure.cosmos.database import DatabaseProxy from azure.cosmos.exceptions import CosmosHttpResponseError -from tests._fault_injection_transport import FaultInjectionTransport +from _fault_injection_transport import FaultInjectionTransport from azure.core.exceptions import ServiceRequestError logger = 
logging.getLogger('azure.cosmos') From 8495c5139742060f74301dd0441c8e5b4fff787a Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 10:34:42 -0700 Subject: [PATCH 51/86] Add cosmosQuery mark to TestQuery --- sdk/cosmos/azure-cosmos/tests/test_query.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 28262aa0f7e3..2a99263ed457 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -17,7 +17,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey - +@pytest.mark.cosmosQuery class TestQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" From b29980c0ed0feda7b3fb04adb5d4560bd813ccd8 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 10:35:05 -0700 Subject: [PATCH 52/86] Correct spelling --- sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 01d1e9e9cf7e..7a93e6e89ec7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -74,7 +74,7 @@ def emit(self, record): [[L1, L2, L3], [L1, L2, L3], [L1, L2, L3]], # 5. No common excluded locations [[L1, L2, L3], [L1], [L2, L3]], - # 6. Reqeust excluded location not in preferred locations + # 6. Request excluded location not in preferred locations [[L1, L2, L3], [L1, L2, L3], [L4]], # 7. Empty excluded locations, remove all client excluded locations [[L1, L2, L3], [L1, L2], []], From 5e79172c2da25361e80a098222b712542b1b19ea Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 10:35:37 -0700 Subject: [PATCH 53/86] Fixed live platform matrix syntax --- sdk/cosmos/live-platform-matrix.json | 44 ++++++++-------------------- 1 file changed, 13 insertions(+), 31 deletions(-) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index b3242623be78..494c5fc62cea 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -86,43 +86,25 @@ "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosLong" - }, + } + } + }, + { + "DESIRED_CONSISTENCIES": "[\"Session\"]", + "ACCOUNT_CONSISTENCY": "Session", + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" + } + }, + "WindowsConfig": { "Windows2022_38_multi_region": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", "PythonVersion": "3.8", "CoverageArg": "--disablecov", "TestSamples": "false", - "TestMarkArgument": "cosmosMultiRegion", - "ArmConfig": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" - } - }, - "Windows2022_310_multi_region": { - "OSVmImage": "env:WINDOWSVMIMAGE", - "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.10", - "CoverageArg": "--disablecov", - "TestSamples": "false", - "TestMarkArgument": "cosmosMultiRegion", - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" - } - } - }, - "Windows2022_312_multi_region": { - "OSVmImage": "env:WINDOWSVMIMAGE", - "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.12", - "CoverageArg": 
"--disablecov", - "TestSamples": "false", - "TestMarkArgument": "cosmosMultiRegion", - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleRegions = $true }" - } - } + "TestMarkArgument": "cosmosMultiRegion" } } } From fd40cd724873ad9ab46520a18304a5a900fde7b1 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 4 Apr 2025 11:42:08 -0700 Subject: [PATCH 54/86] Changed Multi-regions --- .../tests/test_excluded_locations.py | 18 +++++++++--------- sdk/cosmos/live-platform-matrix.json | 6 ++---- sdk/cosmos/test-resources.bicep | 6 +++--- 3 files changed, 14 insertions(+), 16 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 7a93e6e89ec7..2d17bad85ba5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -34,18 +34,18 @@ def emit(self, record): ITEM_PK_VALUE = 'pk' TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} -# L0 = "Default" -# L1 = "West US 3" -# L2 = "West US" -# L3 = "East US 2" -# L4 = "Central US" - L0 = "Default" -L1 = "East US 2" -L2 = "East US" -L3 = "West US 2" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" L4 = "Central US" +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" +# L4 = "Central US" + CLIENT_ONLY_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No excluded location diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 494c5fc62cea..7a02486a8827 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -90,18 +90,16 @@ } }, { - "DESIRED_CONSISTENCIES": "[\"Session\"]", - "ACCOUNT_CONSISTENCY": "Session", "ArmConfig": { "MultiMaster_MultiRegion": { "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" } }, "WindowsConfig": { - "Windows2022_38_multi_region": { + "Windows2022_312": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", - "PythonVersion": "3.8", + "PythonVersion": "3.12", "CoverageArg": "--disablecov", "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion" diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 61588a526eed..b05dead26737 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -30,19 +30,19 @@ var singleRegionConfiguration = [ ] var multiRegionConfiguration = [ { - locationName: 'East US 2' + locationName: 'West US 3' provisioningState: 'Succeeded' failoverPriority: 0 isZoneRedundant: false } { - locationName: 'East US' + locationName: 'West US' provisioningState: 'Succeeded' failoverPriority: 1 isZoneRedundant: false } { - locationName: 'West US 2' + locationName: 'East US 2' provisioningState: 'Succeeded' failoverPriority: 2 isZoneRedundant: false From 85e1206f7771cb0f7a82e7460ae9ca4b247f5712 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Fri, 4 Apr 2025 14:56:01 -0700 Subject: [PATCH 55/86] first ppcb test --- .../azure-cosmos/azure/cosmos/_constants.py | 2 +- ...tition_endpoint_manager_circuit_breaker.py | 1 + ...n_endpoint_manager_circuit_breaker_core.py | 15 +- .../azure/cosmos/_partition_health_tracker.py | 2 +- .../tests/test_ppcb_sm_mrr_async.py | 159 ++++++++++++++++++ 5 files changed, 172 insertions(+), 7 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py diff 
--git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index cf029179f1a1..38848f3a5d72 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -53,7 +53,7 @@ class _Constants: CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE" CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" - FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 70 + FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 90 STALE_PARTITION_UNAVAILABILITY_CHECK = "AZURE_COSMOS_STALE_PARTITION_UNAVAILABILITY_CHECK_IN_SECONDS" STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT: int = 120 # ------------------------------------------------------------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index b95ef0a2a7da..3f38be706e73 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -57,6 +57,7 @@ def record_failure( self.global_partition_endpoint_manager_core.record_failure(request) def resolve_service_endpoint(self, request): + # TODO: @tvaron3 check here if it is healthy tentative and move it back to Unhealthy request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 142c51a1ed19..3f60391db4e5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -81,10 +81,11 @@ def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWr # get the partition key range for the given partition key target_container_link = None for container_link, properties in self.client._container_properties_cache: - # TODO: @tvaron3 consider moving this to a constant with other usages if properties["_rid"] == container_rid: target_container_link = container_link - # throw exception if it is not found + if target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) return PartitionKeyRangeWrapper(pkrange, container_rid) @@ -98,7 +99,11 @@ def record_failure( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) pkrange_wrapper = self._create_pkrange_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, location) + self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, str(location)) + + # TODO: @tvaron3 lower request timeout to 5.5 
seconds for recovering + # TODO: @tvaron3 exponential backoff for recovering + def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: if self.is_circuit_breaker_applicable(request): @@ -112,7 +117,7 @@ def mark_partition_unavailable(self, request: RequestObject) -> None: """ Mark the partition unavailable from the given request. """ - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) pkrange_wrapper = self._create_pkrange_wrapper(request) self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) @@ -124,7 +129,7 @@ def record_success( #convert operation_type to either Read or Write endpoint_operation_type = EndpointOperationType.WriteType if ( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(request.location_endpoint_to_route) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) pkrange_wrapper = self._create_pkrange_wrapper(request) self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 74665a5e7eb5..9fc64dbb5f63 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -168,7 +168,7 @@ def add_failure( self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, - location: Optional[str] + location: str ) -> None: # Retrieve the failure rate threshold from the environment. failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py new file mode 100644 index 000000000000..0a42e0469364 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -0,0 +1,159 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
+import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +import pytest_asyncio +from azure.core.pipeline.transport._aiohttp import AioHttpTransport + +from azure.cosmos import PartitionKey +from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from tests import test_config +from tests._fault_injection_transport_async import FaultInjectionTransportAsync + +COLLECTION = "created_collection" +@pytest_asyncio.fixture(scope='class') +async def setup(): + if (TestPPCBSmMrrAsync.master_key == '[YOUR_KEY_HERE]' or + TestPPCBSmMrrAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPPCBSmMrrAsync.host, TestPPCBSmMrrAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), throughput=10000) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + + + + +def error_codes(): + + return [408, 500, 502, 503] + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPPCBSmMrrAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + async def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + await method_client.close() + + async def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransportAsync() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + 
read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + + + @pytest.mark.parametrize("error_code", error_codes()) + async def test_consecutive_failure_threshold_async(self, setup, error_code): + expected_read_region_uri = self.host + expected_write_region_uri = self.host.replace("localhost", "127.0.0.1") + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(CosmosHttpResponseError( + status_code=error_code, + message="Some injected fault."))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + with pytest.raises(CosmosHttpResponseError): + await container.create_item(body=document_definition) + + TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + + + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should failover and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + + for i in range(10): + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + + # the partition should have been marked as unavailable + TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 1) + + + @staticmethod + def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pkrange_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info: + if health_info[HEALTH_STATUS] == UNHEALTHY_TENTATIVE or health_info[HEALTH_STATUS] == UNHEALTHY: + unhealthy_partitions += 1 + assert len(health_info_map) == expected_unhealthy_partitions + assert unhealthy_partitions == expected_unhealthy_partitions + + + + + + + # test_failure_rate_threshold + + +if __name__ == '__main__': + unittest.main() \ No newline at end of file From 34e3d82ba9d2d2cceafd538e455e96b4b5fccf58 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Mon, 7 Apr 2025 10:54:42 -0400 Subject: [PATCH 56/86] fix test --- ...n_endpoint_manager_circuit_breaker_core.py | 20 ++--- .../azure/cosmos/_location_cache.py | 3 +- .../azure/cosmos/_routing/routing_range.py | 3 + .../azure/cosmos/aio/_retry_utility_async.py | 8 +- .../tests/test_ppcb_sm_mrr_async.py | 80 +++++++++---------- 5 files changed, 58 insertions(+), 56 deletions(-) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 3f60391db4e5..014809cac5b4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -67,12 +67,12 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: if request.resource_type != ResourceType.Document: return False - if request.operation_type != documents._OperationType.QueryPlan: + if request.operation_type == documents._OperationType.QueryPlan: return False return True - def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + def _create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: """ Create a PartitionKeyRangeWrapper object. """ @@ -80,14 +80,14 @@ def _create_pkrange_wrapper(self, request: RequestObject) -> PartitionKeyRangeWr partition_key = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key target_container_link = None - for container_link, properties in self.client._container_properties_cache: + for container_link, properties in self.client._container_properties_cache.items(): if properties["_rid"] == container_rid: target_container_link = container_link - if target_container_link: + if not target_container_link: raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways - pkrange = self.client._routing_map_provider.get_overlapping_range(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pkrange, container_rid) + pk_range = await self.client._routing_map_provider.get_overlapping_ranges(target_container_link, partition_key) + return PartitionKeyRangeWrapper(pk_range, container_rid) def record_failure( self, @@ -98,7 +98,7 @@ def record_failure( endpoint_operation_type = EndpointOperationType.WriteType if ( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, str(location)) # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering @@ -107,7 +107,7 @@ def record_failure( def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) request.set_excluded_locations_from_circuit_breaker( self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) ) @@ -118,7 +118,7 @@ def mark_partition_unavailable(self, request: RequestObject) -> None: Mark the partition unavailable from the given request. 
""" location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location) def record_success( @@ -130,7 +130,7 @@ def record_success( endpoint_operation_type = EndpointOperationType.WriteType if ( documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pkrange_wrapper(request) + pkrange_wrapper = self._create_pk_range_wrapper(request) self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location) # TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 873c032d7ead..cb9ef0840a3c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -201,8 +201,7 @@ def get_read_regional_routing_contexts(self): return self.read_regional_routing_contexts def get_location_from_endpoint(self, endpoint: str) -> str: - regional_routing_context = RegionalRoutingContext(endpoint, endpoint) - return self.account_locations_by_read_regional_endpoints[regional_routing_context] + return self.account_locations_by_read_regional_routing_context[endpoint] def get_write_regional_routing_context(self): return self.get_write_regional_routing_contexts()[0].get_primary() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 4e3d603ef0d8..21a22ca89f61 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -248,3 +248,6 @@ def __eq__(self, other): return False return self.partition_key_range == other.partition_key_range and self.collection_rid == other.collection_rid + def __hash__(self): + return hash((self.partition_key_range, self.collection_rid)) + diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index c5613530994d..e1094736b88b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -103,9 +103,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) + global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) - global_endpoint_manager.record_success(request) if not client.last_response_headers: client.last_response_headers = {} @@ -200,7 +200,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - global_endpoint_manager.record_failure(request) + global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -258,8 +258,8 @@ async def send(self, request): """ 
absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - request_params = request.context.options.get('request_params', None) - global_endpoint_manager = request.context.options.get('global_endpoint_manager', None) + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 1360ae5630ba..a8e39ae361da 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -9,6 +9,7 @@ import pytest import pytest_asyncio from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError from azure.cosmos import PartitionKey from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE @@ -18,19 +19,16 @@ from tests._fault_injection_transport_async import FaultInjectionTransportAsync COLLECTION = "created_collection" -@pytest_asyncio.fixture(scope='class') +@pytest_asyncio.fixture() async def setup(): - if (TestPPCBSmMrrAsync.master_key == '[YOUR_KEY_HERE]' or - TestPPCBSmMrrAsync.host == '[YOUR_ENDPOINT_HERE]'): - raise Exception( - "You must specify your Azure Cosmos account values for " - "'masterKey' and 'host' at the top of this class to run the " - "tests.") os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPPCBSmMrrAsync.host, TestPPCBSmMrrAsync.master_key, consistency_level="Session") created_database = client.get_database_client(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + # print(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + await client.create_database_if_not_exists(TestPPCBSmMrrAsync.TEST_DATABASE_ID) created_collection = await created_database.create_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, - partition_key=PartitionKey("/pk"), throughput=10000) + partition_key=PartitionKey("/pk"), + offer_throughput=10000) yield { COLLECTION: created_collection } @@ -39,13 +37,15 @@ async def setup(): await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" - - - -def error_codes(): - - return [408, 500, 502, 503] - +def errors(): + errors_list = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors_list.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected fault.")) + errors_list.append(ServiceResponseError(message="Injected Service Response Error.")) + return errors_list @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -59,6 +59,7 @@ class TestPPCBSmMrrAsync: async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], transport=custom_transport, **kwargs) db = client.get_database_client(self.TEST_DATABASE_ID) container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) @@ -92,45 +93,48 @@ async def create_custom_transport_sm_mrr(self): emulator_as_multi_region_sm_account_transformation) return custom_transport - - - @pytest.mark.parametrize("error_code", error_codes()) - async def test_consecutive_failure_threshold_async(self, setup, 
error_code):
+    @pytest.mark.parametrize("error", errors())
+    async def test_consecutive_failure_threshold_async(self, setup, error):
+        expected_read_region_uri = self.host
+        expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1")
         custom_transport = await self.create_custom_transport_sm_mrr()
         id_value = 'failoverDoc-' + str(uuid.uuid4())
         document_definition = {'id': id_value,
                                'pk': 'pk1',
                                'name': 'sample document',
                                'key': 'value'}
-        predicate = lambda r: FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value)
-        custom_transport.add_fault(predicate, lambda r: asyncio.create_task(CosmosHttpResponseError(
-            status_code=error_code,
-            message="Some injected fault.")))
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri))
+        custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+            0,
+            error
+        )))
 
         custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host)
         container = custom_setup['col']
 
         # writes should fail in sm mrr with circuit breaker and should not mark a partition unavailable
         for i in range(6):
-            with pytest.raises(CosmosHttpResponseError):
+            with pytest.raises((CosmosHttpResponseError, ServiceResponseError)):
                 await container.create_item(body=document_definition)
 
         TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0)
-
-
         # create item with client without fault injection
         await setup[COLLECTION].create_item(body=document_definition)
 
-        # reads should failover and only the relevant partition should be marked as unavailable
+        # reads should fail over and only the relevant partition should be marked as unavailable
         await container.read_item(item=document_definition['id'], partition_key=document_definition['pk'])
         # partition should not have been marked unavailable after one error
         TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0)
 
-        for i in range(10):
-            await container.read_item(item=document_definition['id'], partition_key=document_definition['pk'])
+        for i in range(11):
+            read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk'])
+            request = read_resp.get_response_headers()["_request"]
+            # Validate the response comes from "Read Region" (the most preferred read-only region)
+            assert request.url.startswith(expected_read_region_uri)
 
-        # the partition should have been marked as unavailable
+        # the partition should have been marked as unavailable after breaking the read threshold
         TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 1)
 
@@ -139,19 +143,15 @@ def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_pa
         health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pkrange_wrapper_to_health_info
         unhealthy_partitions = 0
         for pk_range_wrapper, location_to_health_info in health_info_map.items():
-            for location, health_info in location_to_health_info:
-                if health_info[HEALTH_STATUS] == UNHEALTHY_TENTATIVE or health_info[HEALTH_STATUS] == UNHEALTHY:
+            for location, health_info in location_to_health_info.items():
+                health_status = health_info.unavailability_info.get(HEALTH_STATUS)
+                if health_status == UNHEALTHY_TENTATIVE or health_status 
== UNHEALTHY: unhealthy_partitions += 1 - assert len(health_info_map) == expected_unhealthy_partitions assert unhealthy_partitions == expected_unhealthy_partitions - - - - - # test_failure_rate_threshold - + # test_failure_rate_threshold - add service response error + # test service request marks only a partition unavailable not an entire region if __name__ == '__main__': - unittest.main() \ No newline at end of file + unittest.main() From ce1466618076f815b2312a376c2539856b1ff0dc Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Mon, 7 Apr 2025 18:05:33 -0400 Subject: [PATCH 57/86] refactor due to pk range wrapper needing io call and pylint --- .../azure/cosmos/_cosmos_client_connection.py | 4 +- .../azure/cosmos/_global_endpoint_manager.py | 8 +- ...tition_endpoint_manager_circuit_breaker.py | 54 +++++++++---- ...n_endpoint_manager_circuit_breaker_core.py | 81 +++++++------------ .../azure/cosmos/_location_cache.py | 2 +- .../azure/cosmos/_partition_health_tracker.py | 37 ++++----- .../azure/cosmos/_request_object.py | 4 +- .../azure/cosmos/_retry_utility.py | 28 ++++--- .../azure/cosmos/_routing/routing_range.py | 1 - .../cosmos/_service_request_retry_policy.py | 9 ++- .../cosmos/_service_response_retry_policy.py | 8 +- .../azure/cosmos/_session_retry_policy.py | 12 ++- .../azure/cosmos/_synchronized_request.py | 12 ++- .../cosmos/_timeout_failover_retry_policy.py | 8 +- .../azure/cosmos/aio/_asynchronous_request.py | 8 +- .../aio/_cosmos_client_connection_async.py | 3 +- .../aio/_global_endpoint_manager_async.py | 8 +- ..._endpoint_manager_circuit_breaker_async.py | 62 +++++++++----- .../azure/cosmos/aio/_retry_utility_async.py | 20 ++--- 19 files changed, 215 insertions(+), 154 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 298d3032877b..5e9847c9ad93 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -48,7 +48,7 @@ HttpResponse # pylint: disable=no-legacy-azure-core-http-response-import from . import _base as base -from . import _global_endpoint_manager as global_endpoint_manager +from ._global_partition_endpoint_manager_circuit_breaker import _GlobalPartitionEndpointManagerForCircuitBreaker from . import _query_iterable as query_iterable from . import _runtime_constants as runtime_constants from . import _session @@ -164,7 +164,7 @@ def __init__( # pylint: disable=too-many-statements self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index 8cf9d06d5486..62dd60a30da0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -30,6 +30,8 @@ from . import _constants as constants from . 
import exceptions +from ._request_object import RequestObject +from ._routing.routing_range import PartitionKeyRangeWrapper from .documents import DatabaseAccount from ._location_cache import LocationCache, current_time_millis @@ -67,7 +69,11 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py index 3f38be706e73..288205cb2411 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -28,6 +28,9 @@ from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos.http_constants import HttpHeaders + if TYPE_CHECKING: from azure.cosmos._cosmos_client_connection import CosmosClientConnection @@ -42,34 +45,55 @@ class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) - self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: - """ - Check if circuit breaker is applicable for a request. 
- """ return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) + + def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.Client._container_properties_cache.items(): # pylint: disable=protected-access + if properties["_rid"] == container_rid: + target_container_link = container_link + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + pk_range = (self.Client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(target_container_link, partition_key)) + return PartitionKeyRangeWrapper(pk_range, container_rid) + def record_failure( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_failure(request) + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) - def resolve_service_endpoint(self, request): + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> str: # TODO: @tvaron3 check here if it is healthy tentative and move it back to Unhealthy - request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) - return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request) + if self.is_circuit_breaker_applicable(request): + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request, + pk_range_wrapper) - def mark_partition_unavailable(self, request: RequestObject) -> None: - """ - Mark the partition unavailable from the given request. 
- """ - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) + def mark_partition_unavailable( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) def record_success( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_success(request) - + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 014809cac5b4..1f900215f08d 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -30,7 +30,7 @@ from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import EndpointOperationType, LocationCache from azure.cosmos._request_object import RequestObject -from azure.cosmos.http_constants import ResourceType, HttpHeaders +from azure.cosmos.http_constants import ResourceType from azure.cosmos._constants import _Constants as Constants @@ -49,9 +49,6 @@ def __init__(self, client, location_cache: LocationCache): self.client = client def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: - """ - Check if circuit breaker is applicable for a request. - """ if not request: return False @@ -61,76 +58,54 @@ def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: return False if (not self.location_cache.can_use_multiple_write_locations_for_request(request) - and documents._OperationType.IsWriteOperation(request.operation_type)): + and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access return False if request.resource_type != ResourceType.Document: return False - if request.operation_type == documents._OperationType.QueryPlan: + if request.operation_type == documents._OperationType.QueryPlan: # pylint: disable=protected-access return False return True - def _create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: - """ - Create a PartitionKeyRangeWrapper object. 
- """ - container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - if not target_container_link: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") - # TODO: @tvaron3 check different clients and create them in different ways - pk_range = await self.client._routing_map_provider.get_overlapping_ranges(target_container_link, partition_key) - return PartitionKeyRangeWrapper(pk_range, container_rid) - def record_failure( self, - request: RequestObject + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper ) -> None: - if self.is_circuit_breaker_applicable(request): - #convert operation_type to EndpointOperationType - endpoint_operation_type = EndpointOperationType.WriteType if ( - documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType - location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) - pkrange_wrapper = self._create_pk_range_wrapper(request) - self.partition_health_tracker.add_failure(pkrange_wrapper, endpoint_operation_type, str(location)) - + #convert operation_type to EndpointOperationType + endpoint_operation_type = (EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) # pylint: disable=protected-access + else EndpointOperationType.ReadType) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering # TODO: @tvaron3 exponential backoff for recovering - - def add_excluded_locations_to_request(self, request: RequestObject) -> RequestObject: - if self.is_circuit_breaker_applicable(request): - pkrange_wrapper = self._create_pk_range_wrapper(request) - request.set_excluded_locations_from_circuit_breaker( - self.partition_health_tracker.get_excluded_locations(pkrange_wrapper) - ) + def add_excluded_locations_to_request( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> RequestObject: + request.set_excluded_locations_from_circuit_breaker( + self.partition_health_tracker.get_excluded_locations(pk_range_wrapper) + ) return request - def mark_partition_unavailable(self, request: RequestObject) -> None: - """ - Mark the partition unavailable from the given request. 
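record_failure and record_success both fold the request's operation type down to the two buckets the health tracker counts, reads and writes. A hedged sketch of that mapping; the set of write operation names below is an illustrative subset, not the SDK's authoritative list:

READ_TYPE = "Read"
WRITE_TYPE = "Write"

# Illustrative subset of operation types treated as writes.
_WRITE_OPERATIONS = {"Create", "Upsert", "Replace", "Delete", "Patch", "Batch"}

def to_endpoint_operation_type(operation_type: str) -> str:
    # Anything that is not a known write counts against read health.
    return WRITE_TYPE if operation_type in _WRITE_OPERATIONS else READ_TYPE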
-        """
+    def mark_partition_unavailable(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> None:
         location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route))
-        pkrange_wrapper = self._create_pk_range_wrapper(request)
-        self.partition_health_tracker.mark_partition_unavailable(pkrange_wrapper, location)
+        self.partition_health_tracker.mark_partition_unavailable(pk_range_wrapper, location)
 
     def record_success(
             self,
-            request: RequestObject
+            request: RequestObject,
+            pk_range_wrapper: PartitionKeyRangeWrapper
     ) -> None:
-        if self.is_circuit_breaker_applicable(request):
-            #convert operation_type to either Read or Write
-            endpoint_operation_type = EndpointOperationType.WriteType if (
-                documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType
-            location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route))
-            pkrange_wrapper = self._create_pk_range_wrapper(request)
-            self.partition_health_tracker.add_success(pkrange_wrapper, endpoint_operation_type, location)
+        #convert operation_type to either Read or Write
+        endpoint_operation_type = EndpointOperationType.WriteType if (
+            documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType # pylint: disable=protected-access
+        location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route))
+        self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location)
         # TODO: @tvaron3 there should be no in region retries when trying on healthy tentative
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
index cb9ef0840a3c..1d47ef51b5e0 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
@@ -25,7 +25,7 @@
 import collections
 import logging
 import time
-from typing import Set, Mapping, List, Optional
+from typing import Set, Mapping, List
 from urllib.parse import urlparse
 
 from . import documents
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
index 9fc64dbb5f63..04849ca1cb4b 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
@@ -22,10 +22,10 @@ """Internal class for partition health tracker for circuit breaker.
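The tracker below maintains a small state machine per (partition key range, region): healthy, unhealthy tentative, unhealthy, and healthy tentative. A compact sketch of the transitions it implements; the two timing constants are placeholders, not the SDK's real defaults:

import time

HEALTHY, UNHEALTHY_TENTATIVE, UNHEALTHY, HEALTHY_TENTATIVE = range(4)
INITIAL_UNAVAILABLE_MS = 60_000   # assumed cool-down before probing again
STALE_CHECK_MS = 120_000          # assumed recheck window for UNHEALTHY

def on_failure(state: int) -> int:
    # A failure while probing (HEALTHY_TENTATIVE) demotes straight to
    # UNHEALTHY; a first failure marks the region UNHEALTHY_TENTATIVE.
    return UNHEALTHY if state == HEALTHY_TENTATIVE else UNHEALTHY_TENTATIVE

def on_success(state: int) -> int:
    # A successful probe while HEALTHY_TENTATIVE fully recovers the region.
    return HEALTHY if state == HEALTHY_TENTATIVE else state

def on_stale_check(state: int, marked_at_ms: float) -> int:
    # After enough time has passed, let one request probe the region again.
    elapsed = time.time() * 1000 - marked_at_ms
    if ((state == UNHEALTHY and elapsed > STALE_CHECK_MS)
            or (state == UNHEALTHY_TENTATIVE and elapsed > INITIAL_UNAVAILABLE_MS)):
        return HEALTHY_TENTATIVE
    return state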
""" import os -from typing import Dict, Set, Any, Optional -from ._constants import _Constants as Constants +from typing import Dict, Set, Any +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import current_time_millis, EndpointOperationType -from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range +from ._constants import _Constants as Constants MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @@ -113,9 +113,10 @@ def _transition_health_status_on_failure( if location in region_to_partition_health: # healthy tentative -> unhealthy # if the operation type is not empty, we are in the healthy tentative state - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = UNHEALTHY + region_to_partition_health[location].unavailability_info[HEALTH_STATUS] = UNHEALTHY # reset the last unavailability check time stamp - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] = UNHEALTHY + region_to_partition_health[location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \ + = UNHEALTHY else: # healthy -> unhealthy tentative # if the operation type is empty, we are in the unhealthy tentative state @@ -135,21 +136,22 @@ def _transition_health_status_on_success( # healthy tentative -> healthy self.pkrange_wrapper_to_health_info[pkrange_wrapper].pop(location, None) - def _check_stale_partition_info(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> None: + def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: current_time = current_time_millis() - stale_partition_unavailability_check = int(os.getenv(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, + stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 - if pkrange_wrapper in self.pkrange_wrapper_to_health_info: - for location, partition_health_info in self.pkrange_wrapper_to_health_info[pkrange_wrapper].items(): - elapsed_time = current_time - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] + if pk_range_wrapper in self.pkrange_wrapper_to_health_info: + for _, partition_health_info in self.pkrange_wrapper_to_health_info[pk_range_wrapper].items(): + elapsed_time = (current_time - + partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] # check if the partition key range is still unavailable if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) or (current_health_status == UNHEALTHY_TENTATIVE and elapsed_time > INITIAL_UNAVAILABLE_TIME)): # unhealthy or unhealthy tentative -> healthy tentative - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE + partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE if current_time - self.last_refresh < REFRESH_INTERVAL: # all partition stats reset every minute @@ -160,8 +162,7 @@ def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> S self._check_stale_partition_info(pkrange_wrapper) if pkrange_wrapper in self.pkrange_wrapper_to_health_info: return set(self.pkrange_wrapper_to_health_info[pkrange_wrapper].keys()) - else: - return set() + return set() def add_failure( @@ -171,7 +172,7 @@ def 
add_failure( location: str ) -> None: # Retrieve the failure rate threshold from the environment. - failure_rate_threshold = int(os.getenv(Constants.FAILURE_PERCENTAGE_TOLERATED, + failure_rate_threshold = int(os.environ.get(Constants.FAILURE_PERCENTAGE_TOLERATED, Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) # Ensure that the health info dictionary is properly initialized. @@ -201,7 +202,7 @@ def add_failure( setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1) # Retrieve the consecutive failure threshold from the environment. - consecutive_failure_threshold = int(os.getenv(env_key, default_consec_threshold)) + consecutive_failure_threshold = int(os.environ.get(env_key, default_consec_threshold)) # Call the threshold checker with the current stats. self._check_thresholds( @@ -256,6 +257,6 @@ def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: def _reset_partition_health_tracker_stats(self) -> None: - for pkrange_wrapper in self.pkrange_wrapper_to_health_info: - for location in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location].reset_health_stats() + for locations in self.pkrange_wrapper_to_health_info.values(): + for health_info in locations.values(): + health_info.reset_health_stats() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 28dc2fefd73b..dace40aba2fb 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -24,7 +24,7 @@ from typing import Optional, Mapping, Any, Dict, Set, List -class RequestObject(object): +class RequestObject(object): # pylint: disable=too-many-instance-attributes def __init__( self, resource_type: str, @@ -84,5 +84,5 @@ def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None if self._can_set_excluded_location(options): self.excluded_locations = options['excludedLocations'] - def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None: # pylint: disable=name-too-long self.excluded_locations_circuit_breaker = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 927ed7a41baa..44c6b088696e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -45,7 +45,7 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches -def Execute(client, global_endpoint_manager, function, *args, **kwargs): +def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies :param object client: @@ -58,6 +58,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = 
_endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -73,19 +76,19 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -104,6 +107,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): try: if args: result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs) + global_endpoint_manager.record_success(args[0]) else: result = ExecuteFunction(function, *args, **kwargs) if not client.last_response_headers: @@ -172,9 +176,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker tracking + global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -200,6 +204,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: + global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -291,7 +296,8 @@ def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -317,6 +323,7 @@ def send(self, request): # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core if not _has_database_account_header(request.http_request.headers): + global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = 
self.increment(retry_settings, response=request, error=err) if retry_active: @@ -329,7 +336,7 @@ def send(self, request): if (not _has_read_retryable_headers(request.http_request.headers) or _has_database_account_header(request.http_request.headers)): raise err - + global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -342,6 +349,7 @@ def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 21a22ca89f61..e31682725828 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -250,4 +250,3 @@ def __eq__(self, other): def __hash__(self): return hash((self.partition_key_range, self.collection_rid)) - diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index edd15f20337f..b49185512fa4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -13,9 +13,10 @@ class ServiceRequestRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.total_in_region_retries = 1 self.in_region_retry_count = 0 @@ -45,7 +46,7 @@ def ShouldRetry(self): return False if self.global_endpoint_manager.is_circuit_breaker_applicable(self.request): - self.global_endpoint_manager.mark_partition_unavailable(self.request) + self.global_endpoint_manager.mark_partition_unavailable(self.request, self.pk_range_wrapper) else: refresh_cache = self.request.last_routed_location_endpoint_within_region is not None # This logic is for the last retry and mark the region unavailable @@ -99,7 +100,7 @@ def resolve_current_region_service_endpoint(self): # resolve the next service endpoint in the same region # since we maintain 2 endpoints per region for write operations self.request.route_to_location_with_preferred_location_flag(0, True) - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) # This function prepares the request to go to the next region def resolve_next_region_service_endpoint(self): @@ -113,7 +114,7 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(0, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return 
self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) def mark_endpoint_unavailable(self, unavailable_endpoint, refresh_cache: bool): if _OperationType.IsReadOnlyOperation(self.request.operation_type): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py index 83a856f39d33..330ffb5929a5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py @@ -12,15 +12,17 @@ class ServiceResponseRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.failover_retry_count = 0 self.connection_policy = connection_policy self.request = args[0] if args else None if self.request: - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + pk_range_wrapper) self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy') def ShouldRetry(self): @@ -57,4 +59,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py index 1614f337de5b..e52fbe996e11 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py @@ -41,10 +41,11 @@ class _SessionRetryPolicy(object): Max_retry_attempt_count = 1 Retry_after_in_milliseconds = 0 - def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): + def __init__(self, endpoint_discovery_enable, global_endpoint_manager, pk_range_wrapper, *args): self.global_endpoint_manager = global_endpoint_manager self._max_retry_attempt_count = _SessionRetryPolicy.Max_retry_attempt_count self.session_token_retry_count = 0 + self.pk_range_wrapper = pk_range_wrapper self.retry_after_in_milliseconds = _SessionRetryPolicy.Retry_after_in_milliseconds self.endpoint_discovery_enable = endpoint_discovery_enable self.request = args[0] if args else None @@ -57,7 +58,8 @@ def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) 
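Each retry policy changed in this patch performs the same routing steps on every attempt; schematically, with the method names mirroring the patch but the free function itself being illustrative:

def route_next_attempt(request, endpoint_manager, pk_range_wrapper, attempt_index):
    # Pick the next preferred location by index (True = honor the
    # preferred-locations list rather than the last routed endpoint).
    request.route_to_location_with_preferred_location_flag(attempt_index, True)
    # Resolve through the endpoint manager so any per-partition exclusions
    # from the circuit breaker are applied before a region is chosen.
    endpoint = endpoint_manager.resolve_service_endpoint(request, pk_range_wrapper)
    # Pin the resolution so a later failure can mark this exact endpoint.
    request.route_to_location(endpoint)
    return endpoint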
self.request.route_to_location(self.location_endpoint) def ShouldRetry(self, _exception): @@ -98,7 +100,8 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) return True @@ -113,6 +116,7 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 68e37caf1d9d..43516430bdb6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -65,7 +65,7 @@ def _request_body_from_data(data): return None -def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. 
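The endpoint selection order used in the body below can be summarized as a small helper; the names mirror the patch, but this standalone version is only a sketch:

def pick_base_url(request_params, global_endpoint_manager):
    # 1) An explicit per-request override always wins.
    if request_params.endpoint_override:
        return request_params.endpoint_override
    # 2) Otherwise resolve through the endpoint manager, feeding it the
    #    partition wrapper only when the circuit breaker applies.
    pk_range_wrapper = None
    if global_endpoint_manager.is_circuit_breaker_applicable(request_params):
        pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params)
    return global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper)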
:param _GlobalEndpointManager global_endpoint_manager: @@ -104,7 +104,11 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) @@ -132,6 +136,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -142,6 +148,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index f70e27bae70c..69bc973c3346 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -9,9 +9,10 @@ class _TimeoutFailoverRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.retry_after_in_milliseconds = 500 self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper # If an account only has 1 region, then we still want to retry once on the same region self._max_retry_attempt_count = (len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) + 1) @@ -26,9 +27,6 @@ def ShouldRetry(self, _exception): :returns: a boolean stating whether the request should be retried :rtype: bool """ - # record the failure for circuit breaker tracking - self.global_endpoint_manager.record_failure(self.request) - # we don't retry on write operations for timeouts or any internal server errors if self.request and (not _OperationType.IsReadOnlyOperation(self.request.operation_type)): return False @@ -57,4 +55,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 25f6ac203d85..4fda37ea0a87 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -34,7 +34,7 @@ from .._synchronized_request import _request_body_from_data, _replace_url_prefix -async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. :param _GlobalEndpointManager global_endpoint_manager: @@ -73,7 +73,11 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 8fbdb0fb9a83..31ae9dc334cd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -48,13 +48,14 @@ DistributedTracingPolicy, ProxyPolicy) from azure.core.utils import CaseInsensitiveDict +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import ( + _GlobalPartitionEndpointManagerForCircuitBreakerAsync) from .. import _base as base from .._base import _set_properties_cache from .. import documents from .._change_feed.aio.change_feed_iterable import ChangeFeedIterable from .._change_feed.change_feed_state import ChangeFeedState -from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .._routing import routing_range from ..documents import ConnectionPolicy, DatabaseAccount from .._constants import _Constants as Constants diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 00438cc2214e..0fe666f1983c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,6 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError # pylint: disable=do-not-import-asyncio from typing import Tuple from azure.core.exceptions import AzureError @@ -35,6 +34,7 @@ from .. 
import exceptions from .._location_cache import LocationCache, current_time_millis from .._request_object import RequestObject +from .._routing.routing_range import PartitionKeyRangeWrapper # pylint: disable=protected-access @@ -71,7 +71,11 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 71bb628c31a0..3aadc9b6aba0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -21,13 +21,16 @@ """Internal class for global endpoint manager for circuit breaker. """ -from typing import TYPE_CHECKING, Optional +from typing import TYPE_CHECKING from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import HttpHeaders + if TYPE_CHECKING: from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection @@ -42,32 +45,53 @@ class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManag def __init__(self, client: "CosmosClientConnection"): super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) - self.global_partition_endpoint_manager_core = _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) + + async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.client._container_properties_cache.items(): # pylint: disable=protected-access + if properties["_rid"] == container_rid: + target_container_link = container_link + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + pk_range = await (self.client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(target_container_link, partition_key)) + return PartitionKeyRangeWrapper(pk_range, container_rid) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: - """ - Check if circuit breaker is applicable for a request. 
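On the async path, building the wrapper awaits the routing map provider, so the record hooks become coroutines and the retry utility must await them. A minimal sketch of the shared gate-then-record pattern; record_outcome itself is hypothetical:

async def record_outcome(manager, request, success: bool) -> None:
    # Skip the async routing-map lookup entirely when the breaker is off.
    if not manager.is_circuit_breaker_applicable(request):
        return
    pk_range_wrapper = await manager.create_pk_range_wrapper(request)
    core = manager.global_partition_endpoint_manager_core
    if success:
        core.record_success(request, pk_range_wrapper)
    else:
        core.record_failure(request, pk_range_wrapper)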
- """ return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) - def record_failure( + async def record_failure( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_failure(request) - - def resolve_service_endpoint(self, request): - request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request) - return super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).resolve_service_endpoint(request) - - def mark_partition_unavailable(self, request: RequestObject) -> None: - """ - Mark the partition unavailable from the given request. - """ - self.global_partition_endpoint_manager_core.mark_partition_unavailable(request) + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) and pk_range_wrapper: + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) + .resolve_service_endpoint(request, pk_range_wrapper)) + + def mark_partition_unavailable( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) - def record_success( + async def record_success( self, request: RequestObject ) -> None: - self.global_partition_endpoint_manager_core.record_success(request) + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index e1094736b88b..6e94c9260b8b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -28,7 +28,6 @@ from azure.core.exceptions import AzureError, ClientAuthenticationError, ServiceRequestError, ServiceResponseError from azure.core.pipeline.policies import AsyncRetryPolicy -from ._global_partition_endpoint_manager_circuit_breaker_async import _GlobalPartitionEndpointManagerForCircuitBreakerAsync from .. import _default_retry_policy, _database_account_retry_policy from .. import _endpoint_discovery_retry_policy from .. 
import _gone_retry_policy @@ -59,6 +58,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -74,17 +76,17 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -103,7 +105,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) - global_endpoint_manager.record_success(args[0]) + await global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) if not client.last_response_headers: @@ -172,9 +174,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker tracking + await global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy From 29305f46c2109abcceb80bd5e1b6737a124669c5 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 15:58:01 -0700 Subject: [PATCH 58/86] Added client level ExcludedLocation for async --- sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py | 1 + 1 file changed, 1 insertion(+) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py index 683f16288cd3..647f6d59f615 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py @@ -84,6 +84,7 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations) + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config From c77b4e726b8ab2d7c7f6426c21ef7afb1e7b10e4 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 16:39:55 -0700 Subject: [PATCH 59/86] Update Live test settings --- sdk/cosmos/live-platform-matrix.json | 10 +++++----- sdk/cosmos/test-resources.bicep | 6 ------ 2 files changed, 5 insertions(+), 11 deletions(-) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 7a02486a8827..dc3ad3c32e17 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -90,11 +90,6 @@ } }, { - "ArmConfig": { - "MultiMaster_MultiRegion": { - "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" - } - }, "WindowsConfig": { "Windows2022_312": { "OSVmImage": "env:WINDOWSVMIMAGE", @@ -104,6 +99,11 @@ "TestSamples": "false", "TestMarkArgument": "cosmosMultiRegion" } + }, + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" + } } } ] diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index b05dead26737..88abe955f8d8 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -41,12 +41,6 @@ var multiRegionConfiguration = [ failoverPriority: 1 isZoneRedundant: false } - { - locationName: 'East US 2' - provisioningState: 'Succeeded' - failoverPriority: 2 - isZoneRedundant: false - } ] var locationsConfiguration = (enableMultipleRegions ? multiRegionConfiguration : singleRegionConfiguration) var roleDefinitionId = guid(baseName, 'roleDefinitionId') From d82fa74255e899556f48f8797ff8afbe7ad595bc Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 16:40:32 -0700 Subject: [PATCH 60/86] Added Async tests --- .../tests/test_excluded_locations.py | 76 ++- .../tests/test_excluded_locations_async.py | 470 ++++++++++++++++++ 2 files changed, 504 insertions(+), 42 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 2d17bad85ba5..13e9ba713653 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -38,46 +38,42 @@ def emit(self, record): L1 = "West US 3" L2 = "West US" L3 = "East US 2" -L4 = "Central US" # L0 = "Default" # L1 = "East US 2" # L2 = "East US" # L3 = "West US 2" -# L4 = "Central US" CLIENT_ONLY_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No excluded location - [[L1, L2, L3], [], None], + [[L1, L2], [], None], # 1. 
Single excluded location - [[L1, L2, L3], [L1], None], - # 2. Multiple excluded locations - [[L1, L2, L3], [L1, L2], None], - # 3. Exclude all locations - [[L1, L2, L3], [L1, L2, L3], None], - # 4. Exclude a location not in preferred locations - [[L1, L2, L3], [L4], None], + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], ] CLIENT_AND_REQUEST_TEST_DATA = [ # preferred_locations, client_excluded_locations, excluded_locations_request # 0. No client excluded locations + a request excluded location - [[L1, L2, L3], [], [L1]], + [[L1, L2], [], [L1]], # 1. The same client and request excluded location - [[L1, L2, L3], [L1], [L1]], + [[L1, L2], [L1], [L1]], # 2. Less request excluded locations - [[L1, L2, L3], [L1, L2], [L1]], + [[L1, L2], [L1, L2], [L1]], # 3. More request excluded locations - [[L1, L2, L3], [L1], [L1, L2]], + [[L1, L2], [L1], [L1, L2]], # 4. All locations were excluded - [[L1, L2, L3], [L1, L2, L3], [L1, L2, L3]], + [[L1, L2], [L1, L2], [L1, L2]], # 5. No common excluded locations - [[L1, L2, L3], [L1], [L2, L3]], + [[L1, L2], [L1], [L2]], # 6. Request excluded location not in preferred locations - [[L1, L2, L3], [L1, L2, L3], [L4]], + [[L1, L2], [L1, L2], [L3]], # 7. Empty excluded locations, remove all client excluded locations - [[L1, L2, L3], [L1, L2], []], + [[L1, L2], [L1, L2], []], ] ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA @@ -88,15 +84,14 @@ def read_item_test_data(): client_only_output_data = [ [L1], # 0 [L2], # 1 - [L3], # 2 + [L1], # 2 [L1], # 3 - [L1] # 4 ] client_and_request_output_data = [ [L2], # 0 [L2], # 1 [L2], # 2 - [L3], # 3 + [L1], # 3 [L1], # 4 [L1], # 5 [L1], # 6 @@ -109,21 +104,20 @@ def read_item_test_data(): def query_items_change_feed_test_data(): client_only_output_data = [ - [L1, L1, L1], #0 - [L2, L2, L2], #1 - [L3, L3, L3], #2 - [L1, L1, L1], #3 - [L1, L1, L1] #4 + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 ] client_and_request_output_data = [ - [L1, L1, L2], #0 - [L2, L2, L2], #1 - [L3, L3, L2], #2 - [L2, L2, L3], #3 - [L1, L1, L1], #4 - [L2, L2, L1], #5 - [L1, L1, L1], #6 - [L3, L3, L1], #7 + [L1, L1, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L2, L2], #2 + [L2, L2, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L2, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -134,15 +128,14 @@ def replace_item_test_data(): client_only_output_data = [ [L1, L1], #0 [L2, L2], #1 - [L3, L3], #2 - [L1, L0], #3 - [L1, L1] #4 + [L1, L0], #2 + [L1, L1] #3 ] client_and_request_output_data = [ [L2, L2], #0 [L2, L2], #1 [L2, L2], #2 - [L3, L3], #3 + [L1, L0], #3 [L1, L0], #4 [L1, L1], #5 [L1, L1], #6 @@ -157,7 +150,6 @@ def patch_item_test_data(): client_only_output_data = [ [L1], #0 [L2], #1 - [L3], #2 [L0], #3 [L1] #4 ] @@ -165,7 +157,7 @@ def patch_item_test_data(): [L2], #0 [L2], #1 [L2], #2 - [L3], #3 + [L0], #3 [L0], #4 [L1], #5 [L1], #6 @@ -290,9 +282,9 @@ def test_query_items_change_feed(self, test_data): # API call: query_items_change_feed if request_excluded_locations is None: - list(container.query_items_change_feed()) + items = list(container.query_items_change_feed(start_time="Beginning")) else: - list(container.query_items_change_feed(excluded_locations=request_excluded_locations)) + items = list(container.query_items_change_feed(start_time="Beginning", 
excluded_locations=request_excluded_locations)) # Verify endpoint locations self._verify_endpoint(client, expected_locations) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py new file mode 100644 index 000000000000..7564071de4f9 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -0,0 +1,470 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import logging +import unittest +import uuid +import test_config +import pytest +import pytest_asyncio + +from azure.cosmos.aio import CosmosClient +from azure.cosmos.partition_key import PartitionKey + + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +L0 = "Default" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" + +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2], [], None], + # 1. Single excluded location + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2], [L1, L2], [L1, L2]], + # 5. No common excluded locations + [[L1, L2], [L1], [L2, L3]], + # 6. Request excluded location not in preferred locations + [[L1, L2], [L1, L2], [L3]], + # 7. 
Empty excluded locations, remove all client excluded locations + [[L1, L2], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1] # 3 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1], #0 + [L2, L2, L2], #1 + [L1, L1, L1], #2 + [L1, L1, L1] #3 + ] + client_and_request_output_data = [ + [L1, L2, L2], #0 + [L2, L2, L2], #1 + [L1, L2, L2], #2 + [L2, L1, L1], #3 + [L1, L1, L1], #4 + [L2, L1, L1], #5 + [L1, L1, L1], #6 + [L1, L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + test_client = CosmosClient(HOST, KEY) + container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + await container.create_item(body=TEST_ITEM) + + yield + await test_client.close() + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown") +class TestExcludedLocations: + async def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + + async def _verify_endpoint(self, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. 
+ location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. + location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + msgs = MOCK_HANDLER.messages + req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break + + assert actual_locations == expected_locations + + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_read_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_item + if request_excluded_locations is None: + await container.read_item(ITEM_ID, ITEM_PK_VALUE) + else: + await container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_read_all_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: read_all_items + if request_excluded_locations is None: + all_items = [item async for item in container.read_all_items()] + else: + all_items = [item async for item in container.read_all_items(excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_query_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items + if request_excluded_locations is None: + all_items = [item async for item in container.query_items(None)] + else: + all_items = [item async for item in container.query_items(None, excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + async def test_query_items_change_feed(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + + # API call: query_items_change_feed + if request_excluded_locations is None: + all_items = [item async for item in container.query_items_change_feed(start_time="Beginning")] + else: + all_items = [item async for item in container.query_items_change_feed(start_time="Beginning", 
excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + await self._verify_endpoint(client, expected_locations) + + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_replace_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: replace_item + if request_excluded_locations is None: + await container.replace_item(ITEM_ID, body=TEST_ITEM) + else: + await container.replace_item(ITEM_ID, body=TEST_ITEM, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_upsert_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: upsert_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + await container.upsert_item(body=body) + else: + await container.upsert_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_create_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: create_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_patch_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: patch_item + operations = [ + {"op": "add", "path": "/test_data", "value": f'Data-{str(uuid.uuid4())}'}, + ] + if request_excluded_locations is None: + await container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations) + 
else: + await container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_execute_item_batch(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: execute_item_batch + batch_operations = [] + for i in range(3): + batch_operations.append(("create", ({"id": f'Doc-{str(uuid.uuid4())}', PARTITION_KEY: ITEM_PK_VALUE},))) + + if request_excluded_locations is None: + await container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE,) + else: + await container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + #create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + await container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + MOCK_HANDLER.reset() + + # API call: read_item + if request_excluded_locations is None: + await container.delete_item(item_id, ITEM_PK_VALUE) + else: + await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + await self._verify_endpoint(client, expected_locations) + else: + await self._verify_endpoint(client, [L1]) + + # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature + # @pytest.mark.parametrize('test_data', patch_item_test_data()) + # def test_delete_all_items_by_partition_key(self, test_data): + # # Init test variables + # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + # + # for multiple_write_locations in [True, False]: + # # Client setup + # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + # + # #create before delete + # item_id = f'doc2-{str(uuid.uuid4())}' + # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' + # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) + # MOCK_HANDLER.reset() + # + # # API call: read_item + # if request_excluded_locations is None: + # container.delete_all_items_by_partition_key(pk_value) + # else: + # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) + # + # # 
Verify endpoint locations + # if multiple_write_locations: + # self._verify_endpoint(client, expected_locations) + # else: + # self._verify_endpoint(client, [L1]) + +if __name__ == "__main__": + unittest.main() From 56108892418867fb21d6c7dad05c6ca0a2fbf982 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 16:55:49 -0700 Subject: [PATCH 61/86] Add more live tests for all other Python versions --- sdk/cosmos/live-platform-matrix.json | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index dc3ad3c32e17..6763c1c06562 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -91,6 +91,22 @@ }, { "WindowsConfig": { + "Windows2022_38": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.8", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + }, + "Windows2022_310": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.10", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + }, "Windows2022_312": { "OSVmImage": "env:WINDOWSVMIMAGE", "Pool": "env:WINDOWSPOOL", From f4cb8b3ba9c1507793af77547281741e221b7af1 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 7 Apr 2025 17:07:29 -0700 Subject: [PATCH 62/86] Fix Async test failure --- sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py | 2 +- 2 files changed, 2 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 13e9ba713653..2159c6c97425 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -176,7 +176,7 @@ def setup_and_teardown(): logger.setLevel(logging.DEBUG) container = cosmos_client.CosmosClient(HOST, KEY).get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) - container.create_item(body=TEST_ITEM) + container.upsert_item(body=TEST_ITEM) yield # Code to run after tests diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 7564071de4f9..b0079e753039 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -175,7 +175,7 @@ async def setup_and_teardown(): test_client = CosmosClient(HOST, KEY) container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) - await container.create_item(body=TEST_ITEM) + await container.upsert_item(body=TEST_ITEM) yield await test_client.close() From e98ab571cd172ade4ad46d04b691b7179de96775 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Tue, 8 Apr 2025 12:02:02 -0400 Subject: [PATCH 63/86] add test for failure_rate threshold --- ...n_endpoint_manager_circuit_breaker_core.py | 4 +- .../azure/cosmos/_partition_health_tracker.py | 122 +++++++++++------- ..._endpoint_manager_circuit_breaker_async.py | 16 ++- .../azure/cosmos/aio/_retry_utility_async.py | 8 +- .../tests/test_ppcb_sm_mrr_async.py | 92 +++++++++++-- 5 files changed, 171 insertions(+), 71 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py index 1f900215f08d..577b8410f435 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -26,7 +26,7 @@ from azure.cosmos import documents -from azure.cosmos._partition_health_tracker import PartitionHealthTracker +from azure.cosmos._partition_health_tracker import _PartitionHealthTracker from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper from azure.cosmos._location_cache import EndpointOperationType, LocationCache from azure.cosmos._request_object import RequestObject @@ -44,7 +44,7 @@ class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): def __init__(self, client, location_cache: LocationCache): - self.partition_health_tracker = PartitionHealthTracker() + self.partition_health_tracker = _PartitionHealthTracker() self.location_cache = location_cache self.client = client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 04849ca1cb4b..f6c75c274ab6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -21,6 +21,7 @@ """Internal class for partition health tracker for circuit breaker. """ +import logging import os from typing import Dict, Set, Any from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper @@ -74,8 +75,18 @@ def reset_health_stats(self) -> None: self.read_consecutive_failure_count = 0 self.write_consecutive_failure_count = 0 + def __str__(self) -> str: + return (f"{self.__class__.__name__}: {self.unavailability_info}\n" + f"write failure count: {self.write_failure_count}\n" + f"read failure count: {self.read_failure_count}\n" + f"write success count: {self.write_success_count}\n" + f"read success count: {self.read_success_count}\n" + f"write consecutive failure count: {self.write_consecutive_failure_count}\n" + f"read consecutive failure count: {self.read_consecutive_failure_count}\n") -class PartitionHealthTracker(object): +logger = logging.getLogger("azure.cosmos._PartitionHealthTracker") + +class _PartitionHealthTracker(object): """ This internal class implements the logic for tracking health thresholds for a partition. 
""" @@ -83,7 +94,7 @@ class PartitionHealthTracker(object): def __init__(self) -> None: # partition -> regions -> health info - self.pkrange_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} + self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {} self.last_refresh = current_time_millis() # TODO: @tvaron3 look for useful places to add logs @@ -97,26 +108,27 @@ def _transition_health_status_on_failure( pkrange_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: + logger.warn("{} has been marked as unavailable.".format(pkrange_wrapper)) current_time = current_time_millis() - if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: + if pkrange_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() partition_health_info.unavailability_info = { LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - self.pkrange_wrapper_to_health_info[pkrange_wrapper] = { + self.pk_range_wrapper_to_health_info[pkrange_wrapper] = { location: partition_health_info } else: - region_to_partition_health = self.pkrange_wrapper_to_health_info[pkrange_wrapper] - if location in region_to_partition_health: + region_to_partition_health = self.pk_range_wrapper_to_health_info[pkrange_wrapper] + if location in region_to_partition_health and region_to_partition_health[location].unavailability_info: # healthy tentative -> unhealthy # if the operation type is not empty, we are in the healthy tentative state region_to_partition_health[location].unavailability_info[HEALTH_STATUS] = UNHEALTHY # reset the last unavailability check time stamp region_to_partition_health[location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \ - = UNHEALTHY + = current_time else: # healthy -> unhealthy tentative # if the operation type is empty, we are in the unhealthy tentative state @@ -125,49 +137,55 @@ def _transition_health_status_on_failure( LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info + self.pk_range_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info def _transition_health_status_on_success( self, pkrange_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - if pkrange_wrapper in self.pkrange_wrapper_to_health_info: + if pkrange_wrapper in self.pk_range_wrapper_to_health_info: # healthy tentative -> healthy - self.pkrange_wrapper_to_health_info[pkrange_wrapper].pop(location, None) + self.pk_range_wrapper_to_health_info[pkrange_wrapper].pop(location, None) def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: current_time = current_time_millis() stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK, Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000 - if pk_range_wrapper in self.pkrange_wrapper_to_health_info: - for _, partition_health_info in self.pkrange_wrapper_to_health_info[pk_range_wrapper].items(): - elapsed_time = (current_time - - partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) - current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] - # check if the partition key range is still unavailable - if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) 
- or (current_health_status == UNHEALTHY_TENTATIVE - and elapsed_time > INITIAL_UNAVAILABLE_TIME)): - # unhealthy or unhealthy tentative -> healthy tentative - partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE - - if current_time - self.last_refresh < REFRESH_INTERVAL: + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: + for _, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): + if partition_health_info.unavailability_info: + elapsed_time = (current_time - + partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP]) + current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + # check if the partition key range is still unavailable + if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check) + or (current_health_status == UNHEALTHY_TENTATIVE + and elapsed_time > INITIAL_UNAVAILABLE_TIME)): + # unhealthy or unhealthy tentative -> healthy tentative + partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE + + if current_time - self.last_refresh > REFRESH_INTERVAL: # all partition stats reset every minute self._reset_partition_health_tracker_stats() - def get_excluded_locations(self, pkrange_wrapper: PartitionKeyRangeWrapper) -> Set[str]: - self._check_stale_partition_info(pkrange_wrapper) - if pkrange_wrapper in self.pkrange_wrapper_to_health_info: - return set(self.pkrange_wrapper_to_health_info[pkrange_wrapper].keys()) - return set() + def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> Set[str]: + self._check_stale_partition_info(pk_range_wrapper) + excluded_locations = set() + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: + for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): + if partition_health_info.unavailability_info: + health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + excluded_locations.add(location) + return excluded_locations def add_failure( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str ) -> None: @@ -176,12 +194,12 @@ def add_failure( Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) # Ensure that the health info dictionary is properly initialized. - if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: - self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {} - if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo() + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {} + if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() - health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] + health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] # Determine attribute names and environment variables based on the operation type. 
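
For context, the failure-rate gate that add_failure parametrizes here can be modeled in isolation. The check itself is not visible in this hunk, so the snippet below is only a sketch of its plausible shape — ignore low traffic, then compare the failure percentage against the tolerated threshold — and the 100-request floor is an assumption taken from the MINIMUM_REQUESTS_FOR_FAILURE_RATE constant that a later test in this series temporarily lowers to 10.

```python
# Sketch of a failure-rate gate; an assumed shape, not the SDK's implementation.
MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100  # assumed floor before the rate applies

def failure_rate_exceeded(successes: int, failures: int, tolerated_pct: int) -> bool:
    total = successes + failures
    if total < MINIMUM_REQUESTS_FOR_FAILURE_RATE:
        return False  # not enough traffic for the rate to be meaningful
    return failures * 100 > tolerated_pct * total  # integer math avoids float drift

assert failure_rate_exceeded(60, 40, 30)      # 40% failures > 30% tolerated
assert not failure_rate_exceeded(80, 20, 30)  # 20% is under the threshold
assert not failure_rate_exceeded(5, 45, 30)   # only 50 requests: below the floor
```
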
if operation_type == EndpointOperationType.WriteType: @@ -189,24 +207,24 @@ def add_failure( failure_attr = 'write_failure_count' consecutive_attr = 'write_consecutive_failure_count' env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE - default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT + default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT else: success_attr = 'read_success_count' failure_attr = 'read_failure_count' consecutive_attr = 'read_consecutive_failure_count' env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ - default_consec_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT + default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT # Increment failure and consecutive failure counts. setattr(health_info, failure_attr, getattr(health_info, failure_attr) + 1) setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1) # Retrieve the consecutive failure threshold from the environment. - consecutive_failure_threshold = int(os.environ.get(env_key, default_consec_threshold)) + consecutive_failure_threshold = int(os.environ.get(env_key, default_consecutive_threshold)) # Call the threshold checker with the current stats. self._check_thresholds( - pkrange_wrapper, + pk_range_wrapper, getattr(health_info, success_attr), getattr(health_info, failure_attr), getattr(health_info, consecutive_attr), @@ -214,10 +232,13 @@ def add_failure( failure_rate_threshold, consecutive_failure_threshold ) + print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) + print(pk_range_wrapper) + print(location) def _check_thresholds( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, successes: int, failures: int, consecutive_failures: int, @@ -232,20 +253,20 @@ def _check_thresholds( failures, failure_rate_threshold ): - self._transition_health_status_on_failure(pkrange_wrapper, location) + self._transition_health_status_on_failure(pk_range_wrapper, location) # add to consecutive failures and check that threshold was not exceeded if consecutive_failures >= consecutive_failure_threshold: - self._transition_health_status_on_failure(pkrange_wrapper, location) + self._transition_health_status_on_failure(pk_range_wrapper, location) - def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: + def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None: # Ensure that the health info dictionary is initialized. 
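
add_failure and add_success (continued below) together maintain per-location counters in which any success clears the consecutive-failure streak for that operation type. A minimal standalone model of that interaction, with an illustrative tolerated streak of 10 rather than the SDK's configured defaults:

```python
class StreakTracker:
    """Toy model: failures extend a streak, a single success clears it."""

    def __init__(self, tolerated: int = 10):  # tolerated count is illustrative
        self.tolerated = tolerated
        self.consecutive_failures = 0
        self.unavailable = False

    def add_failure(self) -> None:
        self.consecutive_failures += 1
        if self.consecutive_failures >= self.tolerated:
            self.unavailable = True  # unhealthy-tentative in the real tracker

    def add_success(self) -> None:
        self.consecutive_failures = 0  # one success resets the streak

t = StreakTracker()
for _ in range(9):
    t.add_failure()
t.add_success()              # reset just before the threshold
assert not t.unavailable
for _ in range(10):
    t.add_failure()          # an unbroken streak of 10 trips the breaker
assert t.unavailable
```
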
- if pkrange_wrapper not in self.pkrange_wrapper_to_health_info: - self.pkrange_wrapper_to_health_info[pkrange_wrapper] = {} - if location not in self.pkrange_wrapper_to_health_info[pkrange_wrapper]: - self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] = _PartitionHealthInfo() + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {} + if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() - health_info = self.pkrange_wrapper_to_health_info[pkrange_wrapper][location] + health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] if operation_type == EndpointOperationType.WriteType: health_info.write_success_count += 1 @@ -253,10 +274,13 @@ def add_success(self, pkrange_wrapper: PartitionKeyRangeWrapper, operation_type: else: health_info.read_success_count += 1 health_info.read_consecutive_failure_count = 0 - self._transition_health_status_on_success(pkrange_wrapper, operation_type) + self._transition_health_status_on_success(pk_range_wrapper, operation_type) + print(self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]) + print(pk_range_wrapper) + print(location) def _reset_partition_health_tracker_stats(self) -> None: - for locations in self.pkrange_wrapper_to_health_info.values(): + for locations in self.pk_range_wrapper_to_health_info.values(): for health_info in locations.values(): health_info.reset_health_stats() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 3aadc9b6aba0..2ba64690da7b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -23,9 +23,10 @@ """ from typing import TYPE_CHECKING +from azure.cosmos import PartitionKey from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ _GlobalPartitionEndpointManagerForCircuitBreakerCore -from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager from azure.cosmos._request_object import RequestObject @@ -50,18 +51,23 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key = request.headers[HttpHeaders.PartitionKey] + partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key target_container_link = None for container_link, properties in self.client._container_properties_cache.items(): # pylint: disable=protected-access if properties["_rid"] == container_rid: target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) + + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] if not target_container_link: raise RuntimeError("Illegal state: the container cache is not 
properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways - pk_range = await (self.client._routing_map_provider # pylint: disable=protected-access - .get_overlapping_ranges(target_container_link, partition_key)) - return PartitionKeyRangeWrapper(pk_range, container_rid) + partition_ranges = await (self.client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(target_container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + return PartitionKeyRangeWrapper(partition_range, container_rid) def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index 6e94c9260b8b..5d4d680b50e4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -202,7 +202,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: - global_endpoint_manager.record_failure(args[0]) + await global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -286,7 +286,7 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): - global_endpoint_manager.record_failure(request_params) + await global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -304,7 +304,7 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): - global_endpoint_manager.record_failure(request_params) + await global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -319,7 +319,7 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): - global_endpoint_manager.record_failure(request_params) + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index a8e39ae361da..8902a72d2f6d 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -43,7 +43,7 @@ def errors(): for error_code in error_codes: errors_list.append(CosmosHttpResponseError( status_code=error_code, - message="Some injected fault.")) + message="Some injected error.")) errors_list.append(ServiceResponseError(message="Injected Service Response 
Error.")) return errors_list @@ -115,10 +115,11 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): - with pytest.raises(CosmosHttpResponseError or ServiceResponseError): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): await container.create_item(body=document_definition) + global_endpoint_manager = container.client_connection._global_endpoint_manager - TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection await setup[COLLECTION].create_item(body=document_definition) @@ -126,32 +127,101 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error - TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 0) + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - for i in range(11): + for i in range(10): read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) request = read_resp.get_response_headers()["_request"] # Validate the response comes from "Read Region" (the most preferred read-only region) assert request.url.startswith(expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPPCBSmMrrAsync.validate_unhealthy_partitions(container.client_connection._global_endpoint_manager, 1) + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + @pytest.mark.parametrize("error", errors()) + async def test_failure_rate_threshold_async(self, setup, error): + expected_read_region_uri = self.host + expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + # two documents targeted to same partition, one will always fail and the other will succeed + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + document_definition_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + if i % 2 == 0: + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await container.upsert_item(body=document_definition) + else: + await container.upsert_item(body=document_definition_2) + global_endpoint_manager = container.client_connection._global_endpoint_manager + + 
TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + # lower minimum requests for testing + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + try: + for i in range(20): + if i == 8: + read_resp = await container.read_item(item=document_definition_2['id'], + partition_key=document_definition_2['pk']) + else: + read_resp = await container.read_item(item=document_definition['id'], + partition_key=document_definition['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + + # the partition should have been marked as unavailable after breaking read threshold + TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + finally: + # restore minimum requests + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 @staticmethod - def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): - health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pkrange_wrapper_to_health_info + def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info unhealthy_partitions = 0 for pk_range_wrapper, location_to_health_info in health_info_map.items(): for location, health_info in location_to_health_info.items(): health_status = health_info.unavailability_info.get(HEALTH_STATUS) if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: unhealthy_partitions += 1 - assert unhealthy_partitions == expected_unhealthy_partitions + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + assert unhealthy_partitions == expected_unhealthy_partitions - # test_failure_rate_threshold - add service response error - # test service request marks only a partition unavailable not an entire region + # test_failure_rate_threshold - add service response error - across operation types + # test service request marks only a partition unavailable not an entire region - across operation types + # test cosmos client timeout if __name__ == '__main__': unittest.main() From 4f081681f2cd87a103bf65edd4b61a12283cfcb7 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 10:23:28 -0700 Subject: [PATCH 64/86] Fix live test failures --- .../tests/test_excluded_locations.py | 18 +++-- .../tests/test_excluded_locations_async.py | 66 ++++++++++--------- 2 files changed, 46 insertions(+), 38 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 2159c6c97425..9af367303107 100644 --- 
a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -168,6 +168,12 @@ def patch_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data +def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + container.create_item(body=body) + else: + container.create_item(body=body, excluded_locations=excluded_locations) + @pytest.fixture(scope="class", autouse=True) def setup_and_teardown(): print("Setup: This runs before any tests") @@ -344,10 +350,7 @@ def test_create_item(self, test_data): # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} - if request_excluded_locations is None: - container.create_item(body=body) - else: - container.create_item(body=body, excluded_locations=request_excluded_locations) + _create_item_with_excluded_locations(container, body, request_excluded_locations) # get location from mock_handler if multiple_write_locations: @@ -421,12 +424,13 @@ def test_delete_item(self, test_data): # Client setup client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - #create before delete + # create before delete item_id = f'doc2-{str(uuid.uuid4())}' - container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + _create_item_with_excluded_locations(container, body, request_excluded_locations) MOCK_HANDLER.reset() - # API call: read_item + # API call: delete_item if request_excluded_locations is None: container.delete_item(item_id, ITEM_PK_VALUE) else: diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index b0079e753039..dd6ce3776f68 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -80,20 +80,20 @@ def emit(self, record): def read_item_test_data(): client_only_output_data = [ - [L1], # 0 - [L2], # 1 - [L1], # 2 - [L1] # 3 + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 ] client_and_request_output_data = [ - [L2], # 0 - [L2], # 1 - [L2], # 2 - [L1], # 3 - [L1], # 4 - [L1], # 5 - [L1], # 6 - [L1], # 7 + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -102,20 +102,20 @@ def read_item_test_data(): def query_items_change_feed_test_data(): client_only_output_data = [ - [L1, L1, L1], #0 - [L2, L2, L2], #1 - [L1, L1, L1], #2 - [L1, L1, L1] #3 + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 ] client_and_request_output_data = [ - [L1, L2, L2], #0 - [L2, L2, L2], #1 - [L1, L2, L2], #2 - [L2, L1, L1], #3 - [L1, L1, L1], #4 - [L2, L1, L1], #5 - [L1, L1, L1], #6 - [L1, L1, L1], #7 + [L1, L2, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L2, L2, L2], #2 + [L2, L1, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L1, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data @@ -166,6 +166,12 @@ def patch_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] 
return all_test_data +async def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, excluded_locations=excluded_locations) + @pytest_asyncio.fixture(scope="class", autouse=True) async def setup_and_teardown(): print("Setup: This runs before any tests") @@ -344,10 +350,7 @@ async def test_create_item(self, test_data): # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} - if request_excluded_locations is None: - await container.create_item(body=body) - else: - await container.create_item(body=body, excluded_locations=request_excluded_locations) + await _create_item_with_excluded_locations(container, body, request_excluded_locations) # get location from mock_handler if multiple_write_locations: @@ -421,12 +424,13 @@ async def test_delete_item(self, test_data): # Client setup client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - #create before delete + # create before delete item_id = f'doc2-{str(uuid.uuid4())}' - await container.create_item(body={PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}) + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + await _create_item_with_excluded_locations(container, body, request_excluded_locations) MOCK_HANDLER.reset() - # API call: read_item + # API call: delete_item if request_excluded_locations is None: await container.delete_item(item_id, ITEM_PK_VALUE) else: From 36407c691b76d8ad8bcf2127f91f3c1cfa2dab2f Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 8 Apr 2025 13:41:23 -0400 Subject: [PATCH 65/86] fix pylint and cspell --- .../azure/cosmos/_partition_health_tracker.py | 25 +++++------ ..._endpoint_manager_circuit_breaker_async.py | 9 ++-- .../tests/test_ppcb_sm_mrr_async.py | 45 +++++++++++++------ 3 files changed, 48 insertions(+), 31 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index f6c75c274ab6..4b7c0522f5c2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -99,29 +99,28 @@ def __init__(self) -> None: # TODO: @tvaron3 look for useful places to add logs - def mark_partition_unavailable(self, pkrange_wrapper: PartitionKeyRangeWrapper, location: str) -> None: + def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None: # mark the partition key range as unavailable - self._transition_health_status_on_failure(pkrange_wrapper, location) + self._transition_health_status_on_failure(pk_range_wrapper, location) def _transition_health_status_on_failure( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - logger.warn("{} has been marked as unavailable.".format(pkrange_wrapper)) - current_time = current_time_millis() - if pkrange_wrapper not in self.pk_range_wrapper_to_health_info: + logger.warning("%s has been marked as unavailable.", pk_range_wrapper) current_time = current_time_millis() + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() partition_health_info.unavailability_info = { LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - 
self.pk_range_wrapper_to_health_info[pkrange_wrapper] = { + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = { location: partition_health_info } else: - region_to_partition_health = self.pk_range_wrapper_to_health_info[pkrange_wrapper] + region_to_partition_health = self.pk_range_wrapper_to_health_info[pk_range_wrapper] if location in region_to_partition_health and region_to_partition_health[location].unavailability_info: # healthy tentative -> unhealthy # if the operation type is not empty, we are in the healthy tentative state @@ -137,16 +136,16 @@ def _transition_health_status_on_failure( LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time, HEALTH_STATUS: UNHEALTHY_TENTATIVE } - self.pk_range_wrapper_to_health_info[pkrange_wrapper][location] = partition_health_info + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = partition_health_info def _transition_health_status_on_success( self, - pkrange_wrapper: PartitionKeyRangeWrapper, + pk_range_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - if pkrange_wrapper in self.pk_range_wrapper_to_health_info: + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: # healthy tentative -> healthy - self.pk_range_wrapper_to_health_info[pkrange_wrapper].pop(location, None) + self.pk_range_wrapper_to_health_info[pk_range_wrapper].pop(location, None) def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: current_time = current_time_millis() @@ -178,7 +177,7 @@ def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): if partition_health_info.unavailability_info: health_status = partition_health_info.unavailability_info[HEALTH_STATUS] - if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY): excluded_locations.add(location) return excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 2ba64690da7b..c1badbd8a167 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -37,7 +37,7 @@ -class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): # pylint: disable=protected-access """ This internal class implements the logic for partition endpoint management for geo-replicated database accounts. 
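
The create_pk_range_wrapper body reflowed in the next hunk resolves a partition-key value to its owning physical partition: hash the value to an effective partition key (EPK) range, then ask the routing map for the overlapping partition key range. A reduced sketch of that lookup with stand-in hashing and toy ranges — the SDK's real EPK derivation and Range types differ:

```python
from bisect import bisect_right
from typing import List, Tuple

Range = Tuple[str, str]  # (inclusive min EPK, exclusive max EPK)

def epk_for(partition_key_value: str) -> str:
    # stand-in hash; the SDK derives the EPK from the container's
    # partition key definition, not from Python's hash()
    return format(hash(partition_key_value) & 0xFFFFFFFF, "08x")

def overlapping_range(epk: str, ranges: List[Range]) -> Range:
    mins = [low for low, _ in ranges]          # assumes sorted, non-overlapping ranges
    return ranges[bisect_right(mins, epk) - 1]

# two physical partitions splitting a toy EPK space
physical_ranges = [("00000000", "80000000"), ("80000000", "ffffffff")]
print(overlapping_range(epk_for("pk1"), physical_ranges))
```
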
@@ -54,17 +54,18 @@ async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyR partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): # pylint: disable=protected-access + for container_link, properties in self.client._container_properties_cache.items(): if properties["_rid"] == container_rid: target_container_link = container_link partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], kind=partition_key_definition["kind"]) + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] if not target_container_link: raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways - partition_ranges = await (self.client._routing_map_provider # pylint: disable=protected-access + partition_ranges = await (self.client._routing_map_provider .get_overlapping_ranges(target_container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) return PartitionKeyRangeWrapper(partition_range, container_rid) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 8902a72d2f6d..e26c3a270d83 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -11,29 +11,29 @@ from azure.core.pipeline.transport._aiohttp import AioHttpTransport from azure.core.exceptions import ServiceResponseError +import test_config from azure.cosmos import PartitionKey from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE from azure.cosmos.aio import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError -from tests import test_config -from tests._fault_injection_transport_async import FaultInjectionTransportAsync +from _fault_injection_transport_async import FaultInjectionTransportAsync COLLECTION = "created_collection" @pytest_asyncio.fixture() async def setup(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" - client = CosmosClient(TestPPCBSmMrrAsync.host, TestPPCBSmMrrAsync.master_key, consistency_level="Session") - created_database = client.get_database_client(TestPPCBSmMrrAsync.TEST_DATABASE_ID) + client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) # print(TestPPCBSmMrrAsync.TEST_DATABASE_ID) - await client.create_database_if_not_exists(TestPPCBSmMrrAsync.TEST_DATABASE_ID) - created_collection = await created_database.create_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), offer_throughput=10000) yield { COLLECTION: created_collection } - await 
created_database.delete_container(TestPPCBSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" @@ -50,7 +50,7 @@ def errors(): @pytest.mark.cosmosEmulator @pytest.mark.asyncio @pytest.mark.usefixtures("setup") -class TestPPCBSmMrrAsync: +class TestPerPartitionCircuitBreakerSmMrrAsync: host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey connectionPolicy = test_config.TestConfig.connectionPolicy @@ -70,6 +70,23 @@ async def cleanup_method(initialized_objects: Dict[str, Any]): method_client: CosmosClient = initialized_objects["client"] await method_client.close() + async def perform_write_operation(operation, container, id, pk): + document_definition = {'id': id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == "create": + await container.create_item(body=document_definition) + elif operation == "upsert": + await container.upsert_item(body=document_definition) + elif operation == "replace": + await container.replace_item(item=document_definition['id'], body=document_definition) + elif operation == "delete": + await container.delete_item(item=document_definition['id'], partition_key=document_definition['pk']) + elif operation == "read": + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + + async def create_custom_transport_sm_mrr(self): custom_transport = FaultInjectionTransportAsync() # Inject rule to disallow writes in the read-only region @@ -119,7 +136,7 @@ async def test_consecutive_failure_threshold_async(self, setup, error): await container.create_item(body=document_definition) global_endpoint_manager = container.client_connection._global_endpoint_manager - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection await setup[COLLECTION].create_item(body=document_definition) @@ -127,7 +144,7 @@ async def test_consecutive_failure_threshold_async(self, setup, error): # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) @@ -136,7 +153,7 @@ async def test_consecutive_failure_threshold_async(self, setup, error): assert request.url.startswith(expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) @pytest.mark.parametrize("error", errors()) async def test_failure_rate_threshold_async(self, setup, error): @@ -173,7 +190,7 @@ async def test_failure_rate_threshold_async(self, setup, error): await container.upsert_item(body=document_definition_2) global_endpoint_manager = 
container.client_connection._global_endpoint_manager - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # create item with client without fault injection await setup[COLLECTION].create_item(body=document_definition) @@ -181,7 +198,7 @@ async def test_failure_rate_threshold_async(self, setup, error): # reads should fail over and only the relevant partition should be marked as unavailable await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) # partition should not have been marked unavailable after one error - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # lower minimum requests for testing global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 try: @@ -197,7 +214,7 @@ async def test_failure_rate_threshold_async(self, setup, error): assert request.url.startswith(expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold - TestPPCBSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) finally: # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 From 4e2fd6b691478cd75efaf2d62f65b82f5cc23416 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 10:46:37 -0700 Subject: [PATCH 66/86] Fix live test failures --- .../tests/test_excluded_locations_async.py | 24 +++++++++---------- 1 file changed, 12 insertions(+), 12 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index dd6ce3776f68..4a39b6a78c2c 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -80,20 +80,20 @@ def emit(self, record): def read_item_test_data(): client_only_output_data = [ - [L1, L1], # 0 - [L2, L2], # 1 - [L1, L1], # 2 - [L1, L1], # 3 + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1], # 3 ] client_and_request_output_data = [ - [L2, L2], # 0 - [L2, L2], # 1 - [L2, L2], # 2 - [L1, L1], # 3 - [L1, L1], # 4 - [L1, L1], # 5 - [L1, L1], # 6 - [L1, L1], # 7 + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 ] all_output_test_data = client_only_output_data + client_and_request_output_data From 1baf872d58142af58af08dffb6e2c20a8cad1771 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 8 Apr 2025 13:51:20 -0400 Subject: [PATCH 67/86] fix pylint --- .../azure-cosmos/azure/cosmos/_partition_health_tracker.py | 3 ++- ...global_partition_endpoint_manager_circuit_breaker_async.py | 4 ++-- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 3 ++- 3 files changed, 6 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py index 4b7c0522f5c2..9f30bac2bd2c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -108,7 +108,8 @@ def 
_transition_health_status_on_failure( pk_range_wrapper: PartitionKeyRangeWrapper, location: str ) -> None: - logger.warning("%s has been marked as unavailable.", pk_range_wrapper) current_time = current_time_millis() + logger.warning("%s has been marked as unavailable.", pk_range_wrapper) + current_time = current_time_millis() if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: # healthy -> unhealthy tentative partition_health_info = _PartitionHealthInfo() diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index c1badbd8a167..8fe8b9a79afe 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -36,8 +36,8 @@ from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection - -class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): # pylint: disable=protected-access +# pylint: disable=protected-access +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): """ This internal class implements the logic for partition endpoint management for geo-replicated database accounts. diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index e26c3a270d83..4a6b37b23f0b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -236,7 +236,8 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == expected_unhealthy_partitions - # test_failure_rate_threshold - add service response error - across operation types + # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again + # # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout From e0dab2977be9f6519644aa15ef5569ddd14ad83c Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 11:11:30 -0700 Subject: [PATCH 68/86] Fix live test failures --- .../tests/test_excluded_locations_async.py | 26 +++++++++++++++++-- 1 file changed, 24 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 4a39b6a78c2c..e5c1dcfeed26 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -100,6 +100,28 @@ def read_item_test_data(): all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data +def read_all_item_test_data(): + client_only_output_data = [ + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 + ] + client_and_request_output_data = [ + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + def 
query_items_change_feed_test_data(): client_only_output_data = [ [L1, L1, L1, L1], #0 @@ -243,7 +265,7 @@ async def test_read_item(self, test_data): # Verify endpoint locations await self._verify_endpoint(client, expected_locations) - @pytest.mark.parametrize('test_data', read_item_test_data()) + @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_read_all_items(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data @@ -260,7 +282,7 @@ async def test_read_all_items(self, test_data): # Verify endpoint locations await self._verify_endpoint(client, expected_locations) - @pytest.mark.parametrize('test_data', read_item_test_data()) + @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_query_items(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data From 798c12f513f2822e8e270ffc00111d05f57fdd28 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 11:38:36 -0700 Subject: [PATCH 69/86] Add test_delete_all_items_by_partition_key --- .../tests/test_excluded_locations.py | 55 ++++++++++--------- 1 file changed, 28 insertions(+), 27 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 9af367303107..0a63136700c0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -442,33 +442,34 @@ def test_delete_item(self, test_data): else: self._verify_endpoint(client, [L1]) - # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature - # @pytest.mark.parametrize('test_data', patch_item_test_data()) - # def test_delete_all_items_by_partition_key(self, test_data): - # # Init test variables - # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data - # - # for multiple_write_locations in [True, False]: - # # Client setup - # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - # - # #create before delete - # item_id = f'doc2-{str(uuid.uuid4())}' - # pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' - # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) - # MOCK_HANDLER.reset() - # - # # API call: read_item - # if request_excluded_locations is None: - # container.delete_all_items_by_partition_key(pk_value) - # else: - # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) - # - # # Verify endpoint locations - # if multiple_write_locations: - # self._verify_endpoint(client, expected_locations) - # else: - # self._verify_endpoint(client, [L1]) + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_delete_all_items_by_partition_key(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = self._init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' + body = {PARTITION_KEY: pk_value, 'id': item_id} + 
_create_item_with_excluded_locations(container, body, request_excluded_locations) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + container.delete_all_items_by_partition_key(pk_value) + else: + container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + self._verify_endpoint(client, expected_locations) + else: + self._verify_endpoint(client, [L1]) if __name__ == "__main__": unittest.main() From 2c5b8fce52682014b4214f930a29f2b97b13e91c Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Tue, 8 Apr 2025 15:01:29 -0700 Subject: [PATCH 70/86] Remove test_delete_all_items_by_partition_key --- .../tests/test_excluded_locations.py | 29 ------------------- .../tests/test_excluded_locations_async.py | 28 ------------------ 2 files changed, 57 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py index 0a63136700c0..d99f9a3b3fda 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -442,34 +442,5 @@ def test_delete_item(self, test_data): else: self._verify_endpoint(client, [L1]) - @pytest.mark.parametrize('test_data', patch_item_test_data()) - def test_delete_all_items_by_partition_key(self, test_data): - # Init test variables - preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data - - for multiple_write_locations in [True, False]: - # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations, - multiple_write_locations) - - # create before delete - item_id = f'doc2-{str(uuid.uuid4())}' - pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' - body = {PARTITION_KEY: pk_value, 'id': item_id} - _create_item_with_excluded_locations(container, body, request_excluded_locations) - MOCK_HANDLER.reset() - - # API call: delete_item - if request_excluded_locations is None: - container.delete_all_items_by_partition_key(pk_value) - else: - container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) - - # Verify endpoint locations - if multiple_write_locations: - self._verify_endpoint(client, expected_locations) - else: - self._verify_endpoint(client, [L1]) - if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index e5c1dcfeed26..109f9ff7207a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -464,33 +464,5 @@ async def test_delete_item(self, test_data): else: await self._verify_endpoint(client, [L1]) - # TODO: enable this test once we figure out how to enable delete_all_items_by_partition_key feature - # @pytest.mark.parametrize('test_data', patch_item_test_data()) - # def test_delete_all_items_by_partition_key(self, test_data): - # # Init test variables - # preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data - # - # for multiple_write_locations in [True, False]: - # # Client setup - # client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) - # - # #create before delete - # item_id = f'doc2-{str(uuid.uuid4())}' - # 
pk_value = f'temp_partition_key_value-{str(uuid.uuid4())}' - # container.create_item(body={PARTITION_KEY: pk_value, 'id': item_id}) - # MOCK_HANDLER.reset() - # - # # API call: read_item - # if request_excluded_locations is None: - # container.delete_all_items_by_partition_key(pk_value) - # else: - # container.delete_all_items_by_partition_key(pk_value, excluded_locations=request_excluded_locations) - # - # # Verify endpoint locations - # if multiple_write_locations: - # self._verify_endpoint(client, expected_locations) - # else: - # self._verify_endpoint(client, [L1]) - if __name__ == "__main__": unittest.main() From 739e09006c8023ed9b415d6d08a92234ed4890e8 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Tue, 8 Apr 2025 20:11:37 -0400 Subject: [PATCH 71/86] fix and add tests --- .../azure/cosmos/_retry_utility.py | 2 +- .../tests/test_ppcb_sm_mrr_async.py | 167 ++++++++++++------ 2 files changed, 110 insertions(+), 59 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 44c6b088696e..7d27885f10db 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -59,7 +59,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylin :rtype: tuple of (dict, dict) """ pk_range_wrapper = None - if global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 4a6b37b23f0b..c34c6ee8e731 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -24,7 +24,6 @@ async def setup(): os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key, consistency_level="Session") created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) - # print(TestPPCBSmMrrAsync.TEST_DATABASE_ID) await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, partition_key=PartitionKey("/pk"), @@ -37,15 +36,24 @@ async def setup(): await client.close() os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" -def errors(): - errors_list = [] +def operations_and_errors(): + write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] + read_operations = ["read", "query", "changefeed"] + errors = [] error_codes = [408, 500, 502, 503] for error_code in error_codes: - errors_list.append(CosmosHttpResponseError( + errors.append(CosmosHttpResponseError( status_code=error_code, message="Some injected error.")) - errors_list.append(ServiceResponseError(message="Injected Service Response Error.")) - return errors_list + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for write_operation in write_operations: + for 
read_operation in read_operations:
+            for error in errors:
+                params.append((write_operation, read_operation, error))
+
+    return params
+
 @pytest.mark.cosmosEmulator
 @pytest.mark.asyncio
 @pytest.mark.usefixtures("setup")
@@ -70,21 +78,52 @@ async def cleanup_method(initialized_objects: Dict[str, Any]):
         method_client: CosmosClient = initialized_objects["client"]
         await method_client.close()
 
-    async def perform_write_operation(operation, container, id, pk):
-        document_definition = {'id': id,
-                               'pk': pk,
-                               'name': 'sample document',
-                               'key': 'value'}
+    @staticmethod
+    async def perform_write_operation(operation, container, doc_id, pk):
+        doc = {'id': doc_id,
+               'pk': pk,
+               'name': 'sample document',
+               'key': 'value'}
         if operation == "create":
-            await container.create_item(body=document_definition)
+            await container.create_item(body=doc)
         elif operation == "upsert":
-            await container.upsert_item(body=document_definition)
+            await container.upsert_item(body=doc)
         elif operation == "replace":
-            await container.replace_item(item=document_definition['id'], body=document_definition)
+            new_doc = {'id': doc_id,
+                       'pk': pk,
+                       'name': 'sample document' + str(uuid.uuid4()),
+                       'key': 'value'}
+            await container.replace_item(item=doc['id'], body=new_doc)
         elif operation == "delete":
-            await container.delete_item(item=document_definition['id'], partition_key=document_definition['pk'])
-        elif operation == "read":
-            await container.read_item(item=document_definition['id'], partition_key=document_definition['pk'])
+            await container.create_item(body=doc)
+            await container.delete_item(item=doc['id'], partition_key=doc['pk'])
+        elif operation == "patch":
+            operations = [{"op": "incr", "path": "/company", "value": 3}]
+            await container.patch_item(item=doc['id'], partition_key=doc['pk'], operations=operations)
+        elif operation == "batch":
+            batch_operations = [
+                ("create", (doc, )),
+                ("upsert", (doc,)),
+                ("upsert", (doc,)),
+                ("upsert", (doc,)),
+            ]
+            await container.execute_item_batch(batch_operations, partition_key=doc['pk'])
+
+    @staticmethod
+    async def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri):
+        if operation == "read":
+            read_resp = await container.read_item(item=doc_id, partition_key=pk)
+            request = read_resp.get_response_headers()["_request"]
+            # Validate the response comes from "Read Region" (the most preferred read-only region)
+            assert request.url.startswith(expected_read_region_uri)
+        elif operation == "query":
+            query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk"
+            parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}]
+            async for _ in container.query_items(query=query, parameters=parameters):
+                pass
+        elif operation == "changefeed":
+            async for _ in container.query_items_change_feed():
+                pass
 
     async def create_custom_transport_sm_mrr(self):
@@ -110,8 +149,8 @@ async def create_custom_transport_sm_mrr(self):
             emulator_as_multi_region_sm_account_transformation)
         return custom_transport
 
-    @pytest.mark.parametrize("error", errors())
-    async def test_consecutive_failure_threshold_async(self, setup, error):
+    @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors())
+    async def test_consecutive_failure_threshold_async(self, setup, write_operation, read_operation, error):
         expected_read_region_uri = self.host
         expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1")
         custom_transport = await self.create_custom_transport_sm_mrr()
@@ -133,7 +172,10 @@ async def test_consecutive_failure_threshold_async(self, setup, error):
        # writes
should fail in sm mrr with circuit breaker and should not mark unavailable a partition for i in range(6): with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await container.create_item(body=document_definition) + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, + container, + document_definition['id'], + document_definition['pk']) global_endpoint_manager = container.client_connection._global_endpoint_manager TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) @@ -142,21 +184,26 @@ async def test_consecutive_failure_threshold_async(self, setup, error): await setup[COLLECTION].create_item(body=document_definition) # reads should fail over and only the relevant partition should be marked as unavailable - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) # partition should not have been marked unavailable after one error TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(10): - read_resp = await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_read_region_uri) + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) - @pytest.mark.parametrize("error", errors()) - async def test_failure_rate_threshold_async(self, setup, error): + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) + async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): expected_read_region_uri = self.host expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") custom_transport = await self.create_custom_transport_sm_mrr() @@ -166,10 +213,10 @@ async def test_failure_rate_threshold_async(self, setup, error): 'pk': 'pk1', 'name': 'sample document', 'key': 'value'} - document_definition_2 = {'id': str(uuid.uuid4()), - 'pk': 'pk1', - 'name': 'sample document', - 'key': 'value'} + doc_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) custom_transport.add_fault(predicate, @@ -180,45 +227,50 @@ async def test_failure_rate_threshold_async(self, setup, error): custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) container = custom_setup['col'] - - # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition - for i in range(6): - if i % 2 == 0: - with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): - await container.upsert_item(body=document_definition) - else: - 
await container.upsert_item(body=document_definition_2) global_endpoint_manager = container.client_connection._global_endpoint_manager - - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) - - # create item with client without fault injection - await setup[COLLECTION].create_item(body=document_definition) - - # reads should fail over and only the relevant partition should be marked as unavailable - await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) - # partition should not have been marked unavailable after one error - TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) # lower minimum requests for testing global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 try: + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(14): + if i == 9: + await container.upsert_item(body=doc_2) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, + container, + document_definition['id'], + document_definition['pk']) + + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) for i in range(20): if i == 8: - read_resp = await container.read_item(item=document_definition_2['id'], - partition_key=document_definition_2['pk']) + read_resp = await container.read_item(item=doc_2['id'], + partition_key=doc_2['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) else: - read_resp = await container.read_item(item=document_definition['id'], - partition_key=document_definition['pk']) - request = read_resp.get_response_headers()["_request"] - # Validate the response comes from "Read Region" (the most preferred read-only region) - assert request.url.startswith(expected_read_region_uri) - + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) finally: # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + # look at the urls for verifying fall back + @staticmethod def validate_unhealthy_partitions(global_endpoint_manager, expected_unhealthy_partitions): @@ -237,7 +289,6 @@ def validate_unhealthy_partitions(global_endpoint_manager, assert unhealthy_partitions == expected_unhealthy_partitions # test_failure_rate_threshold - add service response error - across 
operation types - test recovering the partition again - # # test service request marks only a partition unavailable not an entire region - across operation types # test cosmos client timeout From d5c380a4098d8aaa3dcfeee65f6467d104714671 Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Tue, 8 Apr 2025 23:33:54 -0400 Subject: [PATCH 72/86] add collection rid to batch --- ..._endpoint_manager_circuit_breaker_async.py | 39 +++++++++++-------- .../azure-cosmos/azure/cosmos/container.py | 2 + .../tests/test_ppcb_sm_mrr_async.py | 2 +- 3 files changed, 25 insertions(+), 18 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 8fe8b9a79afe..d9a06317804e 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -51,23 +51,28 @@ def __init__(self, client: "CosmosClientConnection"): async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] - partition_key_value = request.headers[HttpHeaders.PartitionKey] - # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], - kind=partition_key_definition["kind"]) - - epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] - if not target_container_link: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") - # TODO: @tvaron3 check different clients and create them in different ways - partition_ranges = await (self.client._routing_map_provider - .get_overlapping_ranges(target_container_link, epk_range)) - partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + print(request.headers) + if request.headers.get(HttpHeaders.PartitionKey): + partition_key_value = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.client._container_properties_cache.items(): + if properties["_rid"] == container_rid: + target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) + + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + partition_ranges = await (self.client._routing_map_provider + .get_overlapping_ranges(target_container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + elif request.headers.get(HttpHeaders.PartitionKeyRangeID): + pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + return PartitionKeyRangeWrapper(partition_range, container_rid) def is_circuit_breaker_applicable(self, 
request: RequestObject) -> bool: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index 3f48824c9f7d..c659ae746b4a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -1088,6 +1088,8 @@ def execute_item_batch( request_options = build_options(kwargs) request_options["partitionKey"] = self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True + container_properties = self._get_properties() + request_options["containerRID"] = container_properties["_rid"] return self.client_connection.Batch( collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index c34c6ee8e731..0be283a289cc 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -99,7 +99,7 @@ async def perform_write_operation(operation, container, doc_id, pk): await container.delete_item(item=doc['id'], partition_key=doc['pk']) elif operation == "patch": operations = [{"op": "incr", "path": "/company", "value": 3}] - await container.patch_item(item=doc['id'], partition_key=doc['pk'], operations=operations) + await container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) elif operation == "batch": batch_operations = [ ("create", (doc, )), From e7f7265e7548b3ed3a215e0c0621d9b28ae711fe Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 9 Apr 2025 10:43:18 -0400 Subject: [PATCH 73/86] add partition key range id to partition key range to cache --- .../_routing/aio/routing_map_provider.py | 38 ++++++++++++++++--- ..._endpoint_manager_circuit_breaker_async.py | 30 +++++++++------ .../tests/test_ppcb_sm_mrr_async.py | 3 +- 3 files changed, 53 insertions(+), 18 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index e70ae355c495..f59513d05d24 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -22,11 +22,13 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +from typing import Dict, Any from ... import _base from ..collection_routing_map import CollectionRoutingMap from .. import routing_range + # pylint: disable=protected-access @@ -58,13 +60,21 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** :return: List of overlapping partition key ranges. 
:rtype: list """ - cl = self._documentClient - collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) + async def initialize_collection_routing_map_if_needed( + self, + collection_link: str, + collection_id: str, + **kwargs: Dict[str, Any] + ): + client = self._documentClient collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: - collection_pk_ranges = [pk async for pk in cl._ReadPartitionKeyRanges(collection_link, **kwargs)] + collection_pk_ranges = [pk async for pk in client._ReadPartitionKeyRanges(collection_link, **kwargs)] # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. @@ -72,8 +82,18 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( [(r, True) for r in collection_pk_ranges], collection_id ) - self._collection_routing_map_by_item[collection_id] = collection_routing_map - return collection_routing_map.get_overlapping_ranges(partition_key_ranges) + self._collection_routing_map_by_item[collection_id] = collection_routing_map + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id(partition_key_range_id) @staticmethod def _discard_parent_ranges(partitionKeyRanges): @@ -196,3 +216,11 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** pass return target_partition_key_ranges + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + return await super().get_range_by_partition_key_range_id(collection_link, partition_key_range_id, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index d9a06317804e..77757b6a2752 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -49,29 +49,35 @@ def __init__(self, client: "CosmosClientConnection"): self.global_partition_endpoint_manager_core = ( _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) - async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + async def create_pk_range_wrapper(self, request: RequestObject, kwargs) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) + target_container_link = None + partition_key = None + # TODO: @tvaron3 see if the container 
link is absolutely necessary or change container cache + for container_link, properties in self.client._container_properties_cache.items(): + if properties["_rid"] == container_rid: + target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) + + if not target_container_link or not partition_key: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + if request.headers.get(HttpHeaders.PartitionKey): partition_key_value = request.headers[HttpHeaders.PartitionKey] # get the partition key range for the given partition key - target_container_link = None - for container_link, properties in self.client._container_properties_cache.items(): - if properties["_rid"] == container_rid: - target_container_link = container_link - partition_key_definition = properties["partitionKey"] - partition_key = PartitionKey(path=partition_key_definition["paths"], - kind=partition_key_definition["kind"]) - - epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] - if not target_container_link: - raise RuntimeError("Illegal state: the container cache is not properly initialized.") # TODO: @tvaron3 check different clients and create them in different ways + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] partition_ranges = await (self.client._routing_map_provider .get_overlapping_ranges(target_container_link, epk_range)) partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) elif request.headers.get(HttpHeaders.PartitionKeyRangeID): pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + range = await (self.client._routing_map_provider + .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) + partition_range = Range.PartitionKeyRangeToRange(range) return PartitionKeyRangeWrapper(partition_range, container_rid) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 0be283a289cc..99f7a822c29a 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -201,6 +201,7 @@ async def test_consecutive_failure_threshold_async(self, setup, write_operation, # the partition should have been marked as unavailable after breaking read threshold TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + # test recovering the partition again @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): @@ -269,7 +270,7 @@ async def test_failure_rate_threshold_async(self, setup, write_operation, read_o # restore minimum requests global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 - # look at the urls for verifying fall back + # look at the urls for verifying fall back and use another id for same partition @staticmethod def validate_unhealthy_partitions(global_endpoint_manager, From 38f80331b1c32162d84ab7c718b4dd878ec046be Mon Sep 17 00:00:00 2001 From: Tomas Varon Saldarriaga Date: Wed, 9 Apr 2025 11:43:43 -0400 Subject: [PATCH 74/86] address failures --- ...obal_partition_endpoint_manager_circuit_breaker_async.py | 2 +- 
sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 6 +++++- 2 files changed, 6 insertions(+), 2 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py index 77757b6a2752..5231ed5c06c4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -49,7 +49,7 @@ def __init__(self, client: "CosmosClientConnection"): self.global_partition_endpoint_manager_core = ( _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) - async def create_pk_range_wrapper(self, request: RequestObject, kwargs) -> PartitionKeyRangeWrapper: + async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: container_rid = request.headers[HttpHeaders.IntendedCollectionRID] print(request.headers) target_container_link = None diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 99f7a822c29a..4b2b33512bd0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -119,8 +119,9 @@ async def perform_read_operation(operation, container, doc_id, pk, expected_read elif operation == "query": query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] - async for _ in container.query_items(query=query, parameters=parameters): + async for _ in container.query_items(query=query, partition_key=pk, parameters=parameters): pass + # need to do query with no pk and with feed range elif operation == "changefeed": async for _ in container.query_items_change_feed(): pass @@ -149,6 +150,9 @@ async def create_custom_transport_sm_mrr(self): emulator_as_multi_region_sm_account_transformation) return custom_transport + + # split this into write and read tests + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) async def test_consecutive_failure_threshold_async(self, setup, write_operation, read_operation, error): expected_read_region_uri = self.host From 828a99b5b044ea8b5d70b368eb8d2189afe72932 Mon Sep 17 00:00:00 2001 From: tvaron3 Date: Wed, 9 Apr 2025 11:52:22 -0400 Subject: [PATCH 75/86] update tests --- sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py index 4b2b33512bd0..113574a7725f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -163,7 +163,7 @@ async def test_consecutive_failure_threshold_async(self, setup, write_operation, 'pk': 'pk1', 'name': 'sample document', 'key': 'value'} - predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( 0, From 
2b9b58fbc885781cda6c039bb231853d6c652b79 Mon Sep 17 00:00:00 2001
From: Allen Kim
Date: Wed, 9 Apr 2025 19:40:18 -0700
Subject: [PATCH 76/86] Added missing doc for excluded_locations in async client

---
 sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py
index 647f6d59f615..e5e526670629 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py
@@ -162,6 +162,7 @@ class CosmosClient:  # pylint: disable=client-accepts-api-version-keyword
     :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts.
         (Default: True)
     :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts.
+    :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations.
     :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy.
         Must be used along with a logger to work.
     :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. Can be passed in at client

From 1c98b48f15660200d77a87e8811080bde636df42 Mon Sep 17 00:00:00 2001
From: Allen Kim
Date: Wed, 9 Apr 2025 19:40:57 -0700
Subject: [PATCH 77/86] Remove duplicate functions

---
 .../tests/test_excluded_locations.py          | 123 +++++++++---------
 .../tests/test_excluded_locations_async.py    | 101 ++++++--------
 2 files changed, 99 insertions(+), 125 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
index d99f9a3b3fda..49b7f0553871 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
@@ -188,51 +188,50 @@ def setup_and_teardown():
     # Code to run after tests
     print("Teardown: This runs after all tests")

-@pytest.mark.cosmosMultiRegion
-class TestExcludedLocations:
-    def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True):
-        client = cosmos_client.CosmosClient(HOST, KEY,
-                                            preferred_locations=preferred_locations,
-                                            excluded_locations=client_excluded_locations,
-                                            multiple_write_locations=multiple_write_locations)
-        db = client.get_database_client(DATABASE_ID)
-        container = db.get_container_client(CONTAINER_ID)
-        MOCK_HANDLER.reset()
-
-        return client, db, container
-
-    def _verify_endpoint(self, client, expected_locations):
-        # get mapping for locations
-        location_mapping = (client.client_connection._global_endpoint_manager.
-                            location_cache.account_locations_by_write_regional_routing_context)
-        default_endpoint = (client.client_connection._global_endpoint_manager.
- location_cache.default_regional_routing_context.get_primary()) - - # get Request URL - msgs = MOCK_HANDLER.messages - req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] - - # get location - actual_locations = [] - for req_url in req_urls: - if req_url.startswith(default_endpoint): - actual_locations.append(L0) - else: - for endpoint in location_mapping: - if req_url.startswith(endpoint): - location = location_mapping[endpoint] - actual_locations.append(location) - break +def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = cosmos_client.CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = client.get_database_client(DATABASE_ID) + container = db.get_container_client(CONTAINER_ID) + MOCK_HANDLER.reset() + + return client, db, container + +def _verify_endpoint(messages, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. + location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. + location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + req_urls = [url.replace("Request URL: '", "") for url in messages if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break - assert actual_locations == expected_locations + assert actual_locations == expected_locations +@pytest.mark.cosmosMultiRegion +class TestExcludedLocations: @pytest.mark.parametrize('test_data', read_item_test_data()) def test_read_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations) # API call: read_item if request_excluded_locations is None: @@ -241,7 +240,7 @@ def test_read_item(self, test_data): container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) # Verify endpoint locations - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_item_test_data()) def test_read_all_items(self, test_data): @@ -249,7 +248,7 @@ def test_read_all_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = self._init_container(preferred_locations, client_excluded_locations) + client, db, container = _init_container(preferred_locations, client_excluded_locations) # API call: read_all_items if request_excluded_locations is None: @@ -258,7 +257,7 @@ def test_read_all_items(self, test_data): list(container.read_all_items(excluded_locations=request_excluded_locations)) # Verify endpoint locations - self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) 
 @pytest.mark.parametrize('test_data', read_item_test_data())
 def test_query_items(self, test_data):
@@ -266,7 +265,7 @@
         preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data

         # Client setup and create an item
-        client, db, container = self._init_container(preferred_locations, client_excluded_locations)
+        client, db, container = _init_container(preferred_locations, client_excluded_locations)

         # API call: query_items
         if request_excluded_locations is None:
@@ -275,7 +274,7 @@
             list(container.query_items(None, excluded_locations=request_excluded_locations))

         # Verify endpoint locations
-        self._verify_endpoint(client, expected_locations)
+        _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)

     @pytest.mark.parametrize('test_data', query_items_change_feed_test_data())
     def test_query_items_change_feed(self, test_data):
@@ -284,7 +283,7 @@

         # Client setup and create an item
-        client, db, container = self._init_container(preferred_locations, client_excluded_locations)
+        client, db, container = _init_container(preferred_locations, client_excluded_locations)

         # API call: query_items_change_feed
         if request_excluded_locations is None:
@@ -293,7 +292,7 @@
             items = list(container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations))

         # Verify endpoint locations
-        self._verify_endpoint(client, expected_locations)
+        _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)

     @pytest.mark.parametrize('test_data', replace_item_test_data())
@@ -303,7 +302,7 @@ def test_replace_item(self, test_data):

         for multiple_write_locations in [True, False]:
             # Client setup and create an item
-            client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
+            client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations)

             # API call: replace_item
             if request_excluded_locations is None:
@@ -313,9 +312,9 @@

             # Verify endpoint locations
             if multiple_write_locations:
-                self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                self._verify_endpoint(client, [expected_locations[0], L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1])

     @pytest.mark.parametrize('test_data', replace_item_test_data())
     def test_upsert_item(self, test_data):
@@ -324,7 +323,7 @@

         for multiple_write_locations in [True, False]:
             # Client setup and create an item
-            client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
+            client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations)

             # API call: upsert_item
             body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'}
@@ -335,9 +334,9 @@

             # get location from mock_handler
             if multiple_write_locations:
-                self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                self._verify_endpoint(client, [expected_locations[0], L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1])

     @pytest.mark.parametrize('test_data', replace_item_test_data())
     def test_create_item(self, test_data):
@@ -346,7 +345,7 @@ def test_create_item(self, test_data):

         for multiple_write_locations in [True, False]:
             # Client setup and create an item
-            client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
+            client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations)

             # API call: create_item
             body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'}
@@ -354,9 +353,9 @@

             # get location from mock_handler
             if multiple_write_locations:
-                self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                self._verify_endpoint(client, [expected_locations[0], L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1])

     @pytest.mark.parametrize('test_data', patch_item_test_data())
     def test_patch_item(self, test_data):
@@ -365,7 +364,7 @@

         for multiple_write_locations in [True, False]:
             # Client setup and create an item
-            client, db, container = self._init_container(preferred_locations, client_excluded_locations,
+            client, db, container = _init_container(preferred_locations, client_excluded_locations,
                                                          multiple_write_locations)

             # API call: patch_item
@@ -382,9 +381,9 @@

             # get location from mock_handler
             if multiple_write_locations:
-                self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                self._verify_endpoint(client, [L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [L1])

     @pytest.mark.parametrize('test_data', patch_item_test_data())
     def test_execute_item_batch(self, test_data):
@@ -393,7 +392,7 @@

         for multiple_write_locations in [True, False]:
             # Client setup and create an item
-            client, db, container = self._init_container(preferred_locations, client_excluded_locations,
+            client, db, container = _init_container(preferred_locations, client_excluded_locations,
                                                          multiple_write_locations)

             # API call: execute_item_batch
@@ -411,9 +410,9 @@

             # get location from mock_handler
             if multiple_write_locations:
-                self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                self._verify_endpoint(client, [L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [L1])

     @pytest.mark.parametrize('test_data', patch_item_test_data())
     def test_delete_item(self, test_data):
@@ -422,7 +421,7 @@

         for multiple_write_locations in [True, False]:
             # Client setup
-            client, db, container = self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
+            client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations)

             # create before delete
             item_id = f'doc2-{str(uuid.uuid4())}'
@@ -438,9 +437,9 @@

             # Verify endpoint locations
             if multiple_write_locations:
-                self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                self._verify_endpoint(client, [L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [L1])

 if __name__ == "__main__":
     unittest.main()

diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py
index 109f9ff7207a..50c0b69acd76 100644
---
a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -10,7 +10,7 @@ from azure.cosmos.aio import CosmosClient from azure.cosmos.partition_key import PartitionKey - +from test_excluded_locations import _verify_endpoint class MockHandler(logging.Handler): def __init__(self): @@ -208,53 +208,28 @@ async def setup_and_teardown(): yield await test_client.close() +async def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + @pytest.mark.cosmosMultiRegion @pytest.mark.asyncio @pytest.mark.usefixtures("setup_and_teardown") class TestExcludedLocations: - async def _init_container(self, preferred_locations, client_excluded_locations, multiple_write_locations = True): - client = CosmosClient(HOST, KEY, - preferred_locations=preferred_locations, - excluded_locations=client_excluded_locations, - multiple_write_locations=multiple_write_locations) - db = await client.create_database_if_not_exists(DATABASE_ID) - container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) - MOCK_HANDLER.reset() - - return client, db, container - - async def _verify_endpoint(self, client, expected_locations): - # get mapping for locations - location_mapping = (client.client_connection._global_endpoint_manager. - location_cache.account_locations_by_write_regional_routing_context) - default_endpoint = (client.client_connection._global_endpoint_manager. 
- location_cache.default_regional_routing_context.get_primary()) - - # get Request URL - msgs = MOCK_HANDLER.messages - req_urls = [url.replace("Request URL: '", "") for url in msgs if 'Request URL:' in url] - - # get location - actual_locations = [] - for req_url in req_urls: - if req_url.startswith(default_endpoint): - actual_locations.append(L0) - else: - for endpoint in location_mapping: - if req_url.startswith(endpoint): - location = location_mapping[endpoint] - actual_locations.append(location) - break - - assert actual_locations == expected_locations - @pytest.mark.parametrize('test_data', read_item_test_data()) async def test_read_item(self, test_data): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: read_item if request_excluded_locations is None: @@ -263,7 +238,7 @@ async def test_read_item(self, test_data): await container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_read_all_items(self, test_data): @@ -271,7 +246,7 @@ async def test_read_all_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: read_all_items if request_excluded_locations is None: @@ -280,7 +255,7 @@ async def test_read_all_items(self, test_data): all_items = [item async for item in container.read_all_items(excluded_locations=request_excluded_locations)] # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', read_all_item_test_data()) async def test_query_items(self, test_data): @@ -288,7 +263,7 @@ async def test_query_items(self, test_data): preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations) # API call: query_items if request_excluded_locations is None: @@ -297,7 +272,7 @@ async def test_query_items(self, test_data): all_items = [item async for item in container.query_items(None, excluded_locations=request_excluded_locations)] # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) async def test_query_items_change_feed(self, test_data): @@ -306,7 +281,7 @@ async def test_query_items_change_feed(self, test_data): # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations) + client, db, 
container = await _init_container(preferred_locations, client_excluded_locations) # API call: query_items_change_feed if request_excluded_locations is None: @@ -315,7 +290,7 @@ async def test_query_items_change_feed(self, test_data): all_items = [item async for item in container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)] # Verify endpoint locations - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) @pytest.mark.parametrize('test_data', replace_item_test_data()) @@ -325,7 +300,7 @@ async def test_replace_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: replace_item if request_excluded_locations is None: @@ -335,9 +310,9 @@ async def test_replace_item(self, test_data): # Verify endpoint locations if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_upsert_item(self, test_data): @@ -346,7 +321,7 @@ async def test_upsert_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: upsert_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} @@ -357,9 +332,9 @@ async def test_upsert_item(self, test_data): # get location from mock_handler if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_create_item(self, test_data): @@ -368,7 +343,7 @@ async def test_create_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) # API call: create_item body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} @@ -376,9 +351,9 @@ async def test_create_item(self, test_data): # get location from mock_handler if multiple_write_locations: - await self._verify_endpoint(client, expected_locations) + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - await self._verify_endpoint(client, [L1]) + _verify_endpoint(client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_patch_item(self, test_data): @@ -387,7 +362,7 @@ async def test_patch_item(self, test_data): for multiple_write_locations in [True, False]: # Client setup and create an item - client, db, container = await 
self._init_container(preferred_locations, client_excluded_locations,
+            client, db, container = await _init_container(preferred_locations, client_excluded_locations,
                                                           multiple_write_locations)
 
             # API call: patch_item
@@ -404,9 +379,9 @@ async def test_patch_item(self, test_data):
             # get location from mock_handler
             if multiple_write_locations:
-                await self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                await self._verify_endpoint(client, [L1])
+                _verify_endpoint(client, [L1])
 
     @pytest.mark.parametrize('test_data', patch_item_test_data())
     async def test_execute_item_batch(self, test_data):
@@ -415,7 +390,7 @@ async def test_execute_item_batch(self, test_data):
         for multiple_write_locations in [True, False]:
             # Client setup and create an item
-            client, db, container = await self._init_container(preferred_locations, client_excluded_locations,
+            client, db, container = await _init_container(preferred_locations, client_excluded_locations,
                                                                multiple_write_locations)
 
             # API call: execute_item_batch
@@ -433,9 +408,9 @@ async def test_execute_item_batch(self, test_data):
             # get location from mock_handler
             if multiple_write_locations:
-                await self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                await self._verify_endpoint(client, [L1])
+                _verify_endpoint(client, [L1])
 
     @pytest.mark.parametrize('test_data', patch_item_test_data())
     async def test_delete_item(self, test_data):
@@ -444,7 +419,7 @@ async def test_delete_item(self, test_data):
         for multiple_write_locations in [True, False]:
             # Client setup
-            client, db, container = await self._init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
+            client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
 
             # create before delete
             item_id = f'doc2-{str(uuid.uuid4())}'
@@ -460,9 +435,9 @@ async def test_delete_item(self, test_data):
             # Verify endpoint locations
             if multiple_write_locations:
-                await self._verify_endpoint(client, expected_locations)
+                _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                await self._verify_endpoint(client, [L1])
+                _verify_endpoint(client, [L1])
 
 if __name__ == "__main__":
     unittest.main()
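A note on the verification pattern both test modules now share: the endpoint assertions are driven by a logging handler that records every SDK log line, and _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) then maps the captured "Request URL:" entries back to region names. A condensed sketch of the capture side, mirroring the MockHandler these test files define (the "azure" logger name and DEBUG level follow the suite's own fixtures):

    import logging

    class MockHandler(logging.Handler):
        # Collect raw log records so tests can inspect the request URLs afterwards
        def __init__(self):
            super().__init__()
            self.messages = []

        def reset(self):
            self.messages = []

        def emit(self, record):
            self.messages.append(record.msg)

    MOCK_HANDLER = MockHandler()
    logger = logging.getLogger("azure")  # the Cosmos SDK logs request URLs at DEBUG
    logger.addHandler(MOCK_HANDLER)
    logger.setLevel(logging.DEBUG)
    # ...run an operation against the client, then assert on the captured log, e.g.:
    # _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)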
From b5accfa44f8ac6d74681444ce09dd26404401d40 Mon Sep 17 00:00:00 2001
From: tvaron3
Date: Thu, 10 Apr 2025 18:55:29 -0400
Subject: [PATCH 78/86] add more operations

---
 .../tests/test_ppcb_sm_mrr_async.py | 17 ++++++++++++++---
 1 file changed, 14 insertions(+), 3 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py
index 113574a7725f..ee5e3fcd6fa9 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py
@@ -38,7 +38,7 @@ async def setup():
 
 def operations_and_errors():
     write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"]
-    read_operations = ["read", "query", "changefeed"]
+    read_operations = ["read", "query", "changefeed", "read_all_items", "delete_all_items_by_partition_key"]
     errors = []
     error_codes = [408, 500, 502, 503]
     for error_code in error_codes:
@@ -108,6 +108,11 @@ async def perform_write_operation(operation, container, doc_id, pk):
                 ("upsert", (doc,)),
             ]
             await container.execute_item_batch(batch_operations, partition_key=doc['pk'])
+        elif operation == "delete_all_items_by_partition_key":
+            await container.create_item(body={'id': doc_id + '-1', 'pk': pk})
+            await container.create_item(body={'id': doc_id + '-2', 'pk': pk})
+            await container.create_item(body={'id': doc_id + '-3', 'pk': pk})
+            await container.delete_all_items_by_partition_key(pk)
 
@@ -119,12 +124,18 @@ def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri):
         elif operation == "query":
             query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk"
             parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}]
-            async for _ in container.query_items(query=query, partition_key=pk, parameters=parameters):
-                pass
+            async for item in container.query_items(query=query, partition_key=pk, parameters=parameters):
+                assert item['id'] == doc_id
         # need to do query with no pk and with feed range
         elif operation == "changefeed":
             async for _ in container.query_items_change_feed():
                 pass
+        elif operation == "read_all_items":
+            async for item in container.read_all_items(partition_key=pk):
+                assert item['pk'] == pk
+
+
+
 
     async def create_custom_transport_sm_mrr(self):

From 8324a71e8ce47cf0501109b3fba0407cf058d65d Mon Sep 17 00:00:00 2001
From: Allen Kim
Date: Fri, 11 Apr 2025 15:36:28 -0700
Subject: [PATCH 79/86] Fix live tests with multi write locations

---
 .../azure-cosmos/tests/test_excluded_locations.py | 12 ++++++------
 .../tests/test_excluded_locations_async.py | 12 ++++++------
 2 files changed, 12 insertions(+), 12 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
index 49b7f0553871..4b517796dd3a 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py
@@ -314,7 +314,7 @@ def test_replace_item(self, test_data):
             if multiple_write_locations:
                 _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                _verify_endpoint(client, [expected_locations[0], L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1])
 
     @pytest.mark.parametrize('test_data', replace_item_test_data())
     def test_upsert_item(self, test_data):
@@ -336,7 +336,7 @@ def test_upsert_item(self, test_data):
             if multiple_write_locations:
                 _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                _verify_endpoint(client, [expected_locations[0], L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1])
 
     @pytest.mark.parametrize('test_data', replace_item_test_data())
     def test_create_item(self, test_data):
@@ -355,7 +355,7 @@ def test_create_item(self, test_data):
             if multiple_write_locations:
                 _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                _verify_endpoint(client, [expected_locations[0], L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1])
 
     @pytest.mark.parametrize('test_data', patch_item_test_data())
     def test_patch_item(self, test_data):
@@ -383,7 +383,7 @@ def test_patch_item(self, test_data):
             if multiple_write_locations:
                 _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                _verify_endpoint(client, [L1])
+                _verify_endpoint(MOCK_HANDLER.messages, client, [L1])
 
     @pytest.mark.parametrize('test_data', patch_item_test_data())
     def test_execute_item_batch(self, test_data):
@@ -412,7 +412,7 @@ def test_execute_item_batch(self, test_data):
             if multiple_write_locations:
                 _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations)
             else:
-                _verify_endpoint(client, [L1])
+
_verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) def test_delete_item(self, test_data): @@ -439,7 +439,7 @@ def test_delete_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py index 50c0b69acd76..11ababfdfafd 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -312,7 +312,7 @@ async def test_replace_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_upsert_item(self, test_data): @@ -334,7 +334,7 @@ async def test_upsert_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', replace_item_test_data()) async def test_create_item(self, test_data): @@ -353,7 +353,7 @@ async def test_create_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_patch_item(self, test_data): @@ -381,7 +381,7 @@ async def test_patch_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_execute_item_batch(self, test_data): @@ -410,7 +410,7 @@ async def test_execute_item_batch(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) @pytest.mark.parametrize('test_data', patch_item_test_data()) async def test_delete_item(self, test_data): @@ -437,7 +437,7 @@ async def test_delete_item(self, test_data): if multiple_write_locations: _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) else: - _verify_endpoint(client, [L1]) + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) if __name__ == "__main__": unittest.main() From b65f07d5b733fb067d7c58ad43d269f8997b49d8 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 11 Apr 2025 15:38:50 -0700 Subject: [PATCH 80/86] Fixed bug with endpoint routing with multi write region partition key API calls --- .../azure/cosmos/_cosmos_client_connection.py | 4 ++-- sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py | 1 + sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py | 9 +++++++-- 3 files changed, 10 insertions(+), 4 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 2ad76e73766d..2de1dedf58e7 
100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2192,8 +2192,8 @@ def DeleteAllItemsByPartitionKey( path = '{}{}/{}'.format(path, "operations", "partitionkeydelete") collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, - "partitionkey", documents._OperationType.Delete, options) - request_params = RequestObject("partitionkey", documents._OperationType.Delete) + http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) + request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete) request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 02b293e29b4b..b2207b4431b6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -500,6 +500,7 @@ def can_use_multiple_write_locations(self): def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long return self.can_use_multiple_write_locations() and ( request.resource_type == http_constants.ResourceType.Document + or request.resource_type == http_constants.ResourceType.PartitionKey or ( request.resource_type == http_constants.ResourceType.StoredProcedure and request.operation_type == documents._OperationType.ExecuteJavaScript diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py index 185aa1d89cb8..e5b39c221016 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py @@ -22,7 +22,7 @@ """Represents a request object. """ from typing import Optional, Mapping, Any - +from . 
import http_constants class RequestObject(object): def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: @@ -57,7 +57,12 @@ def clear_route_to_location(self) -> None: def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: # If resource types for requests are not one of the followings, excluded locations cannot be set - if self.resource_type.lower() not in ['docs', 'documents', 'partitionkey', 'colls']: + acceptable_resource_types = [ + http_constants.ResourceType.Document, + http_constants.ResourceType.PartitionKey, + http_constants.ResourceType.Collection, + ] + if self.resource_type.lower() not in acceptable_resource_types: return False # If 'excludedLocations' wasn't in the options, excluded locations cannot be set From 4a144d9dbe8ddf575eec9c9441607df28ccb6893 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Fri, 11 Apr 2025 15:41:39 -0700 Subject: [PATCH 81/86] Adding emulator tests for delete_all_items_by_partition_key API --- .../tests/_fault_injection_transport.py | 34 +++ .../tests/test_excluded_locations_emulator.py | 127 +++++++++ .../test_excluded_locations_emulator_async.py | 254 ++++++++++++++++++ 3 files changed, 415 insertions(+) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 628456d95158..0a0a81026ae7 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -228,6 +228,40 @@ def transform_topology_mwr( return response + @staticmethod + def transform_topology_mwr_with_url( + first_region_name: str, + first_region_url: str, + second_region_name: str, + second_region_url: str, + inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + + response = inner() + if not FaultInjectionTransport.predicate_is_database_account_call(response.request): + return response + + data = response.body() + if response.status_code == 200 and data: + readable_locations = [ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, + {"name": second_region_name, "databaseAccountEndpoint": second_region_url} + ] + writeable_locations = [ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, + {"name": second_region_name, "databaseAccountEndpoint": second_region_url} + ] + + data = data.decode("utf-8") + result = json.loads(data) + result["readableLocations"] = readable_locations + result["writableLocations"] = writeable_locations + result["enableMultipleWriteLocations"] = True + FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) + request: HttpRequest = response.request + return FaultInjectionTransport.MockHttpResponse(request, 200, result) + + return response + class MockHttpResponse(RequestsTransportResponse): def __init__(self, request: HttpRequest, status_code: int, content:Optional[Dict[str, Any]]): self.request: HttpRequest = request diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py new file mode 100644 index 000000000000..b39058aef6ba --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -0,0 +1,127 @@ +# The MIT License (MIT) +# Copyright (c) 
Microsoft Corporation. All rights reserved.
+
+import logging
+import unittest
+import uuid
+import test_config
+import sys
+import pytest
+from typing import Callable, List, Mapping, Any
+
+from azure.core.rest import HttpRequest
+from _fault_injection_transport import FaultInjectionTransport
+from azure.cosmos.container import ContainerProxy
+from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA
+from test_fault_injection_transport import TestFaultInjectionTransport
+
+logger = logging.getLogger('azure.cosmos')
+logger.setLevel(logging.DEBUG)
+logger.addHandler(logging.StreamHandler(sys.stdout))
+
+CONFIG = test_config.TestConfig()
+HOST = CONFIG.host
+KEY = CONFIG.masterKey
+DATABASE_ID = CONFIG.TEST_DATABASE_ID
+CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID
+
+L1_URL = test_config.TestConfig.local_host
+L2_URL = L1_URL.replace("localhost", "127.0.0.1")
+URL_TO_LOCATIONS = {
+    L1_URL: L1,
+    L2_URL: L2
+}
+
+ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA
+
+def delete_all_items_by_partition_key_test_data() -> List[str]:
+    client_only_output_data = [
+        L1, #0
+        L2, #1
+        L1, #2
+        L1  #3
+    ]
+    client_and_request_output_data = [
+        L2, #0
+        L2, #1
+        L2, #2
+        L1, #3
+        L1, #4
+        L1, #5
+        L1, #6
+        L1, #7
+    ]
+    all_output_test_data = client_only_output_data + client_and_request_output_data
+    # all_output_test_data = client_and_request_output_data
+
+    all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)]
+    return all_test_data
+
+def _get_location(initialized_objects: Mapping[str, Any]) -> str:
+    # get Request URL
+    header = initialized_objects['client'].client_connection.last_response_headers
+    request_url = header["_request"].url
+
+    # verify
+    location = ""
+    for url in URL_TO_LOCATIONS:
+        if request_url.startswith(url):
+            location = URL_TO_LOCATIONS[url]
+            break
+    return location
+
+@pytest.mark.unittest
+@pytest.mark.cosmosEmulator
+class TestExcludedLocations:
+    @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data())
+    def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_data: List[List[str]]):
+        # Init test variables
+        preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data
+
+        # Inject a topology transformation that makes the emulator look like a multiple write region
+        # account with two read regions
+        custom_transport = FaultInjectionTransport()
+        is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \
+            r: FaultInjectionTransport.predicate_is_database_account_call(r)
+        emulator_as_multi_write_region_account_transformation = \
+            lambda r, inner: FaultInjectionTransport.transform_topology_mwr_with_url(
+                first_region_name=L1,
+                first_region_url=L1_URL,
+                second_region_name=L2,
+                second_region_url=L2_URL,
+                inner=inner)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_write_region_account_transformation)
+
+        for multiple_write_locations in [True, False]:
+            # Create client
+            initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport(
+                custom_transport,
+                HOST,
+                preferred_locations=preferred_locations,
+                excluded_locations=client_excluded_locations,
+                multiple_write_locations=multiple_write_locations,
+            )
+            container: ContainerProxy = initialized_objects["col"]
+
+            # create an item
+            id_value: str = str(uuid.uuid4())
+            document_definition = {'id':
id_value, 'pk': id_value} + container.create_item(body=document_definition) + + # API call: delete_all_items_by_partition_key + if request_excluded_locations is None: + container.delete_all_items_by_partition_key(id_value) + else: + container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + actual_location = _get_location(initialized_objects) + if multiple_write_locations: + assert actual_location == expected_location + else: + assert actual_location == L1 + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py new file mode 100644 index 000000000000..f468c80649ca --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -0,0 +1,254 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import logging +import unittest +import uuid +import test_config +import pytest +import pytest_asyncio + +from azure.cosmos.aio import CosmosClient +from azure.cosmos.partition_key import PartitionKey +from test_excluded_locations import _verify_endpoint + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +L0 = "Default" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" + +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2], [], None], + # 1. Single excluded location + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2], [L1, L2], [L1, L2]], + # 5. No common excluded locations + [[L1, L2], [L1], [L2, L3]], + # 6. Request excluded location not in preferred locations + [[L1, L2], [L1, L2], [L3]], + # 7. 
Empty excluded locations, remove all client excluded locations + [[L1, L2], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1], # 3 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def read_all_item_test_data(): + client_only_output_data = [ + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 + ] + client_and_request_output_data = [ + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 + ] + client_and_request_output_data = [ + [L1, L2, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L2, L2, L2], #2 + [L2, L1, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L1, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +async def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, excluded_locations=excluded_locations) + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + test_client = CosmosClient(HOST, KEY) + container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + await container.upsert_item(body=TEST_ITEM) + + yield + await test_client.close() + +async def _init_container(preferred_locations, client_excluded_locations, 
multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown") +class TestExcludedLocations: + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + await _create_item_with_excluded_locations(container, body, request_excluded_locations) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + await container.delete_item(item_id, ITEM_PK_VALUE) + else: + await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(client, [L1]) + +if __name__ == "__main__": + unittest.main() From 9c68f753aaf551ff5e064f6d9740fc6ab3cae2fe Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 14 Apr 2025 10:49:10 -0700 Subject: [PATCH 82/86] minimized duplicate codes --- .../tests/_fault_injection_transport.py | 52 +++++-------------- .../tests/test_excluded_locations_emulator.py | 31 +++++------ .../tests/test_fault_injection_transport.py | 14 +++-- 3 files changed, 36 insertions(+), 61 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 0a0a81026ae7..9a99229ec995 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -203,7 +203,10 @@ def transform_topology_swr_mrr( def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + inner: Callable[[], RequestsTransportResponse], + first_region_url: str = None, + second_region_url: str = test_config.TestConfig.local_host + ) -> RequestsTransportResponse: response = inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -215,46 +218,17 @@ def transform_topology_mwr( result = json.loads(data) readable_locations = result["readableLocations"] writable_locations = result["writableLocations"] - readable_locations[0]["name"] = first_region_name - writable_locations[0]["name"] = first_region_name + + if first_region_url is None: + first_region_url = readable_locations[0]["databaseAccountEndpoint"] + readable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} + writable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} readable_locations.append( - {"name": 
second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) writable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) - result["enableMultipleWriteLocations"] = True - FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) - request: HttpRequest = response.request - return FaultInjectionTransport.MockHttpResponse(request, 200, result) - - return response - - @staticmethod - def transform_topology_mwr_with_url( - first_region_name: str, - first_region_url: str, - second_region_name: str, - second_region_url: str, - inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: - - response = inner() - if not FaultInjectionTransport.predicate_is_database_account_call(response.request): - return response - - data = response.body() - if response.status_code == 200 and data: - readable_locations = [ - {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, - {"name": second_region_name, "databaseAccountEndpoint": second_region_url} - ] - writeable_locations = [ - {"name": first_region_name, "databaseAccountEndpoint": first_region_url}, - {"name": second_region_name, "databaseAccountEndpoint": second_region_url} - ] - - data = data.decode("utf-8") - result = json.loads(data) - result["readableLocations"] = readable_locations - result["writableLocations"] = writeable_locations + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) result["enableMultipleWriteLocations"] = True FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py index b39058aef6ba..7fcb558827f7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -1,11 +1,9 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. 
-import logging import unittest import uuid import test_config -import sys import pytest from typing import Callable, List, Mapping, Any @@ -15,15 +13,7 @@ from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA from test_fault_injection_transport import TestFaultInjectionTransport -logger = logging.getLogger('azure.cosmos') -logger.setLevel(logging.DEBUG) -logger.addHandler(logging.StreamHandler(sys.stdout)) - CONFIG = test_config.TestConfig() -HOST = CONFIG.host -KEY = CONFIG.masterKey -DATABASE_ID = CONFIG.TEST_DATABASE_ID -CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID L1_URL = test_config.TestConfig.local_host L2_URL = L1_URL.replace("localhost", "127.0.0.1") @@ -52,21 +42,22 @@ def delete_all_items_by_partition_key_test_data() -> List[str]: L1, #7 ] all_output_test_data = client_only_output_data + client_and_request_output_data - # all_output_test_data = client_and_request_output_data all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data -def _get_location(initialized_objects: Mapping[str, Any]) -> str: +def _get_location( + initialized_objects: Mapping[str, Any], + url_to_locations: Mapping[str, str] = URL_TO_LOCATIONS) -> str: # get Request URL header = initialized_objects['client'].client_connection.last_response_headers request_url = header["_request"].url # verify location = "" - for url in URL_TO_LOCATIONS: + for url in url_to_locations: if request_url.startswith(url): - location = URL_TO_LOCATIONS[url] + location = url_to_locations[url] break return location @@ -84,12 +75,13 @@ def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_d is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ r: FaultInjectionTransport.predicate_is_database_account_call(r) emulator_as_multi_write_region_account_transformation = \ - lambda r, inner: FaultInjectionTransport.transform_topology_mwr_with_url( + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( first_region_name=L1, - first_region_url=L1_URL, second_region_name=L2, + inner=inner, + first_region_url=L1_URL, second_region_url=L2_URL, - inner=inner) + ) custom_transport.add_response_transformation( is_get_account_predicate, emulator_as_multi_write_region_account_transformation) @@ -98,7 +90,10 @@ def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_d # Create client initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport( custom_transport, - HOST, + default_endpoint=CONFIG.host, + key=CONFIG.masterKey, + database_id=CONFIG.TEST_DATABASE_ID, + container_id=CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID, preferred_locations=preferred_locations, excluded_locations=client_excluded_locations, multiple_write_locations=multiple_write_locations, diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index 304fa8d50f0d..4d7ea16ee58e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -62,11 +62,17 @@ def teardown_class(cls): logger.warning("Exception trying to delete database {}. 
{}".format(created_database.id, containerDeleteError)) @staticmethod - def setup_method_with_custom_transport(custom_transport: RequestsTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + def setup_method_with_custom_transport( + custom_transport: RequestsTransport, + default_endpoint: str = host, + key: str = master_key, + database_id: str = TEST_DATABASE_ID, + container_id: str = SINGLE_PARTITION_CONTAINER_NAME, + **kwargs): + client = CosmosClient(default_endpoint, key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) - db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(SINGLE_PARTITION_CONTAINER_NAME) + db: DatabaseProxy = client.get_database_client(database_id) + container: ContainerProxy = db.get_container_client(container_id) return {"client": client, "db": db, "col": container} From 225bc26133a7d944e6becf60fa3ac33f46c0db1b Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 14 Apr 2025 11:35:09 -0700 Subject: [PATCH 83/86] Added Async emulator tests --- .../tests/_fault_injection_transport_async.py | 18 +- .../tests/test_excluded_locations_emulator.py | 4 +- .../test_excluded_locations_emulator_async.py | 284 ++++-------------- .../test_fault_injection_transport_async.py | 34 ++- 4 files changed, 101 insertions(+), 239 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 13dda0dc7e20..6bdeb4ed49c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -201,7 +201,10 @@ async def transform_topology_swr_mrr( async def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]], + first_region_url: str = None, + second_region_url: str = test_config.TestConfig.local_host + ) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): @@ -213,12 +216,17 @@ async def transform_topology_mwr( result = json.loads(data) readable_locations = result["readableLocations"] writable_locations = result["writableLocations"] - readable_locations[0]["name"] = first_region_name - writable_locations[0]["name"] = first_region_name + + if first_region_url is None: + first_region_url = readable_locations[0]["databaseAccountEndpoint"] + readable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} + writable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} readable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) writable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) result["enableMultipleWriteLocations"] = True FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py index 7fcb558827f7..96b3fc185afb 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -46,7 +46,7 @@ def delete_all_items_by_partition_key_test_data() -> List[str]: all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] return all_test_data -def _get_location( +def get_location( initialized_objects: Mapping[str, Any], url_to_locations: Mapping[str, str] = URL_TO_LOCATIONS) -> str: # get Request URL @@ -112,7 +112,7 @@ def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_d container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) # Verify endpoint locations - actual_location = _get_location(initialized_objects) + actual_location = get_location(initialized_objects) if multiple_write_locations: assert actual_location == expected_location else: diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py index f468c80649ca..f706ca9e465e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -1,254 +1,100 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. -import logging import unittest import uuid import test_config import pytest -import pytest_asyncio +from typing import Callable, List, Mapping, Any -from azure.cosmos.aio import CosmosClient -from azure.cosmos.partition_key import PartitionKey -from test_excluded_locations import _verify_endpoint +from azure.core.pipeline.transport import AioHttpTransport +from _fault_injection_transport_async import FaultInjectionTransportAsync +from azure.cosmos.aio._container import ContainerProxy +from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA +from test_excluded_locations_emulator import L1_URL, L2_URL, get_location +from test_fault_injection_transport_async import TestFaultInjectionTransportAsync -class MockHandler(logging.Handler): - def __init__(self): - super(MockHandler, self).__init__() - self.messages = [] - - def reset(self): - self.messages = [] - - def emit(self, record): - self.messages.append(record.msg) - -MOCK_HANDLER = MockHandler() CONFIG = test_config.TestConfig() -HOST = CONFIG.host -KEY = CONFIG.masterKey -DATABASE_ID = CONFIG.TEST_DATABASE_ID -CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID -PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY -ITEM_ID = 'doc1' -ITEM_PK_VALUE = 'pk' -TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} - -L0 = "Default" -L1 = "West US 3" -L2 = "West US" -L3 = "East US 2" - -# L0 = "Default" -# L1 = "East US 2" -# L2 = "East US" -# L3 = "West US 2" - -CLIENT_ONLY_TEST_DATA = [ - # preferred_locations, client_excluded_locations, excluded_locations_request - # 0. No excluded location - [[L1, L2], [], None], - # 1. Single excluded location - [[L1, L2], [L1], None], - # 2. Exclude all locations - [[L1, L2], [L1, L2], None], - # 3. Exclude a location not in preferred locations - [[L1, L2], [L3], None], -] - -CLIENT_AND_REQUEST_TEST_DATA = [ - # preferred_locations, client_excluded_locations, excluded_locations_request - # 0. 
No client excluded locations + a request excluded location
-    [[L1, L2], [], [L1]],
-    # 1. The same client and request excluded location
-    [[L1, L2], [L1], [L1]],
-    # 2. Less request excluded locations
-    [[L1, L2], [L1, L2], [L1]],
-    # 3. More request excluded locations
-    [[L1, L2], [L1], [L1, L2]],
-    # 4. All locations were excluded
-    [[L1, L2], [L1, L2], [L1, L2]],
-    # 5. No common excluded locations
-    [[L1, L2], [L1], [L2, L3]],
-    # 6. Request excluded location not in preferred locations
-    [[L1, L2], [L1, L2], [L3]],
-    # 7. Empty excluded locations, remove all client excluded locations
-    [[L1, L2], [L1, L2], []],
-]
 
 ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA
 
-def read_item_test_data():
-    client_only_output_data = [
-        [L1], # 0
-        [L2], # 1
-        [L1], # 2
-        [L1], # 3
-    ]
-    client_and_request_output_data = [
-        [L2], # 0
-        [L2], # 1
-        [L2], # 2
-        [L1], # 3
-        [L1], # 4
-        [L1], # 5
-        [L1], # 6
-        [L1], # 7
-    ]
-    all_output_test_data = client_only_output_data + client_and_request_output_data
-
-    all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)]
-    return all_test_data
-
-def read_all_item_test_data():
+def delete_all_items_by_partition_key_test_data() -> List[str]:
     client_only_output_data = [
-        [L1, L1], # 0
-        [L2, L2], # 1
-        [L1, L1], # 2
-        [L1, L1], # 3
+        L1, #0
+        L2, #1
+        L1, #2
+        L1  #3
     ]
     client_and_request_output_data = [
-        [L2, L2], # 0
-        [L2, L2], # 1
-        [L2, L2], # 2
-        [L1, L1], # 3
-        [L1, L1], # 4
-        [L1, L1], # 5
-        [L1, L1], # 6
-        [L1, L1], # 7
+        L2, #0
+        L2, #1
+        L2, #2
+        L1, #3
+        L1, #4
+        L1, #5
+        L1, #6
+        L1, #7
     ]
     all_output_test_data = client_only_output_data + client_and_request_output_data
 
     all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)]
     return all_test_data
 
-def query_items_change_feed_test_data():
-    client_only_output_data = [
-        [L1, L1, L1, L1], #0
-        [L2, L2, L2, L2], #1
-        [L1, L1, L1, L1], #2
-        [L1, L1, L1, L1] #3
-    ]
-    client_and_request_output_data = [
-        [L1, L2, L2, L2], #0
-        [L2, L2, L2, L2], #1
-        [L1, L2, L2, L2], #2
-        [L2, L1, L1, L1], #3
-        [L1, L1, L1, L1], #4
-        [L2, L1, L1, L1], #5
-        [L1, L1, L1, L1], #6
-        [L1, L1, L1, L1], #7
-    ]
-    all_output_test_data = client_only_output_data + client_and_request_output_data
-
-    all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)]
-    return all_test_data
-
-def replace_item_test_data():
-    client_only_output_data = [
-        [L1], #0
-        [L2], #1
-        [L0], #2
-        [L1] #3
-    ]
-    client_and_request_output_data = [
-        [L2], #0
-        [L2], #1
-        [L2], #2
-        [L0], #3
-        [L0], #4
-        [L1], #5
-        [L1], #6
-        [L1], #7
-    ]
-    all_output_test_data = client_only_output_data + client_and_request_output_data
-
-    all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)]
-    return all_test_data
-
-def patch_item_test_data():
-    client_only_output_data = [
-        [L1], #0
-        [L2], #1
-        [L0], #2
-        [L1] #3
-    ]
-    client_and_request_output_data = [
-        [L2], #0
-        [L2], #1
-        [L2], #2
-        [L0], #3
-        [L0], #4
-        [L1], #5
-        [L1], #6
-        [L1], #7
-    ]
-    all_output_test_data = client_only_output_data + client_and_request_output_data
-
-    all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)]
-    return all_test_data
-
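The helper being removed next existed only to toggle the per-request excluded_locations keyword. For reference, the two levels compose as follows; a minimal sketch with placeholder endpoint, key, database, and container names (the region strings reuse this suite's L1/L2 values):

    from azure.cosmos.aio import CosmosClient

    async def demo():
        # Client-level excluded locations apply to every request by default
        async with CosmosClient("https://localhost:8081/", "<key>",
                                preferred_locations=["West US 3", "West US"],
                                excluded_locations=["West US 3"],
                                multiple_write_locations=True) as client:
            container = client.get_database_client("db").get_container_client("col")
            # A request-level list overrides the client-level one for this call only
            await container.create_item(body={'id': 'doc1', 'pk': 'pk'},
                                        excluded_locations=["West US"])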
-async def _create_item_with_excluded_locations(container, body, excluded_locations):
-    if excluded_locations is None:
-        await container.create_item(body=body)
-    else:
-        await container.create_item(body=body, excluded_locations=excluded_locations)
-
-@pytest_asyncio.fixture(scope="class", autouse=True)
-async def setup_and_teardown():
-    print("Setup: This runs before any tests")
-    logger = logging.getLogger("azure")
-    logger.addHandler(MOCK_HANDLER)
-    logger.setLevel(logging.DEBUG)
-
-    test_client = CosmosClient(HOST, KEY)
-    container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID)
-    await container.upsert_item(body=TEST_ITEM)
-
-    yield
-    await test_client.close()
-
-async def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True):
-    client = CosmosClient(HOST, KEY,
-                          preferred_locations=preferred_locations,
-                          excluded_locations=client_excluded_locations,
-                          multiple_write_locations=multiple_write_locations)
-    db = await client.create_database_if_not_exists(DATABASE_ID)
-    container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash'))
-    MOCK_HANDLER.reset()
-
-    return client, db, container
-
-@pytest.mark.cosmosMultiRegion
+@pytest.mark.cosmosEmulator
 @pytest.mark.asyncio
-@pytest.mark.usefixtures("setup_and_teardown")
-class TestExcludedLocations:
-    @pytest.mark.parametrize('test_data', patch_item_test_data())
-    async def test_delete_item(self, test_data):
+class TestExcludedLocationsAsync:
+    @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data())
+    async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsAsync", test_data: List[List[str]]):
         # Init test variables
-        preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data
+        preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data
+
+        # Inject a topology transformation that makes the emulator look like a multiple write region
+        # account with two read regions
+        custom_transport = FaultInjectionTransportAsync()
+        is_get_account_predicate: Callable[[AioHttpTransport], bool] = lambda \
+            r: FaultInjectionTransportAsync.predicate_is_database_account_call(r)
+        emulator_as_multi_write_region_account_transformation = \
+            lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr(
+                first_region_name=L1,
+                first_region_url=L1_URL,
+                inner=inner,
+                second_region_name=L2,
+                second_region_url=L2_URL)
+        custom_transport.add_response_transformation(
+            is_get_account_predicate,
+            emulator_as_multi_write_region_account_transformation)
 
         for multiple_write_locations in [True, False]:
-            # Client setup
-            client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations)
-
-            # create before delete
-            item_id = f'doc2-{str(uuid.uuid4())}'
-            body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id}
-            await _create_item_with_excluded_locations(container, body, request_excluded_locations)
-            MOCK_HANDLER.reset()
-
-            # API call: delete_item
+            # Create client
+            initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport(
+                custom_transport,
+                default_endpoint=CONFIG.host,
+                key=CONFIG.masterKey,
+                database_id=CONFIG.TEST_DATABASE_ID,
+                container_id=CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID,
+                preferred_locations=preferred_locations,
+                excluded_locations=client_excluded_locations,
+                multiple_write_locations=multiple_write_locations,
+            )
+            container: ContainerProxy = initialized_objects["col"]
+
+            # create an item
+
id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, 'pk': id_value} + await container.create_item(body=document_definition) + + # API call: delete_all_items_by_partition_key if request_excluded_locations is None: - await container.delete_item(item_id, ITEM_PK_VALUE) + await container.delete_all_items_by_partition_key(id_value) else: - await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + await container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) # Verify endpoint locations + actual_location = get_location(initialized_objects) if multiple_write_locations: - _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + assert actual_location == expected_location else: - _verify_endpoint(client, [L1]) + assert actual_location == L1 if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 1df1de05936d..83535510c983 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -33,6 +33,7 @@ host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID +SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -50,7 +51,7 @@ async def asyncSetUp(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") cls.database_id = TEST_DATABASE_ID - cls.single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) + cls.single_partition_container_name = SINGLE_PARTITION_CONTAINER_NAME cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) @@ -76,11 +77,18 @@ async def asyncTearDown(cls): except Exception as closeError: logger.warning("Exception trying to close client {}. 
{}".format(created_database.id, closeError)) - async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + @staticmethod + async def setup_method_with_custom_transport( + custom_transport: AioHttpTransport, + default_endpoint: str = host, + key: str = master_key, + database_id: str = TEST_DATABASE_ID, + container_id: str = SINGLE_PARTITION_CONTAINER_NAME, + **kwargs): + client = CosmosClient(default_endpoint, key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) - db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(self.single_partition_container_name) + db: DatabaseProxy = client.get_database_client(database_id) + container: ContainerProxy = db.get_container_client(container_id) return {"client": client, "db": db, "col": container} @staticmethod @@ -106,7 +114,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy status_code=502, message="Some random reverse proxy error.")))) - initialized_objects = await self.setup_method_with_custom_transport(custom_transport) + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport(custom_transport) start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] @@ -151,7 +159,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["Read Region", "Write Region"]) try: @@ -210,7 +218,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -275,7 +283,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -320,7 +328,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -373,7 +381,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -445,7 +453,7 @@ async 
def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -498,7 +506,7 @@ async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsyn 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: From 5f2c5a08276db59a215f7b2e0feabf3f41800942 Mon Sep 17 00:00:00 2001 From: Allen Kim Date: Mon, 14 Apr 2025 12:19:16 -0700 Subject: [PATCH 84/86] Nit: Changed test names --- .../azure-cosmos/tests/test_excluded_locations_emulator.py | 6 +++--- .../tests/test_excluded_locations_emulator_async.py | 6 +++--- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py index 96b3fc185afb..375bcfc899d8 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -63,14 +63,14 @@ def get_location( @pytest.mark.unittest @pytest.mark.cosmosEmulator -class TestExcludedLocations: +class TestExcludedLocationsEmulator: @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) - def test_delete_all_items_by_partition_key(self: "TestExcludedLocations", test_data: List[List[str]]): + def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsEmulator", test_data: List[List[str]]): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data # Inject topology transformation that would make Emulator look like a multiple write region account - # account with two read regions + # with two read regions custom_transport = FaultInjectionTransport() is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ r: FaultInjectionTransport.predicate_is_database_account_call(r) diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py index f706ca9e465e..c24c2f13c9f7 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -42,14 +42,14 @@ def delete_all_items_by_partition_key_test_data() -> List[str]: @pytest.mark.cosmosEmulator @pytest.mark.asyncio -class TestExcludedLocationsAsync: +class TestExcludedLocationsEmulatorAsync: @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) - async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsAsync", test_data: List[List[str]]): + async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsEmulatorAsync", test_data: List[List[str]]): # Init test variables preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data # Inject topology transformation that would make Emulator look like a multiple write region account - # account with two read regions + # with two read regions 
custom_transport = FaultInjectionTransportAsync()
        is_get_account_predicate: Callable[[AioHttpTransport], bool] = lambda \
            r: FaultInjectionTransportAsync.predicate_is_database_account_call(r)

From c3e39e522415d4767e8c4a503da42f483b1bf217 Mon Sep 17 00:00:00 2001
From: Allen Kim
Date: Tue, 15 Apr 2025 14:06:39 -0700
Subject: [PATCH 85/86] Addressed comments about documents

---
 .../azure-cosmos/azure/cosmos/documents.py    |  2 +-
 .../samples/excluded_locations.py             | 25 +++++++++++--------
 2 files changed, 16 insertions(+), 11 deletions(-)

diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
index 9e04829be52f..7ccc99da9dfe 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py
@@ -314,7 +314,7 @@ class ConnectionPolicy:  # pylint: disable=too-many-instance-attributes
         set of locations from the final location evaluation. The locations in this list
         are specified as the names of the Azure Cosmos locations like
         'West US', 'East US', 'Central India' and so on.
-    :vartype ExcludedLocations: ~CosmosExcludedLocations
+    :vartype ExcludedLocations: List[str]
     :ivar RetryOptions:
         Gets or sets the retry options to be applied to all requests when
         retrying.
diff --git a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py
index 06228c1a8cea..a8c699a7cccf 100644
--- a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py
+++ b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py
@@ -15,6 +15,9 @@
 #
 # 2. Microsoft Azure Cosmos
 #    pip install azure-cosmos>=4.3.0b4
+#
+# 3. Configure the Azure Cosmos account with 3+ regions, such as 'West US 3', 'West US', 'East US 2'.
+#    If you added other regions, update L1~L3 with the regions in your account.
 # ----------------------------------------------------------------------------------------------------------
 # Sample - demonstrates how to use excluded locations in client level and request level
 # ----------------------------------------------------------------------------------------------------------
@@ -33,12 +36,14 @@
 DATABASE_ID = config.settings["database_id"]
 CONTAINER_ID = config.settings["container_id"]
-PARTITION_KEY = PartitionKey(path="/id")
+PARTITION_KEY = PartitionKey(path="/pk")
+L1, L2, L3 = 'West US 3', 'West US', 'East US 2'

 def get_test_item(num):
     test_item = {
         'id': 'Item_' + str(num),
+        'pk': 'PartitionKey_' + str(num),
         'test_object': True,
         'lastName': 'Smith'
     }
@@ -51,8 +56,8 @@ def clean_up_db(client):
         pass

 def excluded_locations_client_level_sample():
-    preferred_locations = ['West US 3', 'West US', 'East US 2']
-    excluded_locations = ['West US 3', 'West US']
+    preferred_locations = [L1, L2, L3]
+    excluded_locations = [L1, L2]
     client = CosmosClient(
         HOST,
         MASTER_KEY,
@@ -66,19 +71,19 @@ def excluded_locations_client_level_sample():

     # For write operations with a single master account, the write endpoint will be the default endpoint,
     # since preferred_locations and excluded_locations are not applied to write operations
-    container.create_item(get_test_item(0))
+    created_item = container.create_item(get_test_item(0))

     # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
    # In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'],
     # therefore 'East US 2' will be the read endpoint, and items will be read from the 'East US 2' location
-    item = container.read_item(item='Item_0', partition_key='Item_0')
+    item = container.read_item(item=created_item['id'], partition_key=created_item['pk'])

     clean_up_db(client)

 def excluded_locations_request_level_sample():
-    preferred_locations = ['West US 3', 'West US', 'East US 2']
-    excluded_locations_on_client = ['West US 3', 'West US']
-    excluded_locations_on_request = ['West US 3']
+    preferred_locations = [L1, L2, L3]
+    excluded_locations_on_client = [L1, L2]
+    excluded_locations_on_request = [L1]
     client = CosmosClient(
         HOST,
         MASTER_KEY,
@@ -92,7 +97,7 @@ def excluded_locations_request_level_sample():

     # For write operations with a single master account, the write endpoint will be the default endpoint,
     # since preferred_locations and excluded_locations are not applied to write operations
-    container.create_item(get_test_item(0))
+    created_item = container.create_item(get_test_item(0), excluded_locations=excluded_locations_on_request)

     # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'.
     # However, in our sample, since excluded_locations were passed with the read request, the request-level list overrides the excluded_locations set on the client.
     # With the excluded_locations on request, the read endpoints will be ['West US', 'East US 2']
     # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2']
     # Therefore, items will be read from the 'West US' or 'East US 2' location
-    item = container.read_item(item='Item_0', partition_key='Item_0', excluded_locations=excluded_locations_on_request)
+    item = container.read_item(item=created_item['id'], partition_key=created_item['pk'], excluded_locations=excluded_locations_on_request)

     clean_up_db(client)
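Editor's note: the comments in the sample above describe the read endpoint evaluation as
'preferred_locations' - 'excluded_locations', with the request-level excluded locations
replacing the client-level list when both are set. A minimal sketch of that evaluation in
plain Python (illustrative helper names, not the SDK's internal API):

    def effective_read_locations(preferred, client_excluded, request_excluded=None):
        # request-level excluded locations, when supplied, replace the client-level list
        excluded = request_excluded if request_excluded is not None else client_excluded
        # order of preferred locations is preserved; excluded entries are dropped
        return [location for location in preferred if location not in excluded]

    # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2']
    assert effective_read_locations(
        ['West US 3', 'West US', 'East US 2'],
        ['West US 3', 'West US'],
        request_excluded=['West US 3']) == ['West US', 'East US 2']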
From 39a464cb4a5f10c73557fac5dd685e9fcd15becc Mon Sep 17 00:00:00 2001
From: tvaron3
Date: Wed, 16 Apr 2025 17:10:58 -0700
Subject: [PATCH 86/86] live tests

---
 .../azure/cosmos/_partition_health_tracker.py |   8 +-
 .../azure/cosmos/_request_object.py           |   4 +-
 .../azure-cosmos/tests/test_ppcb_mm_async.py  | 408 ++++++++++++++++++
 .../tests/test_ppcb_sm_mrr_async.py           |   2 +
 4 files changed, 416 insertions(+), 6 deletions(-)
 create mode 100644 sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py

diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
index 9f30bac2bd2c..034d1cf1ac10 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py
@@ -23,7 +23,7 @@
 """
 import logging
 import os
-from typing import Dict, Set, Any
+from typing import Dict, Set, Any, List
 from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper
 from azure.cosmos._location_cache import current_time_millis, EndpointOperationType
 from ._constants import _Constants as Constants
@@ -171,15 +171,15 @@ def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper
                 self._reset_partition_health_tracker_stats()

-    def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> Set[str]:
+    def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> List[str]:
         self._check_stale_partition_info(pk_range_wrapper)

-        excluded_locations = set()
+        excluded_locations = []
         if pk_range_wrapper in self.pk_range_wrapper_to_health_info:
             for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items():
                 if partition_health_info.unavailability_info:
                     health_status = partition_health_info.unavailability_info[HEALTH_STATUS]
                     if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY):
-                        excluded_locations.add(location)
+                        excluded_locations.append(location)
         return excluded_locations

diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
index 253434c8b881..24dc9e0dd9c2 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
@@ -42,7 +42,7 @@ def __init__(
         self.location_endpoint_to_route: Optional[str] = None
         self.last_routed_location_endpoint_within_region: Optional[str] = None
         self.excluded_locations: Optional[List[str]] = None
-        self.excluded_locations_circuit_breaker: Set[str] = set()
+        self.excluded_locations_circuit_breaker: List[str] = []

     def route_to_location_with_preferred_location_flag(  # pylint: disable=name-too-long
         self,
@@ -89,5 +89,5 @@ def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None
         if self._can_set_excluded_location(options):
             self.excluded_locations = options['excludedLocations']

-    def set_excluded_locations_from_circuit_breaker(self, excluded_locations: Set[str]) -> None:  # pylint: disable=name-too-long
+    def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None:  # pylint: disable=name-too-long
         self.excluded_locations_circuit_breaker = excluded_locations
diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py
new file mode 100644
index 000000000000..2d033890429d
--- /dev/null
+++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py
@@ -0,0 +1,408 @@
+# The MIT License (MIT)
+# Copyright (c) Microsoft Corporation. All rights reserved.
+import asyncio
+import os
+import unittest
+import uuid
+from typing import Dict, Any
+
+import pytest
+import pytest_asyncio
+from azure.core.pipeline.transport._aiohttp import AioHttpTransport
+from azure.core.exceptions import ServiceResponseError
+
+import test_config
+from azure.cosmos import PartitionKey, _location_cache
+from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE
+from azure.cosmos.aio import CosmosClient
+from azure.cosmos.exceptions import CosmosHttpResponseError
+from _fault_injection_transport_async import FaultInjectionTransportAsync
+
+REGION_1 = "West US 3"
+REGION_2 = "Mexico Central" # "West US"
+
+COLLECTION = "created_collection"
+
+@pytest_asyncio.fixture()
+async def setup():
+    os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True"
+    client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key, consistency_level="Session")
+    created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID)
+    await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID)
+    created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID,
+                                                                 partition_key=PartitionKey("/pk"),
+                                                                 offer_throughput=10000)
+    yield {
+        COLLECTION: created_collection
+    }
+
+    await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID)
+    await client.close()
+    os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False"
+
+def write_operations_and_errors():
+    write_operations = ["create", "upsert", "replace", "delete", "patch", "batch", "delete_all_items_by_partition_key"]
+    errors = []
+    error_codes = [408, 500, 502, 503]
+    for error_code in error_codes:
+        errors.append(CosmosHttpResponseError(
+            status_code=error_code,
+            message="Some injected error."))
+    errors.append(ServiceResponseError(message="Injected Service Response Error."))
+    params = []
+    for write_operation in write_operations:
+        for error in errors:
+            params.append((write_operation, error))
+
+    return params
+
+def read_operations_and_errors():
+    read_operations = ["read", "query", "changefeed", "read_all_items"]
+    errors = []
+    error_codes = [408, 500, 502, 503]
+    for error_code in error_codes:
+        errors.append(CosmosHttpResponseError(
+            status_code=error_code,
+            message="Some injected error."))
+    errors.append(ServiceResponseError(message="Injected Service Response Error."))
+    params = []
+    for read_operation in read_operations:
+        for error in errors:
+            params.append((read_operation, error))
+
+    return params
+
+def validate_response_uri(response, expected_uri):
+    request = response.get_response_headers()["_request"]
+    assert request.url.startswith(expected_uri)
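# --- Editor's aside (illustrative sketch, not part of the patch) ---
# The two *_operations_and_errors() helpers above build the operation/error
# cartesian product with nested loops; itertools.product expresses the same
# pairing in one call and yields (operation, error) tuples in the same order,
# the shape that pytest.mark.parametrize consumes in the tests below:
import itertools

def operations_and_errors(operations, errors):
    # one (operation, error) tuple per combination
    return list(itertools.product(operations, errors))
# --- end aside ---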
+async def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri):
+    doc = {'id': doc_id,
+           'pk': pk,
+           'name': 'sample document',
+           'key': 'value'}
+    if operation == "create":
+        resp = await fault_injection_container.create_item(body=doc)
+    elif operation == "upsert":
+        resp = await fault_injection_container.upsert_item(body=doc)
+    elif operation == "replace":
+        new_doc = {'id': doc_id,
+                   'pk': pk,
+                   'name': 'sample document' + str(uuid.uuid4()),
+                   'key': 'value'}
+        resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc)
+    elif operation == "delete":
+        await container.create_item(body=doc)
+        resp = await fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk'])
+    elif operation == "patch":
+        operations = [{"op": "incr", "path": "/company", "value": 3}]
+        resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations)
+    elif operation == "batch":
+        batch_operations = [
+            ("create", (doc, )),
+            ("upsert", (doc,)),
+            ("upsert", (doc,)),
+            ("upsert", (doc,)),
+        ]
+        resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk'])
+    elif operation == "delete_all_items_by_partition_key":
+        # upsert (not create) so that repeated writes of the same id cannot raise a 409 Conflict
+        await container.upsert_item(body=doc)
+        await container.upsert_item(body=doc)
+        await container.upsert_item(body=doc)
+        resp = await fault_injection_container.delete_all_items_by_partition_key(pk)
+    validate_response_uri(resp, expected_uri)
+
+
+@pytest.mark.cosmosMultiRegion
+@pytest.mark.asyncio
+@pytest.mark.usefixtures("setup")
+class TestPerPartitionCircuitBreakerMMAsync:
+    host = test_config.TestConfig.host
+    master_key = test_config.TestConfig.masterKey
+    connectionPolicy = test_config.TestConfig.connectionPolicy
+    TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
+    TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4())
+
+    async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs):
+        client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session",
+                              preferred_locations=[REGION_1, REGION_2],
+                              multiple_write_locations=True,
+                              transport=custom_transport, **kwargs)
+        db = client.get_database_client(self.TEST_DATABASE_ID)
+        container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID)
+        return {"client": client, "db": db, "col": container}
+
+    @staticmethod
+    async def cleanup_method(initialized_objects: Dict[str, Any]):
+        method_client: CosmosClient = initialized_objects["client"]
+        await method_client.close()
+
+    @staticmethod
+    async def perform_read_operation(operation, container, doc_id, pk, expected_uri):
+        if operation == "read":
+            read_resp = await container.read_item(item=doc_id, partition_key=pk)
+            request = read_resp.get_response_headers()["_request"]
+            # validate the response comes from the expected region
+            assert request.url.startswith(expected_uri)
+        elif operation == "query":
+            query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk"
+            parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}]
+            async for item in container.query_items(query=query, partition_key=pk, parameters=parameters):
+                assert item['id'] == doc_id
+            # need to do query with no pk and with feed range
+        elif operation == "changefeed":
+            async for _ in container.query_items_change_feed():
+                pass
+        elif operation == "read_all_items":
+            async for item in container.read_all_items(partition_key=pk):
+                assert item['pk'] == pk
+
+    @pytest.mark.parametrize("write_operation, error", write_operations_and_errors())
+    async def test_write_consecutive_failure_threshold_async(self, setup, write_operation, error):
+        expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+        custom_transport = FaultInjectionTransportAsync()
+        id_value = 'failoverDoc-' + str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': 'pk1',
+                               'name': 'sample document',
+                               'key': 'value'}
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri))
+        custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+            0,
+            error
+        )))
+
+        custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host)
+        container = custom_setup['col']
+        global_endpoint_manager = container.client_connection._global_endpoint_manager
+
+        await perform_write_operation(write_operation,
+                                      setup[COLLECTION],
+                                      container,
+                                      document_definition['id'],
+                                      document_definition['pk'],
+                                      expected_uri)
+
+        TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0)
+
+        # writes should fail but still be tracked
+        for i in range(6):
+            with pytest.raises((CosmosHttpResponseError, ServiceResponseError)):
+                await perform_write_operation(write_operation,
+                                              setup[COLLECTION],
+                                              container,
+                                              document_definition['id'],
+                                              document_definition['pk'],
+                                              expected_uri)
+
+        TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1)
+
+    @pytest.mark.parametrize("read_operation, error", read_operations_and_errors())
+    async def test_read_consecutive_failure_threshold_async(self, setup, read_operation, error):
+        expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+        custom_transport = FaultInjectionTransportAsync()
+        id_value = 'failoverDoc-' + str(uuid.uuid4())
+        document_definition = {'id': id_value,
+                               'pk': 'pk1',
+                               'name': 'sample document',
+                               'key': 'value'}
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri))
+        custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+            0,
+            error
+        )))
+
+        custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host)
+        container = custom_setup['col']
+        global_endpoint_manager = container.client_connection._global_endpoint_manager
+
+        # use a fixed write operation to prime failures; this test parametrizes the read operation
+        await perform_write_operation("upsert",
+                                      setup[COLLECTION],
+                                      container,
+                                      document_definition['id'],
+                                      document_definition['pk'],
+                                      expected_uri)
+
+        # consecutive write failures should mark the relevant partition as unavailable
+        for i in range(6):
+            with pytest.raises((CosmosHttpResponseError, ServiceResponseError)):
+                await perform_write_operation("upsert",
+                                              setup[COLLECTION],
+                                              container,
+                                              document_definition['id'],
+                                              document_definition['pk'],
+                                              expected_uri)
+
+        TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1)
+
+        # create item with client without fault injection
+        await setup[COLLECTION].create_item(body=document_definition)
+
+        # reads should fail over and only the relevant partition should be marked as unavailable
+        await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation,
+                                                                           container,
+                                                                           document_definition['id'],
+                                                                           document_definition['pk'],
+                                                                           expected_uri)
+        # partition should not have been marked unavailable after one error
+        TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1)
+
+        for i in range(10):
+            await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation,
+                                                                               container,
+                                                                               document_definition['id'],
+                                                                               document_definition['pk'],
+                                                                               expected_uri)
+
+        # the partition should have been marked as unavailable after breaking the read threshold
+        TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1)
+        # test recovering the partition again
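# --- Editor's aside (illustrative sketch, not part of the patch) ---
# The consecutive-failure bookkeeping exercised above can be pictured as a
# counter per (partition, region) pair that resets on success and trips once
# a threshold is crossed; the names below are hypothetical, not the SDK's:
class ConsecutiveFailureSketch:
    def __init__(self, threshold: int = 10):
        self.threshold = threshold
        self.consecutive_failures = 0

    def record(self, success: bool) -> bool:
        # a success resets the streak; a failure extends it
        self.consecutive_failures = 0 if success else self.consecutive_failures + 1
        # True means the (partition, region) pair should be marked unavailable
        return self.consecutive_failures >= self.threshold
# --- end aside ---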
+    @pytest.mark.parametrize("write_operation, error", write_operations_and_errors())
+    async def test_failure_rate_threshold_async(self, setup, write_operation, error):
+        expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+        custom_transport = FaultInjectionTransportAsync()
+        id_value = 'failoverDoc-' + str(uuid.uuid4())
+        # two documents targeted to same partition, one will always fail and the other will succeed
+        document_definition = {'id': id_value,
+                               'pk': 'pk1',
+                               'name': 'sample document',
+                               'key': 'value'}
+        doc_2 = {'id': str(uuid.uuid4()),
+                 'pk': 'pk1',
+                 'name': 'sample document',
+                 'key': 'value'}
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri))
+        custom_transport.add_fault(predicate,
+                                   lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+                                       0,
+                                       error
+                                   )))
+
+        custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host)
+        container = custom_setup['col']
+        global_endpoint_manager = container.client_connection._global_endpoint_manager
+        # lower minimum requests for testing
+        global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10
+        try:
+            # writes should fail but should not mark the partition as unavailable
+            for i in range(14):
+                if i == 9:
+                    await container.upsert_item(body=doc_2)
+                with pytest.raises((CosmosHttpResponseError, ServiceResponseError)):
+                    await perform_write_operation(write_operation,
+                                                  setup[COLLECTION],
+                                                  container,
+                                                  document_definition['id'],
+                                                  document_definition['pk'],
+                                                  expected_uri)
+
+            TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0)
+
+            # create item with client without fault injection
+            await setup[COLLECTION].create_item(body=document_definition)
+        finally:
+            # restore minimum requests
+            global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100
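# --- Editor's aside (illustrative sketch, not part of the patch) ---
# Unlike the consecutive counter, the failure-rate check only fires after a
# minimum number of requests (MINIMUM_REQUESTS_FOR_FAILURE_RATE above) has
# been observed in the window; hypothetical names and defaults, not the SDK's:
def over_failure_rate_threshold(failures: int, successes: int,
                                minimum_requests: int = 10,
                                threshold_percent: float = 90.0) -> bool:
    total = failures + successes
    if total < minimum_requests:
        # not enough traffic yet to make a statistically meaningful call
        return False
    return (failures / total) * 100.0 >= threshold_percent
# --- end aside ---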
+    @pytest.mark.parametrize("read_operation, error", read_operations_and_errors())
+    async def test_read_failure_rate_threshold_async(self, setup, read_operation, error):
+        expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+        custom_transport = FaultInjectionTransportAsync()
+        id_value = 'failoverDoc-' + str(uuid.uuid4())
+        # two documents targeted to same partition, one will always fail and the other will succeed
+        document_definition = {'id': id_value,
+                               'pk': 'pk1',
+                               'name': 'sample document',
+                               'key': 'value'}
+        doc_2 = {'id': str(uuid.uuid4()),
+                 'pk': 'pk1',
+                 'name': 'sample document',
+                 'key': 'value'}
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri))
+        custom_transport.add_fault(predicate,
+                                   lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+                                       0,
+                                       error
+                                   )))
+
+        custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host)
+        container = custom_setup['col']
+        global_endpoint_manager = container.client_connection._global_endpoint_manager
+        # lower minimum requests for testing
+        global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10
+        try:
+            # writes should fail but should not mark the partition as unavailable
+            # (a fixed write operation is used here, since this test parametrizes the read operation)
+            for i in range(14):
+                if i == 9:
+                    await container.upsert_item(body=doc_2)
+                with pytest.raises((CosmosHttpResponseError, ServiceResponseError)):
+                    await perform_write_operation("upsert",
+                                                  setup[COLLECTION],
+                                                  container,
+                                                  document_definition['id'],
+                                                  document_definition['pk'],
+                                                  expected_uri)
+
+            TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0)
+
+            # create item with client without fault injection
+            await setup[COLLECTION].create_item(body=document_definition)
+
+            # reads should fail over and only the relevant partition should be marked as unavailable
+            await container.read_item(item=document_definition['id'], partition_key=document_definition['pk'])
+            # partition should not have been marked unavailable after one error
+            TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0)
+            for i in range(20):
+                if i == 8:
+                    read_resp = await container.read_item(item=doc_2['id'],
+                                                          partition_key=doc_2['pk'])
+                    request = read_resp.get_response_headers()["_request"]
+                    # validate the response comes from the expected region
+                    assert request.url.startswith(expected_uri)
+                else:
+                    await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation,
+                                                                                       container,
+                                                                                       document_definition['id'],
+                                                                                       document_definition['pk'],
+                                                                                       expected_uri)
+            # the partition should have been marked as unavailable after breaking the read threshold
+            TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1)
+        finally:
+            # restore minimum requests
+            global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100
+
+    # look at the urls for verifying fall back and use another id for same partition
+
+    @staticmethod
+    def validate_unhealthy_partitions(global_endpoint_manager,
+                                      expected_unhealthy_partitions):
+        health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info
+        unhealthy_partitions = 0
+        for pk_range_wrapper, location_to_health_info in health_info_map.items():
+            for location, health_info in location_to_health_info.items():
+                health_status = health_info.unavailability_info.get(HEALTH_STATUS)
+                if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY:
+                    unhealthy_partitions += 1
+                else:
+                    assert health_info.read_consecutive_failure_count < 10
+                    assert health_info.write_failure_count == 0
+                    assert health_info.write_consecutive_failure_count == 0
+
+        assert unhealthy_partitions == expected_unhealthy_partitions
+
+    # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again
+    # test service request marks only a partition unavailable not an entire region - across operation types
+    # test cosmos client timeout
+
+if __name__ == '__main__':
+    unittest.main()
diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py
index ee5e3fcd6fa9..ec6cf5ecff7d 100644
--- a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py
+++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py
@@ -297,6 +297,8 @@ def validate_unhealthy_partitions(global_endpoint_manager,
                 health_status = health_info.unavailability_info.get(HEALTH_STATUS)
                 if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY:
                     unhealthy_partitions += 1
+                    assert health_info.write_failure_count == 0
+                    assert health_info.write_consecutive_failure_count == 0
                 else:
                     assert health_info.read_consecutive_failure_count < 10
                     assert health_info.write_failure_count == 0
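Editor's note: both validate_unhealthy_partitions helpers (in the multi-write test file above
and in the single-write test file this last hunk touches) walk the same nested map,
{pk_range_wrapper: {location: health_info}}, counting entries whose unavailability info
carries an unhealthy status. A standalone sketch of that traversal over plain dictionaries
(stand-in shapes and status strings, not the SDK's real types or constant values):

    UNHEALTHY_STATUSES = ("unhealthy tentative", "unhealthy")

    def count_unhealthy(health_info_map):
        unhealthy = 0
        for location_to_health_info in health_info_map.values():
            for health_info in location_to_health_info.values():
                # unavailability info holds a health status entry only while
                # the (partition, region) pair is considered unavailable
                status = (health_info.get("unavailability_info") or {}).get("health_status")
                if status in UNHEALTHY_STATUSES:
                    unhealthy += 1
        return unhealthy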