diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index 685754287cb9..9637a28e24b8 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -3,6 +3,7 @@ ### 4.10.0b5 (Unreleased) #### Features Added +* Per partition circuit breaker support. It can be enabled through the environment variable `AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER`. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). #### Breaking Changes diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index addde69e515e..7b5ac8f13dbf 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -63,6 +63,7 @@ 'priority': 'priorityLevel', 'no_response': 'responsePayloadOnWriteDisabled', 'max_item_count': 'maxItemCount', + 'excluded_locations': 'excludedLocations', } # Cosmos resource ID validation regex breakdown: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py index cb82d5bd7de2..d0f5a3d185ad 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_constants.py @@ -51,6 +51,18 @@ class _Constants: HS_MAX_ITEMS_CONFIG_DEFAULT: int = 1000 MAX_ITEM_BUFFER_VS_CONFIG: str = "AZURE_COSMOS_MAX_ITEM_BUFFER_VECTOR_SEARCH" MAX_ITEM_BUFFER_VS_CONFIG_DEFAULT: int = 50000 + CIRCUIT_BREAKER_ENABLED_CONFIG: str = "AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER" + CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT: str = "False" + # Only applicable when circuit breaker is enabled ------------------------- + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT: int = 10 + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE: str = "AZURE_COSMOS_CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE" + CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT: int = 5 + FAILURE_PERCENTAGE_TOLERATED = "AZURE_COSMOS_FAILURE_PERCENTAGE_TOLERATED" + FAILURE_PERCENTAGE_TOLERATED_DEFAULT: int = 90 + STALE_PARTITION_UNAVAILABILITY_CHECK = "AZURE_COSMOS_STALE_PARTITION_UNAVAILABILITY_CHECK_IN_SECONDS" + STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT: int = 120 + # ------------------------------------------------------------------------- # Error code translations ERROR_TRANSLATIONS: Dict[int, str] = { diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index e43cd4c9c287..1d134b096d7b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -48,7 +48,7 @@ HttpResponse # pylint: disable=no-legacy-azure-core-http-response-import from . import _base as base -from . import _global_endpoint_manager as global_endpoint_manager +from ._global_partition_endpoint_manager_circuit_breaker import _GlobalPartitionEndpointManagerForCircuitBreaker from . import _query_iterable as query_iterable from . import _runtime_constants as runtime_constants from . 
import _session @@ -164,7 +164,7 @@ def __init__( # pylint: disable=too-many-statements self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreaker(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, HTTPPolicy): @@ -2043,7 +2043,8 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(resource_type, documents._OperationType.Patch) + request_params = RequestObject(resource_type, documents._OperationType.Patch, headers) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2131,7 +2132,8 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = RequestObject("docs", documents._OperationType.Batch) + request_params = RequestObject("docs", documents._OperationType.Batch, headers) + request_params.set_excluded_location_from_options(options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2190,8 +2192,9 @@ def DeleteAllItemsByPartitionKey( path = '{}{}/{}'.format(path, "operations", "partitionkeydelete") collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, - "partitionkey", documents._OperationType.Delete, options) - request_params = RequestObject("partitionkey", documents._OperationType.Delete) + http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, options) + request_params = RequestObject(http_constants.ResourceType.PartitionKey, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) _, last_response_headers = self.__Post( path=path, request_params=request_params, @@ -2362,7 +2365,8 @@ def ExecuteStoredProcedure( documents._OperationType.ExecuteJavaScript, options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + request_params = RequestObject("sprocs", documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers = self.__Post(path, request_params, params, headers, **kwargs) return result @@ -2558,7 +2562,7 @@ def GetDatabaseAccount( headers = base.GetHeaders(self, self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = RequestObject("databaseaccount", documents._OperationType.Read, headers, url_connection) result, last_response_headers = self.__Get("", request_params, headers, **kwargs) self.last_response_headers = last_response_headers database_account = DatabaseAccount() @@ -2607,7 +2611,7 @@ def _GetDatabaseAccountCheck( headers = base.GetHeaders(self, 
self.default_headers, "get", "", "", "", documents._OperationType.Read,{}, client_id=self.client_id) - request_params = RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = RequestObject("databaseaccount", documents._OperationType.Read, headers, url_connection) self.__Get("", request_params, headers, **kwargs) @@ -2646,7 +2650,8 @@ def Create( options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create) + request_params = RequestObject(typ, documents._OperationType.Create, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2692,7 +2697,8 @@ def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params = RequestObject(typ, documents._OperationType.Upsert, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2735,7 +2741,8 @@ def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace) + request_params = RequestObject(typ, documents._OperationType.Replace, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2776,7 +2783,8 @@ def Read( initial_headers = initial_headers or self.default_headers headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read) + request_params = RequestObject(typ, documents._OperationType.Read, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2815,7 +2823,8 @@ def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete) + request_params = RequestObject(typ, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3047,11 +3056,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. 
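+        # Note (reviewer sketch): GetHeaders now runs before the RequestObject is constructed so the
+        # resolved headers can be attached to the request, which the per partition circuit breaker reads later.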
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - request_params = RequestObject( - resource_type, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed - ) headers = base.GetHeaders( self, initial_headers, @@ -3059,11 +3065,18 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, resource_id, resource_type, - request_params.operation_type, + op_typ, options, partition_key_range_id ) + request_params = RequestObject( + resource_type, + op_typ, + headers + ) + request_params.set_excluded_location_from_options(options) + change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: change_feed_state.populate_request_headers(self._routing_map_provider, headers) @@ -3089,7 +3102,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = RequestObject(resource_type, documents._OperationType.SqlQuery) req_headers = base.GetHeaders( self, initial_headers, @@ -3102,6 +3114,9 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: partition_key_range_id ) + request_params = RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) + # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) if isPrefixPartitionQuery: @@ -3256,7 +3271,7 @@ def _AddPartitionKey( options: Mapping[str, Any] ) -> Dict[str, Any]: collection_link = base.TrimBeginningAndEndingSlashes(collection_link) - partitionKeyDefinition = self._get_partition_key_definition(collection_link) + partitionKeyDefinition = self._get_partition_key_definition(collection_link, options) new_options = dict(options) # If the collection doesn't have a partition key definition, skip it as it's a legacy collection if partitionKeyDefinition: @@ -3358,7 +3373,11 @@ def _UpdateSessionIfRequired( # update session self.session.update_session(response_result, response_headers) - def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]: + def _get_partition_key_definition( + self, + collection_link: str, + options: Mapping[str, Any] + ) -> Optional[Dict[str, Any]]: partition_key_definition: Optional[Dict[str, Any]] # If the document collection link is present in the cache, then use the cached partitionkey definition if collection_link in self.__container_properties_cache: @@ -3366,7 +3385,7 @@ def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[s partition_key_definition = cached_container.get("partitionKey") # Else read the collection from backend and add it to the cache else: - container = self.ReadContainer(collection_link) + container = self.ReadContainer(collection_link, options) partition_key_definition = container.get("partitionKey") self.__container_properties_cache[collection_link] = _set_properties_cache(container) return partition_key_definition diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py index e167871dd4a5..62dd60a30da0 100644 --- 
a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_endpoint_manager.py @@ -30,8 +30,10 @@ from . import _constants as constants from . import exceptions +from ._request_object import RequestObject +from ._routing.routing_range import PartitionKeyRangeWrapper from .documents import DatabaseAccount -from ._location_cache import LocationCache +from ._location_cache import LocationCache, current_time_millis # pylint: disable=protected-access @@ -50,19 +52,14 @@ def __init__(self, client): self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( - self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy ) self.refresh_needed = False self.refresh_lock = threading.RLock() self.last_refresh_time = 0 self._database_account_cache = None - def get_use_multiple_write_locations(self): - return self.location_cache.can_use_multiple_write_locations() - def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime @@ -72,7 +69,11 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): @@ -98,7 +99,7 @@ def update_location_cache(self): self.location_cache.update_location_cache() def refresh_endpoint_list(self, database_account, **kwargs): - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: with self.refresh_lock: @@ -114,11 +115,11 @@ def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() # this will perform getDatabaseAccount calls to check endpoint health self._endpoints_health_check(**kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py new file mode 100644 index 000000000000..288205cb2411 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker.py @@ -0,0 +1,99 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, 
sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. +""" +from typing import TYPE_CHECKING + +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore + +from azure.cosmos._global_endpoint_manager import _GlobalEndpointManager +from azure.cosmos._request_object import RequestObject +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos.http_constants import HttpHeaders + +if TYPE_CHECKING: + from azure.cosmos._cosmos_client_connection import CosmosClientConnection + + + +class _GlobalPartitionEndpointManagerForCircuitBreaker(_GlobalEndpointManager): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. + """ + + + def __init__(self, client: "CosmosClientConnection"): + super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).__init__(client) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) + + + def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + partition_key = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + target_container_link = None + for container_link, properties in self.Client._container_properties_cache.items(): # pylint: disable=protected-access + if properties["_rid"] == container_rid: + target_container_link = container_link + if not target_container_link: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + # TODO: @tvaron3 check different clients and create them in different ways + pk_range = (self.Client._routing_map_provider # pylint: disable=protected-access + .get_overlapping_ranges(target_container_link, partition_key)) + return PartitionKeyRangeWrapper(pk_range, container_rid) + + def record_failure( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> str: + # TODO: @tvaron3 check here if it is healthy tentative and move it back to Unhealthy + if self.is_circuit_breaker_applicable(request): + request = 
self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return super(_GlobalPartitionEndpointManagerForCircuitBreaker, self).resolve_service_endpoint(request, + pk_range_wrapper) + + def mark_partition_unavailable( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) + + def record_success( + self, + request: RequestObject + ) -> None: + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py new file mode 100644 index 000000000000..577b8410f435 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_global_partition_endpoint_manager_circuit_breaker_core.py @@ -0,0 +1,111 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. +""" +import logging +import os + +from azure.cosmos import documents + +from azure.cosmos._partition_health_tracker import _PartitionHealthTracker +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper +from azure.cosmos._location_cache import EndpointOperationType, LocationCache +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import ResourceType +from azure.cosmos._constants import _Constants as Constants + + +logger = logging.getLogger("azure.cosmos._GlobalEndpointManagerForCircuitBreakerCore") + +class _GlobalPartitionEndpointManagerForCircuitBreakerCore(object): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. 
+ """ + + + def __init__(self, client, location_cache: LocationCache): + self.partition_health_tracker = _PartitionHealthTracker() + self.location_cache = location_cache + self.client = client + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + if not request: + return False + + circuit_breaker_enabled = os.environ.get(Constants.CIRCUIT_BREAKER_ENABLED_CONFIG, + Constants.CIRCUIT_BREAKER_ENABLED_CONFIG_DEFAULT) == "True" + if not circuit_breaker_enabled: + return False + + if (not self.location_cache.can_use_multiple_write_locations_for_request(request) + and documents._OperationType.IsWriteOperation(request.operation_type)): # pylint: disable=protected-access + return False + + if request.resource_type != ResourceType.Document: + return False + + if request.operation_type == documents._OperationType.QueryPlan: # pylint: disable=protected-access + return False + + return True + + def record_failure( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + #convert operation_type to EndpointOperationType + endpoint_operation_type = (EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) # pylint: disable=protected-access + else EndpointOperationType.ReadType) + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_failure(pk_range_wrapper, endpoint_operation_type, str(location)) + # TODO: @tvaron3 lower request timeout to 5.5 seconds for recovering + # TODO: @tvaron3 exponential backoff for recovering + + def add_excluded_locations_to_request( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> RequestObject: + request.set_excluded_locations_from_circuit_breaker( + self.partition_health_tracker.get_excluded_locations(pk_range_wrapper) + ) + return request + + def mark_partition_unavailable(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper) -> None: + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.mark_partition_unavailable(pk_range_wrapper, location) + + def record_success( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + #convert operation_type to either Read or Write + endpoint_operation_type = EndpointOperationType.WriteType if ( + documents._OperationType.IsWriteOperation(request.operation_type)) else EndpointOperationType.ReadType # pylint: disable=protected-access + location = self.location_cache.get_location_from_endpoint(str(request.location_endpoint_to_route)) + self.partition_health_tracker.add_success(pk_range_wrapper, endpoint_operation_type, location) + +# TODO: @tvaron3 there should be no in region retries when trying on healthy tentative ----------------------- diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py index 96651d5c8b7f..21a65c349dab 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py @@ -25,12 +25,13 @@ import collections import logging import time -from typing import Set +from typing import Set, Mapping, List from urllib.parse import urlparse from . import documents from . 
import http_constants from .documents import _OperationType +from ._request_object import RequestObject # pylint: disable=protected-access @@ -113,7 +114,10 @@ def get_endpoints_by_location(new_locations, except Exception as e: raise e - return endpoints_by_location, parsed_locations + # Also store a hash map of endpoints for each location + locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()} + + return endpoints_by_location, locations_by_endpoints, parsed_locations def add_endpoint_if_preferred(endpoint: str, preferred_endpoints: Set[str], endpoints: Set[str]) -> bool: if endpoint in preferred_endpoints: @@ -150,22 +154,33 @@ def _get_health_check_endpoints( return endpoints +def _get_applicable_regional_endpoints(endpoints: List[RegionalRoutingContext], + location_name_by_endpoint: Mapping[str, str], + fall_back_endpoint: RegionalRoutingContext, + exclude_location_list: List[str]) -> List[RegionalRoutingContext]: + # filter endpoints by excluded locations + applicable_endpoints = [] + for endpoint in endpoints: + if location_name_by_endpoint.get(endpoint.get_primary()) not in exclude_location_list: + applicable_endpoints.append(endpoint) + + # if endpoint is empty add fallback endpoint + if not applicable_endpoints: + applicable_endpoints.append(fall_back_endpoint) + + return applicable_endpoints + +def current_time_millis(): + return int(round(time.time() * 1000)) class LocationCache(object): # pylint: disable=too-many-public-methods,too-many-instance-attributes - def current_time_millis(self): - return int(round(time.time() * 1000)) def __init__( self, - preferred_locations, default_endpoint, - enable_endpoint_discovery, - use_multiple_write_locations, + connection_policy, ): - self.preferred_locations = preferred_locations self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint) - self.enable_endpoint_discovery = enable_endpoint_discovery - self.use_multiple_write_locations = use_multiple_write_locations self.enable_multiple_writable_locations = False self.write_regional_routing_contexts = [self.default_regional_routing_context] self.read_regional_routing_contexts = [self.default_regional_routing_context] @@ -173,8 +188,11 @@ def __init__( self.last_cache_update_time_stamp = 0 self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long + self.account_locations_by_read_regional_routing_context = {} # pylint: disable=name-too-long + self.account_locations_by_write_regional_routing_context = {} # pylint: disable=name-too-long self.account_write_locations = [] self.account_read_locations = [] + self.connection_policy = connection_policy def get_write_regional_routing_contexts(self): return self.write_regional_routing_contexts @@ -182,6 +200,9 @@ def get_write_regional_routing_contexts(self): def get_read_regional_routing_contexts(self): return self.read_regional_routing_contexts + def get_location_from_endpoint(self, endpoint: str) -> str: + return self.account_locations_by_read_regional_routing_context[endpoint] + def get_write_regional_routing_context(self): return self.get_write_regional_routing_contexts()[0].get_primary() @@ -207,6 +228,45 @@ def get_ordered_write_locations(self): def get_ordered_read_locations(self): return self.account_read_locations + def _get_configured_excluded_locations(self, request: RequestObject): + # If excluded locations were configured on 
request, use request level excluded locations. + excluded_locations = request.excluded_locations + if excluded_locations is None: + # If excluded locations were only configured on client(connection_policy), use client level + excluded_locations = self.connection_policy.ExcludedLocations + excluded_locations.extend(request.excluded_locations_circuit_breaker) + return excluded_locations + + def _get_applicable_read_regional_endpoints(self, request: RequestObject): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. + if excluded_locations: + return _get_applicable_regional_endpoints( + self.get_read_regional_routing_contexts(), + self.account_locations_by_read_regional_routing_context, + self.get_write_regional_routing_contexts()[0], + excluded_locations) + + # Else, return all regional endpoints + return self.get_read_regional_routing_contexts() + + def _get_applicable_write_regional_endpoints(self, request: RequestObject): + # Get configured excluded locations + excluded_locations = self._get_configured_excluded_locations(request) + + # If excluded locations were configured, return filtered regional endpoints by excluded locations. + if excluded_locations: + return _get_applicable_regional_endpoints( + self.get_write_regional_routing_contexts(), + self.account_locations_by_write_regional_routing_context, + self.default_regional_routing_context, + excluded_locations) + + # Else, return all regional endpoints + return self.get_write_regional_routing_contexts() + def resolve_service_endpoint(self, request): if request.location_endpoint_to_route: return request.location_endpoint_to_route @@ -227,7 +287,7 @@ def resolve_service_endpoint(self, request): # For non-document resource types in case of client can use multiple write locations # or when client cannot use multiple write locations, flip-flop between the # first and the second writable region in DatabaseAccount (for manual failover) - if self.enable_endpoint_discovery and self.account_write_locations: + if self.connection_policy.EnableEndpointDiscovery and self.account_write_locations: location_index = min(location_index % 2, len(self.account_write_locations) - 1) write_location = self.account_write_locations[location_index] if (self.account_write_regional_routing_contexts_by_location @@ -247,9 +307,9 @@ def resolve_service_endpoint(self, request): return self.default_regional_routing_context.get_primary() regional_routing_contexts = ( - self.get_write_regional_routing_contexts() + self._get_applicable_write_regional_endpoints(request) if documents._OperationType.IsWriteOperation(request.operation_type) - else self.get_read_regional_routing_contexts() + else self._get_applicable_read_regional_endpoints(request) ) regional_routing_context = regional_routing_contexts[location_index % len(regional_routing_contexts)] if ( @@ -263,12 +323,14 @@ def resolve_service_endpoint(self, request): return regional_routing_context.get_primary() def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements - most_preferred_location = self.preferred_locations[0] if self.preferred_locations else None + most_preferred_location = self.connection_policy.PreferredLocations[0] \ + if self.connection_policy.PreferredLocations else None # we should schedule refresh in background if we are unable to target the user's most preferredLocation. 
- if self.enable_endpoint_discovery: + if self.connection_policy.EnableEndpointDiscovery: - should_refresh = self.use_multiple_write_locations and not self.enable_multiple_writable_locations + should_refresh = (self.connection_policy.UseMultipleWriteLocations + and not self.enable_multiple_writable_locations) if (most_preferred_location and most_preferred_location in self.account_read_regional_routing_contexts_by_location): @@ -358,25 +420,27 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl if enable_multiple_writable_locations: self.enable_multiple_writable_locations = enable_multiple_writable_locations - if self.enable_endpoint_discovery: + if self.connection_policy.EnableEndpointDiscovery: if read_locations: (self.account_read_regional_routing_contexts_by_location, + self.account_locations_by_read_regional_routing_context, self.account_read_locations) = get_endpoints_by_location( read_locations, self.account_read_regional_routing_contexts_by_location, self.default_regional_routing_context, False, - self.use_multiple_write_locations + self.connection_policy.UseMultipleWriteLocations ) if write_locations: (self.account_write_regional_routing_contexts_by_location, + self.account_locations_by_write_regional_routing_context, self.account_write_locations) = get_endpoints_by_location( write_locations, self.account_write_regional_routing_contexts_by_location, self.default_regional_routing_context, True, - self.use_multiple_write_locations + self.connection_policy.UseMultipleWriteLocations ) self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts( @@ -391,7 +455,6 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl EndpointOperationType.ReadType, self.write_regional_routing_contexts[0] ) - self.last_cache_update_timestamp = self.current_time_millis() # pylint: disable=attribute-defined-outside-init def get_preferred_regional_routing_contexts( self, endpoints_by_location, orderedLocations, expected_available_operation, fallback_endpoint @@ -399,18 +462,18 @@ def get_preferred_regional_routing_contexts( regional_endpoints = [] # if enableEndpointDiscovery is false, we always use the defaultEndpoint that # user passed in during documentClient init - if self.enable_endpoint_discovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks + if self.connection_policy.EnableEndpointDiscovery and endpoints_by_location: # pylint: disable=too-many-nested-blocks if ( self.can_use_multiple_write_locations() or expected_available_operation == EndpointOperationType.ReadType ): unavailable_endpoints = [] - if self.preferred_locations: + if self.connection_policy.PreferredLocations: # When client can not use multiple write locations, preferred locations # list should only be used determining read endpoints order. If client # can use multiple write locations, preferred locations list should be # used for determining both read and write endpoints order. 
- for location in self.preferred_locations: + for location in self.connection_policy.PreferredLocations: regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \ else None if regional_endpoint: @@ -436,11 +499,12 @@ def get_preferred_regional_routing_contexts( return regional_endpoints def can_use_multiple_write_locations(self): - return self.use_multiple_write_locations and self.enable_multiple_writable_locations + return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations def can_use_multiple_write_locations_for_request(self, request): # pylint: disable=name-too-long return self.can_use_multiple_write_locations() and ( request.resource_type == http_constants.ResourceType.Document + or request.resource_type == http_constants.ResourceType.PartitionKey or ( request.resource_type == http_constants.ResourceType.StoredProcedure and request.operation_type == documents._OperationType.ExecuteJavaScript diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py new file mode 100644 index 000000000000..034d1cf1ac10 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_partition_health_tracker.py @@ -0,0 +1,286 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for partition health tracker for circuit breaker. 
+"""
+import logging
+import os
+from typing import Dict, Set, Any, List
+from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper
+from azure.cosmos._location_cache import current_time_millis, EndpointOperationType
+from ._constants import _Constants as Constants
+
+
+MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100
+REFRESH_INTERVAL = 60 * 1000  # milliseconds
+INITIAL_UNAVAILABLE_TIME = 60 * 1000  # milliseconds
+# partition is unhealthy if the sdk tried to recover and failed
+UNHEALTHY = "unhealthy"
+# partition is unhealthy tentative when it is initially marked unavailable
+UNHEALTHY_TENTATIVE = "unhealthy_tentative"
+# partition is healthy tentative when the sdk is trying to recover
+HEALTHY_TENTATIVE = "healthy_tentative"
+# unavailability info keys
+LAST_UNAVAILABILITY_CHECK_TIME_STAMP = "lastUnavailabilityCheckTimeStamp"
+HEALTH_STATUS = "healthStatus"
+
+
+def _has_exceeded_failure_rate_threshold(
+        successes: int,
+        failures: int,
+        failure_rate_threshold: int,
+) -> bool:
+    if successes + failures < MINIMUM_REQUESTS_FOR_FAILURE_RATE:
+        return False
+    # failure rate is the percentage of failed requests out of all tracked requests
+    return (failures / (successes + failures) * 100) >= failure_rate_threshold
+
+class _PartitionHealthInfo(object):
+    """
+    This internal class keeps the health and statistics for a partition.
+    """
+
+    def __init__(self) -> None:
+        self.write_failure_count: int = 0
+        self.read_failure_count: int = 0
+        self.write_success_count: int = 0
+        self.read_success_count: int = 0
+        self.read_consecutive_failure_count: int = 0
+        self.write_consecutive_failure_count: int = 0
+        self.unavailability_info: Dict[str, Any] = {}
+
+    def reset_health_stats(self) -> None:
+        self.write_failure_count = 0
+        self.read_failure_count = 0
+        self.write_success_count = 0
+        self.read_success_count = 0
+        self.read_consecutive_failure_count = 0
+        self.write_consecutive_failure_count = 0
+
+    def __str__(self) -> str:
+        return (f"{self.__class__.__name__}: {self.unavailability_info}\n"
+                f"write failure count: {self.write_failure_count}\n"
+                f"read failure count: {self.read_failure_count}\n"
+                f"write success count: {self.write_success_count}\n"
+                f"read success count: {self.read_success_count}\n"
+                f"write consecutive failure count: {self.write_consecutive_failure_count}\n"
+                f"read consecutive failure count: {self.read_consecutive_failure_count}\n")
+
+logger = logging.getLogger("azure.cosmos._PartitionHealthTracker")
+
+class _PartitionHealthTracker(object):
+    """
+    This internal class implements the logic for tracking health thresholds for a partition.
+    """
+
+    def __init__(self) -> None:
+        # partition -> regions -> health info
+        self.pk_range_wrapper_to_health_info: Dict[PartitionKeyRangeWrapper, Dict[str, _PartitionHealthInfo]] = {}
+        self.last_refresh = current_time_millis()
+
+    # TODO: @tvaron3 look for useful places to add logs
+
+    def mark_partition_unavailable(self, pk_range_wrapper: PartitionKeyRangeWrapper, location: str) -> None:
+        # mark the partition key range as unavailable
+        self._transition_health_status_on_failure(pk_range_wrapper, location)
+
+    def _transition_health_status_on_failure(
+            self,
+            pk_range_wrapper: PartitionKeyRangeWrapper,
+            location: str
+    ) -> None:
+        logger.warning("%s has been marked as unavailable.", pk_range_wrapper)
+        current_time = current_time_millis()
+        if pk_range_wrapper not in self.pk_range_wrapper_to_health_info:
+            # healthy -> unhealthy tentative
+            partition_health_info = _PartitionHealthInfo()
+            partition_health_info.unavailability_info = {
+                LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time,
+                HEALTH_STATUS: UNHEALTHY_TENTATIVE
+            }
+            self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {
+                location: partition_health_info
+            }
+        else:
+            region_to_partition_health = self.pk_range_wrapper_to_health_info[pk_range_wrapper]
+            if location in region_to_partition_health and region_to_partition_health[location].unavailability_info:
+                # healthy tentative -> unhealthy
+                # existing unavailability info means this region was already marked and was recovering
+                region_to_partition_health[location].unavailability_info[HEALTH_STATUS] = UNHEALTHY
+                # reset the last unavailability check time stamp
+                region_to_partition_health[location].unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP] \
+                    = current_time
+            else:
+                # healthy -> unhealthy tentative
+                # no unavailability info yet for this region, so it transitions to unhealthy tentative
+                partition_health_info = _PartitionHealthInfo()
+                partition_health_info.unavailability_info = {
+                    LAST_UNAVAILABILITY_CHECK_TIME_STAMP: current_time,
+                    HEALTH_STATUS: UNHEALTHY_TENTATIVE
+                }
+                self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = partition_health_info
+
+    def _transition_health_status_on_success(
+            self,
+            pk_range_wrapper: PartitionKeyRangeWrapper,
+            location: str
+    ) -> None:
+        if pk_range_wrapper in self.pk_range_wrapper_to_health_info:
+            # healthy tentative -> healthy
+            self.pk_range_wrapper_to_health_info[pk_range_wrapper].pop(location, None)
+
+    def _check_stale_partition_info(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> None:
+        current_time = current_time_millis()
+
+        stale_partition_unavailability_check = int(os.environ.get(Constants.STALE_PARTITION_UNAVAILABILITY_CHECK,
+                                                   Constants.STALE_PARTITION_UNAVAILABILITY_CHECK_DEFAULT)) * 1000
+        if pk_range_wrapper in self.pk_range_wrapper_to_health_info:
+            for _, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items():
+                if partition_health_info.unavailability_info:
+                    elapsed_time = (current_time -
+                                    partition_health_info.unavailability_info[LAST_UNAVAILABILITY_CHECK_TIME_STAMP])
+                    current_health_status = partition_health_info.unavailability_info[HEALTH_STATUS]
+                    # check if the partition key range is still unavailable
+                    if ((current_health_status == UNHEALTHY and elapsed_time > stale_partition_unavailability_check)
+                            or (current_health_status == UNHEALTHY_TENTATIVE
+                                and elapsed_time > INITIAL_UNAVAILABLE_TIME)):
+                        # unhealthy or unhealthy tentative -> healthy tentative
+                        partition_health_info.unavailability_info[HEALTH_STATUS] = HEALTHY_TENTATIVE
+
+        if current_time
- self.last_refresh > REFRESH_INTERVAL: + # all partition stats reset every minute + self._reset_partition_health_tracker_stats() + + + def get_excluded_locations(self, pk_range_wrapper: PartitionKeyRangeWrapper) -> List[str]: + self._check_stale_partition_info(pk_range_wrapper) + excluded_locations = [] + if pk_range_wrapper in self.pk_range_wrapper_to_health_info: + for location, partition_health_info in self.pk_range_wrapper_to_health_info[pk_range_wrapper].items(): + if partition_health_info.unavailability_info: + health_status = partition_health_info.unavailability_info[HEALTH_STATUS] + if health_status in (UNHEALTHY_TENTATIVE, UNHEALTHY): + excluded_locations.append(location) + return excluded_locations + + + def add_failure( + self, + pk_range_wrapper: PartitionKeyRangeWrapper, + operation_type: str, + location: str + ) -> None: + # Retrieve the failure rate threshold from the environment. + failure_rate_threshold = int(os.environ.get(Constants.FAILURE_PERCENTAGE_TOLERATED, + Constants.FAILURE_PERCENTAGE_TOLERATED_DEFAULT)) + + # Ensure that the health info dictionary is properly initialized. + if pk_range_wrapper not in self.pk_range_wrapper_to_health_info: + self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {} + if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]: + self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo() + + health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] + + # Determine attribute names and environment variables based on the operation type. + if operation_type == EndpointOperationType.WriteType: + success_attr = 'write_success_count' + failure_attr = 'write_failure_count' + consecutive_attr = 'write_consecutive_failure_count' + env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE + default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_WRITE_DEFAULT + else: + success_attr = 'read_success_count' + failure_attr = 'read_failure_count' + consecutive_attr = 'read_consecutive_failure_count' + env_key = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ + default_consecutive_threshold = Constants.CONSECUTIVE_ERROR_COUNT_TOLERATED_FOR_READ_DEFAULT + + # Increment failure and consecutive failure counts. + setattr(health_info, failure_attr, getattr(health_info, failure_attr) + 1) + setattr(health_info, consecutive_attr, getattr(health_info, consecutive_attr) + 1) + + # Retrieve the consecutive failure threshold from the environment. + consecutive_failure_threshold = int(os.environ.get(env_key, default_consecutive_threshold)) + + # Call the threshold checker with the current stats. 
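+        # Exceeding either the failure rate threshold or the consecutive failure threshold
+        # below marks this partition and region combination as unavailable.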
+        self._check_thresholds(
+            pk_range_wrapper,
+            getattr(health_info, success_attr),
+            getattr(health_info, failure_attr),
+            getattr(health_info, consecutive_attr),
+            location,
+            failure_rate_threshold,
+            consecutive_failure_threshold
+        )
+        # log the updated health info at debug level instead of printing to stdout
+        logger.debug("Recorded failure for %s in %s: %s", pk_range_wrapper, location,
+                     self.pk_range_wrapper_to_health_info[pk_range_wrapper][location])
+
+    def _check_thresholds(
+            self,
+            pk_range_wrapper: PartitionKeyRangeWrapper,
+            successes: int,
+            failures: int,
+            consecutive_failures: int,
+            location: str,
+            failure_rate_threshold: int,
+            consecutive_failure_threshold: int,
+    ) -> None:
+
+        # check the failure rate was not exceeded
+        if _has_exceeded_failure_rate_threshold(
+                successes,
+                failures,
+                failure_rate_threshold
+        ):
+            self._transition_health_status_on_failure(pk_range_wrapper, location)
+
+        # check that the consecutive failure threshold was not exceeded
+        if consecutive_failures >= consecutive_failure_threshold:
+            self._transition_health_status_on_failure(pk_range_wrapper, location)
+
+    def add_success(self, pk_range_wrapper: PartitionKeyRangeWrapper, operation_type: str, location: str) -> None:
+        # Ensure that the health info dictionary is initialized.
+        if pk_range_wrapper not in self.pk_range_wrapper_to_health_info:
+            self.pk_range_wrapper_to_health_info[pk_range_wrapper] = {}
+        if location not in self.pk_range_wrapper_to_health_info[pk_range_wrapper]:
+            self.pk_range_wrapper_to_health_info[pk_range_wrapper][location] = _PartitionHealthInfo()
+
+        health_info = self.pk_range_wrapper_to_health_info[pk_range_wrapper][location]
+
+        if operation_type == EndpointOperationType.WriteType:
+            health_info.write_success_count += 1
+            health_info.write_consecutive_failure_count = 0
+        else:
+            health_info.read_success_count += 1
+            health_info.read_consecutive_failure_count = 0
+        # a success clears any unavailability info for this location
+        self._transition_health_status_on_success(pk_range_wrapper, location)
+        logger.debug("Recorded success for %s in %s: %s", pk_range_wrapper, location,
+                     self.pk_range_wrapper_to_health_info[pk_range_wrapper][location])
+
+    def _reset_partition_health_tracker_stats(self) -> None:
+        for locations in self.pk_range_wrapper_to_health_info.values():
+            for health_info in locations.values():
+                health_info.reset_health_stats()
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
index a220c6af42c2..24dc9e0dd9c2 100644
--- a/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
+++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_request_object.py
@@ -21,18 +21,28 @@
 """Represents a request object.
 """
-from typing import Optional
+from typing import Optional, Mapping, Any, Dict, Set, List
+from .
import http_constants -class RequestObject(object): - def __init__(self, resource_type: str, operation_type: str, endpoint_override: Optional[str] = None) -> None: +class RequestObject(object): # pylint: disable=too-many-instance-attributes + def __init__( + self, + resource_type: str, + operation_type: str, + headers: Dict[str, Any], + endpoint_override: Optional[str] = None, + ) -> None: self.resource_type = resource_type self.operation_type = operation_type self.endpoint_override = endpoint_override self.should_clear_session_token_on_session_read_failure: bool = False # pylint: disable=name-too-long + self.headers = headers self.use_preferred_locations: Optional[bool] = None self.location_index_to_route: Optional[int] = None self.location_endpoint_to_route: Optional[str] = None self.last_routed_location_endpoint_within_region: Optional[str] = None + self.excluded_locations: Optional[List[str]] = None + self.excluded_locations_circuit_breaker: List[str] = [] def route_to_location_with_preferred_location_flag( # pylint: disable=name-too-long self, @@ -52,3 +62,32 @@ def clear_route_to_location(self) -> None: self.location_index_to_route = None self.use_preferred_locations = None self.location_endpoint_to_route = None + + def _can_set_excluded_location(self, options: Mapping[str, Any]) -> bool: + # If resource types for requests are not one of the followings, excluded locations cannot be set + acceptable_resource_types = [ + http_constants.ResourceType.Document, + http_constants.ResourceType.PartitionKey, + http_constants.ResourceType.Collection, + ] + if self.resource_type.lower() not in acceptable_resource_types: + return False + + # If 'excludedLocations' wasn't in the options, excluded locations cannot be set + if (options is None + or 'excludedLocations' not in options): + return False + + # The 'excludedLocations' cannot be None + if options['excludedLocations'] is None: + raise ValueError("Excluded locations cannot be None. 
" + "If you want to remove all excluded locations, try passing an empty list.") + + return True + + def set_excluded_location_from_options(self, options: Mapping[str, Any]) -> None: + if self._can_set_excluded_location(options): + self.excluded_locations = options['excludedLocations'] + + def set_excluded_locations_from_circuit_breaker(self, excluded_locations: List[str]) -> None: # pylint: disable=name-too-long + self.excluded_locations_circuit_breaker = excluded_locations diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py index 927ed7a41baa..7d27885f10db 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_retry_utility.py @@ -45,7 +45,7 @@ # pylint: disable=protected-access, disable=too-many-lines, disable=too-many-statements, disable=too-many-branches -def Execute(client, global_endpoint_manager, function, *args, **kwargs): +def Execute(client, global_endpoint_manager, function, *args, **kwargs): # pylint: disable=too-many-locals """Executes the function with passed parameters applying all retry policies :param object client: @@ -58,6 +58,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -73,19 +76,19 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -104,6 +107,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): try: if args: result = ExecuteFunction(function, global_endpoint_manager, *args, **kwargs) + global_endpoint_manager.record_success(args[0]) else: result = ExecuteFunction(function, *args, **kwargs) if not client.last_response_headers: @@ -172,9 +176,9 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): retry_policy.container_rid = 
cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker tracking + global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -200,6 +204,7 @@ def Execute(client, global_endpoint_manager, function, *args, **kwargs): if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: + global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -291,7 +296,8 @@ def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) - + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -317,6 +323,7 @@ def send(self, request): # since request wasn't sent, raise exception immediately to be dealt with in client retry policies # This logic is based on the _retry.py file from azure-core if not _has_database_account_header(request.http_request.headers): + global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -329,7 +336,7 @@ def send(self, request): if (not _has_read_retryable_headers(request.http_request.headers) or _has_database_account_header(request.http_request.headers)): raise err - + global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -342,6 +349,7 @@ def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): + global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py index e70ae355c495..f59513d05d24 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/aio/routing_map_provider.py @@ -22,11 +22,13 @@ """Internal class for partition key range cache implementation in the Azure Cosmos database service. """ +from typing import Dict, Any from ... import _base from ..collection_routing_map import CollectionRoutingMap from .. import routing_range + # pylint: disable=protected-access @@ -58,13 +60,21 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** :return: List of overlapping partition key ranges. 
:rtype: list """ - cl = self._documentClient - collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_overlapping_ranges(partition_key_ranges) + async def initialize_collection_routing_map_if_needed( + self, + collection_link: str, + collection_id: str, + **kwargs: Dict[str, Any] + ): + client = self._documentClient collection_routing_map = self._collection_routing_map_by_item.get(collection_id) if collection_routing_map is None: - collection_pk_ranges = [pk async for pk in cl._ReadPartitionKeyRanges(collection_link, **kwargs)] + collection_pk_ranges = [pk async for pk in client._ReadPartitionKeyRanges(collection_link, **kwargs)] # for large collections, a split may complete between the read partition key ranges query page responses, # causing the partitionKeyRanges to have both the children ranges and their parents. Therefore, we need # to discard the parent ranges to have a valid routing map. @@ -72,8 +82,18 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** collection_routing_map = CollectionRoutingMap.CompleteRoutingMap( [(r, True) for r in collection_pk_ranges], collection_id ) - self._collection_routing_map_by_item[collection_id] = collection_routing_map - return collection_routing_map.get_overlapping_ranges(partition_key_ranges) + self._collection_routing_map_by_item[collection_id] = collection_routing_map + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + collection_id = _base.GetResourceIdOrFullNameFromLink(collection_link) + await self.initialize_collection_routing_map_if_needed(collection_link, str(collection_id), **kwargs) + + return self._collection_routing_map_by_item[collection_id].get_range_by_partition_key_range_id(partition_key_range_id) @staticmethod def _discard_parent_ranges(partitionKeyRanges): @@ -196,3 +216,11 @@ async def get_overlapping_ranges(self, collection_link, partition_key_ranges, ** pass return target_partition_key_ranges + + async def get_range_by_partition_key_range_id( + self, + collection_link: str, + partition_key_range_id: int, + **kwargs: Dict[str, Any] + ) -> Dict[str, Any]: + return await super().get_range_by_partition_key_range_id(collection_link, partition_key_range_id, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py index 452bc32e5b34..e31682725828 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_routing/routing_range.py @@ -226,3 +226,27 @@ def is_subset(self, parent_range: 'Range') -> bool: normalized_child_range = self.to_normalized_range() return (normalized_parent_range.min <= normalized_child_range.min and normalized_parent_range.max >= normalized_child_range.max) + +class PartitionKeyRangeWrapper(object): + """Internal class for a representation of a unique partition for an account + """ + + def __init__(self, partition_key_range: Range, collection_rid: str) -> None: + self.partition_key_range = partition_key_range + self.collection_rid = collection_rid + + + def __str__(self) -> str: + return ( + f"PartitionKeyRangeWrapper(" + f"partition_key_range={self.partition_key_range}, " + f"collection_rid={self.collection_rid}, " + ) + + def 
__eq__(self, other): + if not isinstance(other, PartitionKeyRangeWrapper): + return False + return self.partition_key_range == other.partition_key_range and self.collection_rid == other.collection_rid + + def __hash__(self): + return hash((self.partition_key_range, self.collection_rid)) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py index 9b34d048e3a6..b49185512fa4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_request_retry_policy.py @@ -13,9 +13,10 @@ class ServiceRequestRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.total_in_region_retries = 1 self.in_region_retry_count = 0 @@ -44,9 +45,12 @@ def ShouldRetry(self): if self.request.resource_type == ResourceType.DatabaseAccount: return False - refresh_cache = self.request.last_routed_location_endpoint_within_region is not None - # This logic is for the last retry and mark the region unavailable - self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) + if self.global_endpoint_manager.is_circuit_breaker_applicable(self.request): + self.global_endpoint_manager.mark_partition_unavailable(self.request, self.pk_range_wrapper) + else: + refresh_cache = self.request.last_routed_location_endpoint_within_region is not None + # This logic is for the last retry and mark the region unavailable + self.mark_endpoint_unavailable(self.request.location_endpoint_to_route, refresh_cache) # Check if it is safe to do another retry if self.in_region_retry_count >= self.total_in_region_retries: @@ -58,14 +62,14 @@ def ShouldRetry(self): self.request.last_routed_location_endpoint_within_region = self.request.location_endpoint_to_route if (_OperationType.IsReadOnlyOperation(self.request.operation_type) - or self.global_endpoint_manager.get_use_multiple_write_locations()): + or self.global_endpoint_manager.can_use_multiple_write_locations(self.request)): self.update_location_cache() # We just directly got to the next location in case of read requests # We don't retry again on the same region for regional endpoint self.failover_retry_count += 1 if self.failover_retry_count >= self.total_retries: return False - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region location_endpoint = self.resolve_next_region_service_endpoint() else: location_endpoint = self.resolve_current_region_service_endpoint() @@ -80,7 +84,7 @@ def ShouldRetry(self): # and we reset the in region retry count self.in_region_retry_count = 0 self.failover_retry_count += 1 - # # Check if it is safe to failover to another region + # Check if it is safe to failover to another region if self.failover_retry_count >= self.total_retries: return False location_endpoint = self.resolve_next_region_service_endpoint() @@ -96,7 +100,7 @@ def resolve_current_region_service_endpoint(self): # resolve the next service endpoint in the same region # since we maintain 2 endpoints per region for write operations self.request.route_to_location_with_preferred_location_flag(0, True) - return 
self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) # This function prepares the request to go to the next region def resolve_next_region_service_endpoint(self): @@ -110,7 +114,7 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(0, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) def mark_endpoint_unavailable(self, unavailable_endpoint, refresh_cache: bool): if _OperationType.IsReadOnlyOperation(self.request.operation_type): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py index 83a856f39d33..330ffb5929a5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_service_response_retry_policy.py @@ -12,15 +12,17 @@ class ServiceResponseRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.args = args self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper self.total_retries = len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) self.failover_retry_count = 0 self.connection_policy = connection_policy self.request = args[0] if args else None if self.request: - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + pk_range_wrapper) self.logger = logging.getLogger('azure.cosmos.ServiceResponseRetryPolicy') def ShouldRetry(self): @@ -57,4 +59,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.failover_retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py index 1614f337de5b..e52fbe996e11 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session_retry_policy.py @@ -41,10 +41,11 @@ class _SessionRetryPolicy(object): Max_retry_attempt_count = 1 Retry_after_in_milliseconds = 0 - def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): + def __init__(self, endpoint_discovery_enable, global_endpoint_manager, pk_range_wrapper, *args): self.global_endpoint_manager = global_endpoint_manager self._max_retry_attempt_count = _SessionRetryPolicy.Max_retry_attempt_count self.session_token_retry_count = 0 + self.pk_range_wrapper = pk_range_wrapper self.retry_after_in_milliseconds = 
_SessionRetryPolicy.Retry_after_in_milliseconds self.endpoint_discovery_enable = endpoint_discovery_enable self.request = args[0] if args else None @@ -57,7 +58,8 @@ def __init__(self, endpoint_discovery_enable, global_endpoint_manager, *args): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) def ShouldRetry(self, _exception): @@ -98,7 +100,8 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) return True @@ -113,6 +116,7 @@ def ShouldRetry(self, _exception): # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request) + self.location_endpoint = self.global_endpoint_manager.resolve_service_endpoint(self.request, + self.pk_range_wrapper) self.request.route_to_location(self.location_endpoint) return True diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py index 68e37caf1d9d..43516430bdb6 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_synchronized_request.py @@ -65,7 +65,7 @@ def _request_body_from_data(data): return None -def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. 
:param _GlobalEndpointManager global_endpoint_manager: @@ -104,7 +104,11 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) @@ -132,6 +136,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -142,6 +148,8 @@ def _Request(global_endpoint_manager, request_params, connection_policy, pipelin read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py index b170fb4fd9d2..69bc973c3346 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_timeout_failover_retry_policy.py @@ -9,11 +9,10 @@ class _TimeoutFailoverRetryPolicy(object): - def __init__(self, connection_policy, global_endpoint_manager, *args): + def __init__(self, connection_policy, global_endpoint_manager, pk_range_wrapper, *args): self.retry_after_in_milliseconds = 500 - self.args = args - self.global_endpoint_manager = global_endpoint_manager + self.pk_range_wrapper = pk_range_wrapper # If an account only has 1 region, then we still want to retry once on the same region self._max_retry_attempt_count = (len(self.global_endpoint_manager.location_cache.read_regional_routing_contexts) + 1) @@ -56,4 +55,4 @@ def resolve_next_region_service_endpoint(self): self.request.route_to_location_with_preferred_location_flag(self.retry_count, True) # Resolve the endpoint for the request and pin the resolution to the resolved endpoint # This enables marking the endpoint unavailability on endpoint failover/unreachability - return self.global_endpoint_manager.resolve_service_endpoint(self.request) + return self.global_endpoint_manager.resolve_service_endpoint(self.request, self.pk_range_wrapper) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py index 81430d8df42c..4fda37ea0a87 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_asynchronous_request.py @@ -34,7 +34,7 @@ from .._synchronized_request import _request_body_from_data, _replace_url_prefix -async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, **kwargs): +async def _Request(global_endpoint_manager, request_params, connection_policy, pipeline_client, request, 
**kwargs): # pylint: disable=too-many-statements """Makes one http request using the requests module. :param _GlobalEndpointManager global_endpoint_manager: @@ -73,7 +73,11 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p if request_params.endpoint_override: base_url = request_params.endpoint_override else: - base_url = global_endpoint_manager.resolve_service_endpoint(request_params) + pk_range_wrapper = None + if global_endpoint_manager.is_circuit_breaker_applicable(request_params): + # Circuit breaker is applicable, so we need to use the endpoint from the request + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(request_params) + base_url = global_endpoint_manager.resolve_service_endpoint(request_params, pk_range_wrapper) if not request.url.startswith(base_url): request.url = _replace_url_prefix(request.url, base_url) @@ -101,6 +105,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, connection_verify=kwargs.pop("connection_verify", ca_certs), connection_cert=kwargs.pop("connection_cert", cert_files), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) else: @@ -111,6 +117,8 @@ async def _Request(global_endpoint_manager, request_params, connection_policy, p read_timeout=read_timeout, # If SSL is disabled, verify = false connection_verify=kwargs.pop("connection_verify", is_ssl_enabled), + request_params=request_params, + global_endpoint_manager=global_endpoint_manager, **kwargs ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py index 79a5b766eb3e..360ed38af64b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_container.py @@ -164,6 +164,8 @@ async def read( range statistics in response headers. :keyword bool populate_quota_info: Enable returning collection storage quota information in response headers. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each @@ -224,7 +226,13 @@ async def create_item( :keyword bool enable_automatic_id_generation: Enable automatic id generation if no id present. :keyword str session_token: Token for use with Session consistency. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. + :keyword response_hook: A callable invoked with the response metadata.
:paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. @@ -303,6 +311,8 @@ async def read_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :returns: A CosmosDict representing the retrieved item. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -361,6 +371,8 @@ def read_all_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] """ @@ -441,6 +453,8 @@ def query_items( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An AsyncItemPaged of items (dicts). :rtype: AsyncItemPaged[Dict[str, Any]] @@ -537,6 +551,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -575,6 +591,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations.
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -601,6 +619,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -639,6 +659,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -675,6 +697,8 @@ def query_items_change_feed( # pylint: disable=unused-argument ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An AsyncItemPaged of items (dicts). @@ -754,6 +778,8 @@ async def upsert_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. @@ -827,6 +853,8 @@ async def replace_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -903,6 +931,8 @@ async def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from client-level options. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty if @@ -970,6 +1000,8 @@ async def delete_item( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1220,6 +1252,8 @@ async def delete_all_items_by_partition_key( :keyword str pre_trigger_include: trigger id to be used as pre operation trigger. :keyword str post_trigger_include: trigger id to be used as post operation trigger. :keyword str session_token: Token for use with Session consistency. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :rtype: None """ @@ -1275,6 +1309,8 @@ async def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword Callable response_hook: A callable invoked with the response metadata. :returns: A CosmosList representing the items after the batch operations went through. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The batch failed to execute. 
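The `excluded_locations` keyword threaded through the async container APIs above is surfaced at two levels: once on the client, where `_build_connection_policy` copies it into the connection policy (see the `_cosmos_client.py` hunk that follows), and once per operation, where it flows through the request options into `RequestObject.set_excluded_location_from_options`. A minimal usage sketch is shown below; it is illustrative only, the endpoint, key, database, container and item values are placeholders, and it assumes an account replicated across the named regions.

```python
import asyncio

from azure.cosmos.aio import CosmosClient

# Placeholder connection details for illustration only.
ENDPOINT = "https://myaccount.documents.azure.com:443/"
KEY = "<account-key>"


async def main() -> None:
    # Client-level exclusion: requests from this client skip 'West US'
    # when resolving an endpoint from the preferred locations.
    async with CosmosClient(
        ENDPOINT,
        credential=KEY,
        preferred_locations=["West US", "East US"],
        excluded_locations=["West US"],
    ) as client:
        database = client.get_database_client("mydatabase")
        container = database.get_container_client("mycontainer")

        # Per-operation exclusion: the list passed here applies to this call,
        # so 'East US' is skipped when routing this read.
        item = await container.read_item(
            item="item-id",
            partition_key="partition-key-value",
            excluded_locations=["East US"],
        )
        print(item["id"])


asyncio.run(main())
```

Per the `_can_set_excluded_location` error message in the `_request_object` changes earlier, passing an empty list on an individual call is the way to clear exclusions for that request.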
diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py index 683f16288cd3..e5e526670629 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client.py @@ -84,6 +84,7 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations) + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config @@ -161,6 +162,8 @@ class CosmosClient: # pylint: disable=client-accepts-api-version-keyword :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts. (Default: True) :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts. + :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy. Must be used along with a logger to work. :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. Can be passed in at client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index cd2e3ca9c9f0..31ae9dc334cd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -48,6 +48,8 @@ DistributedTracingPolicy, ProxyPolicy) from azure.core.utils import CaseInsensitiveDict +from azure.cosmos.aio._global_partition_endpoint_manager_circuit_breaker_async import ( + _GlobalPartitionEndpointManagerForCircuitBreakerAsync) from .. import _base as base from .._base import _set_properties_cache @@ -63,7 +65,6 @@ from .. import _runtime_constants as runtime_constants from .. import _request_object from . import _asynchronous_request as asynchronous_request -from . import _global_endpoint_manager_async as global_endpoint_manager_async from .._routing.aio.routing_map_provider import SmartRoutingMapProvider from ._retry_utility_async import _ConnectionRetryPolicy from .. import _session @@ -169,7 +170,7 @@ def __init__( # pylint: disable=too-many-statements # Keeps the latest response headers from the server.
self.last_response_headers: CaseInsensitiveDict = CaseInsensitiveDict() self.UseMultipleWriteLocations = False - self._global_endpoint_manager = global_endpoint_manager_async._GlobalEndpointManager(self) + self._global_endpoint_manager = _GlobalPartitionEndpointManagerForCircuitBreakerAsync(self) retry_policy = None if isinstance(self.connection_policy.ConnectionRetryConfiguration, AsyncHTTPPolicy): @@ -415,7 +416,8 @@ async def GetDatabaseAccount( documents._OperationType.Read, {}, client_id=self.client_id) # path # id # type - request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, + headers, url_connection) result, self.last_response_headers = await self.__Get("", request_params, headers, **kwargs) database_account = documents.DatabaseAccount() @@ -465,7 +467,9 @@ async def _GetDatabaseAccountCheck( documents._OperationType.Read, {}, client_id=self.client_id) # path # id # type - request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, url_connection) + request_params = _request_object.RequestObject("databaseaccount", documents._OperationType.Read, + headers, + url_connection) await self.__Get("", request_params, headers, **kwargs) async def CreateDatabase( @@ -729,7 +733,9 @@ async def ExecuteStoredProcedure( documents._OperationType.ExecuteJavaScript, options) # ExecuteStoredProcedure will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject("sprocs", documents._OperationType.ExecuteJavaScript) + request_params = _request_object.RequestObject("sprocs", + documents._OperationType.ExecuteJavaScript, headers) + request_params.set_excluded_location_from_options(options) result, self.last_response_headers = await self.__Post(path, request_params, params, headers, **kwargs) return result @@ -767,7 +773,8 @@ async def Create( documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -906,7 +913,8 @@ async def Upsert( headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -1207,7 +1215,8 @@ async def Read( headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await 
self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1465,7 +1474,8 @@ async def PatchItem( headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) + request_params.set_excluded_location_from_options(options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1569,7 +1579,8 @@ async def Replace( headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1892,7 +1903,8 @@ async def DeleteResource( headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) + request_params.set_excluded_location_from_options(options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2005,7 +2017,8 @@ async def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) - request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) + request_params = _request_object.RequestObject("docs", documents._OperationType.Batch, headers) + request_params.set_excluded_location_from_options(options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2856,13 +2869,16 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. 
if query is None: + op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) + headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, op_typ, + options, partition_key_range_id) request_params = _request_object.RequestObject( typ, - documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed + op_typ, + headers ) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) + request_params.set_excluded_location_from_options(options) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -2889,9 +2905,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, + documents._OperationType.SqlQuery, options, partition_key_range_id) + request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) + request_params.set_excluded_location_from_options(options) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) @@ -3259,7 +3277,9 @@ async def DeleteAllItemsByPartitionKey( initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) - request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) + request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete, + headers) + request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) self.last_response_headers = last_response_headers diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py index 4d00a7ef5629..0fe666f1983c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_endpoint_manager_async.py @@ -25,7 +25,6 @@ import asyncio # pylint: disable=do-not-import-asyncio import logging -from asyncio import CancelledError from typing import Tuple from azure.core.exceptions import AzureError @@ -33,8 +32,9 @@ from .. import _constants as constants from .. 
import exceptions -from .._location_cache import LocationCache - +from .._location_cache import LocationCache, current_time_millis +from .._request_object import RequestObject +from .._routing.routing_range import PartitionKeyRangeWrapper # pylint: disable=protected-access @@ -48,15 +48,12 @@ class _GlobalEndpointManager(object): # pylint: disable=too-many-instance-attrib def __init__(self, client): self.client = client - self.EnableEndpointDiscovery = client.connection_policy.EnableEndpointDiscovery self.PreferredLocations = client.connection_policy.PreferredLocations self.DefaultEndpoint = client.url_connection self.refresh_time_interval_in_ms = self.get_refresh_time_interval_in_ms_stub() self.location_cache = LocationCache( - self.PreferredLocations, self.DefaultEndpoint, - self.EnableEndpointDiscovery, - client.connection_policy.UseMultipleWriteLocations + client.connection_policy ) self.startup = True self.refresh_task = None @@ -65,9 +62,6 @@ def __init__(self, client): self.last_refresh_time = 0 self._database_account_cache = None - def get_use_multiple_write_locations(self): - return self.location_cache.can_use_multiple_write_locations() - def get_refresh_time_interval_in_ms_stub(self): return constants._Constants.DefaultEndpointsRefreshTime @@ -77,7 +71,11 @@ def get_write_endpoint(self): def get_read_endpoint(self): return self.location_cache.get_read_regional_routing_context() - def resolve_service_endpoint(self, request): + def resolve_service_endpoint( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper # pylint: disable=unused-argument + ) -> str: return self.location_cache.resolve_service_endpoint(request) def mark_endpoint_unavailable_for_read(self, endpoint, refresh_cache): @@ -105,9 +103,9 @@ async def refresh_endpoint_list(self, database_account, **kwargs): try: await self.refresh_task self.refresh_task = None - except (Exception, CancelledError) as exception: #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) as exception: #pylint: disable=broad-exception-caught logger.exception("Health check task failed: %s", exception) - if self.location_cache.current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: + if current_time_millis() - self.last_refresh_time > self.refresh_time_interval_in_ms: self.refresh_needed = True if self.refresh_needed: async with self.refresh_lock: @@ -123,11 +121,11 @@ async def _refresh_endpoint_list_private(self, database_account=None, **kwargs): if database_account and not self.startup: self.location_cache.perform_on_database_account_read(database_account) self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() else: if self.location_cache.should_refresh_endpoints() or self.refresh_needed: self.refresh_needed = False - self.last_refresh_time = self.location_cache.current_time_millis() + self.last_refresh_time = current_time_millis() if not self.startup: # this will perform getDatabaseAccount calls to check endpoint health # in background @@ -217,5 +215,5 @@ async def close(self): self.refresh_task.cancel() try: await self.refresh_task - except (Exception, CancelledError) : #pylint: disable=broad-exception-caught + except (Exception, asyncio.CancelledError) : #pylint: disable=broad-exception-caught pass diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py 
b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py new file mode 100644 index 000000000000..5231ed5c06c4 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_global_partition_endpoint_manager_circuit_breaker_async.py @@ -0,0 +1,115 @@ +# The MIT License (MIT) +# Copyright (c) 2021 Microsoft Corporation + +# Permission is hereby granted, free of charge, to any person obtaining a copy +# of this software and associated documentation files (the "Software"), to deal +# in the Software without restriction, including without limitation the rights +# to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +# copies of the Software, and to permit persons to whom the Software is +# furnished to do so, subject to the following conditions: + +# The above copyright notice and this permission notice shall be included in all +# copies or substantial portions of the Software. + +# THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +# IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +# FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +# AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +# LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +# OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +# SOFTWARE. + +"""Internal class for global endpoint manager for circuit breaker. +""" +from typing import TYPE_CHECKING + +from azure.cosmos import PartitionKey +from azure.cosmos._global_partition_endpoint_manager_circuit_breaker_core import \ + _GlobalPartitionEndpointManagerForCircuitBreakerCore +from azure.cosmos._routing.routing_range import PartitionKeyRangeWrapper, Range + +from azure.cosmos.aio._global_endpoint_manager_async import _GlobalEndpointManager +from azure.cosmos._request_object import RequestObject +from azure.cosmos.http_constants import HttpHeaders + +if TYPE_CHECKING: + from azure.cosmos.aio._cosmos_client_connection_async import CosmosClientConnection + + +# pylint: disable=protected-access +class _GlobalPartitionEndpointManagerForCircuitBreakerAsync(_GlobalEndpointManager): + """ + This internal class implements the logic for partition endpoint management for + geo-replicated database accounts. 
+ """ + + + def __init__(self, client: "CosmosClientConnection"): + super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self).__init__(client) + self.global_partition_endpoint_manager_core = ( + _GlobalPartitionEndpointManagerForCircuitBreakerCore(client, self.location_cache)) + + async def create_pk_range_wrapper(self, request: RequestObject) -> PartitionKeyRangeWrapper: + container_rid = request.headers[HttpHeaders.IntendedCollectionRID] + print(request.headers) + target_container_link = None + partition_key = None + # TODO: @tvaron3 see if the container link is absolutely necessary or change container cache + for container_link, properties in self.client._container_properties_cache.items(): + if properties["_rid"] == container_rid: + target_container_link = container_link + partition_key_definition = properties["partitionKey"] + partition_key = PartitionKey(path=partition_key_definition["paths"], + kind=partition_key_definition["kind"]) + + if not target_container_link or not partition_key: + raise RuntimeError("Illegal state: the container cache is not properly initialized.") + + if request.headers.get(HttpHeaders.PartitionKey): + partition_key_value = request.headers[HttpHeaders.PartitionKey] + # get the partition key range for the given partition key + # TODO: @tvaron3 check different clients and create them in different ways + epk_range = [partition_key._get_epk_range_for_partition_key(partition_key_value)] + partition_ranges = await (self.client._routing_map_provider + .get_overlapping_ranges(target_container_link, epk_range)) + partition_range = Range.PartitionKeyRangeToRange(partition_ranges[0]) + elif request.headers.get(HttpHeaders.PartitionKeyRangeID): + pk_range_id = request.headers[HttpHeaders.PartitionKeyRangeID] + range = await (self.client._routing_map_provider + .get_range_by_partition_key_range_id(target_container_link, pk_range_id)) + partition_range = Range.PartitionKeyRangeToRange(range) + + return PartitionKeyRangeWrapper(partition_range, container_rid) + + def is_circuit_breaker_applicable(self, request: RequestObject) -> bool: + return self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) + + async def record_failure( + self, + request: RequestObject + ) -> None: + if self.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_failure(request, pk_range_wrapper) + + def resolve_service_endpoint(self, request: RequestObject, pk_range_wrapper: PartitionKeyRangeWrapper): + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request) and pk_range_wrapper: + request = self.global_partition_endpoint_manager_core.add_excluded_locations_to_request(request, + pk_range_wrapper) + return (super(_GlobalPartitionEndpointManagerForCircuitBreakerAsync, self) + .resolve_service_endpoint(request, pk_range_wrapper)) + + def mark_partition_unavailable( + self, + request: RequestObject, + pk_range_wrapper: PartitionKeyRangeWrapper + ) -> None: + self.global_partition_endpoint_manager_core.mark_partition_unavailable(request, pk_range_wrapper) + + async def record_success( + self, + request: RequestObject + ) -> None: + if self.global_partition_endpoint_manager_core.is_circuit_breaker_applicable(request): + pk_range_wrapper = await self.create_pk_range_wrapper(request) + self.global_partition_endpoint_manager_core.record_success(request, pk_range_wrapper) diff --git 
a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py index ef5bad070014..5d4d680b50e4 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_retry_utility_async.py @@ -58,6 +58,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg :returns: the result of running the passed in function as a (result, headers) tuple :rtype: tuple of (dict, dict) """ + pk_range_wrapper = None + if args and global_endpoint_manager.is_circuit_breaker_applicable(args[0]): + pk_range_wrapper = await global_endpoint_manager.create_pk_range_wrapper(args[0]) # instantiate all retry policies here to be applied for each request execution endpointDiscovery_retry_policy = _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy( client.connection_policy, global_endpoint_manager, *args @@ -73,17 +76,17 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg defaultRetry_policy = _default_retry_policy.DefaultRetryPolicy(*args) sessionRetry_policy = _session_retry_policy._SessionRetryPolicy( - client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, *args + client.connection_policy.EnableEndpointDiscovery, global_endpoint_manager, pk_range_wrapper, *args ) partition_key_range_gone_retry_policy = _gone_retry_policy.PartitionKeyRangeGoneRetryPolicy(client, *args) timeout_failover_retry_policy = _timeout_failover_retry_policy._TimeoutFailoverRetryPolicy( - client.connection_policy, global_endpoint_manager, *args + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args ) service_response_retry_policy = _service_response_retry_policy.ServiceResponseRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) service_request_retry_policy = _service_request_retry_policy.ServiceRequestRetryPolicy( - client.connection_policy, global_endpoint_manager, *args, + client.connection_policy, global_endpoint_manager, pk_range_wrapper, *args, ) # HttpRequest we would need to modify for Container Recreate Retry Policy request = None @@ -102,6 +105,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg try: if args: result = await ExecuteFunctionAsync(function, global_endpoint_manager, *args, **kwargs) + await global_endpoint_manager.record_success(args[0]) else: result = await ExecuteFunctionAsync(function, *args, **kwargs) if not client.last_response_headers: @@ -170,9 +174,9 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg retry_policy.container_rid = cached_container["_rid"] request.headers[retry_policy._intended_headers] = retry_policy.container_rid - elif e.status_code == StatusCodes.REQUEST_TIMEOUT: - retry_policy = timeout_failover_retry_policy - elif e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + elif e.status_code == StatusCodes.REQUEST_TIMEOUT or e.status_code >= StatusCodes.INTERNAL_SERVER_ERROR: + # record the failure for circuit breaker tracking + await global_endpoint_manager.record_failure(args[0]) retry_policy = timeout_failover_retry_policy else: retry_policy = defaultRetry_policy @@ -198,6 +202,7 @@ async def ExecuteAsync(client, global_endpoint_manager, function, *args, **kwarg if client_timeout: kwargs['timeout'] = client_timeout - (time.time() - start_time) if kwargs['timeout'] <= 0: + await 
global_endpoint_manager.record_failure(args[0]) raise exceptions.CosmosClientTimeoutError() except ServiceRequestError as e: @@ -255,6 +260,8 @@ async def send(self, request): """ absolute_timeout = request.context.options.pop('timeout', None) per_request_timeout = request.context.options.pop('connection_timeout', 0) + request_params = request.context.options.pop('request_params', None) + global_endpoint_manager = request.context.options.pop('global_endpoint_manager', None) retry_error = None retry_active = True response = None @@ -279,6 +286,7 @@ async def send(self, request): # the request ran into a socket timeout or failed to establish a new connection # since request wasn't sent, raise exception immediately to be dealt with in client retry policies if not _has_database_account_header(request.http_request.headers): + await global_endpoint_manager.record_failure(request_params) if retry_settings['connect'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: @@ -296,6 +304,7 @@ async def send(self, request): ClientConnectionError) if (isinstance(err.inner_exception, ClientConnectionError) or _has_read_retryable_headers(request.http_request.headers)): + await global_endpoint_manager.record_failure(request_params) # This logic is based on the _retry.py file from azure-core if retry_settings['read'] > 0: retry_active = self.increment(retry_settings, response=request, error=err) @@ -310,6 +319,7 @@ async def send(self, request): if _has_database_account_header(request.http_request.headers): raise err if self._is_method_retryable(retry_settings, request.http_request): + await global_endpoint_manager.record_failure(request_params) retry_active = self.increment(retry_settings, response=request, error=err) if retry_active: await self.sleep(retry_settings, request.context.transport) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py index ef676cf3e1d2..c659ae746b4a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/container.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/container.py @@ -167,6 +167,8 @@ def read( # pylint:disable=docstring-missing-param request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :keyword dict[str, str] initial_headers: Initial headers to be sent as part of the request. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Raised if the container couldn't be retrieved. @@ -233,6 +235,8 @@ def read_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: A CosmosDict representing the item to be retrieved. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item couldn't be retrieved. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -298,6 +302,8 @@ def read_all_items( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: Iterable[Dict[str, Any]] """ @@ -364,6 +370,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -403,6 +411,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -429,6 +439,8 @@ def query_items_change_feed( request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. :paramtype priority: Literal["High", "Low"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -466,6 +478,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. 
:keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :returns: An Iterable of items (dicts). @@ -501,6 +515,8 @@ def query_items_change_feed( ALL_VERSIONS_AND_DELETES: Query all versions and deleted items from either `start_time='Now'` or 'continuation' token. :paramtype mode: Literal["LatestVersion", "AllVersionsAndDeletes"] + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], Dict[str, Any]], None] :param Any args: args @@ -601,6 +617,8 @@ def query_items( # pylint:disable=docstring-missing-param :keyword bool populate_index_metrics: Used to obtain the index metrics to understand how the query engine used existing indexes and how it could use potential new indexes. Please note that this options will incur overhead, so it should be enabled only when debugging slow queries. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :returns: An Iterable of items (dicts). :rtype: ItemPaged[Dict[str, Any]] @@ -716,6 +734,8 @@ def replace_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The replace operation failed or the item with given id does not exist. :returns: A CosmosDict representing the item after replace went through. The dict will be empty if `no_response` @@ -795,6 +815,8 @@ def upsert_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The given item could not be upserted. :returns: A CosmosDict representing the upserted item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -876,6 +898,8 @@ def create_item( # pylint:disable=docstring-missing-param :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. 
The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: Item with the given ID already exists. :returns: A CosmosDict representing the new item. The dict will be empty if `no_response` is specified. :rtype: ~azure.cosmos.CosmosDict[str, Any] @@ -967,6 +991,8 @@ def patch_item( :keyword bool no_response: Indicates whether service should be instructed to skip sending response payloads. When not specified explicitly here, the default value will be determined from kwargs or when also not specified there from client-level kwargs. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The patch operations failed or the item with given id does not exist. :returns: A CosmosDict representing the item after the patch operations went through. The dict will be empty @@ -1027,6 +1053,8 @@ def execute_item_batch( :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: [Callable[[Mapping[str, str], List[Dict[str, Any]]], None] :returns: A CosmosList representing the items after the batch operations went through. @@ -1060,6 +1088,8 @@ def execute_item_batch( request_options = build_options(kwargs) request_options["partitionKey"] = self._set_partition_key(partition_key) request_options["disableAutomaticIdGeneration"] = True + container_properties = self._get_properties() + request_options["containerRID"] = container_properties["_rid"] return self.client_connection.Batch( collection_link=self.container_link, batch_operations=batch_operations, options=request_options, **kwargs) @@ -1099,6 +1129,8 @@ def delete_item( # pylint:disable=docstring-missing-param :keyword Literal["High", "Low"] priority: Priority based execution allows users to set a priority for each request. Once the user has reached their provisioned throughput, low priority requests are throttled before high priority requests start getting throttled. Feature must first be enabled at the account level. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] :raises ~azure.cosmos.exceptions.CosmosHttpResponseError: The item wasn't deleted successfully. @@ -1374,6 +1406,8 @@ def delete_all_items_by_partition_key( :keyword str pre_trigger_include: trigger id to be used as pre operation trigger. :keyword str post_trigger_include: trigger id to be used as post operation trigger. 
:keyword str session_token: Token for use with Session consistency. + :keyword list[str] excluded_locations: Excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword response_hook: A callable invoked with the response metadata. :paramtype response_hook: Callable[[Mapping[str, str], None], None] = None, :rtype: None diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py index 10543f97c47b..b7a6ea94bd2b 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/cosmos_client.py @@ -93,6 +93,8 @@ def _build_connection_policy(kwargs: Dict[str, Any]) -> ConnectionPolicy: policy.ProxyConfiguration = kwargs.pop('proxy_config', policy.ProxyConfiguration) policy.EnableEndpointDiscovery = kwargs.pop('enable_endpoint_discovery', policy.EnableEndpointDiscovery) policy.PreferredLocations = kwargs.pop('preferred_locations', policy.PreferredLocations) + # TODO: Consider storing callback method instead, such as 'Supplier' in JAVA SDK + policy.ExcludedLocations = kwargs.pop('excluded_locations', policy.ExcludedLocations) policy.UseMultipleWriteLocations = kwargs.pop('multiple_write_locations', policy.UseMultipleWriteLocations) # SSL config @@ -181,6 +183,8 @@ class CosmosClient: # pylint: disable=client-accepts-api-version-keyword :keyword bool enable_endpoint_discovery: Enable endpoint discovery for geo-replicated database accounts. (Default: True) :keyword list[str] preferred_locations: The preferred locations for geo-replicated database accounts. + :keyword list[str] excluded_locations: The excluded locations to be skipped from preferred locations. The locations + in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US' and so on. :keyword bool enable_diagnostics_logging: Enable the CosmosHttpLogging policy. Must be used along with a logger to work. :keyword ~logging.Logger logger: Logger to be used for collecting request diagnostics. Can be passed in at client diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py index 40fbed24451f..7ccc99da9dfe 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/documents.py @@ -308,6 +308,13 @@ class ConnectionPolicy: # pylint: disable=too-many-instance-attributes locations in this list are specified as the names of the azure Cosmos locations like, 'West US', 'East US', 'Central India' and so on. :vartype PreferredLocations: List[str] + :ivar ExcludedLocations: + Gets or sets the excluded locations for geo-replicated database + accounts. When ExcludedLocations is non-empty, the client will skip this + set of locations from the final location evaluation. The locations in + this list are specified as the names of the azure Cosmos locations like, + 'West US', 'East US', 'Central India' and so on. + :vartype ExcludedLocations: List[str] :ivar RetryOptions: Gets or sets the retry options to be applied to all requests when retrying. 
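The two location lists described above are meant to compose in a fixed order of precedence: a request-level excluded_locations list, when supplied, replaces the client-level ExcludedLocations for that call, and whatever remains of PreferredLocations is then evaluated in order (falling back to the account's default endpoint when everything is excluded). The helper below is only an illustrative sketch of that precedence under those assumptions, not the SDK's internal LocationCache routine; the function name and signature are invented for the example.

from typing import List, Optional

def effective_read_locations(
        preferred_locations: List[str],
        client_excluded_locations: List[str],
        request_excluded_locations: Optional[List[str]] = None) -> List[str]:
    # A request-level list, even an empty one, takes priority over the client-level list.
    excluded = (request_excluded_locations
                if request_excluded_locations is not None
                else client_excluded_locations)
    # Preserve preferred-location order and drop excluded names; an empty result means
    # the client falls back to the account's default endpoint.
    return [loc for loc in preferred_locations if loc not in excluded]

# ['West US 3', 'West US', 'East US 2'] minus ['West US 3', 'West US'] leaves ['East US 2'].
assert effective_read_locations(
    ['West US 3', 'West US', 'East US 2'], ['West US 3', 'West US']) == ['East US 2']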
@@ -347,6 +354,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations: List[str] = [] self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/azure-cosmos/samples/excluded_locations.py b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py new file mode 100644 index 000000000000..a8c699a7cccf --- /dev/null +++ b/sdk/cosmos/azure-cosmos/samples/excluded_locations.py @@ -0,0 +1,115 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See LICENSE.txt in the project root for +# license information. +# ------------------------------------------------------------------------- +from azure.cosmos import CosmosClient +from azure.cosmos.partition_key import PartitionKey +import config + +# ---------------------------------------------------------------------------------------------------------- +# Prerequisites - +# +# 1. An Azure Cosmos account - +# https://learn.microsoft.com/azure/cosmos-db/create-sql-api-python#create-a-database-account +# +# 2. Microsoft Azure Cosmos +# pip install azure-cosmos>=4.3.0b4 +# +# 3. Configure Azure Cosmos account to add 3+ regions, such as 'West US 3', 'West US', 'East US 2'. +# If you added other regions, update L1~L3 with the regions in your account. +# ---------------------------------------------------------------------------------------------------------- +# Sample - demonstrates how to use excluded locations in client level and request level +# ---------------------------------------------------------------------------------------------------------- +# Note: +# This sample creates a Container to your database account. +# Each time a Container is created the account will be billed for 1 hour of usage based on +# the provisioned throughput (RU/s) of that account. 
+# ---------------------------------------------------------------------------------------------------------- + +HOST = config.settings["host"] +MASTER_KEY = config.settings["master_key"] + +TENANT_ID = config.settings["tenant_id"] +CLIENT_ID = config.settings["client_id"] +CLIENT_SECRET = config.settings["client_secret"] + +DATABASE_ID = config.settings["database_id"] +CONTAINER_ID = config.settings["container_id"] +PARTITION_KEY = PartitionKey(path="/pk") + +L1, L2, L3 = 'West US 3', 'West US', 'East US 2' + +def get_test_item(num): + test_item = { + 'id': 'Item_' + str(num), + 'pk': 'PartitionKey_' + str(num), + 'test_object': True, + 'lastName': 'Smith' + } + return test_item + +def clean_up_db(client): + try: + client.delete_database(DATABASE_ID) + except Exception as e: + pass + +def excluded_locations_client_level_sample(): + preferred_locations = [L1, L2, L3] + excluded_locations = [L1, L2] + client = CosmosClient( + HOST, + MASTER_KEY, + preferred_locations=preferred_locations, + excluded_locations=excluded_locations + ) + clean_up_db(client) + + db = client.create_database(DATABASE_ID) + container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) + + # For write operations with single master account, write endpoint will be the default endpoint, + # since preferred_locations or excluded_locations are ignored and used + created_item = container.create_item(get_test_item(0)) + + # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'. + # In our sample, ['West US 3', 'West US', 'East US 2'] - ['West US 3', 'West US'] => ['East US 2'], + # therefore 'East US 2' will be the read endpoint, and items will be read from 'East US 2' location + item = container.read_item(item=created_item['id'], partition_key=created_item['pk']) + + clean_up_db(client) + +def excluded_locations_request_level_sample(): + preferred_locations = [L1, L2, L3] + excluded_locations_on_client = [L1, L2] + excluded_locations_on_request = [L1] + client = CosmosClient( + HOST, + MASTER_KEY, + preferred_locations=preferred_locations, + excluded_locations=excluded_locations_on_client + ) + clean_up_db(client) + + db = client.create_database(DATABASE_ID) + container = db.create_container(id=CONTAINER_ID, partition_key=PARTITION_KEY) + + # For write operations with single master account, write endpoint will be the default endpoint, + # since preferred_locations or excluded_locations are ignored and used + created_item = container.create_item(get_test_item(0), excluded_locations=excluded_locations_on_request) + + # For read operations, read endpoints will be 'preferred_locations' - 'excluded_locations'. + # However, in our sample, since the excluded_locations` were passed with the read request, the `excluded_location` + # will be replaced with the locations from request, ['West US 3']. The `excluded_locations` on request always takes + # the highest priority! 
+ # With the excluded_locations on request, the read endpoints will be ['West US', 'East US 2'] + # ['West US 3', 'West US', 'East US 2'] - ['West US 3'] => ['West US', 'East US 2'] + # Therefore, items will be read from 'West US' or 'East US 2' location + item = container.read_item(item=created_item['id'], partition_key=created_item['pk'], excluded_locations=excluded_locations_on_request) + + clean_up_db(client) + +if __name__ == "__main__": + # excluded_locations_client_level_sample() + excluded_locations_request_level_sample() diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py index 628456d95158..9a99229ec995 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport.py @@ -203,7 +203,10 @@ def transform_topology_swr_mrr( def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], RequestsTransportResponse]) -> RequestsTransportResponse: + inner: Callable[[], RequestsTransportResponse], + first_region_url: str = None, + second_region_url: str = test_config.TestConfig.local_host + ) -> RequestsTransportResponse: response = inner() if not FaultInjectionTransport.predicate_is_database_account_call(response.request): @@ -215,12 +218,17 @@ def transform_topology_mwr( result = json.loads(data) readable_locations = result["readableLocations"] writable_locations = result["writableLocations"] - readable_locations[0]["name"] = first_region_name - writable_locations[0]["name"] = first_region_name + + if first_region_url is None: + first_region_url = readable_locations[0]["databaseAccountEndpoint"] + readable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} + writable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} readable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) writable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) result["enableMultipleWriteLocations"] = True FaultInjectionTransport.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request diff --git a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py index 13dda0dc7e20..6bdeb4ed49c9 100644 --- a/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/_fault_injection_transport_async.py @@ -201,7 +201,10 @@ async def transform_topology_swr_mrr( async def transform_topology_mwr( first_region_name: str, second_region_name: str, - inner: Callable[[], Awaitable[AioHttpTransportResponse]]) -> AioHttpTransportResponse: + inner: Callable[[], Awaitable[AioHttpTransportResponse]], + first_region_url: str = None, + second_region_url: str = test_config.TestConfig.local_host + ) -> AioHttpTransportResponse: response = await inner() if not FaultInjectionTransportAsync.predicate_is_database_account_call(response.request): @@ -213,12 +216,17 @@ async def transform_topology_mwr( result = json.loads(data) readable_locations = result["readableLocations"] writable_locations = result["writableLocations"] - 
readable_locations[0]["name"] = first_region_name - writable_locations[0]["name"] = first_region_name + + if first_region_url is None: + first_region_url = readable_locations[0]["databaseAccountEndpoint"] + readable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} + writable_locations[0] = \ + {"name": first_region_name, "databaseAccountEndpoint": first_region_url} readable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) writable_locations.append( - {"name": second_region_name, "databaseAccountEndpoint": test_config.TestConfig.local_host}) + {"name": second_region_name, "databaseAccountEndpoint": second_region_url}) result["enableMultipleWriteLocations"] = True FaultInjectionTransportAsync.logger.info("Transformed Account Topology: {}".format(result)) request: HttpRequest = response.request diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py new file mode 100644 index 000000000000..4b517796dd3a --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations.py @@ -0,0 +1,445 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import logging +import unittest +import uuid +import test_config +import pytest + +import azure.cosmos.cosmos_client as cosmos_client +from azure.cosmos.partition_key import PartitionKey +from azure.cosmos.exceptions import CosmosResourceNotFoundError + + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +L0 = "Default" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" + +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2], [], None], + # 1. Single excluded location + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2], [L1, L2], [L1, L2]], + # 5. No common excluded locations + [[L1, L2], [L1], [L2]], + # 6. Request excluded location not in preferred locations + [[L1, L2], [L1, L2], [L3]], + # 7. 
Empty excluded locations, remove all client excluded locations + [[L1, L2], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA +# ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA +# ALL_INPUT_TEST_DATA = CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1], # 3 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 + ] + client_and_request_output_data = [ + [L1, L1, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L2, L2], #2 + [L2, L2, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L2, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1, L1], #0 + [L2, L2], #1 + [L1, L0], #2 + [L1, L1] #3 + ] + client_and_request_output_data = [ + [L2, L2], #0 + [L2, L2], #1 + [L2, L2], #2 + [L1, L0], #3 + [L1, L0], #4 + [L1, L1], #5 + [L1, L1], #6 + [L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #3 + [L1] #4 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + container.create_item(body=body) + else: + container.create_item(body=body, excluded_locations=excluded_locations) + +@pytest.fixture(scope="class", autouse=True) +def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + container = cosmos_client.CosmosClient(HOST, KEY).get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + container.upsert_item(body=TEST_ITEM) + + yield + # Code to run after tests + print("Teardown: This runs after all tests") + +def _init_container(preferred_locations, client_excluded_locations, multiple_write_locations = True): + client = cosmos_client.CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = client.get_database_client(DATABASE_ID) + container = db.get_container_client(CONTAINER_ID) + MOCK_HANDLER.reset() + + return client, db, container + +def 
_verify_endpoint(messages, client, expected_locations): + # get mapping for locations + location_mapping = (client.client_connection._global_endpoint_manager. + location_cache.account_locations_by_write_regional_routing_context) + default_endpoint = (client.client_connection._global_endpoint_manager. + location_cache.default_regional_routing_context.get_primary()) + + # get Request URL + req_urls = [url.replace("Request URL: '", "") for url in messages if 'Request URL:' in url] + + # get location + actual_locations = [] + for req_url in req_urls: + if req_url.startswith(default_endpoint): + actual_locations.append(L0) + else: + for endpoint in location_mapping: + if req_url.startswith(endpoint): + location = location_mapping[endpoint] + actual_locations.append(location) + break + + assert actual_locations == expected_locations + +@pytest.mark.cosmosMultiRegion +class TestExcludedLocations: + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_read_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = _init_container(preferred_locations, client_excluded_locations) + + # API call: read_item + if request_excluded_locations is None: + container.read_item(ITEM_ID, ITEM_PK_VALUE) + else: + container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_read_all_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = _init_container(preferred_locations, client_excluded_locations) + + # API call: read_all_items + if request_excluded_locations is None: + list(container.read_all_items()) + else: + list(container.read_all_items(excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + @pytest.mark.parametrize('test_data', read_item_test_data()) + def test_query_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations) + + # API call: query_items + if request_excluded_locations is None: + list(container.query_items(None)) + else: + list(container.query_items(None, excluded_locations=request_excluded_locations)) + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + def test_query_items_change_feed(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations) + + # API call: query_items_change_feed + if request_excluded_locations is None: + items = list(container.query_items_change_feed(start_time="Beginning")) + else: + items = list(container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)) 
+ + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_replace_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: replace_item + if request_excluded_locations is None: + container.replace_item(ITEM_ID, body=TEST_ITEM) + else: + container.replace_item(ITEM_ID, body=TEST_ITEM, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_upsert_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: upsert_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + container.upsert_item(body=body) + else: + container.upsert_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + def test_create_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: create_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + _create_item_with_excluded_locations(container, body, request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [expected_locations[0], L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_patch_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: patch_item + operations = [ + {"op": "add", "path": "/test_data", "value": f'Data-{str(uuid.uuid4())}'}, + ] + if request_excluded_locations is None: + container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations) + else: + container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + 
patch_operations=operations, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_execute_item_batch(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = _init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: execute_item_batch + batch_operations = [] + for i in range(3): + batch_operations.append(("create", ({"id": f'Doc-{str(uuid.uuid4())}', PARTITION_KEY: ITEM_PK_VALUE},))) + + if request_excluded_locations is None: + container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE,) + else: + container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + _create_item_with_excluded_locations(container, body, request_excluded_locations) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + container.delete_item(item_id, ITEM_PK_VALUE) + else: + container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py new file mode 100644 index 000000000000..11ababfdfafd --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_async.py @@ -0,0 +1,443 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
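The async tests below mirror the synchronous suite above; they exercise the same client-level and request-level excluded_locations keywords through azure.cosmos.aio. The snippet that follows is a minimal usage sketch of that surface, with placeholder endpoint, key, database, and container names standing in for a real account:

import asyncio
from azure.cosmos.aio import CosmosClient

async def main():
    async with CosmosClient(
            "https://<account>.documents.azure.com:443/", "<key>",
            preferred_locations=["West US 3", "West US", "East US 2"],
            excluded_locations=["West US 3"]) as client:
        container = client.get_database_client("<database>").get_container_client("<container>")
        # A request-level list replaces the client-level exclusions for this call only,
        # so the read is routed to the first preferred location not excluded here.
        await container.read_item("doc1", partition_key="pk",
                                  excluded_locations=["West US 3", "West US"])

# asyncio.run(main())  # requires a reachable account with the regions above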
+ +import logging +import unittest +import uuid +import test_config +import pytest +import pytest_asyncio + +from azure.cosmos.aio import CosmosClient +from azure.cosmos.partition_key import PartitionKey +from test_excluded_locations import _verify_endpoint + +class MockHandler(logging.Handler): + def __init__(self): + super(MockHandler, self).__init__() + self.messages = [] + + def reset(self): + self.messages = [] + + def emit(self, record): + self.messages.append(record.msg) + +MOCK_HANDLER = MockHandler() +CONFIG = test_config.TestConfig() +HOST = CONFIG.host +KEY = CONFIG.masterKey +DATABASE_ID = CONFIG.TEST_DATABASE_ID +CONTAINER_ID = CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID +PARTITION_KEY = CONFIG.TEST_CONTAINER_PARTITION_KEY +ITEM_ID = 'doc1' +ITEM_PK_VALUE = 'pk' +TEST_ITEM = {'id': ITEM_ID, PARTITION_KEY: ITEM_PK_VALUE} + +L0 = "Default" +L1 = "West US 3" +L2 = "West US" +L3 = "East US 2" + +# L0 = "Default" +# L1 = "East US 2" +# L2 = "East US" +# L3 = "West US 2" + +CLIENT_ONLY_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No excluded location + [[L1, L2], [], None], + # 1. Single excluded location + [[L1, L2], [L1], None], + # 2. Exclude all locations + [[L1, L2], [L1, L2], None], + # 3. Exclude a location not in preferred locations + [[L1, L2], [L3], None], +] + +CLIENT_AND_REQUEST_TEST_DATA = [ + # preferred_locations, client_excluded_locations, excluded_locations_request + # 0. No client excluded locations + a request excluded location + [[L1, L2], [], [L1]], + # 1. The same client and request excluded location + [[L1, L2], [L1], [L1]], + # 2. Less request excluded locations + [[L1, L2], [L1, L2], [L1]], + # 3. More request excluded locations + [[L1, L2], [L1], [L1, L2]], + # 4. All locations were excluded + [[L1, L2], [L1, L2], [L1, L2]], + # 5. No common excluded locations + [[L1, L2], [L1], [L2, L3]], + # 6. Request excluded location not in preferred locations + [[L1, L2], [L1, L2], [L3]], + # 7. 
Empty excluded locations, remove all client excluded locations + [[L1, L2], [L1, L2], []], +] + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def read_item_test_data(): + client_only_output_data = [ + [L1], # 0 + [L2], # 1 + [L1], # 2 + [L1], # 3 + ] + client_and_request_output_data = [ + [L2], # 0 + [L2], # 1 + [L2], # 2 + [L1], # 3 + [L1], # 4 + [L1], # 5 + [L1], # 6 + [L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def read_all_item_test_data(): + client_only_output_data = [ + [L1, L1], # 0 + [L2, L2], # 1 + [L1, L1], # 2 + [L1, L1], # 3 + ] + client_and_request_output_data = [ + [L2, L2], # 0 + [L2, L2], # 1 + [L2, L2], # 2 + [L1, L1], # 3 + [L1, L1], # 4 + [L1, L1], # 5 + [L1, L1], # 6 + [L1, L1], # 7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def query_items_change_feed_test_data(): + client_only_output_data = [ + [L1, L1, L1, L1], #0 + [L2, L2, L2, L2], #1 + [L1, L1, L1, L1], #2 + [L1, L1, L1, L1] #3 + ] + client_and_request_output_data = [ + [L1, L2, L2, L2], #0 + [L2, L2, L2, L2], #1 + [L1, L2, L2, L2], #2 + [L2, L1, L1, L1], #3 + [L1, L1, L1, L1], #4 + [L2, L1, L1, L1], #5 + [L1, L1, L1, L1], #6 + [L1, L1, L1, L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def replace_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def patch_item_test_data(): + client_only_output_data = [ + [L1], #0 + [L2], #1 + [L0], #2 + [L1] #3 + ] + client_and_request_output_data = [ + [L2], #0 + [L2], #1 + [L2], #2 + [L0], #3 + [L0], #4 + [L1], #5 + [L1], #6 + [L1], #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +async def _create_item_with_excluded_locations(container, body, excluded_locations): + if excluded_locations is None: + await container.create_item(body=body) + else: + await container.create_item(body=body, excluded_locations=excluded_locations) + +@pytest_asyncio.fixture(scope="class", autouse=True) +async def setup_and_teardown(): + print("Setup: This runs before any tests") + logger = logging.getLogger("azure") + logger.addHandler(MOCK_HANDLER) + logger.setLevel(logging.DEBUG) + + test_client = CosmosClient(HOST, KEY) + container = test_client.get_database_client(DATABASE_ID).get_container_client(CONTAINER_ID) + await container.upsert_item(body=TEST_ITEM) + + yield + await test_client.close() + +async def _init_container(preferred_locations, client_excluded_locations, 
multiple_write_locations = True): + client = CosmosClient(HOST, KEY, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations) + db = await client.create_database_if_not_exists(DATABASE_ID) + container = await db.create_container_if_not_exists(CONTAINER_ID, PartitionKey(path='/' + PARTITION_KEY, kind='Hash')) + MOCK_HANDLER.reset() + + return client, db, container + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup_and_teardown") +class TestExcludedLocations: + @pytest.mark.parametrize('test_data', read_item_test_data()) + async def test_read_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = await _init_container(preferred_locations, client_excluded_locations) + + # API call: read_item + if request_excluded_locations is None: + await container.read_item(ITEM_ID, ITEM_PK_VALUE) + else: + await container.read_item(ITEM_ID, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + @pytest.mark.parametrize('test_data', read_all_item_test_data()) + async def test_read_all_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup + client, db, container = await _init_container(preferred_locations, client_excluded_locations) + + # API call: read_all_items + if request_excluded_locations is None: + all_items = [item async for item in container.read_all_items()] + else: + all_items = [item async for item in container.read_all_items(excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + @pytest.mark.parametrize('test_data', read_all_item_test_data()) + async def test_query_items(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations) + + # API call: query_items + if request_excluded_locations is None: + all_items = [item async for item in container.query_items(None)] + else: + all_items = [item async for item in container.query_items(None, excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + + @pytest.mark.parametrize('test_data', query_items_change_feed_test_data()) + async def test_query_items_change_feed(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations) + + # API call: query_items_change_feed + if request_excluded_locations is None: + all_items = [item async for item in container.query_items_change_feed(start_time="Beginning")] + else: + all_items = [item async for item in container.query_items_change_feed(start_time="Beginning", excluded_locations=request_excluded_locations)] + + # Verify endpoint locations + _verify_endpoint(MOCK_HANDLER.messages, 
client, expected_locations) + + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_replace_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: replace_item + if request_excluded_locations is None: + await container.replace_item(ITEM_ID, body=TEST_ITEM) + else: + await container.replace_item(ITEM_ID, body=TEST_ITEM, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_upsert_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: upsert_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + if request_excluded_locations is None: + await container.upsert_item(body=body) + else: + await container.upsert_item(body=body, excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', replace_item_test_data()) + async def test_create_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # API call: create_item + body = {'pk': 'pk', 'id': f'doc2-{str(uuid.uuid4())}'} + await _create_item_with_excluded_locations(container, body, request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_patch_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: patch_item + operations = [ + {"op": "add", "path": "/test_data", "value": f'Data-{str(uuid.uuid4())}'}, + ] + if request_excluded_locations is None: + await container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations) + else: + await container.patch_item(item=ITEM_ID, partition_key=ITEM_PK_VALUE, + patch_operations=operations, + 
excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_execute_item_batch(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup and create an item + client, db, container = await _init_container(preferred_locations, client_excluded_locations, + multiple_write_locations) + + # API call: execute_item_batch + batch_operations = [] + for i in range(3): + batch_operations.append(("create", ({"id": f'Doc-{str(uuid.uuid4())}', PARTITION_KEY: ITEM_PK_VALUE},))) + + if request_excluded_locations is None: + await container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE,) + else: + await container.execute_item_batch(batch_operations=batch_operations, + partition_key=ITEM_PK_VALUE, + excluded_locations=request_excluded_locations) + + # get location from mock_handler + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + + @pytest.mark.parametrize('test_data', patch_item_test_data()) + async def test_delete_item(self, test_data): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_locations = test_data + + for multiple_write_locations in [True, False]: + # Client setup + client, db, container = await _init_container(preferred_locations, client_excluded_locations, multiple_write_locations) + + # create before delete + item_id = f'doc2-{str(uuid.uuid4())}' + body = {PARTITION_KEY: ITEM_PK_VALUE, 'id': item_id} + await _create_item_with_excluded_locations(container, body, request_excluded_locations) + MOCK_HANDLER.reset() + + # API call: delete_item + if request_excluded_locations is None: + await container.delete_item(item_id, ITEM_PK_VALUE) + else: + await container.delete_item(item_id, ITEM_PK_VALUE, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + if multiple_write_locations: + _verify_endpoint(MOCK_HANDLER.messages, client, expected_locations) + else: + _verify_endpoint(MOCK_HANDLER.messages, client, [L1]) + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py new file mode 100644 index 000000000000..375bcfc899d8 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator.py @@ -0,0 +1,122 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
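This emulator-backed test relies on the fault-injection transport's transform_topology_mwr hook (shown earlier in this diff) to make the single-endpoint emulator advertise two named write regions, one on localhost and one on 127.0.0.1. A sketch of the fabricated databaseaccount topology follows; the field names match the transformed response, while the concrete endpoint URLs are illustrative emulator defaults:

FAKE_TOPOLOGY = {
    "readableLocations": [
        {"name": "West US 3", "databaseAccountEndpoint": "https://localhost:8081/"},
        {"name": "West US", "databaseAccountEndpoint": "https://127.0.0.1:8081/"},
    ],
    "writableLocations": [
        {"name": "West US 3", "databaseAccountEndpoint": "https://localhost:8081/"},
        {"name": "West US", "databaseAccountEndpoint": "https://127.0.0.1:8081/"},
    ],
    "enableMultipleWriteLocations": True,
}

# Resolving a request URL back to a region name is then a prefix lookup over these
# endpoints, which is what the get_location helper in this module does.
URL_TO_REGION = {loc["databaseAccountEndpoint"]: loc["name"]
                 for loc in FAKE_TOPOLOGY["readableLocations"]}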
+ +import unittest +import uuid +import test_config +import pytest +from typing import Callable, List, Mapping, Any + +from azure.core.rest import HttpRequest +from _fault_injection_transport import FaultInjectionTransport +from azure.cosmos.container import ContainerProxy +from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA +from test_fault_injection_transport import TestFaultInjectionTransport + +CONFIG = test_config.TestConfig() + +L1_URL = test_config.TestConfig.local_host +L2_URL = L1_URL.replace("localhost", "127.0.0.1") +URL_TO_LOCATIONS = { + L1_URL: L1, + L2_URL: L2 +} + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def delete_all_items_by_partition_key_test_data() -> List[str]: + client_only_output_data = [ + L1, #0 + L2, #1 + L1, #3 + L1 #4 + ] + client_and_request_output_data = [ + L2, #0 + L2, #1 + L2, #2 + L1, #3 + L1, #4 + L1, #5 + L1, #6 + L1, #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +def get_location( + initialized_objects: Mapping[str, Any], + url_to_locations: Mapping[str, str] = URL_TO_LOCATIONS) -> str: + # get Request URL + header = initialized_objects['client'].client_connection.last_response_headers + request_url = header["_request"].url + + # verify + location = "" + for url in url_to_locations: + if request_url.startswith(url): + location = url_to_locations[url] + break + return location + +@pytest.mark.unittest +@pytest.mark.cosmosEmulator +class TestExcludedLocationsEmulator: + @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) + def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsEmulator", test_data: List[List[str]]): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data + + # Inject topology transformation that would make Emulator look like a multiple write region account + # with two read regions + custom_transport = FaultInjectionTransport() + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda \ + r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name=L1, + second_region_name=L2, + inner=inner, + first_region_url=L1_URL, + second_region_url=L2_URL, + ) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + for multiple_write_locations in [True, False]: + # Create client + initialized_objects = TestFaultInjectionTransport.setup_method_with_custom_transport( + custom_transport, + default_endpoint=CONFIG.host, + key=CONFIG.masterKey, + database_id=CONFIG.TEST_DATABASE_ID, + container_id=CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID, + preferred_locations=preferred_locations, + excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations, + ) + container: ContainerProxy = initialized_objects["col"] + + # create an item + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, 'pk': id_value} + container.create_item(body=document_definition) + + # API call: delete_all_items_by_partition_key + if request_excluded_locations is None: + 
container.delete_all_items_by_partition_key(id_value) + else: + container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + actual_location = get_location(initialized_objects) + if multiple_write_locations: + assert actual_location == expected_location + else: + assert actual_location == L1 + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py new file mode 100644 index 000000000000..c24c2f13c9f7 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_excluded_locations_emulator_async.py @@ -0,0 +1,100 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import unittest +import uuid +import test_config +import pytest +from typing import Callable, List, Mapping, Any + +from azure.core.pipeline.transport import AioHttpTransport +from _fault_injection_transport_async import FaultInjectionTransportAsync +from azure.cosmos.aio._container import ContainerProxy +from test_excluded_locations import L1, L2, CLIENT_ONLY_TEST_DATA, CLIENT_AND_REQUEST_TEST_DATA +from test_excluded_locations_emulator import L1_URL, L2_URL, get_location +from test_fault_injection_transport_async import TestFaultInjectionTransportAsync + +CONFIG = test_config.TestConfig() + +ALL_INPUT_TEST_DATA = CLIENT_ONLY_TEST_DATA + CLIENT_AND_REQUEST_TEST_DATA + +def delete_all_items_by_partition_key_test_data() -> List[str]: + client_only_output_data = [ + L1, #0 + L2, #1 + L1, #3 + L1 #4 + ] + client_and_request_output_data = [ + L2, #0 + L2, #1 + L2, #2 + L1, #3 + L1, #4 + L1, #5 + L1, #6 + L1, #7 + ] + all_output_test_data = client_only_output_data + client_and_request_output_data + + all_test_data = [input_data + [output_data] for input_data, output_data in zip(ALL_INPUT_TEST_DATA, all_output_test_data)] + return all_test_data + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +class TestExcludedLocationsEmulatorAsync: + @pytest.mark.parametrize('test_data', delete_all_items_by_partition_key_test_data()) + async def test_delete_all_items_by_partition_key(self: "TestExcludedLocationsEmulatorAsync", test_data: List[List[str]]): + # Init test variables + preferred_locations, client_excluded_locations, request_excluded_locations, expected_location = test_data + + # Inject topology transformation that would make Emulator look like a multiple write region account + # with two read regions + custom_transport = FaultInjectionTransportAsync() + is_get_account_predicate: Callable[[AioHttpTransport], bool] = lambda \ + r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name=L1, + first_region_url=L1_URL, + inner=inner, + second_region_name=L2, + second_region_url=L2_URL) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + + for multiple_write_locations in [True, False]: + # Create client + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( + custom_transport, + default_endpoint=CONFIG.host, + key=CONFIG.masterKey, + database_id=CONFIG.TEST_DATABASE_ID, + container_id=CONFIG.TEST_SINGLE_PARTITION_CONTAINER_ID, + preferred_locations=preferred_locations, + 
excluded_locations=client_excluded_locations, + multiple_write_locations=multiple_write_locations, + ) + container: ContainerProxy = initialized_objects["col"] + + # create an item + id_value: str = str(uuid.uuid4()) + document_definition = {'id': id_value, 'pk': id_value} + await container.create_item(body=document_definition) + + # API call: delete_all_items_by_partition_key + if request_excluded_locations is None: + await container.delete_all_items_by_partition_key(id_value) + else: + await container.delete_all_items_by_partition_key(id_value, excluded_locations=request_excluded_locations) + + # Verify endpoint locations + actual_location = get_location(initialized_objects) + if multiple_write_locations: + assert actual_location == expected_location + else: + assert actual_location == L1 + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py index 304fa8d50f0d..4d7ea16ee58e 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport.py @@ -62,11 +62,17 @@ def teardown_class(cls): logger.warning("Exception trying to delete database {}. {}".format(created_database.id, containerDeleteError)) @staticmethod - def setup_method_with_custom_transport(custom_transport: RequestsTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + def setup_method_with_custom_transport( + custom_transport: RequestsTransport, + default_endpoint: str = host, + key: str = master_key, + database_id: str = TEST_DATABASE_ID, + container_id: str = SINGLE_PARTITION_CONTAINER_NAME, + **kwargs): + client = CosmosClient(default_endpoint, key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) - db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(SINGLE_PARTITION_CONTAINER_NAME) + db: DatabaseProxy = client.get_database_client(database_id) + container: ContainerProxy = db.get_container_client(container_id) return {"client": client, "db": db, "col": container} diff --git a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py index 1df1de05936d..83535510c983 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fault_injection_transport_async.py @@ -33,6 +33,7 @@ host = test_config.TestConfig.host master_key = test_config.TestConfig.masterKey TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID +SINGLE_PARTITION_CONTAINER_NAME = os.path.basename(__file__) + str(uuid.uuid4()) @pytest.mark.cosmosEmulator @pytest.mark.asyncio @@ -50,7 +51,7 @@ async def asyncSetUp(cls): "'masterKey' and 'host' at the top of this class to run the " "tests.") cls.database_id = TEST_DATABASE_ID - cls.single_partition_container_name = os.path.basename(__file__) + str(uuid.uuid4()) + cls.single_partition_container_name = SINGLE_PARTITION_CONTAINER_NAME cls.mgmt_client = CosmosClient(cls.host, cls.master_key, consistency_level="Session", logger=logger) created_database = cls.mgmt_client.get_database_client(cls.database_id) @@ -76,11 +77,18 @@ async def asyncTearDown(cls): except Exception as closeError: logger.warning("Exception trying to close client {}. 
{}".format(created_database.id, closeError)) - async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): - client = CosmosClient(default_endpoint, master_key, consistency_level="Session", + @staticmethod + async def setup_method_with_custom_transport( + custom_transport: AioHttpTransport, + default_endpoint: str = host, + key: str = master_key, + database_id: str = TEST_DATABASE_ID, + container_id: str = SINGLE_PARTITION_CONTAINER_NAME, + **kwargs): + client = CosmosClient(default_endpoint, key, consistency_level="Session", transport=custom_transport, logger=logger, enable_diagnostics_logging=True, **kwargs) - db: DatabaseProxy = client.get_database_client(TEST_DATABASE_ID) - container: ContainerProxy = db.get_container_client(self.single_partition_container_name) + db: DatabaseProxy = client.get_database_client(database_id) + container: ContainerProxy = db.get_container_client(container_id) return {"client": client, "db": db, "col": container} @staticmethod @@ -106,7 +114,7 @@ async def test_throws_injected_error_async(self: "TestFaultInjectionTransportAsy status_code=502, message="Some random reverse proxy error.")))) - initialized_objects = await self.setup_method_with_custom_transport(custom_transport) + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport(custom_transport) start: float = time.perf_counter() try: container: ContainerProxy = initialized_objects["col"] @@ -151,7 +159,7 @@ async def test_swr_mrr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["Read Region", "Write Region"]) try: @@ -210,7 +218,7 @@ async def test_swr_mrr_region_down_read_succeeds_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -275,7 +283,7 @@ async def test_swr_mrr_region_down_envoy_read_succeeds_async(self: "TestFaultInj 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, default_endpoint=expected_write_region_uri, preferred_locations=["Read Region", "Write Region"]) @@ -320,7 +328,7 @@ async def test_mwr_succeeds_async(self: "TestFaultInjectionTransportAsync"): 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -373,7 +381,7 @@ async def test_mwr_region_down_succeeds_async(self: "TestFaultInjectionTransport 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -445,7 +453,7 @@ async 
def test_swr_mrr_all_regions_down_for_read_async(self: "TestFaultInjection 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: @@ -498,7 +506,7 @@ async def test_mwr_all_regions_down_async(self: "TestFaultInjectionTransportAsyn 'name': 'sample document', 'key': 'value'} - initialized_objects = await self.setup_method_with_custom_transport( + initialized_objects = await TestFaultInjectionTransportAsync.setup_method_with_custom_transport( custom_transport, preferred_locations=["First Region", "Second Region"]) try: diff --git a/sdk/cosmos/azure-cosmos/tests/test_health_check.py b/sdk/cosmos/azure-cosmos/tests/test_health_check.py index 0d313e6c911c..75db9deacc41 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_health_check.py +++ b/sdk/cosmos/azure-cosmos/tests/test_health_check.py @@ -126,14 +126,14 @@ def test_health_check_timeouts_on_unavailable_endpoints(self, setup): locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(TestHealthCheck.host, REGION_1) setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.mark_endpoint_unavailable_for_read( locational_endpoint, True) - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = REGIONS + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = REGIONS try: setup[COLLECTION].create_item(body={'id': 'item' + str(uuid.uuid4()), 'pk': 'pk'}) finally: _global_endpoint_manager._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub _cosmos_client_connection.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations class MockGetDatabaseAccountCheck(object): def __init__(self, client_connection=None, endpoint_unavailable=False): diff --git a/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py b/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py index ae2bf13fd8a7..a92eca0dd778 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_health_check_async.py @@ -153,8 +153,8 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g # checks the background health check works as expected when all endpoints healthy self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub self.original_getDatabaseAccountCheck = _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = preferred_location + self.original_preferred_locations = 
setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = preferred_location mock_get_database_account_check = self.MockGetDatabaseAccountCheck() _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = ( self.MockGetDatabaseAccount(REGIONS, use_write_global_endpoint, use_read_global_endpoint)) @@ -168,7 +168,7 @@ async def test_health_check_success(self, setup, preferred_location, use_write_g finally: _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub _cosmos_client_connection_async.CosmosClientConnection._GetDatabaseAccountCheck = self.original_getDatabaseAccountCheck - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations expected_regional_routing_contexts = [] locational_endpoint = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_1) @@ -189,8 +189,8 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g self.original_getDatabaseAccountStub = _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = ( self.MockGetDatabaseAccount(REGIONS, use_write_global_endpoint, use_read_global_endpoint)) - self.original_preferred_locations = setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = preferred_location + self.original_preferred_locations = setup[COLLECTION].client_connection.connection_policy.PreferredLocations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = preferred_location try: setup[COLLECTION].client_connection._global_endpoint_manager.startup = False @@ -201,7 +201,7 @@ async def test_health_check_failure(self, setup, preferred_location, use_write_g await asyncio.sleep(1) finally: _global_endpoint_manager_async._GlobalEndpointManager._GetDatabaseAccountStub = self.original_getDatabaseAccountStub - setup[COLLECTION].client_connection._global_endpoint_manager.location_cache.preferred_locations = self.original_preferred_locations + setup[COLLECTION].client_connection.connection_policy.PreferredLocations = self.original_preferred_locations if not use_write_global_endpoint: num_unavailable_endpoints = len(REGIONS) diff --git a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py index a957094f1790..f65a1f1a3d21 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_location_cache.py +++ b/sdk/cosmos/azure-cosmos/tests/test_location_cache.py @@ -3,8 +3,10 @@ import time import unittest +from typing import Mapping, Any import pytest +from azure.cosmos import documents from azure.cosmos.documents import DatabaseAccount, _OperationType from azure.cosmos.http_constants import ResourceType @@ -35,15 +37,15 @@ def create_database_account(enable_multiple_writable_locations): return db_acc -def refresh_location_cache(preferred_locations, use_multiple_write_locations): - lc = LocationCache(preferred_locations=preferred_locations, - default_endpoint=default_endpoint, - enable_endpoint_discovery=True, - use_multiple_write_locations=use_multiple_write_locations) +def 
refresh_location_cache(preferred_locations, use_multiple_write_locations, connection_policy=documents.ConnectionPolicy()): + connection_policy.PreferredLocations = preferred_locations + connection_policy.UseMultipleWriteLocations = use_multiple_write_locations + lc = LocationCache(default_endpoint=default_endpoint, + connection_policy=connection_policy) return lc @pytest.mark.cosmosEmulator -class TestLocationCache(unittest.TestCase): +class TestLocationCache: def test_mark_endpoint_unavailable(self): lc = refresh_location_cache([], False) @@ -136,6 +138,140 @@ def test_resolve_request_endpoint_preferred_regions(self): assert read_resolved == write_resolved assert read_resolved == default_endpoint + @pytest.mark.parametrize("test_type",["OnClient", "OnRequest", "OnBoth"]) + def test_get_applicable_regional_endpoints_excluded_regions(self, test_type): + # Init test data + if test_type == "OnClient": + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + excluded_locations_on_requests_list = [None] * 5 + elif test_type == "OnRequest": + excluded_locations_on_client_list = [[]] * 5 + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + else: + excluded_locations_on_client_list = [ + [location1_name], + [location1_name, location2_name, location3_name], + [location1_name, location2_name], + [location2_name], + [location1_name, location2_name, location3_name], + ] + excluded_locations_on_requests_list = [ + [location1_name], + [location1_name, location2_name], + [location1_name, location2_name, location3_name], + [location4_name], + [], + ] + + expected_read_endpoints_list = [ + [location2_endpoint], + [location1_endpoint], + [location1_endpoint], + [location1_endpoint, location2_endpoint], + [location1_endpoint, location2_endpoint], + ] + expected_write_endpoints_list = [ + [location2_endpoint, location3_endpoint], + [location3_endpoint], + [default_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + [location1_endpoint, location2_endpoint, location3_endpoint], + ] + + # Loop over each test cases + for excluded_locations_on_client, excluded_locations_on_requests, expected_read_endpoints, expected_write_endpoints in zip(excluded_locations_on_client_list, excluded_locations_on_requests_list, expected_read_endpoints_list, expected_write_endpoints_list): + # Init excluded_locations in ConnectionPolicy + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Init requests and set excluded regions on requests + write_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + write_doc_request.excluded_locations = excluded_locations_on_requests + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Read) + read_doc_request.excluded_locations = excluded_locations_on_requests + + # Test if read endpoints were correctly filtered on client level + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = 
[regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + # Test if write endpoints were correctly filtered on client level + write_doc_endpoint = location_cache._get_applicable_write_regional_endpoints(write_doc_request) + write_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in write_doc_endpoint] + assert write_doc_endpoint == expected_write_endpoints + + def test_set_excluded_locations_for_requests(self): + # Init excluded_locations in ConnectionPolicy + excluded_locations_on_client = [location1_name, location2_name] + connection_policy = documents.ConnectionPolicy() + connection_policy.ExcludedLocations = excluded_locations_on_client + + # Init location_cache + location_cache = refresh_location_cache([location1_name, location2_name, location3_name], True, + connection_policy) + database_account = create_database_account(True) + location_cache.perform_on_database_account_read(database_account) + + # Test setting excluded locations + excluded_locations = [location1_name] + options: Mapping[str, Any] = {"excludedLocations": excluded_locations} + + expected_excluded_locations = excluded_locations + read_doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location2_endpoint] + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + # Test setting excluded locations with invalid resource types + expected_excluded_locations = None + for resource_type in [ResourceType.Offer, ResourceType.Conflict]: + options: Mapping[str, Any] = {"excludedLocations": [location1_name]} + read_doc_request = RequestObject(resource_type, _OperationType.Create) + read_doc_request.set_excluded_location_from_options(options) + actual_excluded_locations = read_doc_request.excluded_locations + assert actual_excluded_locations == expected_excluded_locations + + expected_read_endpoints = [location1_endpoint] + read_doc_endpoint = location_cache._get_applicable_read_regional_endpoints(read_doc_request) + read_doc_endpoint = [regional_endpoint.get_primary() for regional_endpoint in read_doc_endpoint] + assert read_doc_endpoint == expected_read_endpoints + + + + # Test setting excluded locations with None value + expected_error_message = ("Excluded locations cannot be None. " + "If you want to remove all excluded locations, try passing an empty list.") + with pytest.raises(ValueError) as e: + options: Mapping[str, Any] = {"excludedLocations": None} + doc_request = RequestObject(ResourceType.Document, _OperationType.Create) + doc_request.set_excluded_location_from_options(options) + assert str( + e.value) == expected_error_message + if __name__ == "__main__": unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py new file mode 100644 index 000000000000..2d033890429d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_mm_async.py @@ -0,0 +1,408 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. 
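+
+# These tests exercise the per-partition circuit breaker against a multi-write,
+# multi-region account. The feature is opt-in, so the `setup` fixture below sets the
+# AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER environment variable to "True" before each test
+# and back to "False" afterwards, roughly:
+#
+#     os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True"
+#     try:
+#         ...  # run operations while FaultInjectionTransportAsync fails one region
+#     finally:
+#         os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False"
+#
+# (Sketch only: the fixture yields the created container between the two assignments
+# rather than using try/finally.)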
+import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +import pytest_asyncio +from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey, _location_cache +from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport_async import FaultInjectionTransportAsync + +REGION_1 = "West US 3" +REGION_2 = "Mexico Central" # "West US" + + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPerPartitionCircuitBreakerMMAsync.host, TestPerPartitionCircuitBreakerMMAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) + await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerMMAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestPerPartitionCircuitBreakerMMAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +def write_operations_and_errors(): + write_operations = ["create", "upsert", "replace", "delete", "patch", "batch", "delete_all_items_by_partition_key"] + errors = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected error.")) + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for write_operation in write_operations: + for error in errors: + params.append((write_operation, error)) + + return params + +def read_operations_and_errors(): + read_operations = ["read", "query", "changefeed", "read_all_items"] + errors = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected error.")) + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for read_operation in read_operations: + for error in errors: + params.append((read_operation, error)) + + return params + +def validate_response_uri(response, expected_uri): + request = response.get_response_headers()["_request"] + assert request.url.startswith(expected_uri) + +async def perform_write_operation(operation, container, fault_injection_container, doc_id, pk, expected_uri): + doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == "create": + resp = await fault_injection_container.create_item(body=doc) + elif operation == "upsert": + resp = await fault_injection_container.upsert_item(body=doc) + elif operation == "replace": + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + resp = await fault_injection_container.replace_item(item=doc['id'], body=new_doc) + elif operation == "delete": + await container.create_item(body=doc) + resp = await 
fault_injection_container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation == "patch": + operations = [{"op": "incr", "path": "/company", "value": 3}] + resp = await fault_injection_container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) + elif operation == "batch": + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", (doc,)), + ] + resp = await fault_injection_container.execute_item_batch(batch_operations, partition_key=doc['pk']) + elif operation == "delete_all_items_by_partition_key": + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.create_item(body=doc) + resp = await fault_injection_container.delete_all_items_by_partition_key(pk) + validate_response_uri(resp, expected_uri) + + +@pytest.mark.cosmosMultiRegion +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPerPartitionCircuitBreakerMMAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=[REGION_1, REGION_2], + multiple_write_locations=True, + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + async def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + await method_client.close() + + @staticmethod + async def perform_read_operation(operation, container, doc_id, pk, expected_uri): + if operation == "read": + read_resp = await container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + elif operation == "query": + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == "changefeed": + async for _ in container.query_items_change_feed(): + pass + elif operation == "read_all_items": + async for item in container.read_all_items(partition_key=pk): + assert item['pk'] == pk + + + @pytest.mark.parametrize("write_operation, error", write_operations_and_errors()) + async def test_write_consecutive_failure_threshold_async(self, setup, write_operation, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + custom_transport = FaultInjectionTransportAsync() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + 
FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri)) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + + await perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # writes should fail but still be tracked + for i in range(6): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + + @pytest.mark.parametrize("read_operation, error", read_operations_and_errors()) + async def test_read_consecutive_failure_threshold_async(self, setup, read_operation, error): + expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2) + custom_transport = FaultInjectionTransportAsync() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri)) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + + + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + + for i in range(10): + await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + # the partition should 
have been marked as unavailable after breaking read threshold
+        TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1)
+        # test recovering the partition again
+
+    @pytest.mark.parametrize("write_operation, error", write_operations_and_errors())
+    async def test_failure_rate_threshold_async(self, setup, write_operation, error):
+        expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+        custom_transport = FaultInjectionTransportAsync()
+        id_value = 'failoverDoc-' + str(uuid.uuid4())
+        # two documents targeted to same partition, one will always fail and the other will succeed
+        document_definition = {'id': id_value,
+                               'pk': 'pk1',
+                               'name': 'sample document',
+                               'key': 'value'}
+        doc_2 = {'id': str(uuid.uuid4()),
+                 'pk': 'pk1',
+                 'name': 'sample document',
+                 'key': 'value'}
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri))
+        custom_transport.add_fault(predicate,
+                                   lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+                                       0,
+                                       error
+                                   )))
+
+        custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host)
+        container = custom_setup['col']
+        global_endpoint_manager = container.client_connection._global_endpoint_manager
+        # lower minimum requests for testing
+        global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10
+        try:
+            # writes to the faulted region should fail, but the partition should not be marked unavailable
+            for i in range(14):
+                if i == 9:
+                    await container.upsert_item(body=doc_2)
+                with pytest.raises((CosmosHttpResponseError, ServiceResponseError)):
+                    await perform_write_operation(write_operation,
+                                                  setup[COLLECTION],
+                                                  container,
+                                                  document_definition['id'],
+                                                  document_definition['pk'],
+                                                  expected_uri)
+
+            TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0)
+
+            # create item with client without fault injection
+            await setup[COLLECTION].create_item(body=document_definition)
+        finally:
+            # restore minimum requests
+            global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100
+
+    @pytest.mark.parametrize("write_operation, error", write_operations_and_errors())
+    async def test_read_failure_rate_threshold_async(self, setup, write_operation, error):
+        expected_uri = _location_cache.LocationCache.GetLocationalEndpoint(self.host, REGION_2)
+        # point reads are used for the read-side failure tracking in this test
+        read_operation = "read"
+        custom_transport = FaultInjectionTransportAsync()
+        id_value = 'failoverDoc-' + str(uuid.uuid4())
+        # two documents targeted to same partition, one will always fail and the other will succeed
+        document_definition = {'id': id_value,
+                               'pk': 'pk1',
+                               'name': 'sample document',
+                               'key': 'value'}
+        doc_2 = {'id': str(uuid.uuid4()),
+                 'pk': 'pk1',
+                 'name': 'sample document',
+                 'key': 'value'}
+        predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and
+                               FaultInjectionTransportAsync.predicate_targets_region(r, expected_uri))
+        custom_transport.add_fault(predicate,
+                                   lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay(
+                                       0,
+                                       error
+                                   )))
+
+        custom_setup = await self.setup_method_with_custom_transport(custom_transport,
container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + # lower minimum requests for testing + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + try: + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(14): + if i == 9: + await container.upsert_item(body=doc_2) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerMMAsync.perform_write_operation(write_operation, + setup[COLLECTION], + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + for i in range(20): + if i == 8: + read_resp = await container.read_item(item=doc_2['id'], + partition_key=doc_2['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_uri) + else: + await TestPerPartitionCircuitBreakerMMAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_uri) + # the partition should have been marked as unavailable after breaking read threshold + TestPerPartitionCircuitBreakerMMAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + finally: + # restore minimum requests + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + # look at the urls for verifying fall back and use another id for same partition + + @staticmethod + def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + + assert unhealthy_partitions == expected_unhealthy_partitions + + # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again + # test service request marks only a partition unavailable not an entire region - across operation types + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py new file mode 100644 index 
000000000000..ec6cf5ecff7d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_ppcb_sm_mrr_async.py @@ -0,0 +1,314 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +import asyncio +import os +import unittest +import uuid +from typing import Dict, Any + +import pytest +import pytest_asyncio +from azure.core.pipeline.transport._aiohttp import AioHttpTransport +from azure.core.exceptions import ServiceResponseError + +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos._partition_health_tracker import HEALTH_STATUS, UNHEALTHY, UNHEALTHY_TENTATIVE +from azure.cosmos.aio import CosmosClient +from azure.cosmos.exceptions import CosmosHttpResponseError +from _fault_injection_transport_async import FaultInjectionTransportAsync + +COLLECTION = "created_collection" +@pytest_asyncio.fixture() +async def setup(): + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "True" + client = CosmosClient(TestPerPartitionCircuitBreakerSmMrrAsync.host, TestPerPartitionCircuitBreakerSmMrrAsync.master_key, consistency_level="Session") + created_database = client.get_database_client(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) + await client.create_database_if_not_exists(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_DATABASE_ID) + created_collection = await created_database.create_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID, + partition_key=PartitionKey("/pk"), + offer_throughput=10000) + yield { + COLLECTION: created_collection + } + + await created_database.delete_container(TestPerPartitionCircuitBreakerSmMrrAsync.TEST_CONTAINER_SINGLE_PARTITION_ID) + await client.close() + os.environ["AZURE_COSMOS_ENABLE_CIRCUIT_BREAKER"] = "False" + +def operations_and_errors(): + write_operations = ["create", "upsert", "replace", "delete", "patch", "batch"] + read_operations = ["read", "query", "changefeed", "read_all_items", "delete_all_items_by_partition_key"] + errors = [] + error_codes = [408, 500, 502, 503] + for error_code in error_codes: + errors.append(CosmosHttpResponseError( + status_code=error_code, + message="Some injected error.")) + errors.append(ServiceResponseError(message="Injected Service Response Error.")) + params = [] + for write_operation in write_operations: + for read_operation in read_operations: + for error in errors: + params.append((write_operation, read_operation, error)) + + return params + + +@pytest.mark.cosmosEmulator +@pytest.mark.asyncio +@pytest.mark.usefixtures("setup") +class TestPerPartitionCircuitBreakerSmMrrAsync: + host = test_config.TestConfig.host + master_key = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + TEST_DATABASE_ID = test_config.TestConfig.TEST_DATABASE_ID + TEST_CONTAINER_SINGLE_PARTITION_ID = os.path.basename(__file__) + str(uuid.uuid4()) + + async def setup_method_with_custom_transport(self, custom_transport: AioHttpTransport, default_endpoint=host, **kwargs): + client = CosmosClient(default_endpoint, self.master_key, consistency_level="Session", + preferred_locations=["Write Region", "Read Region"], + transport=custom_transport, **kwargs) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_CONTAINER_SINGLE_PARTITION_ID) + return {"client": client, "db": db, "col": container} + + @staticmethod + async def cleanup_method(initialized_objects: Dict[str, Any]): + method_client: CosmosClient = initialized_objects["client"] + await method_client.close() 
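+
+    # The helpers below map the operation names produced by operations_and_errors()
+    # onto the corresponding ContainerProxy calls; perform_read_operation additionally
+    # verifies, for point reads, that the response was served from the expected
+    # read-region endpoint.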
+ + @staticmethod + async def perform_write_operation(operation, container, doc_id, pk): + doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document', + 'key': 'value'} + if operation == "create": + await container.create_item(body=doc) + elif operation == "upsert": + await container.upsert_item(body=doc) + elif operation == "replace": + new_doc = {'id': doc_id, + 'pk': pk, + 'name': 'sample document' + str(uuid), + 'key': 'value'} + await container.replace_item(item=doc['id'], body=new_doc) + elif operation == "delete": + await container.create_item(body=doc) + await container.delete_item(item=doc['id'], partition_key=doc['pk']) + elif operation == "patch": + operations = [{"op": "incr", "path": "/company", "value": 3}] + await container.patch_item(item=doc['id'], partition_key=doc['pk'], patch_operations=operations) + elif operation == "batch": + batch_operations = [ + ("create", (doc, )), + ("upsert", (doc,)), + ("upsert", (doc,)), + ("upsert", (doc,)), + ] + await container.execute_item_batch(batch_operations, partition_key=doc['pk']) + elif operation == "delete_all_items_by_partition_key": + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.create_item(body=doc) + await container.delete_all_items_by_partition_key(pk) + + @staticmethod + async def perform_read_operation(operation, container, doc_id, pk, expected_read_region_uri): + if operation == "read": + read_resp = await container.read_item(item=doc_id, partition_key=pk) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + elif operation == "query": + query = "SELECT * FROM c WHERE c.id = @id AND c.pk = @pk" + parameters = [{"name": "@id", "value": doc_id}, {"name": "@pk", "value": pk}] + async for item in container.query_items(query=query, partition_key=pk, parameters=parameters): + assert item['id'] == doc_id + # need to do query with no pk and with feed range + elif operation == "changefeed": + async for _ in container.query_items_change_feed(): + pass + elif operation == "read_all_items": + async for item in container.read_all_items(partition_key=pk): + assert item['pk'] == pk + + + + + + async def create_custom_transport_sm_mrr(self): + custom_transport = FaultInjectionTransportAsync() + # Inject rule to disallow writes in the read-only region + is_write_operation_in_read_region_predicate = lambda \ + r: FaultInjectionTransportAsync.predicate_is_write_operation(r, self.host) + + custom_transport.add_fault( + is_write_operation_in_read_region_predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_write_forbidden())) + + # Inject topology transformation that would make Emulator look like a single write region + # account with two read regions + is_get_account_predicate = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_region_sm_account_transformation = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_swr_mrr( + write_region_name="Write Region", + read_region_name="Read Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_region_sm_account_transformation) + return custom_transport + + + # split this into write and read tests + + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) + async def test_consecutive_failure_threshold_async(self, 
setup, write_operation, read_operation, error): + expected_read_region_uri = self.host + expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_is_document_operation(r) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + custom_transport.add_fault(predicate, lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, default_endpoint=self.host) + container = custom_setup['col'] + + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(6): + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, + container, + document_definition['id'], + document_definition['pk']) + global_endpoint_manager = container.client_connection._global_endpoint_manager + + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + for i in range(10): + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) + + # the partition should have been marked as unavailable after breaking read threshold + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + # test recovering the partition again + + @pytest.mark.parametrize("write_operation, read_operation, error", operations_and_errors()) + async def test_failure_rate_threshold_async(self, setup, write_operation, read_operation, error): + expected_read_region_uri = self.host + expected_write_region_uri = expected_read_region_uri.replace("localhost", "127.0.0.1") + custom_transport = await self.create_custom_transport_sm_mrr() + id_value = 'failoverDoc-' + str(uuid.uuid4()) + # two documents targeted to same partition, one will always fail and the other will succeed + document_definition = {'id': id_value, + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + doc_2 = {'id': str(uuid.uuid4()), + 'pk': 'pk1', + 'name': 'sample document', + 'key': 'value'} + predicate = lambda r: (FaultInjectionTransportAsync.predicate_req_for_document_with_id(r, id_value) and + FaultInjectionTransportAsync.predicate_targets_region(r, expected_write_region_uri)) + custom_transport.add_fault(predicate, + lambda r: asyncio.create_task(FaultInjectionTransportAsync.error_after_delay( + 0, + error + ))) + + custom_setup = await self.setup_method_with_custom_transport(custom_transport, 
default_endpoint=self.host) + container = custom_setup['col'] + global_endpoint_manager = container.client_connection._global_endpoint_manager + # lower minimum requests for testing + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 10 + try: + # writes should fail in sm mrr with circuit breaker and should not mark unavailable a partition + for i in range(14): + if i == 9: + await container.upsert_item(body=doc_2) + with pytest.raises((CosmosHttpResponseError, ServiceResponseError)): + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_write_operation(write_operation, + container, + document_definition['id'], + document_definition['pk']) + + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + + # create item with client without fault injection + await setup[COLLECTION].create_item(body=document_definition) + + # reads should fail over and only the relevant partition should be marked as unavailable + await container.read_item(item=document_definition['id'], partition_key=document_definition['pk']) + # partition should not have been marked unavailable after one error + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 0) + for i in range(20): + if i == 8: + read_resp = await container.read_item(item=doc_2['id'], + partition_key=doc_2['pk']) + request = read_resp.get_response_headers()["_request"] + # Validate the response comes from "Read Region" (the most preferred read-only region) + assert request.url.startswith(expected_read_region_uri) + else: + await TestPerPartitionCircuitBreakerSmMrrAsync.perform_read_operation(read_operation, + container, + document_definition['id'], + document_definition['pk'], + expected_read_region_uri) + # the partition should have been marked as unavailable after breaking read threshold + TestPerPartitionCircuitBreakerSmMrrAsync.validate_unhealthy_partitions(global_endpoint_manager, 1) + finally: + # restore minimum requests + global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.MINIMUM_REQUESTS_FOR_FAILURE_RATE = 100 + + # look at the urls for verifying fall back and use another id for same partition + + @staticmethod + def validate_unhealthy_partitions(global_endpoint_manager, + expected_unhealthy_partitions): + health_info_map = global_endpoint_manager.global_partition_endpoint_manager_core.partition_health_tracker.pk_range_wrapper_to_health_info + unhealthy_partitions = 0 + for pk_range_wrapper, location_to_health_info in health_info_map.items(): + for location, health_info in location_to_health_info.items(): + health_status = health_info.unavailability_info.get(HEALTH_STATUS) + if health_status == UNHEALTHY_TENTATIVE or health_status == UNHEALTHY: + unhealthy_partitions += 1 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + else: + assert health_info.read_consecutive_failure_count < 10 + assert health_info.write_failure_count == 0 + assert health_info.write_consecutive_failure_count == 0 + + assert unhealthy_partitions == expected_unhealthy_partitions + + # test_failure_rate_threshold - add service response error - across operation types - test recovering the partition again + # test service request marks only a partition unavailable not an entire region - across operation types + # test cosmos client timeout + +if __name__ == '__main__': + unittest.main() diff --git 
a/sdk/cosmos/azure-cosmos/tests/test_query.py b/sdk/cosmos/azure-cosmos/tests/test_query.py index 28262aa0f7e3..2a99263ed457 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_query.py @@ -17,7 +17,7 @@ from azure.cosmos.documents import _DistinctType from azure.cosmos.partition_key import PartitionKey - +@pytest.mark.cosmosQuery class TestQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" diff --git a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py index be1683d1504d..4faef31c9495 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_retry_policy_async.py @@ -42,6 +42,7 @@ def __init__(self) -> None: self.ProxyConfiguration: Optional[ProxyConfiguration] = None self.EnableEndpointDiscovery: bool = True self.PreferredLocations: List[str] = [] + self.ExcludedLocations = None self.RetryOptions: RetryOptions = RetryOptions() self.DisableSSLVerification: bool = False self.UseMultipleWriteLocations: bool = False diff --git a/sdk/cosmos/live-platform-matrix.json b/sdk/cosmos/live-platform-matrix.json index 485a15ca92e8..6763c1c06562 100644 --- a/sdk/cosmos/live-platform-matrix.json +++ b/sdk/cosmos/live-platform-matrix.json @@ -88,6 +88,39 @@ "TestMarkArgument": "cosmosLong" } } + }, + { + "WindowsConfig": { + "Windows2022_38": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.8", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + }, + "Windows2022_310": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.10", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + }, + "Windows2022_312": { + "OSVmImage": "env:WINDOWSVMIMAGE", + "Pool": "env:WINDOWSPOOL", + "PythonVersion": "3.12", + "CoverageArg": "--disablecov", + "TestSamples": "false", + "TestMarkArgument": "cosmosMultiRegion" + } + }, + "ArmConfig": { + "MultiMaster_MultiRegion": { + "ArmTemplateParameters": "@{ enableMultipleWriteLocations = $true; defaultConsistencyLevel = 'Session'; enableMultipleRegions = $true }" + } + } } ] } diff --git a/sdk/cosmos/test-resources.bicep b/sdk/cosmos/test-resources.bicep index 17d88b0be92a..88abe955f8d8 100644 --- a/sdk/cosmos/test-resources.bicep +++ b/sdk/cosmos/test-resources.bicep @@ -30,13 +30,13 @@ var singleRegionConfiguration = [ ] var multiRegionConfiguration = [ { - locationName: 'East US 2' + locationName: 'West US 3' provisioningState: 'Succeeded' failoverPriority: 0 isZoneRedundant: false } { - locationName: 'East US' + locationName: 'West US' provisioningState: 'Succeeded' failoverPriority: 1 isZoneRedundant: false