Skip to content

Effective Preferred Locations #39714

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Open
wants to merge 11 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from 7 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@
#### Features Added

#### Breaking Changes
* Adds cross region retries when no preferred locations are set. This is only a breaking change for customers using bounded staleness consistency. See [PR 39714](https://github.com/Azure/azure-sdk-for-python/pull/39714)

#### Bugs Fixed

Expand Down
16 changes: 10 additions & 6 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -134,6 +134,7 @@ def __init__(
refresh_time_interval_in_ms,
):
self.preferred_locations = preferred_locations
self.effective_preferred_locations = []
self.default_regional_endpoint = RegionalEndpoint(default_endpoint, default_endpoint)
self.enable_endpoint_discovery = enable_endpoint_discovery
self.use_multiple_write_locations = use_multiple_write_locations
Expand Down Expand Up @@ -245,7 +246,7 @@ def resolve_service_endpoint(self, request):
return regional_endpoint.get_current()

def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements
most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None
most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None

# we should schedule refresh in background if we are unable to target the user's most preferredLocation.
if self.enable_endpoint_discovery:
Expand Down Expand Up @@ -357,9 +358,6 @@ def mark_endpoint_unavailable(self, unavailable_endpoint: str, unavailable_opera
if refresh_cache:
self.update_location_cache()

def get_preferred_locations(self):
return self.preferred_locations

def update_location_cache(self, write_locations=None, read_locations=None, enable_multiple_writable_locations=None):
if enable_multiple_writable_locations:
self.enable_multiple_writable_locations = enable_multiple_writable_locations
Expand Down Expand Up @@ -387,6 +385,12 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
self.use_multiple_write_locations,
)

# if preferred locations is empty, we should use the read locations from gateway
if len(self.preferred_locations) == 0:
self.effective_preferred_locations = self.available_read_locations
else:
self.effective_preferred_locations = self.preferred_locations

self.write_regional_endpoints = self.get_preferred_available_regional_endpoints(
self.available_write_regional_endpoints_by_location,
self.available_write_locations,
Expand All @@ -413,12 +417,12 @@ def get_preferred_available_regional_endpoints( # pylint: disable=name-too-long
or expected_available_operation == EndpointOperationType.ReadType
):
unavailable_endpoints = []
if self.preferred_locations:
if self.effective_preferred_locations:
# When client can not use multiple write locations, preferred locations
# list should only be used determining read endpoints order. If client
# can use multiple write locations, preferred locations list should be
# used for determining both read and write endpoints order.
for location in self.preferred_locations:
for location in self.effective_preferred_locations:
regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \
else None
if regional_endpoint:
Expand Down
102 changes: 102 additions & 0 deletions sdk/cosmos/azure-cosmos/test/test_effective_preferred_locations.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,102 @@
from typing import List

import pytest
import test_config
from azure.cosmos import DatabaseAccount, _location_cache, CosmosClient, _global_endpoint_manager, \
_cosmos_client_connection

from azure.cosmos._location_cache import RegionalEndpoint

COLLECTION = "created_collection"
REGION_1 = "East US"
REGION_2 = "West US"
REGION_3 = "West US 2"
ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3]

@pytest.fixture()
def setup():
if (TestPreferredLocations.masterKey == '[YOUR_KEY_HERE]' or
TestPreferredLocations.host == '[YOUR_ENDPOINT_HERE]'):
raise Exception(
"You must specify your Azure Cosmos account values for "
"'masterKey' and 'host' at the top of this class to run the "
"tests.")

client = CosmosClient(TestPreferredLocations.host, TestPreferredLocations.masterKey, consistency_level="Session")
created_database = client.get_database_client(TestPreferredLocations.TEST_DATABASE_ID)
created_collection = created_database.get_container_client(TestPreferredLocations.TEST_CONTAINER_SINGLE_PARTITION_ID)
yield {
COLLECTION: created_collection
}

def preferred_locations():
return [([]), ([REGION_1, REGION_2]), ([REGION_1]), ([REGION_2, REGION_3]), ([REGION_1, REGION_2, REGION_3])]

@pytest.mark.cosmosEmulator
@pytest.mark.unittest
@pytest.mark.usefixtures("setup")
class TestPreferredLocations:
host = test_config.TestConfig.host
masterKey = test_config.TestConfig.masterKey
connectionPolicy = test_config.TestConfig.connectionPolicy
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID

@pytest.mark.parametrize("preferred_location", preferred_locations())
def test_effective_preferred_regions(self, setup, preferred_location):

self.original_getDatabaseAccountStub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub
self.original_getDatabaseAccountCheck = _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck
_global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
_cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
try:
client = CosmosClient(self.host, self.masterKey, preferred_locations=preferred_location)
# this will setup the location cache
client.client_connection._global_endpoint_manager.force_refresh(None)
finally:
_global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub
_cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck
expected_dual_endpoints = []

if preferred_location:
expected_locations = preferred_location
else:
expected_locations = ACCOUNT_REGIONS

for location in expected_locations:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location)
expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, locational_endpoint))

