diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 44f14bb3676d..2d58c009440a 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -1,4 +1,14 @@ ## Release History +### 4.10.1b1 (Unreleased) + +#### Features Added + +#### Breaking Changes +* Adds cross region retries when no preferred locations are set. This is only a breaking change for customers using bounded staleness consistency. See [PR 39714](https://github.com/Azure/azure-sdk-for-python/pull/39714) + +#### Bugs Fixed + +#### Other Changes ### 4.10.0b1 (2025-02-13) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index d7d68a4563ef..43d65acdc931 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -23,8 +23,10 @@ DatabaseAccount with multiple writable and readable locations. """ import collections +import difflib import logging import time +from typing import List, OrderedDict from urllib.parse import urlparse from . import documents @@ -113,7 +115,6 @@ def get_endpoints_by_location(new_locations, default_regional_endpoint.get_current(), new_location["name"]) regional_object.set_previous(constructed_region_uri) - # pass in object with region uri , last known good, curr etc endpoints_by_location.update({new_location["name"]: regional_object}) except Exception as e: raise e @@ -134,6 +135,7 @@ def __init__( refresh_time_interval_in_ms, ): self.preferred_locations = preferred_locations + self.effective_preferred_locations = [] self.default_regional_endpoint = RegionalEndpoint(default_endpoint, default_endpoint) self.enable_endpoint_discovery = enable_endpoint_discovery self.use_multiple_write_locations = use_multiple_write_locations @@ -245,7 +247,7 @@ def resolve_service_endpoint(self, request): return regional_endpoint.get_current() def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements - most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None + most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None # we should schedule refresh in background if we are unable to target the user's most preferredLocation. if self.enable_endpoint_discovery: @@ -357,9 +359,6 @@ def mark_endpoint_unavailable(self, unavailable_endpoint: str, unavailable_opera if refresh_cache: self.update_location_cache() - def get_preferred_locations(self): - return self.preferred_locations - def update_location_cache(self, write_locations=None, read_locations=None, enable_multiple_writable_locations=None): if enable_multiple_writable_locations: self.enable_multiple_writable_locations = enable_multiple_writable_locations @@ -387,6 +386,19 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl self.use_multiple_write_locations, ) + # if preferred locations is empty and the default endpoint is a global endpoint, + # we should use the read locations from gateway + if self.preferred_locations: + self.effective_preferred_locations = self.preferred_locations + elif not LocationCache.is_global_endpoint( + self.default_regional_endpoint.get_current(), + self.available_read_regional_endpoints_by_location, + self.available_read_locations + ): + self.effective_preferred_locations = [] + else: + self.effective_preferred_locations = self.available_read_locations + self.write_regional_endpoints = self.get_preferred_available_regional_endpoints( self.available_write_regional_endpoints_by_location, self.available_write_locations, @@ -413,12 +425,12 @@ def get_preferred_available_regional_endpoints( # pylint: disable=name-too-long or expected_available_operation == EndpointOperationType.ReadType ): unavailable_endpoints = [] - if self.preferred_locations: + if self.effective_preferred_locations: # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. - for location in self.preferred_locations: + for location in self.effective_preferred_locations: regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \ else None if regional_endpoint: @@ -481,3 +493,46 @@ def GetLocationalEndpoint(default_endpoint, location_name): return locational_endpoint return None + + @staticmethod + def is_global_endpoint( + endpoint: str, + account_read_endpoints_by_location: OrderedDict[str, RegionalEndpoint], + read_locations: List[str] + ): + + # if there is only one read location, the sdk cannot figure out the global endpoint + if len(read_locations) < 2: + return True + + first_endpoint_url = urlparse(account_read_endpoints_by_location[read_locations[0]].get_current()) + second_endpoint_url = urlparse(account_read_endpoints_by_location[read_locations[1]].get_current()) + + # hostname attribute in endpoint_url will return something like 'contoso.documents.azure.com' + if first_endpoint_url.hostname is not None and second_endpoint_url.hostname is not None: + first_hostname_parts = str(first_endpoint_url.hostname).lower().split(".") + # first account name will return something like 'contoso-eastus' + first_account_name = first_hostname_parts[0] + second_hostname_parts = str(second_endpoint_url.hostname).lower().split(".") + # second account name will return something like 'contoso-westus' + second_account_name = second_hostname_parts[0] + # + common_account_name = LocationCache.get_common_part(first_account_name, second_account_name) + + # if both were regional endpoints the common account name will have a - at the end + if common_account_name[len(common_account_name) - 1] == "-": + global_account_name = common_account_name[:-1] + else: + global_account_name = common_account_name + default_endpoint_account_name = str(urlparse(endpoint).hostname).lower().split('.', maxsplit=1)[0] + return global_account_name == default_endpoint_account_name + return False + + # finds the longest matching consecutive substring in two strings + @staticmethod + def get_common_part(str1: str, str2: str): + sequence_matcher = difflib.SequenceMatcher(None, str1, str2) + match = sequence_matcher.find_longest_match(0, len(str1), 0, len(str2)) + if match.size != 0: + return str1[match.a: match.a + match.size] + return "" diff --git a/sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations.py b/sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations.py new file mode 100644 index 000000000000..c3772d7d2a12 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations.py @@ -0,0 +1,168 @@ +import random +from collections import OrderedDict +from typing import List + +import pytest +import test_config +from azure.cosmos import DatabaseAccount, _location_cache, CosmosClient, _global_endpoint_manager, \ + _cosmos_client_connection + +from azure.cosmos._location_cache import RegionalEndpoint, LocationCache + +COLLECTION = "created_collection" +REGION_1 = "East US" +REGION_2 = "West US" +REGION_3 = "West US 2" +ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3] + +@pytest.fixture() +def setup(): + if (TestPreferredLocations.masterKey == '[YOUR_KEY_HERE]' or + TestPreferredLocations.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(TestPreferredLocations.host, TestPreferredLocations.masterKey, consistency_level="Session") + created_database = client.get_database_client(TestPreferredLocations.TEST_DATABASE_ID) + created_collection = created_database.get_container_client(TestPreferredLocations.TEST_CONTAINER_SINGLE_PARTITION_ID) + yield { + COLLECTION: created_collection + } + +def preferred_locations(): + host = test_config.TestConfig.host + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(host, REGION_2) + return [ + ([], host), + ([REGION_1, REGION_2], host), + ([REGION_1], host), + ([REGION_2, REGION_3], host), + ([REGION_1, REGION_2, REGION_3], host), + ([], locational_endpoint), + ([REGION_2], locational_endpoint), + ([REGION_3, REGION_1], locational_endpoint), + ([REGION_1, REGION_3], locational_endpoint), + ([REGION_1, REGION_2, REGION_3], locational_endpoint) + ] + +def create_account_endpoints_by_location(global_endpoint: str): + locational_endpoints = [] + for region in ACCOUNT_REGIONS: + locational_endpoints.append(_location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, region)) + _location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, REGION_1) + account_endpoints_by_location = OrderedDict() + for i, region in enumerate(ACCOUNT_REGIONS): + # should create some with the global endpoint sometimes + if random.random() < 0.5: + account_endpoints_by_location[region] = RegionalEndpoint(locational_endpoints[i], global_endpoint) + else: + account_endpoints_by_location[region] = RegionalEndpoint(global_endpoint, locational_endpoints[i]) + + + return locational_endpoints, account_endpoints_by_location + +def is_global_endpoint_inputs(): + global_endpoint = test_config.TestConfig.host + locational_endpoints, account_endpoints_by_location = create_account_endpoints_by_location(global_endpoint) + + # testing if customers account name includes a region ex. contoso-eastus + global_endpoint_2 = _location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, REGION_1) + locational_endpoints_2, account_endpoints_by_location_2 = create_account_endpoints_by_location(global_endpoint_2) + + # endpoint, locations, account_endpoints_by_location, result + return [ + (global_endpoint, ACCOUNT_REGIONS, account_endpoints_by_location, True), + (locational_endpoints[0], ACCOUNT_REGIONS, account_endpoints_by_location, False), + (locational_endpoints[1], ACCOUNT_REGIONS, account_endpoints_by_location, False), + (locational_endpoints[2], ACCOUNT_REGIONS, account_endpoints_by_location, False), + (global_endpoint_2, ACCOUNT_REGIONS, account_endpoints_by_location_2, True), + (locational_endpoints_2[0], ACCOUNT_REGIONS, account_endpoints_by_location_2, False), + (locational_endpoints_2[1], ACCOUNT_REGIONS, account_endpoints_by_location_2, False), + (locational_endpoints_2[2], ACCOUNT_REGIONS, account_endpoints_by_location_2, False) + ] + + +@pytest.mark.cosmosEmulator +@pytest.mark.unittest +@pytest.mark.usefixtures("setup") +class TestPreferredLocations: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID + + @pytest.mark.parametrize("preferred_location, default_endpoint", preferred_locations()) + def test_effective_preferred_regions(self, setup, preferred_location, default_endpoint): + + self.original_getDatabaseAccountStub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub + self.original_getDatabaseAccountCheck = _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck + _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + try: + client = CosmosClient(default_endpoint, self.masterKey, preferred_locations=preferred_location) + # this will setup the location cache + client.client_connection._global_endpoint_manager.force_refresh(None) + finally: + _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub + _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck + expected_dual_endpoints = [] + + # if preferred location set should use that + if preferred_location: + expected_locations = preferred_location + # if client created with regional endpoint preferred locations, only use hub region + elif default_endpoint != self.host: + expected_locations = ACCOUNT_REGIONS[:1] + # if client created with global endpoint and no preferred locations, use all regions + else: + expected_locations = ACCOUNT_REGIONS + + for location in expected_locations: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location) + if default_endpoint == self.host or preferred_location: + expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, locational_endpoint)) + else: + expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, default_endpoint)) + + read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_endpoints + assert read_dual_endpoints == expected_dual_endpoints + + + @pytest.mark.parametrize("default_endpoint, read_locations, account_endpoints_by_location, result", is_global_endpoint_inputs()) + def test_is_global_endpoint(self, default_endpoint, read_locations, account_endpoints_by_location, result): + assert result == LocationCache.is_global_endpoint(default_endpoint, account_endpoints_by_location, read_locations) + + + class MockGetDatabaseAccount(object): + def __init__( + self, + regions: List[str], + ): + self.regions = regions + + def __call__(self, endpoint): + read_regions = self.regions + read_locations = [] + counter = 0 + for loc in read_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocations.host, loc) + read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + counter += 1 + write_regions = [self.regions[0]] + write_locations = [] + for loc in write_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocations.host, loc) + write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + multi_write = False + + db_acc = DatabaseAccount() + db_acc.DatabasesLink = "/dbs/" + db_acc.MediaLink = "/media/" + db_acc._ReadableLocations = read_locations + db_acc._WritableLocations = write_locations + db_acc._EnableMultipleWritableLocations = multi_write + db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} + return db_acc diff --git a/sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations_async.py b/sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations_async.py new file mode 100644 index 000000000000..700faedc77df --- /dev/null +++ b/sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations_async.py @@ -0,0 +1,126 @@ +from typing import List + +import pytest +import pytest_asyncio + +import test_config +from azure.cosmos import DatabaseAccount, _location_cache + +from azure.cosmos._location_cache import RegionalEndpoint +from azure.cosmos.aio import _global_endpoint_manager_async, _cosmos_client_connection_async, CosmosClient + +COLLECTION = "created_collection" +REGION_1 = "East US" +REGION_2 = "West US" +REGION_3 = "West US 2" +ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3] + +@pytest_asyncio.fixture() +async def setup(): + if (TestPreferredLocationsAsync.masterKey == '[YOUR_KEY_HERE]' or + TestPreferredLocationsAsync.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + client = CosmosClient(TestPreferredLocationsAsync.host, TestPreferredLocationsAsync.masterKey, consistency_level="Session") + created_database = client.get_database_client(TestPreferredLocationsAsync.TEST_DATABASE_ID) + created_collection = created_database.get_container_client(TestPreferredLocationsAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + yield { + COLLECTION: created_collection + } + await client.close() + +def preferred_locations(): + host = test_config.TestConfig.host + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(host, REGION_2) + return [ + ([], host), + ([REGION_1, REGION_2], host), + ([REGION_1], host), + ([REGION_2, REGION_3], host), + ([REGION_1, REGION_2, REGION_3], host), + ([], locational_endpoint), + ([REGION_2], locational_endpoint), + ([REGION_3, REGION_1], locational_endpoint), + ([REGION_1, REGION_3], locational_endpoint), + ([REGION_1, REGION_2, REGION_3], locational_endpoint) + ] + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPreferredLocationsAsync: + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID + + @pytest.mark.parametrize("preferred_location, default_endpoint", preferred_locations()) + async def test_effective_preferred_regions_async(self, setup, preferred_location, default_endpoint): + + self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub + self.original_getDatabaseAccountCheck = _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck + _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS) + try: + client = CosmosClient(default_endpoint, self.masterKey, preferred_locations=preferred_location) + # this will setup the location cache + await client.client_connection._global_endpoint_manager.force_refresh(None) + finally: + _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub + _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck + expected_dual_endpoints = [] + + # if preferred location set should use that + if preferred_location: + expected_locations = preferred_location + # if client created with regional endpoint preferred locations, only use hub region + elif default_endpoint != self.host: + expected_locations = ACCOUNT_REGIONS[:1] + # if client created with global endpoint and no preferred locations, use all regions + else: + expected_locations = ACCOUNT_REGIONS + + for location in expected_locations: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location) + if default_endpoint == self.host or preferred_location: + expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, locational_endpoint)) + else: + expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, default_endpoint)) + + read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_endpoints + assert read_dual_endpoints == expected_dual_endpoints + + class MockGetDatabaseAccount(object): + def __init__( + self, + regions: List[str], + ): + self.regions = regions + + async def __call__(self, endpoint): + read_regions = self.regions + read_locations = [] + counter = 0 + for loc in read_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocationsAsync.host, loc) + read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + counter += 1 + write_regions = [self.regions[0]] + write_locations = [] + for loc in write_regions: + locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocationsAsync.host, loc) + write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc}) + multi_write = False + + db_acc = DatabaseAccount() + db_acc.DatabasesLink = "/dbs/" + db_acc.MediaLink = "/media/" + db_acc._ReadableLocations = read_locations + db_acc._WritableLocations = write_locations + db_acc._EnableMultipleWritableLocations = multi_write + db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} + return db_acc diff --git a/sdk/cosmos/azure-cosmos/test/test_location_cache.py b/sdk/cosmos/azure-cosmos/test/test_location_cache.py index 1fb4927b5862..4e4ef349f690 100644 --- a/sdk/cosmos/azure-cosmos/test/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/test/test_location_cache.py @@ -102,7 +102,7 @@ def test_get_locations(self): # check read endpoints without preferred locations read_regions = lc.get_read_regional_endpoints() - assert len(read_regions) == 1 + assert len(read_regions) == 3 assert read_regions[0].get_current() == location1_endpoint # check read endpoints with preferred locations diff --git a/sdk/cosmos/azure-cosmos/test/test_regional_endpoint.py b/sdk/cosmos/azure-cosmos/test/test_regional_endpoint.py index c8915be30972..7a7681adcd65 100644 --- a/sdk/cosmos/azure-cosmos/test/test_regional_endpoint.py +++ b/sdk/cosmos/azure-cosmos/test/test_regional_endpoint.py @@ -14,10 +14,6 @@ class TestRegionalEndpoints(unittest.TestCase): host = test_config.TestConfig.host masterKey = test_config.TestConfig.masterKey - REGION1 = "West US" - REGION2 = "East US" - REGION3 = "West US 2" - REGIONAL_ENDPOINT = RegionalEndpoint(host, "something_different") TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID @@ -34,43 +30,14 @@ def setUpClass(cls): cls.created_container = cls.created_database.get_container_client(cls.TEST_CONTAINER_ID) def test_no_swaps_on_successful_request(self): - original_get_database_account_stub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - mocked_client = CosmosClient(self.host, self.masterKey) - db = mocked_client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_ID) - # Mock the GetDatabaseAccountStub to return the regional endpoints - original_read_endpoint = (mocked_client.client_connection._global_endpoint_manager + original_read_endpoint = (self.client.client_connection._global_endpoint_manager .location_cache.get_read_regional_endpoint()) - try: - container.create_item(body={"id": str(uuid.uuid4())}) - finally: - # Check for if there was a swap - self.assertEqual(original_read_endpoint, - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_read_regional_endpoint()) - # return it - self.assertEqual(self.REGIONAL_ENDPOINT.get_current(), - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_write_regional_endpoint()) - _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = original_get_database_account_stub - - def MockGetDatabaseAccountStub(self, endpoint): - read_locations = [] - read_locations.append({'databaseAccountEndpoint': endpoint, 'name': "West US"}) - read_locations.append({'databaseAccountEndpoint': "some different endpoint", 'name': "West US"}) - write_regions = ["West US"] - write_locations = [] - for loc in write_regions: - write_locations.append({'databaseAccountEndpoint': endpoint, 'name': loc}) - multi_write = False - - db_acc = DatabaseAccount() - db_acc.DatabasesLink = "/dbs/" - db_acc.MediaLink = "/media/" - db_acc._ReadableLocations = read_locations - db_acc._WritableLocations = write_locations - db_acc._EnableMultipleWritableLocations = multi_write - db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} - return db_acc + self.created_container.create_item(body={"id": str(uuid.uuid4())}) + # Check for if there was a swap + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_read_regional_endpoint()) + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_write_regional_endpoint()) diff --git a/sdk/cosmos/azure-cosmos/test/test_regional_endpoint_async.py b/sdk/cosmos/azure-cosmos/test/test_regional_endpoint_async.py index e9dab2a9ac08..a46bc0b888f1 100644 --- a/sdk/cosmos/azure-cosmos/test/test_regional_endpoint_async.py +++ b/sdk/cosmos/azure-cosmos/test/test_regional_endpoint_async.py @@ -6,19 +6,14 @@ import pytest import test_config -from azure.cosmos import DatabaseAccount from azure.cosmos._location_cache import RegionalEndpoint -from azure.cosmos.aio import CosmosClient, _global_endpoint_manager_async +from azure.cosmos.aio import CosmosClient @pytest.mark.cosmosEmulator class TestRegionalEndpoints(unittest.IsolatedAsyncioTestCase): host = test_config.TestConfig.host masterKey = test_config.TestConfig.masterKey - REGION1 = "West US" - REGION2 = "East US" - REGION3 = "West US 2" - REGIONAL_ENDPOINT = RegionalEndpoint(host, "something_different") TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID TEST_CONTAINER_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID @@ -36,44 +31,19 @@ async def asyncSetUp(self): self.created_database = self.client.get_database_client(self.TEST_DATABASE_ID) self.created_container = self.created_database.get_container_client(self.TEST_CONTAINER_ID) - async def test_no_swaps_on_successful_request(self): - original_get_database_account_stub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub - mocked_client = CosmosClient(self.host, self.masterKey) - db = mocked_client.get_database_client(self.TEST_DATABASE_ID) - container = db.get_container_client(self.TEST_CONTAINER_ID) - # Mock the GetDatabaseAccountStub to return the regional endpoints + async def asyncTearDown(self): + await self.client.close() - original_read_endpoint = (mocked_client.client_connection._global_endpoint_manager + async def test_no_swaps_on_successful_request(self): + # Make sure that getDatabaseAccount call has finished + await self.client.client_connection._global_endpoint_manager.force_refresh(None) + original_read_endpoint = (self.client.client_connection._global_endpoint_manager .location_cache.get_read_regional_endpoint()) - try: - await container.create_item(body={"id": str(uuid.uuid4())}) - finally: - # Check for if there was a swap - self.assertEqual(original_read_endpoint, - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_read_regional_endpoint()) - # return it - self.assertEqual(self.REGIONAL_ENDPOINT.get_current(), - mocked_client.client_connection._global_endpoint_manager - .location_cache.get_write_regional_endpoint()) - _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = original_get_database_account_stub - - async def MockGetDatabaseAccountStub(self, endpoint): - read_locations = [] - read_locations.append({'databaseAccountEndpoint': endpoint, 'name': "West US"}) - read_locations.append({'databaseAccountEndpoint': "some different endpoint", 'name': "West US"}) - write_regions = ["West US"] - write_locations = [] - for loc in write_regions: - write_locations.append({'databaseAccountEndpoint': endpoint, 'name': loc}) - multi_write = False - - db_acc = DatabaseAccount() - db_acc.DatabasesLink = "/dbs/" - db_acc.MediaLink = "/media/" - db_acc._ReadableLocations = read_locations - db_acc._WritableLocations = write_locations - db_acc._EnableMultipleWritableLocations = multi_write - db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"} - return db_acc + await self.created_container.create_item(body={"id": str(uuid.uuid4())}) + # Check for if there was a swap + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_read_regional_endpoint()) + self.assertEqual(original_read_endpoint, + self.client.client_connection._global_endpoint_manager + .location_cache.get_write_regional_endpoint())