diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 1949488630a3..f88d60f41d51 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -5,8 +5,10 @@ #### Features Added #### Breaking Changes +* Adds cross region retries when no preferred locations are set. This is only a breaking change for customers using bounded staleness consistency. See [PR 39714](https://github.com/Azure/azure-sdk-for-python/pull/39714) #### Bugs Fixed +* Fixed bug where replacing manual throughput using `ThroughputProperties` would not work. See [PR 41564](https://github.com/Azure/azure-sdk-for-python/pull/41564) #### Other Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 90578c63e5dd..9e2797835ead 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -24,12 +24,13 @@ """ import collections import logging -from typing import Set, Mapping, List +from typing import Set, Mapping, OrderedDict, Dict +from typing import List from urllib.parse import urlparse from . import documents, _base as base from .http_constants import ResourceType -from .documents import _OperationType +from .documents import _OperationType, ConnectionPolicy from ._request_object import RequestObject # pylint: disable=protected-access @@ -43,8 +44,8 @@ class EndpointOperationType(object): class RegionalRoutingContext(object): def __init__(self, primary_endpoint: str, alternate_endpoint: str): - self.primary_endpoint = primary_endpoint - self.alternate_endpoint = alternate_endpoint + self.primary_endpoint: str = primary_endpoint + self.alternate_endpoint: str = alternate_endpoint def set_primary(self, endpoint: str): self.primary_endpoint = endpoint @@ -65,13 +66,13 @@ def __eq__(self, other): def __str__(self): return "Primary: " + self.primary_endpoint + ", Alternate: " + self.alternate_endpoint -def get_endpoints_by_location(new_locations, - old_endpoints_by_location, - default_regional_endpoint, - writes, - use_multiple_write_locations): +def get_endpoints_by_location(new_locations: List[Dict[str, str]], + old_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext], + default_regional_endpoint: RegionalRoutingContext, + writes: bool, + use_multiple_write_locations: bool): # construct from previous object - endpoints_by_location = collections.OrderedDict() + regional_routing_context_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict() parsed_locations = [] @@ -86,8 +87,8 @@ def get_endpoints_by_location(new_locations, parsed_locations.append(new_location["name"]) if not writes or use_multiple_write_locations: regional_object = RegionalRoutingContext(region_uri, region_uri) - elif new_location["name"] in old_endpoints_by_location: - regional_object = old_endpoints_by_location[new_location["name"]] + elif new_location["name"] in old_regional_routing_contexts_by_location: + regional_object = old_regional_routing_contexts_by_location[new_location["name"]] current = regional_object.get_primary() # swap the previous with current and current with new region_uri received from the gateway if current != region_uri: @@ -108,15 +109,14 @@ def get_endpoints_by_location(new_locations, default_regional_endpoint.get_primary(), new_location["name"]) regional_object.set_alternate(constructed_region_uri) - # pass in object with region uri , last known good, curr etc - endpoints_by_location.update({new_location["name"]: regional_object}) + regional_routing_context_by_location.update({new_location["name"]: regional_object}) except Exception as e: raise e # Also store a hash map of endpoints for each location - locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()} + locations_by_endpoints = {value.get_primary(): key for key, value in regional_routing_context_by_location.items()} - return endpoints_by_location, locations_by_endpoints, parsed_locations + return regional_routing_context_by_location, locations_by_endpoints, parsed_locations def _get_health_check_endpoints(regional_routing_contexts) -> Set[str]: # should use the endpoints in the order returned from gateway and only the ones specified in preferred locations @@ -154,22 +154,24 @@ class LocationCache(object): # pylint: disable=too-many-public-methods,too-many def __init__( self, - default_endpoint, - connection_policy, + default_endpoint: str, + connection_policy: ConnectionPolicy, ): - self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint) - self.enable_multiple_writable_locations = False - self.write_regional_routing_contexts = [self.default_regional_routing_context] - self.read_regional_routing_contexts = [self.default_regional_routing_context] - self.location_unavailability_info_by_endpoint = {} - self.last_cache_update_time_stamp = 0 - self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long - self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long - self.account_locations_by_read_endpoints = {} # pylint: disable=name-too-long - self.account_locations_by_write_endpoints = {} # pylint: disable=name-too-long - self.account_write_locations = [] - self.account_read_locations = [] - self.connection_policy = connection_policy + self.default_regional_routing_context: RegionalRoutingContext = RegionalRoutingContext(default_endpoint, + default_endpoint) + self.effective_preferred_locations: List[str] = [] + self.enable_multiple_writable_locations: bool = False + self.write_regional_routing_contexts: List[RegionalRoutingContext] = [self.default_regional_routing_context] + self.read_regional_routing_contexts: List[RegionalRoutingContext] = [self.default_regional_routing_context] + self.location_unavailability_info_by_endpoint: Dict[str, Dict[str, Set[EndpointOperationType]]] = {} + self.last_cache_update_time_stamp: int = 0 + self.account_read_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext] = {} # pylint: disable=name-too-long + self.account_write_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext] = {} # pylint: disable=name-too-long + self.account_locations_by_read_endpoints: Dict[str, str] = {} # pylint: disable=name-too-long + self.account_locations_by_write_endpoints: Dict[str, str] = {} # pylint: disable=name-too-long + self.account_write_locations: List[str] = [] + self.account_read_locations: List[str] = [] + self.connection_policy: ConnectionPolicy = connection_policy def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -310,8 +312,7 @@ def resolve_service_endpoint(self, request): return regional_routing_context.get_primary() def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements - most_preferred_location = self.connection_policy.PreferredLocations[0] \ - if self.connection_policy.PreferredLocations else None + most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None # we should schedule refresh in background if we are unable to target the user's most preferredLocation. if self.connection_policy.EnableEndpointDiscovery: @@ -379,7 +380,7 @@ def is_endpoint_unavailable_internal(self, endpoint: str, expected_available_ope return True def mark_endpoint_unavailable( - self, unavailable_endpoint: str, unavailable_operation_type: str, refresh_cache: bool): + self, unavailable_endpoint: str, unavailable_operation_type: EndpointOperationType, refresh_cache: bool): logger.warning("Marking %s unavailable for %s ", unavailable_endpoint, unavailable_operation_type) @@ -431,6 +432,15 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.connection_policy.UseMultipleWriteLocations ) + # if preferred locations is empty and the default endpoint is a global endpoint, + # we should use the read locations from gateway as effective preferred locations + if self.connection_policy.PreferredLocations: + self.effective_preferred_locations = self.connection_policy.PreferredLocations + elif self.is_default_endpoint_regional(): + self.effective_preferred_locations = [] + elif not self.effective_preferred_locations: + self.effective_preferred_locations = self.account_read_locations + self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts( self.account_write_regional_routing_contexts_by_location, self.account_write_locations, @@ -456,12 +466,12 @@ def get_preferred_regional_routing_contexts( or expected_available_operation == EndpointOperationType.ReadType ): unavailable_endpoints = [] - if self.connection_policy.PreferredLocations: + if self.effective_preferred_locations: # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. - for location in self.connection_policy.PreferredLocations: + for location in self.effective_preferred_locations: regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \ else None if regional_endpoint: @@ -486,6 +496,13 @@ def get_preferred_regional_routing_contexts( return regional_endpoints + # if the endpoint is returned from the gateway in the account topology, it is a regional endpoint + def is_default_endpoint_regional(self) -> bool: + return any( + context.get_primary() == self.default_regional_routing_context.get_primary() + for context in self.account_read_regional_routing_contexts_by_location.values() + ) + def can_use_multiple_write_locations(self): return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py index 69b9d52f286d..f10366ac4c7f 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py @@ -23,17 +23,8 @@ in the Azure Cosmos database service. """ -import logging from azure.cosmos.documents import _OperationType -logger = logging.getLogger(__name__) -logger.setLevel(logging.INFO) -log_formatter = logging.Formatter("%(levelname)s:%(message)s") -log_handler = logging.StreamHandler() -log_handler.setFormatter(log_formatter) -logger.addHandler(log_handler) - - class _SessionRetryPolicy(object): """The session retry policy used to handle read/write session unavailability. """ diff --git a/sdk/cosmos/azure-cosmos/tests/test_effective_preferred_locations.py b/sdk/cosmos/azure-cosmos/tests/test_effective_preferred_locations.py new file mode 100644 index 000000000000..d36c93b5c746 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_effective_preferred_locations.py @@ -0,0 +1,206 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import uuid +from typing import List + +import pytest +from azure.core.exceptions import ServiceRequestError + +import test_config +from azure.cosmos import DatabaseAccount, _location_cache, CosmosClient, _global_endpoint_manager, \ + _cosmos_client_connection +from azure.cosmos._location_cache import RegionalRoutingContext +from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos.exceptions import CosmosHttpResponseError + +COLLECTION = "created_collection" +REGION_1 = "West US 3" +REGION_2 = "West US" +REGION_3 = "West US 2" +ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3] + +@pytest.fixture() +def setup(): + if (TestPreferredLocations.master_key == '[YOUR_KEY_HERE]' or + TestPreferredLocations.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(TestPreferredLocations.host, TestPreferredLocations.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPreferredLocations.TEST_DATABASE_ID) + created_collection = created_database.get_container_client(TestPreferredLocations.TEST_CONTAINER_SINGLE_PARTITION_ID) + yield { + COLLECTION: created_collection + } + +def preferred_locations(): + host = test_config.TestConfig.host + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(host, REGION_2) + return [ + ([], host), + ([REGION_1, REGION_2], host), + ([REGION_1], host), + ([REGION_2, REGION_3], host), + ([REGION_1, REGION_2, REGION_3], host), + ([], locational_endpoint), + ([REGION_2], locational_endpoint), + ([REGION_3, REGION_1], locational_endpoint), + ([REGION_1, REGION_3], locational_endpoint), + ([REGION_1, REGION_2, REGION_3], locational_endpoint) + ] + +def construct_item(): + return { + "id": "test_item_no_preferred_locations" + str(uuid.uuid4()), + test_config.TestConfig.TEST_CONTAINER_PARTITION_KEY: str(uuid.uuid4()) + } + +def error(): + status_codes = [503, 408, 404] + sub_status = [0, 0, 1002] + errors = [] + for i, status_code in enumerate(status_codes): + errors.append(CosmosHttpResponseError( + status_code=status_code, + message=f"Error with status code {status_code} and substatus {sub_status[i]}", + sub_status=sub_status[i] + )) + return errors + +@pytest.mark.unittest +@pytest.mark.usefixtures("setup") +class TestPreferredLocations: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID + partition_key = test_config.TestConfig.TEST_CONTAINER_PARTITION_KEY + + def setup_method_with_custom_transport(self, custom_transport, error_lambda, default_endpoint=host, **kwargs): + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + predicate = lambda r: (FaultInjectionTransport.predicate_is_document_operation(r) and + (FaultInjectionTransport.predicate_targets_region(r, uri_down) or + FaultInjectionTransport.predicate_targets_region(r, default_endpoint))) + custom_transport.add_fault(predicate, + error_lambda) + client = CosmosClient(default_endpoint, + self.master_key, + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @pytest.mark.cosmosEmulator + @pytest.mark.parametrize("preferred_location, default_endpoint", preferred_locations()) + def test_effective_preferred_regions(self, setup, preferred_location, default_endpoint): + + self.original_getDatabaseAccountStub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub + self.original_getDatabaseAccountCheck = _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck + _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + try: + client = CosmosClient(default_endpoint, self.master_key, preferred_locations=preferred_location) + # this will setup the location cache + client.client_connection._global_endpoint_manager.force_refresh_on_startup(None) + finally: + _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub + _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck + expected_dual_endpoints = [] + + # if preferred location set should use that + if preferred_location: + expected_locations = preferred_location + # if client created with regional endpoint preferred locations, only use hub region + elif default_endpoint != self.host: + expected_locations = ACCOUNT_REGIONS[:1] + # if client created with global endpoint and no preferred locations, use all regions + else: + expected_locations = ACCOUNT_REGIONS + + for location in expected_locations: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location) + if default_endpoint == self.host or preferred_location: + expected_dual_endpoints.append(RegionalRoutingContext(locational_endpoint, locational_endpoint)) + else: + expected_dual_endpoints.append(RegionalRoutingContext(locational_endpoint, default_endpoint)) + + read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_routing_contexts + assert read_dual_endpoints == expected_dual_endpoints + + @pytest.mark.cosmosMultiRegion + @pytest.mark.parametrize("error", error()) + def test_read_no_preferred_locations_with_errors(self, setup, error): + container = setup[COLLECTION] + item_to_read = construct_item() + container.create_item(item_to_read) + + # setup fault injection so that first account region fails + custom_transport = FaultInjectionTransport() + error_lambda = lambda r:FaultInjectionTransport.error_after_delay( + 0, + error + ) + expected = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + fault_setup = self.setup_method_with_custom_transport(custom_transport=custom_transport, error_lambda=error_lambda) + fault_container = fault_setup["col"] + response = fault_container.read_item(item=item_to_read["id"], partition_key=item_to_read[self.partition_key]) + request = response.get_response_headers()["_request"] + # Validate the response comes from another region meaning that the account locations were used + assert request.url.startswith(expected) + + # should fail if using excluded locations because no where to failover to + with pytest.raises(CosmosHttpResponseError): + fault_container.read_item(item=item_to_read["id"], partition_key=item_to_read[self.partition_key], excluded_locations=[REGION_2]) + + @pytest.mark.cosmosMultiRegion + def test_write_no_preferred_locations_with_errors(self, setup): + # setup fault injection so that first account region fails + custom_transport = FaultInjectionTransport() + expected = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + error_lambda = lambda r: FaultInjectionTransport.error_region_down() + + fault_setup = self.setup_method_with_custom_transport(custom_transport=custom_transport, error_lambda=error_lambda) + fault_container = fault_setup["col"] + response = fault_container.create_item(body=construct_item()) + request = response.get_response_headers()["_request"] + # Validate the response comes from another region meaning that the account locations were used + assert request.url.startswith(expected) + + # should fail if using excluded locations because no where to failover to + with pytest.raises(ServiceRequestError): + fault_container.create_item(body=construct_item(), excluded_locations=[REGION_2]) + + class MockGetDatabaseAccount(object): + def __init__( + self, + regions: List[str], + ): + self.regions = regions + + def __call__(self, endpoint): + read_regions = self.regions + read_locations = [] + counter = 0 + for loc in read_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocations.host, loc) + read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + counter += 1 + write_regions = [self.regions[0]] + write_locations = [] + for loc in write_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocations.host, loc) + write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + multi_write = False + + db_acc = DatabaseAccount() + db_acc.DatabasesLink = "/dbs/" + db_acc.MediaLink = "/media/" + db_acc._ReadableLocations = read_locations + db_acc._WritableLocations = write_locations + db_acc._EnableMultipleWritableLocations = multi_write + db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} + return db_acc diff --git a/sdk/cosmos/azure-cosmos/tests/test_effective_preferred_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_effective_preferred_locations_async.py new file mode 100644 index 000000000000..f5c234511738 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_effective_preferred_locations_async.py @@ -0,0 +1,196 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import asyncio +from typing import List + +import pytest +import pytest_asyncio +from azure.core.exceptions import ServiceRequestError + +import test_config +from azure.cosmos import DatabaseAccount, _location_cache +from azure.cosmos._location_cache import RegionalRoutingContext + +from azure.cosmos.aio import _global_endpoint_manager_async, _cosmos_client_connection_async, CosmosClient +from _fault_injection_transport_async import FaultInjectionTransportAsync +from azure.cosmos.exceptions import CosmosHttpResponseError +from test_circuit_breaker_emulator import COLLECTION +from test_effective_preferred_locations import REGION_1, REGION_2, REGION_3, ACCOUNT_REGIONS, construct_item, error + + +@pytest_asyncio.fixture() +async def setup(): + if (TestPreferredLocationsAsync.master_key == '[YOUR_KEY_HERE]' or + TestPreferredLocationsAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(TestPreferredLocationsAsync.host, TestPreferredLocationsAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPreferredLocationsAsync.TEST_DATABASE_ID) + created_collection = created_database.get_container_client(TestPreferredLocationsAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + yield { + COLLECTION: created_collection + } + await client.close() + +def preferred_locations(): + host = test_config.TestConfig.host + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(host, REGION_2) + return [ + ([], host), + ([REGION_1, REGION_2], host), + ([REGION_1], host), + ([REGION_2, REGION_3], host), + ([REGION_1, REGION_2, REGION_3], host), + ([], locational_endpoint), + ([REGION_2], locational_endpoint), + ([REGION_3, REGION_1], locational_endpoint), + ([REGION_1, REGION_3], locational_endpoint), + ([REGION_1, REGION_2, REGION_3], locational_endpoint) + ] + +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPreferredLocationsAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID + partition_key = test_config.TestConfig.TEST_CONTAINER_PARTITION_KEY + + async def setup_method_with_custom_transport(self, custom_transport, error_lambda, default_endpoint=host, **kwargs): + uri_down = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + (FaultInjectionTransportAsync.predicate_targets_region(r, uri_down) or + FaultInjectionTransportAsync.predicate_targets_region(r, self.host))) + custom_transport.add_fault(predicate, + error_lambda) + client = CosmosClient(default_endpoint, + self.master_key, + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @pytest.mark.cosmosEmulator + @pytest.mark.parametrize("preferred_location, default_endpoint", preferred_locations()) + async def test_effective_preferred_regions_async(self, setup, preferred_location, default_endpoint): + + self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub + self.original_getDatabaseAccountCheck = _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck + _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + try: + client = CosmosClient(default_endpoint, self.master_key, preferred_locations=preferred_location) + # this will setup the location cache + await client.client_connection._global_endpoint_manager.force_refresh_on_startup(None) + finally: + _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub + _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck + expected_dual_endpoints = [] + + # if preferred location set should use that + if preferred_location: + expected_locations = preferred_location + # if client created with regional endpoint preferred locations, only use hub region + elif default_endpoint != self.host: + expected_locations = ACCOUNT_REGIONS[:1] + # if client created with global endpoint and no preferred locations, use all regions + else: + expected_locations = ACCOUNT_REGIONS + + for location in expected_locations: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location) + if default_endpoint == self.host or preferred_location: + expected_dual_endpoints.append(RegionalRoutingContext(locational_endpoint, locational_endpoint)) + else: + expected_dual_endpoints.append(RegionalRoutingContext(locational_endpoint, default_endpoint)) + + read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_routing_contexts + assert read_dual_endpoints == expected_dual_endpoints + + @pytest.mark.cosmosMultiRegion + @pytest.mark.parametrize("error", error()) + async def test_read_no_preferred_locations_with_errors_async(self, setup, error): + container = setup[COLLECTION] + item_to_read = construct_item() + await container.create_item(item_to_read) + + # setup fault injection so that first account region fails + custom_transport = FaultInjectionTransportAsync() + expected = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + )) + try: + fault_setup = await self.setup_method_with_custom_transport(custom_transport=custom_transport, error_lambda=error_lambda) + fault_container = fault_setup["col"] + response = await fault_container.read_item(item=item_to_read["id"], partition_key=item_to_read[self.partition_key]) + request = response.get_response_headers()["_request"] + # Validate the response comes from another region meaning that the account locations were used + assert request.url.startswith(expected) + + # should fail if using excluded locations because no where to failover to + with pytest.raises(CosmosHttpResponseError): + await fault_container.read_item(item=item_to_read["id"], partition_key=item_to_read[self.partition_key], excluded_locations=[REGION_2]) + + finally: + await fault_setup["client"].close() + + @pytest.mark.cosmosMultiRegion + async def test_write_no_preferred_locations_with_errors_async(self, setup): + # setup fault injection so that first account region fails + custom_transport = FaultInjectionTransportAsync() + expected = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + error_lambda = lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_region_down()) + + try: + fault_setup = await self.setup_method_with_custom_transport(custom_transport=custom_transport, error_lambda=error_lambda) + fault_container = fault_setup["col"] + response = await fault_container.create_item(body=construct_item()) + request = response.get_response_headers()["_request"] + # Validate the response comes from another region meaning that the account locations were used + assert request.url.startswith(expected) + + # should fail if using excluded locations because no where to failover to + with pytest.raises(ServiceRequestError): + await fault_container.create_item(body=construct_item(), excluded_locations=[REGION_2]) + + finally: + await fault_setup["client"].close() + + class MockGetDatabaseAccount(object): + def __init__( + self, + regions: List[str], + ): + self.regions = regions + + async def __call__(self, endpoint): + read_regions = self.regions + read_locations = [] + counter = 0 + for loc in read_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocationsAsync.host, loc) + read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + counter += 1 + write_regions = [self.regions[0]] + write_locations = [] + for loc in write_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocationsAsync.host, loc) + write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + multi_write = False + + db_acc = DatabaseAccount() + db_acc.DatabasesLink = "/dbs/" + db_acc.MediaLink = "/media/" + db_acc._ReadableLocations = read_locations + db_acc._WritableLocations = write_locations + db_acc._EnableMultipleWritableLocations = multi_write + db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} + return db_acc diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index 887be44f2273..1a0f132f8e0f 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -95,7 +95,7 @@ def test_get_locations(self): # check read endpoints without preferred locations read_regions = lc.get_read_regional_routing_contexts() - assert len(read_regions) == 1 + assert len(read_regions) == 3 assert read_regions[0].get_primary() == location1_endpoint # check read endpoints with preferred locations diff --git a/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context.py b/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context.py index 173b4c20f35b..f44d9c75eeae 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context.py +++ b/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context.py @@ -6,18 +6,13 @@ import pytest import test_config -from azure.cosmos import (CosmosClient, DatabaseAccount, _global_endpoint_manager) -from azure.cosmos._location_cache import RegionalRoutingContext +from azure.cosmos import CosmosClient @pytest.mark.cosmosEmulator class TestRegionalRoutingContext(unittest.TestCase): host = test_config.TestConfig.host masterKey = test_config.TestConfig.masterKey - REGION1 = "West US" - REGION2 = "East US" - REGION3 = "West US 2" - REGIONAL_ROUTING_CONTEXT = RegionalRoutingContext(host, "something_different") TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID @@ -34,42 +29,14 @@ def setUpClass(cls): cls.created_container = cls.created_database.get_container_client(cls.TEST_CONTAINER_ID) def test_no_swaps_on_successful_request(self): - original_get_database_account_stub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - mocked_client = CosmosClient(self.host, self.masterKey) - db = mocked_client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_ID) - # Mock the GetDatabaseAccountStub to return the regional endpoints - original_read_endpoint = (mocked_client.client_connection._global_endpoint_manager + original_read_endpoint = (self.client.client_connection._global_endpoint_manager .location_cache.get_read_regional_routing_context()) - try: - container.create_item(body={"id": str(uuid.uuid4())}) - finally: - # Check for if there was a swap - self.assertEqual(original_read_endpoint, - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_read_regional_routing_context()) - self.assertEqual(self.REGIONAL_ROUTING_CONTEXT.get_primary(), - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_write_regional_routing_context()) - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = original_get_database_account_stub - - def MockGetDatabaseAccountStub(self, endpoint): - read_locations = [] - read_locations.append({'databaseAccountEndpoint': endpoint, 'name': "West US"}) - read_locations.append({'databaseAccountEndpoint': "some different endpoint", 'name': "East US"}) - write_regions = ["West US"] - write_locations = [] - for loc in write_regions: - write_locations.append({'databaseAccountEndpoint': endpoint, 'name': loc}) - multi_write = False - - db_acc = DatabaseAccount() - db_acc.DatabasesLink = "/dbs/" - db_acc.MediaLink = "/media/" - db_acc._ReadableLocations = read_locations - db_acc._WritableLocations = write_locations - db_acc._EnableMultipleWritableLocations = multi_write - db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} - return db_acc + self.created_container.create_item(body={"id": str(uuid.uuid4())}) + # Check for if there was a swap + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_read_regional_routing_context()) + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_write_regional_routing_context()) diff --git a/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context_async.py b/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context_async.py index 134f059dde99..5886d7ce58c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_regional_routing_context_async.py @@ -6,19 +6,13 @@ import pytest import test_config -from azure.cosmos import DatabaseAccount -from azure.cosmos._location_cache import RegionalRoutingContext -from azure.cosmos.aio import CosmosClient, _global_endpoint_manager_async +from azure.cosmos.aio import CosmosClient @pytest.mark.cosmosEmulator class TestRegionalRoutingContextAsync(unittest.IsolatedAsyncioTestCase): host = test_config.TestConfig.host masterKey = test_config.TestConfig.masterKey - REGION1 = "West US" - REGION2 = "East US" - REGION3 = "West US 2" - REGIONAL_ROUTING_CONTEXT = RegionalRoutingContext(host, "something_different") TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID @@ -39,45 +33,16 @@ async def asyncSetUp(self): async def asyncTearDown(self): await self.client.close() - async def test_no_swaps_on_successful_request(self): - original_get_database_account_stub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - mocked_client = CosmosClient(self.host, self.masterKey) - db = mocked_client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_ID) - # Mock the GetDatabaseAccountStub to return the regional endpoints - - original_read_endpoint = (mocked_client.client_connection._global_endpoint_manager + async def test_no_swaps_on_successful_request_async(self): + # Make sure that getDatabaseAccount call has finished + await self.client.client_connection._global_endpoint_manager.force_refresh_on_startup(None) + original_read_endpoint = (self.client.client_connection._global_endpoint_manager .location_cache.get_read_regional_routing_context()) - try: - await container.create_item(body={"id": str(uuid.uuid4())}) - finally: - # Check for if there was a swap - self.assertEqual(original_read_endpoint, - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_read_regional_routing_context()) - # return it - self.assertEqual(self.REGIONAL_ROUTING_CONTEXT.get_primary(), - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_write_regional_routing_context()) - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = original_get_database_account_stub - await mocked_client.close() - - async def MockGetDatabaseAccountStub(self, endpoint): - read_locations = [] - read_locations.append({'databaseAccountEndpoint': endpoint, 'name': "West US"}) - read_locations.append({'databaseAccountEndpoint': "some different endpoint", 'name': "East US"}) - write_regions = ["West US"] - write_locations = [] - for loc in write_regions: - write_locations.append({'databaseAccountEndpoint': endpoint, 'name': loc}) - multi_write = False - - db_acc = DatabaseAccount() - db_acc.DatabasesLink = "/dbs/" - db_acc.MediaLink = "/media/" - db_acc._ReadableLocations = read_locations - db_acc._WritableLocations = write_locations - db_acc._EnableMultipleWritableLocations = multi_write - db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} - return db_acc + await self.created_container.create_item(body={"id": str(uuid.uuid4())}) + # Check for if there was a swap + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_read_regional_routing_context()) + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_write_regional_routing_context())