Skip to content

Effective Preferred Locations #39714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
@@ -1,4 +1,14 @@
## Release History
### 4.10.1b1 (Unreleased)

#### Features Added

#### Breaking Changes
* Adds cross region retries when no preferred locations are set. This is only a breaking change for customers using bounded staleness consistency. See [PR 39714](https://github.com/Azure/azure-sdk-for-python/pull/39714)

#### Bugs Fixed

#### Other Changes

### 4.10.0b1 (2025-02-13)

Expand Down
69 changes: 62 additions & 7 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,8 +23,10 @@
DatabaseAccount with multiple writable and readable locations.
"""
import collections
import difflib
import logging
import time
from typing import List, OrderedDict
from urllib.parse import urlparse

from . import documents
Expand Down Expand Up @@ -113,7 +115,6 @@ def get_endpoints_by_location(new_locations,
default_regional_endpoint.get_current(),
new_location["name"])
regional_object.set_previous(constructed_region_uri)
# pass in object with region uri , last known good, curr etc
endpoints_by_location.update({new_location["name"]: regional_object})
except Exception as e:
raise e
Expand All @@ -134,6 +135,7 @@ def __init__(
refresh_time_interval_in_ms,
):
self.preferred_locations = preferred_locations
self.effective_preferred_locations = []
self.default_regional_endpoint = RegionalEndpoint(default_endpoint, default_endpoint)
self.enable_endpoint_discovery = enable_endpoint_discovery
self.use_multiple_write_locations = use_multiple_write_locations
Expand Down Expand Up @@ -245,7 +247,7 @@ def resolve_service_endpoint(self, request):
return regional_endpoint.get_current()

def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements
most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None
most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None

# we should schedule refresh in background if we are unable to target the user's most preferredLocation.
if self.enable_endpoint_discovery:
Expand Down Expand Up @@ -357,9 +359,6 @@ def mark_endpoint_unavailable(self, unavailable_endpoint: str, unavailable_opera
if refresh_cache:
self.update_location_cache()

def get_preferred_locations(self):
return self.preferred_locations

def update_location_cache(self, write_locations=None, read_locations=None, enable_multiple_writable_locations=None):
if enable_multiple_writable_locations:
self.enable_multiple_writable_locations = enable_multiple_writable_locations
Expand Down Expand Up @@ -387,6 +386,19 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
self.use_multiple_write_locations,
)

# if preferred locations is empty and the default endpoint is a global endpoint,
# we should use the read locations from gateway
if self.preferred_locations:
self.effective_preferred_locations = self.preferred_locations
elif not LocationCache.is_global_endpoint(
self.default_regional_endpoint.get_current(),
self.available_read_regional_endpoints_by_location,
self.available_read_locations
):
self.effective_preferred_locations = []
else:
self.effective_preferred_locations = self.available_read_locations

self.write_regional_endpoints = self.get_preferred_available_regional_endpoints(
self.available_write_regional_endpoints_by_location,
self.available_write_locations,
Expand All @@ -413,12 +425,12 @@ def get_preferred_available_regional_endpoints( # pylint: disable=name-too-long
or expected_available_operation == EndpointOperationType.ReadType
):
unavailable_endpoints = []
if self.preferred_locations:
if self.effective_preferred_locations:
# When client can not use multiple write locations, preferred locations
# list should only be used determining read endpoints order. If client
# can use multiple write locations, preferred locations list should be
# used for determining both read and write endpoints order.
for location in self.preferred_locations:
for location in self.effective_preferred_locations:
regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \
else None
if regional_endpoint:
Expand Down Expand Up @@ -481,3 +493,46 @@ def GetLocationalEndpoint(default_endpoint, location_name):
return locational_endpoint

return None

@staticmethod
def is_global_endpoint(
endpoint: str,
account_read_endpoints_by_location: OrderedDict[str, RegionalEndpoint],
read_locations: List[str]
):

# if there is only one read location, the sdk cannot figure out the global endpoint
if len(read_locations) < 2:
return True

first_endpoint_url = urlparse(account_read_endpoints_by_location[read_locations[0]].get_current())
second_endpoint_url = urlparse(account_read_endpoints_by_location[read_locations[1]].get_current())

# hostname attribute in endpoint_url will return something like 'contoso.documents.azure.com'
if first_endpoint_url.hostname is not None and second_endpoint_url.hostname is not None:
first_hostname_parts = str(first_endpoint_url.hostname).lower().split(".")
# first account name will return something like 'contoso-eastus'
first_account_name = first_hostname_parts[0]
second_hostname_parts = str(second_endpoint_url.hostname).lower().split(".")
# second account name will return something like 'contoso-westus'
second_account_name = second_hostname_parts[0]
#
common_account_name = LocationCache.get_common_part(first_account_name, second_account_name)

# if both were regional endpoints the common account name will have a - at the end
if common_account_name[len(common_account_name) - 1] == "-":
global_account_name = common_account_name[:-1]
else:
global_account_name = common_account_name
default_endpoint_account_name = str(urlparse(endpoint).hostname).lower().split('.', maxsplit=1)[0]
return global_account_name == default_endpoint_account_name
return False

# finds the longest matching consecutive substring in two strings
@staticmethod
def get_common_part(str1: str, str2: str):
sequence_matcher = difflib.SequenceMatcher(None, str1, str2)
match = sequence_matcher.find_longest_match(0, len(str1), 0, len(str2))
if match.size != 0:
return str1[match.a: match.a + match.size]
return ""
168 changes: 168 additions & 0 deletions sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,168 @@
import random
from collections import OrderedDict
from typing import List

import pytest
import test_config
from azure.cosmos import DatabaseAccount, _location_cache, CosmosClient, _global_endpoint_manager, \
_cosmos_client_connection

from azure.cosmos._location_cache import RegionalEndpoint, LocationCache

COLLECTION = "created_collection"
REGION_1 = "East US"
REGION_2 = "West US"
REGION_3 = "West US 2"
ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3]

@pytest.fixture()
def setup():
if (TestPreferredLocations.masterKey == '[YOUR_KEY_HERE]' or
TestPreferredLocations.host == '[YOUR_ENDPOINT_HERE]'):
raise Exception(
"You must specify your Azure Cosmos account values for "
"'masterKey' and 'host' at the top of this class to run the "
"tests.")

client = CosmosClient(TestPreferredLocations.host, TestPreferredLocations.masterKey, consistency_level="Session")
created_database = client.get_database_client(TestPreferredLocations.TEST_DATABASE_ID)
created_collection = created_database.get_container_client(TestPreferredLocations.TEST_CONTAINER_SINGLE_PARTITION_ID)
yield {
COLLECTION: created_collection
}

def preferred_locations():
host = test_config.TestConfig.host
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(host, REGION_2)
return [
([], host),
([REGION_1, REGION_2], host),
([REGION_1], host),
([REGION_2, REGION_3], host),
([REGION_1, REGION_2, REGION_3], host),
([], locational_endpoint),
([REGION_2], locational_endpoint),
([REGION_3, REGION_1], locational_endpoint),
([REGION_1, REGION_3], locational_endpoint),
([REGION_1, REGION_2, REGION_3], locational_endpoint)
]

def create_account_endpoints_by_location(global_endpoint: str):
locational_endpoints = []
for region in ACCOUNT_REGIONS:
locational_endpoints.append(_location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, region))
_location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, REGION_1)
account_endpoints_by_location = OrderedDict()
for i, region in enumerate(ACCOUNT_REGIONS):
# should create some with the global endpoint sometimes
if random.random() < 0.5:
account_endpoints_by_location[region] = RegionalEndpoint(locational_endpoints[i], global_endpoint)
else:
account_endpoints_by_location[region] = RegionalEndpoint(global_endpoint, locational_endpoints[i])


return locational_endpoints, account_endpoints_by_location

def is_global_endpoint_inputs():
global_endpoint = test_config.TestConfig.host
locational_endpoints, account_endpoints_by_location = create_account_endpoints_by_location(global_endpoint)

# testing if customers account name includes a region ex. contoso-eastus
global_endpoint_2 = _location_cache.LocationCache.GetLocationalEndpoint(global_endpoint, REGION_1)
locational_endpoints_2, account_endpoints_by_location_2 = create_account_endpoints_by_location(global_endpoint_2)

# endpoint, locations, account_endpoints_by_location, result
return [
(global_endpoint, ACCOUNT_REGIONS, account_endpoints_by_location, True),
(locational_endpoints[0], ACCOUNT_REGIONS, account_endpoints_by_location, False),
(locational_endpoints[1], ACCOUNT_REGIONS, account_endpoints_by_location, False),
(locational_endpoints[2], ACCOUNT_REGIONS, account_endpoints_by_location, False),
(global_endpoint_2, ACCOUNT_REGIONS, account_endpoints_by_location_2, True),
(locational_endpoints_2[0], ACCOUNT_REGIONS, account_endpoints_by_location_2, False),
(locational_endpoints_2[1], ACCOUNT_REGIONS, account_endpoints_by_location_2, False),
(locational_endpoints_2[2], ACCOUNT_REGIONS, account_endpoints_by_location_2, False)
]


@pytest.mark.cosmosEmulator
@pytest.mark.unittest
@pytest.mark.usefixtures("setup")
class TestPreferredLocations:
host = test_config.TestConfig.host
masterKey = test_config.TestConfig.masterKey
connectionPolicy = test_config.TestConfig.connectionPolicy
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID

@pytest.mark.parametrize("preferred_location, default_endpoint", preferred_locations())
def test_effective_preferred_regions(self, setup, preferred_location, default_endpoint):

self.original_getDatabaseAccountStub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub
self.original_getDatabaseAccountCheck = _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck
_global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
_cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
try:
client = CosmosClient(default_endpoint, self.masterKey, preferred_locations=preferred_location)
# this will setup the location cache
client.client_connection._global_endpoint_manager.force_refresh(None)
finally:
_global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub
_cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck
expected_dual_endpoints = []

# if preferred location set should use that
if preferred_location:
expected_locations = preferred_location
# if client created with regional endpoint preferred locations, only use hub region
elif default_endpoint != self.host:
expected_locations = ACCOUNT_REGIONS[:1]
# if client created with global endpoint and no preferred locations, use all regions
else:
expected_locations = ACCOUNT_REGIONS

for location in expected_locations:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location)
if default_endpoint == self.host or preferred_location:
expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, locational_endpoint))
else:
expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, default_endpoint))

read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_endpoints
assert read_dual_endpoints == expected_dual_endpoints


@pytest.mark.parametrize("default_endpoint, read_locations, account_endpoints_by_location, result", is_global_endpoint_inputs())
def test_is_global_endpoint(self, default_endpoint, read_locations, account_endpoints_by_location, result):
assert result == LocationCache.is_global_endpoint(default_endpoint, account_endpoints_by_location, read_locations)


class MockGetDatabaseAccount(object):
def __init__(
self,
regions: List[str],
):
self.regions = regions

def __call__(self, endpoint):
read_regions = self.regions
read_locations = []
counter = 0
for loc in read_regions:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocations.host, loc)
read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc})
counter += 1
write_regions = [self.regions[0]]
write_locations = []
for loc in write_regions:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestPreferredLocations.host, loc)
write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc})
multi_write = False

db_acc = DatabaseAccount()
db_acc.DatabasesLink = "/dbs/"
db_acc.MediaLink = "/media/"
db_acc._ReadableLocations = read_locations
db_acc._WritableLocations = write_locations
db_acc._EnableMultipleWritableLocations = multi_write
db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"}
return db_acc
Loading
Loading