read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_endpoints
assert read_dual_endpoints == expected_dual_endpoints

class MockGetDatabaseAccount(object):
def __init__(
self,
regions: List[str],
):
self.regions = regions

def __call__(self, endpoint):
read_regions = self.regions
read_locations = []
counter = 0
for loc in read_regions:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(endpoint, loc)
read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc})
counter += 1
write_regions = [self.regions[0]]
write_locations = []
for loc in write_regions:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(endpoint, loc)
write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc})
multi_write = False

db_acc = DatabaseAccount()
db_acc.DatabasesLink = "/dbs/"
db_acc.MediaLink = "/media/"
db_acc._ReadableLocations = read_locations
db_acc._WritableLocations = write_locations
db_acc._EnableMultipleWritableLocations = multi_write
db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"}
return db_acc
Original file line number Diff line number Diff line change
@@ -0,0 +1,105 @@
from typing import List

import pytest
import pytest_asyncio

import test_config
from azure.cosmos import DatabaseAccount, _location_cache

from azure.cosmos._location_cache import RegionalEndpoint
from azure.cosmos.aio import _global_endpoint_manager_async, _cosmos_client_connection_async, CosmosClient

COLLECTION = "created_collection"
REGION_1 = "East US"
REGION_2 = "West US"
REGION_3 = "West US 2"
ACCOUNT_REGIONS = [REGION_1, REGION_2, REGION_3]

@pytest_asyncio.fixture()
async def setup():
if (TestPreferredLocationsAsync.masterKey == '[YOUR_KEY_HERE]' or
TestPreferredLocationsAsync.host == '[YOUR_ENDPOINT_HERE]'):
raise Exception(
"You must specify your Azure Cosmos account values for "
"'masterKey' and 'host' at the top of this class to run the "
"tests.")

client = CosmosClient(TestPreferredLocationsAsync.host, TestPreferredLocationsAsync.masterKey, consistency_level="Session")
created_database = client.get_database_client(TestPreferredLocationsAsync.TEST_DATABASE_ID)
created_collection = created_database.get_container_client(TestPreferredLocationsAsync.TEST_CONTAINER_SINGLE_PARTITION_ID)
yield {
COLLECTION: created_collection
}
await client.close()

def preferred_locations():
return [([]), ([REGION_1, REGION_2]), ([REGION_1]), ([REGION_2, REGION_3]), ([REGION_1, REGION_2, REGION_3])]

@pytest.mark.cosmosEmulator
@pytest.mark.asyncio
@pytest.mark.usefixtures("setup")
class TestPreferredLocationsAsync:
host = test_config.TestConfig.host
masterKey = test_config.TestConfig.masterKey
connectionPolicy = test_config.TestConfig.connectionPolicy
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
TEST_CONTAINER_SINGLE_PARTITION_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID

@pytest.mark.parametrize("preferred_location", preferred_locations())
async def test_effective_preferred_regions_async(self, setup, preferred_location):

self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub
self.original_getDatabaseAccountCheck = _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck
_global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
_cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.MockGetDatabaseAccount(ACCOUNT_REGIONS)
try:
client = CosmosClient(self.host, self.masterKey, preferred_locations=preferred_location)
# this will setup the location cache
await client.client_connection._global_endpoint_manager.force_refresh(None)
finally:
_global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub
_cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck
expected_dual_endpoints = []

if preferred_location:
expected_locations = preferred_location
else:
expected_locations = ACCOUNT_REGIONS

for location in expected_locations:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, location)
expected_dual_endpoints.append(RegionalEndpoint(locational_endpoint, locational_endpoint))

read_dual_endpoints = client.client_connection._global_endpoint_manager.location_cache.read_regional_endpoints
assert read_dual_endpoints == expected_dual_endpoints

