Skip to content

[Cosmos] Session container fixes #40366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Draft
wants to merge 1 commit into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
101 changes: 71 additions & 30 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -167,37 +167,9 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches
if options.get("indexingDirective"):
headers[http_constants.HttpHeaders.IndexingDirective] = options["indexingDirective"]

consistency_level = None

# get default client consistency level
default_client_consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel)

# set consistency level. check if set via options, this will override the default
# set request consistency level - if session consistency, the client should be setting this on its own
if options.get("consistencyLevel"):
consistency_level = options["consistencyLevel"]
# TODO: move this line outside of if-else cause to remove the code duplication
headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level
elif default_client_consistency_level is not None:
consistency_level = default_client_consistency_level
headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level

# figure out if consistency level for this request is session
is_session_consistency = consistency_level == documents.ConsistencyLevel.Session

# set session token if required
if is_session_consistency is True and not IsMasterResource(resource_type):
# if there is a token set via option, then use it to override default
if options.get("sessionToken"):
headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"]
else:
# check if the client's default consistency is session (and request consistency level is same),
# then update from session container
if default_client_consistency_level == documents.ConsistencyLevel.Session and \
cosmos_client_connection.session:
# populate session token from the client's session container
headers[http_constants.HttpHeaders.SessionToken] = cosmos_client_connection.session.get_session_token(
path
)
headers[http_constants.HttpHeaders.ConsistencyLevel] = options["consistencyLevel"]

if options.get("enableScanInQuery"):
headers[http_constants.HttpHeaders.EnableScanInQuery] = options["enableScanInQuery"]
Expand Down Expand Up @@ -337,6 +309,75 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches

return headers

def _is_session_token_request(
cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"],
headers: dict,
resource_type: str,
operation_type: str) -> None:
consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel)
# Figure out if consistency level for this request is session
is_session_consistency = consistency_level == documents.ConsistencyLevel.Session

# Verify that it is not a metadata request, and that it is either a read request, batch request, or an account
# configured to use multiple write regions
return (is_session_consistency is True and not IsMasterResource(resource_type)
and (documents._OperationType.IsReadOnlyOperation(operation_type) or operation_type == "Batch"
or cosmos_client_connection._global_endpoint_manager.get_use_multiple_write_locations()))


def set_session_token_header(
cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"],
headers: dict,
path: str,
resource_type: str,
operation_type: str,
options: Mapping[str, Any],
partition_key_range_id: Optional[str] = None) -> None:
# set session token if required
if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type):
# if there is a token set via option, then use it to override default
if options.get("sessionToken"):
headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"]
else:
# check if the client's default consistency is session (and request consistency level is same),
# then update from session container
if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \
cosmos_client_connection.session:
# populate session token from the client's session container
session_token = cosmos_client_connection.session.get_session_token(path,
options.get('partitionKey'),
cosmos_client_connection._container_properties_cache,
cosmos_client_connection._routing_map_provider,
partition_key_range_id)
if session_token != "":
headers[http_constants.HttpHeaders.SessionToken] = session_token

async def set_session_token_header_async(
cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"],
headers: dict,
path: str,
resource_type: str,
operation_type: str,
options: Mapping[str, Any],
partition_key_range_id: Optional[str] = None) -> None:
# set session token if required
if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type):
# if there is a token set via option, then use it to override default
if options.get("sessionToken"):
headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"]
else:
# check if the client's default consistency is session (and request consistency level is same),
# then update from session container
if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \
cosmos_client_connection.session:
# populate session token from the client's session container
session_token = await cosmos_client_connection.session.get_session_token_async(path,
options.get('partitionKey'),
cosmos_client_connection._container_properties_cache,
cosmos_client_connection._routing_map_provider,
partition_key_range_id)
if session_token != "":
headers[http_constants.HttpHeaders.SessionToken] = session_token

