Skip to content

Emulator tests #40453

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 24 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
f03f51f
Add Excluded Locations Feature
allenkim0129 Mar 31, 2025
5bb9f1f
Added multi-region tests
kushagraThapar Apr 3, 2025
996217a
Fix _AddParitionKey to pass options to sub methods
allenkim0129 Apr 3, 2025
41fc917
Added initial live tests
allenkim0129 Apr 3, 2025
07b8f39
Updated live-platform-matrix for multi-region tests
allenkim0129 Apr 3, 2025
8495c51
Add cosmosQuery mark to TestQuery
allenkim0129 Apr 4, 2025
b29980c
Correct spelling
allenkim0129 Apr 4, 2025
5e79172
Fixed live platform matrix syntax
allenkim0129 Apr 4, 2025
fd40cd7
Changed Multi-regions
allenkim0129 Apr 4, 2025
29305f4
Added client level ExcludedLocation for async
allenkim0129 Apr 7, 2025
c77b4e7
Update Live test settings
allenkim0129 Apr 7, 2025
d82fa74
Added Async tests
allenkim0129 Apr 7, 2025
5610889
Add more live tests for all other Python versions
allenkim0129 Apr 7, 2025
f4cb8b3
Fix Async test failure
allenkim0129 Apr 8, 2025
9b0236d
Merge branch 'main' into user/allekim/feature/addExcludedLocations
allenkim0129 Apr 8, 2025
4f08168
Fix live test failures
allenkim0129 Apr 8, 2025
4e2fd6b
Fix live test failures
allenkim0129 Apr 8, 2025
e0dab29
Fix live test failures
allenkim0129 Apr 8, 2025
798c12f
Add test_delete_all_items_by_partition_key
allenkim0129 Apr 8, 2025
2c5b8fc
Remove test_delete_all_items_by_partition_key
allenkim0129 Apr 8, 2025
2b9b58f
Added missing doc for excluded_locations in async client
allenkim0129 Apr 10, 2025
eead750
Remove duplicate functions
allenkim0129 Apr 10, 2025
67d312e
test emulator
allenkim0129 Apr 10, 2025
ea7f189
Fix import
allenkim0129 Apr 10, 2025
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@
'priority': 'priorityLevel',
'no_response': 'responsePayloadOnWriteDisabled',
'max_item_count': 'maxItemCount',
'excluded_locations': 'excludedLocations',
}