class MockGetDatabaseAccount(object):
def __init__(
self,
regions: List[str],
):
self.regions = regions

async def __call__(self, endpoint):
read_regions = self.regions
read_locations = []
counter = 0
for loc in read_regions:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(endpoint, loc)
read_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc})
counter += 1
write_regions = [self.regions[0]]
write_locations = []
for loc in write_regions:
locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(endpoint, loc)
write_locations.append({'databaseAccountEndpoint': locational_endpoint, 'name': loc})
multi_write = False

db_acc = DatabaseAccount()
db_acc.DatabasesLink = "/dbs/"
db_acc.MediaLink = "/media/"
db_acc._ReadableLocations = read_locations
db_acc._WritableLocations = write_locations
db_acc._EnableMultipleWritableLocations = multi_write
db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"}
return db_acc
2 changes: 1 addition & 1 deletion sdk/cosmos/azure-cosmos/test/test_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -102,7 +102,7 @@ def test_get_locations(self):

# check read endpoints without preferred locations
read_regions = lc.get_read_regional_endpoints()
assert len(read_regions) == 1
assert len(read_regions) == 3
assert read_regions[0].get_current() == location1_endpoint

# check read endpoints with preferred locations
Expand Down
51 changes: 9 additions & 42 deletions sdk/cosmos/azure-cosmos/test/test_regional_endpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,10 +14,6 @@
class TestRegionalEndpoints(unittest.TestCase):
host = test_config.TestConfig.host
masterKey = test_config.TestConfig.masterKey
REGION1 = "West US"
REGION2 = "East US"
REGION3 = "West US 2"
REGIONAL_ENDPOINT = RegionalEndpoint(host, "something_different")
TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID
TEST_CONTAINER_ID = test_config.TestConfig.TEST_SINGLE_PARTITION_CONTAINER_ID

Expand All @@ -34,43 +30,14 @@ def setUpClass(cls):
cls.created_container = cls.created_database.get_container_client(cls.TEST_CONTAINER_ID)

def test_no_swaps_on_successful_request(self):
original_get_database_account_stub = _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub
_global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.MockGetDatabaseAccountStub
mocked_client = CosmosClient(self.host, self.masterKey)
db = mocked_client.get_database_client(self.TEST_DATABASE_ID)
container = db.get_container_client(self.TEST_CONTAINER_ID)
# Mock the GetDatabaseAccountStub to return the regional endpoints

original_read_endpoint = (mocked_client.client_connection._global_endpoint_manager
original_read_endpoint = (self.client.client_connection._global_endpoint_manager
.location_cache.get_read_regional_endpoint())
try:
container.create_item(body={"id": str(uuid.uuid4())})
finally:
# Check for if there was a swap
self.assertEqual(original_read_endpoint,
mocked_client.client_connection._global_endpoint_manager
.location_cache.get_read_regional_endpoint())
# return it
self.assertEqual(self.REGIONAL_ENDPOINT.get_current(),
mocked_client.client_connection._global_endpoint_manager
.location_cache.get_write_regional_endpoint())
_global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = original_get_database_account_stub

def MockGetDatabaseAccountStub(self, endpoint):
read_locations = []
read_locations.append({'databaseAccountEndpoint': endpoint, 'name': "West US"})
read_locations.append({'databaseAccountEndpoint': "some different endpoint", 'name': "West US"})
write_regions = ["West US"]
write_locations = []
for loc in write_regions:
write_locations.append({'databaseAccountEndpoint': endpoint, 'name': loc})
multi_write = False

db_acc = DatabaseAccount()
db_acc.DatabasesLink = "/dbs/"
db_acc.MediaLink = "/media/"
db_acc._ReadableLocations = read_locations
db_acc._WritableLocations = write_locations
db_acc._EnableMultipleWritableLocations = multi_write
db_acc.ConsistencyPolicy = {"defaultConsistencyLevel": "Session"}
return db_acc
self.created_container.create_item(body={"id": str(uuid.uuid4())})
# Check for if there was a swap
self.assertEqual(original_read_endpoint,
self.client.client_connection._global_endpoint_manager
.location_cache.get_read_regional_endpoint())
self.assertEqual(original_read_endpoint,
self.client.client_connection._global_endpoint_manager
.location_cache.get_write_regional_endpoint())
Loading
Loading