def GetResourceIdOrFullNameFromLink(resource_link: str) -> Optional[str]:
"""Gets resource id or full name from resource link.
Expand Down
62 changes: 38 additions & 24 deletions sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py
Original file line number Diff line number Diff line change
Expand Up @@ -2042,6 +2042,7 @@ def PatchItem(

headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type,
documents._OperationType.Patch, options)
base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Patch, options)
# Patch will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(resource_type, documents._OperationType.Patch)
request_data = {}
Expand Down Expand Up @@ -2131,6 +2132,7 @@ def _Batch(
base._populate_batch_headers(initial_headers)
headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs",
documents._OperationType.Batch, options)
base.set_session_token_header(self, headers, path, "docs", documents._OperationType.Batch, options)
request_params = RequestObject("docs", documents._OperationType.Batch)
return cast(
Tuple[List[Dict[str, Any]], CaseInsensitiveDict],
Expand Down Expand Up @@ -2191,6 +2193,8 @@ def DeleteAllItemsByPartitionKey(
collection_id = base.GetResourceIdOrFullNameFromLink(collection_link)
headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id,
"partitionkey", documents._OperationType.Delete, options)
base.set_session_token_header(self, headers, path, "partitionkey", documents._OperationType.Delete,
options)
request_params = RequestObject("partitionkey", documents._OperationType.Delete)
_, last_response_headers = self.__Post(
path=path,
Expand Down Expand Up @@ -2615,7 +2619,7 @@ def Create(
self,
body: Dict[str, Any],
path: str,
typ: str,
resource_type: str,
id: Optional[str],
initial_headers: Optional[Mapping[str, Any]],
options: Optional[Mapping[str, Any]] = None,
Expand All @@ -2625,7 +2629,7 @@ def Create(

:param dict body:
:param str path:
:param str typ:
:param str resource_type:
:param str id:
:param dict initial_headers:
:param dict options:
Expand All @@ -2642,11 +2646,12 @@ def Create(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create,
options)
headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type,
documents._OperationType.Create, options)
base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Create, options)
# Create will use WriteEndpoint since it uses POST operation

request_params = RequestObject(typ, documents._OperationType.Create)
request_params = RequestObject(resource_type, documents._OperationType.Create)
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
self.last_response_headers = last_response_headers

Expand All @@ -2660,7 +2665,7 @@ def Upsert(
self,
body: Dict[str, Any],
path: str,
typ: str,
resource_type: str,
id: Optional[str],
initial_headers: Optional[Mapping[str, Any]],
options: Optional[Mapping[str, Any]] = None,
Expand All @@ -2670,7 +2675,7 @@ def Upsert(

:param dict body:
:param str path:
:param str typ:
:param str resource_type:
:param str id:
:param dict initial_headers:
:param dict options:
Expand All @@ -2687,12 +2692,13 @@ def Upsert(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert,
options)
headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type,
documents._OperationType.Upsert, options)
headers[http_constants.HttpHeaders.IsUpsert] = True
base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Upsert, options)

# Upsert will use WriteEndpoint since it uses POST operation
request_params = RequestObject(typ, documents._OperationType.Upsert)
request_params = RequestObject(resource_type, documents._OperationType.Upsert)
result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs)
self.last_response_headers = last_response_headers
# update session for write request
Expand All @@ -2705,7 +2711,7 @@ def Replace(
self,
resource: Dict[str, Any],
path: str,
typ: str,
resource_type: str,
id: Optional[str],
initial_headers: Optional[Mapping[str, Any]],
options: Optional[Mapping[str, Any]] = None,
Expand All @@ -2715,7 +2721,7 @@ def Replace(

:param dict resource:
:param str path:
:param str typ:
:param str resource_type:
:param str id:
:param dict initial_headers:
:param dict options:
Expand All @@ -2732,10 +2738,11 @@ def Replace(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace,
options)
headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type,
documents._OperationType.Replace, options)
base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Replace, options)
# Replace will use WriteEndpoint since it uses PUT operation
request_params = RequestObject(typ, documents._OperationType.Replace)
request_params = RequestObject(resource_type, documents._OperationType.Replace)
result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs)
self.last_response_headers = last_response_headers

Expand All @@ -2748,7 +2755,7 @@ def Replace(
def Read(
self,
path: str,
typ: str,
resource_type: str,
id: Optional[str],
initial_headers: Optional[Mapping[str, Any]],
options: Optional[Mapping[str, Any]] = None,
Expand All @@ -2757,7 +2764,7 @@ def Read(
"""Reads an Azure Cosmos resource and returns it.

:param str path:
:param str typ:
:param str resource_type:
:param str id:
:param dict initial_headers:
:param dict options:
Expand All @@ -2774,9 +2781,11 @@ def Read(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options)
headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type,
documents._OperationType.Read, options)
base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Read, options)
# Read will use ReadEndpoint since it uses GET operation
request_params = RequestObject(typ, documents._OperationType.Read)
request_params = RequestObject(resource_type, documents._OperationType.Read)
result, last_response_headers = self.__Get(path, request_params, headers, **kwargs)
self.last_response_headers = last_response_headers
if response_hook:
Expand All @@ -2786,7 +2795,7 @@ def Read(
def DeleteResource(
self,
path: str,
typ: str,
resource_type: str,
id: Optional[str],
initial_headers: Optional[Mapping[str, Any]],
options: Optional[Mapping[str, Any]] = None,
Expand All @@ -2812,10 +2821,11 @@ def DeleteResource(
options = {}

initial_headers = initial_headers or self.default_headers
headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete,
options)
headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type,
documents._OperationType.Delete, options)
base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Delete, options)
# Delete will use WriteEndpoint since it uses DELETE operation
request_params = RequestObject(typ, documents._OperationType.Delete)
request_params = RequestObject(resource_type, documents._OperationType.Delete)
result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs)
self.last_response_headers = last_response_headers

Expand Down Expand Up @@ -3063,6 +3073,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
options,
partition_key_range_id
)
base.set_session_token_header(self, headers, path, resource_type, request_params.operation_type, options,
partition_key_range_id)

change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState")
if change_feed_state is not None:
Expand Down Expand Up @@ -3101,6 +3113,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]:
options,
partition_key_range_id
)
base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery,
options)

# check if query has prefix partition key
isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None)
Expand Down Expand Up @@ -3355,7 +3369,7 @@ def _UpdateSessionIfRequired(

if is_session_consistency and self.session:
# update session
self.session.update_session(response_result, response_headers)
self.session.update_session(self, response_result, response_headers)

def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]:
partition_key_definition: Optional[Dict[str, Any]]
Expand Down
Loading
Loading