# Cosmos resource ID validation regex breakdown:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2044,6 +2044,7 @@ def PatchItem(
documents._OperationType.Patch, options)
# Patch will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(resource_type, documents._OperationType.Patch)
request_params.set_excluded_location_from_options(options)
request_data = {}
if options.get("filterPredicate"):
request_data["condition"] = options.get("filterPredicate")
Expand Down Expand Up @@ -2132,6 +2133,7 @@ def _Batch(
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
documents._OperationType.Batch, options)
request_params = RequestObject("docs", documents._OperationType.Batch)
request_params.set_excluded_location_from_options(options)
return cast(
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
self.__Post(path, request_params, batch_operations, headers, **kwargs)
Expand Down Expand Up @@ -2192,6 +2194,7 @@ def DeleteAllItemsByPartitionKey(
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
"partitionkey", documents._OperationType.Delete, options)
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
request_params.set_excluded_location_from_options(options)
_, last_response_headers = self.__Post(
path=path,
request_params=request_params,
Expand Down Expand Up @@ -2647,6 +2650,7 @@ def Create(
# Create will use WriteEndpoint since it uses POST operation

request_params = RequestObject(typ, documents._OperationType.Create)
request_params.set_excluded_location_from_options(options)
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
self.last_response_headers = last_response_headers

Expand Down Expand Up @@ -2693,6 +2697,7 @@ def Upsert(

# Upsert will use WriteEndpoint since it uses POST operation
request_params = RequestObject(typ, documents._OperationType.Upsert)
request_params.set_excluded_location_from_options(options)
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
self.last_response_headers = last_response_headers
# update session for write request
Expand Down Expand Up @@ -2736,6 +2741,7 @@ def Replace(
options)
# Replace will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(typ, documents._OperationType.Replace)
request_params.set_excluded_location_from_options(options)
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
self.last_response_headers = last_response_headers

Expand Down Expand Up @@ -2777,6 +2783,7 @@ def Read(
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
# Read will use ReadEndpoint since it uses GET operation
request_params = RequestObject(typ, documents._OperationType.Read)
request_params.set_excluded_location_from_options(options)
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
self.last_response_headers = last_response_headers
if response_hook:
Expand Down Expand Up @@ -2816,6 +2823,7 @@ def DeleteResource(
options)
# Delete will use WriteEndpoint since it uses DELETE operation
request_params = RequestObject(typ, documents._OperationType.Delete)
request_params.set_excluded_location_from_options(options)
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
self.last_response_headers = last_response_headers

Expand Down Expand Up @@ -3052,6 +3060,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
resource_type,
documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed
)
request_params.set_excluded_location_from_options(options)
headers = base.GetHeaders(
self,
initial_headers,
Expand Down Expand Up @@ -3090,6 +3099,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:

# Query operations will use ReadEndpoint even though it uses POST(for regular query operations)
request_params = RequestObject(resource_type, documents._OperationType.SqlQuery)
request_params.set_excluded_location_from_options(options)
req_headers = base.GetHeaders(
self,
initial_headers,
Expand Down Expand Up @@ -3256,7 +3266,7 @@ def _AddPartitionKey(
options: Mapping[str, Any]
) -> Dict[str, Any]:
collection_link = base.TrimBeginningAndEndingSlashes(collection_link)
partitionKeyDefinition = self._get_partition_key_definition(collection_link)
partitionKeyDefinition = self._get_partition_key_definition(collection_link, options)
new_options = dict(options)
# If the collection doesn't have a partition key definition, skip it as it's a legacy collection
if partitionKeyDefinition:
Expand Down Expand Up @@ -3358,15 +3368,19 @@ def _UpdateSessionIfRequired(
# update session
self.session.update_session(response_result, response_headers)

def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]:
def _get_partition_key_definition(
self,
collection_link: str,
options: Mapping[str, Any]
) -> Optional[Dict[str, Any]]:
partition_key_definition: Optional[Dict[str, Any]]
# If the document collection link is present in the cache, then use the cached partitionkey definition
if collection_link in self.__container_properties_cache:
cached_container: Dict[str, Any] = self.__container_properties_cache.get(collection_link, {})
partition_key_definition = cached_container.get("partitionKey")
# Else read the collection from backend and add it to the cache
else:
container = self.ReadContainer(collection_link)
container = self.ReadContainer(collection_link, options)
partition_key_definition = container.get("partitionKey")
self.__container_properties_cache[collection_link] = _set_properties_cache(container)
return partition_key_definition
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,8 @@ def __init__(self, client):
self.DefaultEndpoint = client.url_connection
self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub()
self.location_cache = LocationCache(
self.PreferredLocations,
self.DefaultEndpoint,
self.EnableEndpointDiscovery,
client.connection_policy.UseMultipleWriteLocations
client.connection_policy
)
self.refresh_needed = False
self.refresh_lock = threading.RLock()
Expand Down
101 changes: 80 additions & 21 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,12 +25,13 @@
import collections
import logging
import time
from typing import Set
from typing import Set, Mapping, List
from urllib.parse import urlparse

from . import documents
from . import http_constants
from .documents import _OperationType
from ._request_object import RequestObject

# pylint: disable=protected-access

Expand Down Expand Up @@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations,
except Exception as e:
raise e

return endpoints_by_location, parsed_locations
# Also store a hash map of endpoints for each location
locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()}

return endpoints_by_location, locations_by_endpoints, parsed_locations

def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool:
if endpoint in preferred_endpoints:
Expand Down Expand Up @@ -150,31 +154,44 @@ def _get_health_check_endpoints(

return endpoints

def _get_applicable_regional_endpoints(endpoints: List[RegionalRoutingContext],
                                       location_name_by_endpoint: Mapping[str, str],
                                       fall_back_endpoint: RegionalRoutingContext,
                                       exclude_location_list: List[str]) -> List[RegionalRoutingContext]:
    """Drop every endpoint whose location name appears in the exclusion list.

    :param endpoints: Candidate routing contexts, in the order they should be tried.
    :param location_name_by_endpoint: Maps an endpoint's primary URL (``get_primary()``)
        to its location name.
    :param fall_back_endpoint: Routing context used when filtering leaves nothing.
    :param exclude_location_list: Location names that must not be routed to.
    :returns: A new list holding the surviving endpoints in their original order, or
        a single-element list with ``fall_back_endpoint`` when none survive.
    """
    # An endpoint missing from the mapping resolves to None and is therefore
    # kept unless None itself is listed as excluded.
    surviving_endpoints = [
        candidate
        for candidate in endpoints
        if location_name_by_endpoint.get(candidate.get_primary()) not in exclude_location_list
    ]

    # Callers index into the result, so never hand back an empty list.
    return surviving_endpoints or [fall_back_endpoint]

class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes
def current_time_millis(self):
    """Return the current wall-clock time as a whole number of milliseconds.

    :returns: ``time.time()`` scaled to milliseconds and rounded to the nearest int.
    """
    now_in_millis = time.time() * 1000
    return int(round(now_in_millis))

def __init__(
self,
preferred_locations,
default_endpoint,
enable_endpoint_discovery,
use_multiple_write_locations,
connection_policy,
):
self.preferred_locations = preferred_locations
self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint)
self.enable_endpoint_discovery = enable_endpoint_discovery
self.use_multiple_write_locations = use_multiple_write_locations
self.enable_multiple_writable_locations = False
self.write_regional_routing_contexts = [self.default_regional_routing_context]
self.read_regional_routing_contexts = [self.default_regional_routing_context]
self.location_unavailability_info_by_endpoint = {}
self.last_cache_update_time_stamp = 0
self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
self.account_locations_by_read_regional_routing_context = {} # pylint: disable=name-too-long
self.account_locations_by_write_regional_routing_context = {} # pylint: disable=name-too-long
self.account_write_locations = []
self.account_read_locations = []
self.connection_policy = connection_policy

def get_write_regional_routing_contexts(self):
    """Return the list stored in ``self.write_regional_routing_contexts``.

    The stored object itself is returned, not a copy.
    """
    write_contexts = self.write_regional_routing_contexts
    return write_contexts
Expand Down Expand Up @@ -207,6 +224,44 @@ def get_ordered_write_locations(self):
def get_ordered_read_locations(self):
    """Return the list stored in ``self.account_read_locations``.

    The stored object itself is returned, not a copy.
    """
    read_locations = self.account_read_locations
    return read_locations

def _get_configured_excluded_locations(self, request: RequestObject):
    """Pick the excluded-locations setting that applies to ``request``.

    :param request: Request whose ``excluded_locations`` attribute is consulted first.
    :returns: ``request.excluded_locations`` when it is not ``None`` (an empty list
        counts as set), otherwise ``self.connection_policy.ExcludedLocations``.
    """
    request_level_exclusions = request.excluded_locations
    if request_level_exclusions is not None:
        # A request-level value, even an empty one, overrides the client setting.
        return request_level_exclusions

    # Nothing was set on the request: use the client (connection policy) value.
    return self.connection_policy.ExcludedLocations

def _get_applicable_read_regional_endpoints(self, request: RequestObject):
    """Return the read routing contexts this request is allowed to use.

    :param request: Request passed to ``_get_configured_excluded_locations``.
    :returns: The unfiltered read routing contexts when no exclusions are configured
        (``None`` or empty); otherwise the read contexts filtered by the excluded
        locations, with the first write routing context as the fallback.
    """
    excluded = self._get_configured_excluded_locations(request)

    # No exclusions configured: every read endpoint is applicable.
    if not excluded:
        return self.get_read_regional_routing_contexts()

    read_contexts = self.get_read_regional_routing_contexts()
    location_names = self.account_locations_by_read_regional_routing_context
    # If every read region is excluded, fall back to the first write context.
    fallback_context = self.get_write_regional_routing_contexts()[0]
    return _get_applicable_regional_endpoints(
        read_contexts,
        location_names,
        fallback_context,
        excluded)

def _get_applicable_write_regional_endpoints(self, request: RequestObject):
    """Return the write routing contexts this request is allowed to use.

    :param request: Request passed to ``_get_configured_excluded_locations``.
    :returns: The unfiltered write routing contexts when no exclusions are configured
        (``None`` or empty); otherwise the write contexts filtered by the excluded
        locations, with ``self.default_regional_routing_context`` as the fallback.
    """
    excluded = self._get_configured_excluded_locations(request)

    # No exclusions configured: every write endpoint is applicable.
    if not excluded:
        return self.get_write_regional_routing_contexts()

    write_contexts = self.get_write_regional_routing_contexts()
    location_names = self.account_locations_by_write_regional_routing_context
    # If every write region is excluded, fall back to the default routing context.
    fallback_context = self.default_regional_routing_context
    return _get_applicable_regional_endpoints(
        write_contexts,
        location_names,
        fallback_context,
        excluded)

def resolve_service_endpoint(self, request):
if request.location_endpoint_to_route:
return request.location_endpoint_to_route
Expand All @@ -227,7 +282,7 @@ def resolve_service_endpoint(self, request):
# For non-document resource types in case of client can use multiple write locations
# or when client cannot use multiple write locations, flip-flop between the
# first and the second writable region in DatabaseAccount (for manual failover)
if self.enable_endpoint_discovery and self.account_write_locations:
if self.connection_policy.EnableEndpointDiscovery and self.account_write_locations:
location_index = min(location_index % 2, len(self.account_write_locations) - 1)
write_location = self.account_write_locations[location_index]
if (self.account_write_regional_routing_contexts_by_location
Expand All @@ -247,9 +302,9 @@ def resolve_service_endpoint(self, request):
return self.default_regional_routing_context.get_primary()

regional_routing_contexts = (
self.get_write_regional_routing_contexts()
self._get_applicable_write_regional_endpoints(request)
if documents._OperationType.IsWriteOperation(request.operation_type)
else self.get_read_regional_routing_contexts()
else self._get_applicable_read_regional_endpoints(request)
)
regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)]
if (
Expand All @@ -263,12 +318,14 @@ def resolve_service_endpoint(self, request):
return regional_routing_context.get_primary()

def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements
most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None
most_preferred_location = self.connection_policy.PreferredLocations[0] \
if self.connection_policy.PreferredLocations else None

# we should schedule refresh in background if we are unable to target the user's most preferredLocation.
if self.enable_endpoint_discovery:
if self.connection_policy.EnableEndpointDiscovery:

should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations
should_refresh = (self.connection_policy.UseMultipleWriteLocations
and not self.enable_multiple_writable_locations)

if (most_preferred_location and most_preferred_location in
self.account_read_regional_routing_contexts_by_location):
Expand Down Expand Up @@ -358,25 +415,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
if enable_multiple_writable_locations:
self.enable_multiple_writable_locations = enable_multiple_writable_locations

if self.enable_endpoint_discovery:
if self.connection_policy.EnableEndpointDiscovery:
if read_locations:
(self.account_read_regional_routing_contexts_by_location,
self.account_locations_by_read_regional_routing_context,
self.account_read_locations) = get_endpoints_by_location(
read_locations,
self.account_read_regional_routing_contexts_by_location,
self.default_regional_routing_context,
False,
self.use_multiple_write_locations
self.connection_policy.UseMultipleWriteLocations
)

if write_locations:
(self.account_write_regional_routing_contexts_by_location,
self.account_locations_by_write_regional_routing_context,
self.account_write_locations) = get_endpoints_by_location(
write_locations,
self.account_write_regional_routing_contexts_by_location,
self.default_regional_routing_context,
True,
self.use_multiple_write_locations
self.connection_policy.UseMultipleWriteLocations
)

self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts(
Expand All @@ -399,18 +458,18 @@ def get_preferred_regional_routing_contexts(
regional_endpoints = []
# if enableEndpointDiscovery is false, we always use the defaultEndpoint that
# user passed in during documentClient init
if self.enable_endpoint_discovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks
if self.connection_policy.EnableEndpointDiscovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks
if (
self.can_use_multiple_write_locations()
or expected_available_operation == EndpointOperationType.ReadType
):
unavailable_endpoints = []
if self.preferred_locations:
if self.connection_policy.PreferredLocations:
# When client can not use multiple write locations, preferred locations
# list should only be used determining read endpoints order. If client
# can use multiple write locations, preferred locations list should be
# used for determining both read and write endpoints order.
for location in self.preferred_locations:
for location in self.connection_policy.PreferredLocations:
regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \
else None
if regional_endpoint:
Expand All @@ -436,7 +495,7 @@ def get_preferred_regional_routing_contexts(
return regional_endpoints

def can_use_multiple_write_locations(self):
return self.use_multiple_write_locations and self.enable_multiple_writable_locations
return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations

def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long
return self.can_use_multiple_write_locations() and (
Expand Down
Loading
Loading