Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
25 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions sdk/cosmos/azure-cosmos/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,10 @@
#### Features Added

#### Breaking Changes
* Adds cross region retries when no preferred locations are set. This is only a breaking change for customers using bounded staleness consistency. See [PR 39714](https://github.com/Azure/azure-sdk-for-python/pull/39714)

#### Bugs Fixed
* Fixed bug where replacing manual throughput using `ThroughputProperties` would not work. See [PR 41564](https://github.com/Azure/azure-sdk-for-python/pull/41564)

#### Other Changes

Expand Down
89 changes: 53 additions & 36 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_location_cache.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,12 +24,13 @@
"""
import collections
import logging
from typing import Set, Mapping, List
from typing import Set, Mapping, OrderedDict, Dict
from typing import List
from urllib.parse import urlparse

from . import documents, _base as base
from .http_constants import ResourceType
from .documents import _OperationType
from .documents import _OperationType, ConnectionPolicy
from ._request_object import RequestObject

# pylint: disable=protected-access
Expand All @@ -43,8 +44,8 @@ class EndpointOperationType(object):

class RegionalRoutingContext(object):
def __init__(self, primary_endpoint: str, alternate_endpoint: str):
self.primary_endpoint = primary_endpoint
self.alternate_endpoint = alternate_endpoint
self.primary_endpoint: str = primary_endpoint
self.alternate_endpoint: str = alternate_endpoint

def set_primary(self, endpoint: str):
self.primary_endpoint = endpoint
Expand All @@ -65,13 +66,13 @@ def __eq__(self, other):
def __str__(self):
return "Primary: " + self.primary_endpoint + ", Alternate: " + self.alternate_endpoint

def get_endpoints_by_location(new_locations,
old_endpoints_by_location,
default_regional_endpoint,
writes,
use_multiple_write_locations):
def get_endpoints_by_location(new_locations: List[Dict[str, str]],
old_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext],
default_regional_endpoint: RegionalRoutingContext,
writes: bool,
use_multiple_write_locations: bool):
# construct from previous object
endpoints_by_location = collections.OrderedDict()
regional_routing_context_by_location: OrderedDict[str, RegionalRoutingContext] = collections.OrderedDict()
parsed_locations = []


Expand All @@ -86,8 +87,8 @@ def get_endpoints_by_location(new_locations,
parsed_locations.append(new_location["name"])
if not writes or use_multiple_write_locations:
regional_object = RegionalRoutingContext(region_uri, region_uri)
elif new_location["name"] in old_endpoints_by_location:
regional_object = old_endpoints_by_location[new_location["name"]]
elif new_location["name"] in old_regional_routing_contexts_by_location:
regional_object = old_regional_routing_contexts_by_location[new_location["name"]]
current = regional_object.get_primary()
# swap the previous with current and current with new region_uri received from the gateway
if current != region_uri:
Expand All @@ -108,15 +109,14 @@ def get_endpoints_by_location(new_locations,
default_regional_endpoint.get_primary(),
new_location["name"])
regional_object.set_alternate(constructed_region_uri)
# pass in object with region uri , last known good, curr etc
endpoints_by_location.update({new_location["name"]: regional_object})
regional_routing_context_by_location.update({new_location["name"]: regional_object})
except Exception as e:
raise e

# Also store a hash map of endpoints for each location
locations_by_endpoints = {value.get_primary(): key for key, value in endpoints_by_location.items()}
locations_by_endpoints = {value.get_primary(): key for key, value in regional_routing_context_by_location.items()}

return endpoints_by_location, locations_by_endpoints, parsed_locations
return regional_routing_context_by_location, locations_by_endpoints, parsed_locations

def _get_health_check_endpoints(regional_routing_contexts) -> Set[str]:
# should use the endpoints in the order returned from gateway and only the ones specified in preferred locations
Expand Down Expand Up @@ -154,22 +154,24 @@ class LocationCache(object): # pylint: disable=too-many-public-methods,too-many

def __init__(
self,
default_endpoint,
connection_policy,
default_endpoint: str,
connection_policy: ConnectionPolicy,
):
self.default_regional_routing_context = RegionalRoutingContext(default_endpoint, default_endpoint)
self.enable_multiple_writable_locations = False
self.write_regional_routing_contexts = [self.default_regional_routing_context]
self.read_regional_routing_contexts = [self.default_regional_routing_context]
self.location_unavailability_info_by_endpoint = {}
self.last_cache_update_time_stamp = 0
self.account_read_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
self.account_write_regional_routing_contexts_by_location = {} # pylint: disable=name-too-long
self.account_locations_by_read_endpoints = {} # pylint: disable=name-too-long
self.account_locations_by_write_endpoints = {} # pylint: disable=name-too-long
self.account_write_locations = []
self.account_read_locations = []
self.connection_policy = connection_policy
self.default_regional_routing_context: RegionalRoutingContext = RegionalRoutingContext(default_endpoint,
default_endpoint)
self.effective_preferred_locations: List[str] = []
self.enable_multiple_writable_locations: bool = False
self.write_regional_routing_contexts: List[RegionalRoutingContext] = [self.default_regional_routing_context]
self.read_regional_routing_contexts: List[RegionalRoutingContext] = [self.default_regional_routing_context]
self.location_unavailability_info_by_endpoint: Dict[str, Dict[str, Set[EndpointOperationType]]] = {}
self.last_cache_update_time_stamp: int = 0
self.account_read_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext] = {} # pylint: disable=name-too-long
self.account_write_regional_routing_contexts_by_location: Dict[str, RegionalRoutingContext] = {} # pylint: disable=name-too-long
self.account_locations_by_read_endpoints: Dict[str, str] = {} # pylint: disable=name-too-long
self.account_locations_by_write_endpoints: Dict[str, str] = {} # pylint: disable=name-too-long
self.account_write_locations: List[str] = []
self.account_read_locations: List[str] = []
self.connection_policy: ConnectionPolicy = connection_policy

def get_write_regional_routing_contexts(self):
return self.write_regional_routing_contexts
Expand Down Expand Up @@ -310,8 +312,7 @@ def resolve_service_endpoint(self, request):
return regional_routing_context.get_primary()

def should_refresh_endpoints(self): # pylint: disable=too-many-return-statements
most_preferred_location = self.connection_policy.PreferredLocations[0] \
if self.connection_policy.PreferredLocations else None
most_preferred_location = self.effective_preferred_locations[0] if self.effective_preferred_locations else None

# we should schedule refresh in background if we are unable to target the user's most preferredLocation.
if self.connection_policy.EnableEndpointDiscovery:
Expand Down Expand Up @@ -379,7 +380,7 @@ def is_endpoint_unavailable_internal(self, endpoint: str, expected_available_ope
return True

def mark_endpoint_unavailable(
self, unavailable_endpoint: str, unavailable_operation_type: str, refresh_cache: bool):
self, unavailable_endpoint: str, unavailable_operation_type: EndpointOperationType, refresh_cache: bool):
logger.warning("Marking %s unavailable for %s ",
unavailable_endpoint,
unavailable_operation_type)
Expand Down Expand Up @@ -431,6 +432,15 @@ def update_location_cache(self, write_locations=None, read_locations=None, enabl
self.connection_policy.UseMultipleWriteLocations
)

# if preferred locations is empty and the default endpoint is a global endpoint,
# we should use the read locations from gateway as effective preferred locations
if self.connection_policy.PreferredLocations:
self.effective_preferred_locations = self.connection_policy.PreferredLocations
elif self.is_default_endpoint_regional():
self.effective_preferred_locations = []
elif not self.effective_preferred_locations:
self.effective_preferred_locations = self.account_read_locations

self.write_regional_routing_contexts = self.get_preferred_regional_routing_contexts(
self.account_write_regional_routing_contexts_by_location,
self.account_write_locations,
Expand All @@ -456,12 +466,12 @@ def get_preferred_regional_routing_contexts(
or expected_available_operation == EndpointOperationType.ReadType
):
unavailable_endpoints = []
if self.connection_policy.PreferredLocations:
if self.effective_preferred_locations:
# When client can not use multiple write locations, preferred locations
# list should only be used determining read endpoints order. If client
# can use multiple write locations, preferred locations list should be
# used for determining both read and write endpoints order.
for location in self.connection_policy.PreferredLocations:
for location in self.effective_preferred_locations:
regional_endpoint = endpoints_by_location[location] if location in endpoints_by_location \
else None
if regional_endpoint:
Expand All @@ -486,6 +496,13 @@ def get_preferred_regional_routing_contexts(

return regional_endpoints

# if the endpoint is returned from the gateway in the account topology, it is a regional endpoint
def is_default_endpoint_regional(self) -> bool:
return any(
context.get_primary() == self.default_regional_routing_context.get_primary()
for context in self.account_read_regional_routing_contexts_by_location.values()
)

def can_use_multiple_write_locations(self):
return self.connection_policy.UseMultipleWriteLocations and self.enable_multiple_writable_locations

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -23,17 +23,8 @@
in the Azure Cosmos database service.
"""

import logging
from azure.cosmos.documents import _OperationType

logger = logging.getLogger(__name__)
logger.setLevel(logging.INFO)
log_formatter = logging.Formatter("%(levelname)s:%(message)s")
log_handler = logging.StreamHandler()
log_handler.setFormatter(log_formatter)
logger.addHandler(log_handler)


class _SessionRetryPolicy(object):
"""The session retry policy used to handle read/write session unavailability.
"""
Expand Down
Loading
Loading