diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 654b23c5d71f..ab2b658009d1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -167,37 +167,9 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches if options.get("indexingDirective"): headers[http_constants.HttpHeaders.IndexingDirective] = options["indexingDirective"] - consistency_level = None - - # get default client consistency level - default_client_consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) - - # set consistency level. check if set via options, this will override the default + # set request consistency level - if session consistency, the client should be setting this on its own if options.get("consistencyLevel"): - consistency_level = options["consistencyLevel"] - # TODO: move this line outside of if-else cause to remove the code duplication - headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level - elif default_client_consistency_level is not None: - consistency_level = default_client_consistency_level - headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level - - # figure out if consistency level for this request is session - is_session_consistency = consistency_level == documents.ConsistencyLevel.Session - - # set session token if required - if is_session_consistency is True and not IsMasterResource(resource_type): - # if there is a token set via option, then use it to override default - if options.get("sessionToken"): - headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] - else: - # check if the client's default consistency is session (and request consistency level is same), - # then update from session container - if default_client_consistency_level == documents.ConsistencyLevel.Session and \ - cosmos_client_connection.session: - # populate session token from the client's session container - headers[http_constants.HttpHeaders.SessionToken] = cosmos_client_connection.session.get_session_token( - path - ) + headers[http_constants.HttpHeaders.ConsistencyLevel] = options["consistencyLevel"] if options.get("enableScanInQuery"): headers[http_constants.HttpHeaders.EnableScanInQuery] = options["enableScanInQuery"] @@ -337,6 +309,75 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches return headers +def _is_session_token_request( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + resource_type: str, + operation_type: str) -> None: + consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) + # Figure out if consistency level for this request is session + is_session_consistency = consistency_level == documents.ConsistencyLevel.Session + + # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account + # configured to use multiple write regions + return (is_session_consistency is True and not IsMasterResource(resource_type) + and (documents._OperationType.IsReadOnlyOperation(operation_type) or operation_type == "Batch" + or cosmos_client_connection._global_endpoint_manager.get_use_multiple_write_locations())) + + +def set_session_token_header( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + path: str, + resource_type: str, + operation_type: str, + options: Mapping[str, Any], + partition_key_range_id: Optional[str] = None) -> None: + # set session token if required + if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type): + # if there is a token set via option, then use it to override default + if options.get("sessionToken"): + headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] + else: + # check if the client's default consistency is session (and request consistency level is same), + # then update from session container + if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ + cosmos_client_connection.session: + # populate session token from the client's session container + session_token = cosmos_client_connection.session.get_session_token(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id) + if session_token != "": + headers[http_constants.HttpHeaders.SessionToken] = session_token + +async def set_session_token_header_async( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + path: str, + resource_type: str, + operation_type: str, + options: Mapping[str, Any], + partition_key_range_id: Optional[str] = None) -> None: + # set session token if required + if _is_session_token_request(cosmos_client_connection, headers, resource_type, operation_type): + # if there is a token set via option, then use it to override default + if options.get("sessionToken"): + headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] + else: + # check if the client's default consistency is session (and request consistency level is same), + # then update from session container + if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ + cosmos_client_connection.session: + # populate session token from the client's session container + session_token = await cosmos_client_connection.session.get_session_token_async(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id) + if session_token != "": + headers[http_constants.HttpHeaders.SessionToken] = session_token def GetResourceIdOrFullNameFromLink(resource_link: str) -> Optional[str]: """Gets resource id or full name from resource link. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3934c23bcf99..38ae3f4a0bf2 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -2042,6 +2042,7 @@ def PatchItem( headers = base.GetHeaders(self, self.default_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation request_params = RequestObject(resource_type, documents._OperationType.Patch) request_data = {} @@ -2131,6 +2132,7 @@ def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) + base.set_session_token_header(self, headers, path, "docs", documents._OperationType.Batch, options) request_params = RequestObject("docs", documents._OperationType.Batch) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], @@ -2191,6 +2193,8 @@ def DeleteAllItemsByPartitionKey( collection_id = base.GetResourceIdOrFullNameFromLink(collection_link) headers = base.GetHeaders(self, self.default_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) + base.set_session_token_header(self, headers, path, "partitionkey", documents._OperationType.Delete, + options) request_params = RequestObject("partitionkey", documents._OperationType.Delete) _, last_response_headers = self.__Post( path=path, @@ -2615,7 +2619,7 @@ def Create( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2625,7 +2629,7 @@ def Create( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2642,11 +2646,12 @@ def Create( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Create, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Create) + request_params = RequestObject(resource_type, documents._OperationType.Create) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2660,7 +2665,7 @@ def Upsert( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2670,7 +2675,7 @@ def Upsert( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2687,12 +2692,13 @@ def Upsert( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Upsert, options) # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert) + request_params = RequestObject(resource_type, documents._OperationType.Upsert) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2705,7 +2711,7 @@ def Replace( self, resource: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2715,7 +2721,7 @@ def Replace( :param dict resource: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2732,10 +2738,11 @@ def Replace( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, - options) + headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, + documents._OperationType.Replace, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace) + request_params = RequestObject(resource_type, documents._OperationType.Replace) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2748,7 +2755,7 @@ def Replace( def Read( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2757,7 +2764,7 @@ def Read( """Reads an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2774,9 +2781,11 @@ def Read( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) + headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, + documents._OperationType.Read, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read) + request_params = RequestObject(resource_type, documents._OperationType.Read) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -2786,7 +2795,7 @@ def Read( def DeleteResource( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2812,10 +2821,11 @@ def DeleteResource( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, - options) + headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, + documents._OperationType.Delete, options) + base.set_session_token_header(self, headers, path, resource_type, documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete) + request_params = RequestObject(resource_type, documents._OperationType.Delete) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -3063,6 +3073,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) + base.set_session_token_header(self, headers, path, resource_type, request_params.operation_type, options, + partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -3101,6 +3113,8 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) + base.set_session_token_header(self, req_headers, path, resource_type, documents._OperationType.SqlQuery, + options) # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) @@ -3355,7 +3369,7 @@ def _UpdateSessionIfRequired( if is_session_consistency and self.session: # update session - self.session.update_session(response_result, response_headers) + self.session.update_session(self, response_result, response_headers) def _get_partition_key_definition(self, collection_link: str) -> Optional[Dict[str, Any]]: partition_key_definition: Optional[Dict[str, Any]] diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 84dd5914e208..811d778fca90 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -30,6 +30,8 @@ from . import http_constants from ._vector_session_token import VectorSessionToken from .exceptions import CosmosHttpResponseError +from .partition_key import PartitionKey +from typing import Any, Dict, Optional class SessionContainer(object): @@ -38,12 +40,24 @@ def __init__(self): self.rid_to_session_token = {} self.session_lock = threading.RLock() - def get_session_token(self, resource_path): - """Get Session Token for collection_link. + def get_session_token( + self, + resource_path: str, + pk_value: str, + container_properties_cache: Dict[str, Dict[str, Any]], + routing_map_provider: Any, + partition_key_range_id: Optional[int]) -> str: + """Get Session Token for collection_link and operation_type. :param str resource_path: Self link / path to the resource - :return: Session Token dictionary for the collection_id - :rtype: dict + :param str operation_type: Operation type (e.g. 'Create', 'Read', 'Upsert', 'Replace') + :param str pk_value: The partition key value being used for the operation + :param container_properties_cache: Container properties cache used to fetch partition key definitions + :type container_properties_cache: Dict[str, Dict[str, Any]] + :param int partition_key_range_id: The partition key range ID used for the operation + :return: Session Token dictionary for the collection_id, will be empty string if not found or if the operation + does not require a session token (single master write operations). + :rtype: str """ with self.session_lock: @@ -59,23 +73,91 @@ def get_session_token(self, resource_path): else: collection_rid = _base.GetItemContainerLink(resource_path) - if collection_rid in self.rid_to_session_token: + if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] - session_token_list = [] - for key in token_dict.keys(): - session_token_list.append("{0}:{1}".format(key, token_dict[key].convert_to_string())) - session_token = ",".join(session_token_list) + if partition_key_range_id is not None: + session_token = token_dict.get(partition_key_range_id) + else: + collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") + partition_key = PartitionKey(path=collection_pk_definition['paths'], + kind=collection_pk_definition['kind'], + version=collection_pk_definition['version']) + epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) + pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) + vector_session_token = token_dict.get(pk_range[0]['id']) + session_token = "{0}:{1}".format(pk_range[0]['id'], vector_session_token.session_token) return session_token + return "" + except Exception: # pylint: disable=broad-except + return "" + + async def get_session_token_async( + self, + resource_path: str, + pk_value: str, + container_properties_cache: Dict[str, Dict[str, Any]], + routing_map_provider: Any, + partition_key_range_id: Optional[str]) -> str: + """Get Session Token for collection_link and operation_type. + + :param str resource_path: Self link / path to the resource + :param str operation_type: Operation type (e.g. 'Create', 'Read', 'Upsert', 'Replace') + :param str pk_value: The partition key value being used for the operation + :param container_properties_cache: Container properties cache used to fetch partition key definitions + :type container_properties_cache: Dict[str, Dict[str, Any]] + :param Any routing_map_provider: The routing map provider containing the partition key range cache logic + :param str partition_key_range_id: The partition key range ID used for the operation + :return: Session Token dictionary for the collection_id, will be empty string if not found or if the operation + does not require a session token (single master write operations). + :rtype: str + """ + + with self.session_lock: + is_name_based = _base.IsNameBased(resource_path) + collection_rid = "" + session_token = "" - # return empty token if not found + try: + if is_name_based: + # get the collection name + collection_name = _base.GetItemContainerLink(resource_path) + collection_rid = self.collection_name_to_rid[collection_name] + else: + collection_rid = _base.GetItemContainerLink(resource_path) + + if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: + token_dict = self.rid_to_session_token[collection_rid] + if partition_key_range_id is not None: + vector_session_token = token_dict.get(partition_key_range_id) + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token.session_token) + else: + collection_pk_definition = container_properties_cache[collection_name].get("partitionKey") + partition_key = PartitionKey(path=collection_pk_definition['paths'], + kind=collection_pk_definition['kind'], + version=collection_pk_definition['version']) + epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) + pk_range = await routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) + session_token_list = [] + parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) + for parent in parents: + vector_session_token = token_dict.get(parent) + session_token = "{0}:{1}".format(parent, vector_session_token.session_token) + session_token_list.append(session_token) + # if vector_session_token is not None: + # session_token = "{0}:{1}".format(parent, vector_session_token.session_token) + # session_token_list.append(session_token) + session_token = ",".join(session_token_list) + return session_token return "" except Exception: # pylint: disable=broad-except return "" - def set_session_token(self, response_result, response_headers): + def set_session_token(self, client_connection, response_result, response_headers): """Session token must only be updated from response of requests that successfully mutate resource on the server side (write, replace, delete etc). + :param client_connection: Client connection used to refresh the partition key range cache if needed :param dict response_result: :param dict response_headers: :return: None @@ -86,8 +168,6 @@ def set_session_token(self, response_result, response_headers): # x-ms-alt-content-path which is the string representation of the resource with self.session_lock: - collection_rid = "" - collection_name = "" try: self_link = response_result["_self"] @@ -105,10 +185,15 @@ def set_session_token(self, response_result, response_headers): response_result_id = response_result[response_result_id_key] else: return - collection_rid, collection_name = _base.GetItemContainerInfo( - self_link, alt_content_path, response_result_id - ) - + collection_rid, collection_name = _base.GetItemContainerInfo(self_link, alt_content_path, + response_result_id) + # if the response came in with a new partition key range id after a split, refresh the pk range cache + partition_key_range_id = response_headers.get(http_constants.HttpHeaders.PartitionKeyRangeID) + collection_ranges = None + if client_connection: + collection_ranges = client_connection._routing_map_provider._collection_routing_map_by_item.get(collection_name) + if collection_ranges and not collection_ranges._rangeById.get(partition_key_range_id): + client_connection.refresh_routing_map_provider() except ValueError: return except Exception: # pylint: disable=broad-except @@ -194,7 +279,7 @@ def parse_session_token(response_headers): class Session(object): - """State of a Azure Cosmos session. + """State of an Azure Cosmos session. This session object can be shared across clients within the same process. @@ -209,8 +294,13 @@ def __init__(self, url_connection): def clear_session_token(self, response_headers): self.session_container.clear_session_token(response_headers) - def update_session(self, response_result, response_headers): - self.session_container.set_session_token(response_result, response_headers) + def update_session(self, client_connection, response_result, response_headers): + self.session_container.set_session_token(client_connection, response_result, response_headers) + + def get_session_token(self, resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id): + return self.session_container.get_session_token(resource_path, pk_value, container_properties_cache, + routing_map_provider, partition_key_range_id) - def get_session_token(self, resource_path): - return self.session_container.get_session_token(resource_path) + async def get_session_token_async(self, resource_path, pk_value, container_properties_cache, routing_map_provider, partition_key_range_id): + return await self.session_container.get_session_token_async(resource_path, pk_value, container_properties_cache, + routing_map_provider, partition_key_range_id) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index 49219533a7e6..0fb466697109 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -737,7 +737,7 @@ async def Create( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -747,7 +747,7 @@ async def Create( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -763,11 +763,13 @@ async def Create( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Create, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Create) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Create) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -874,7 +876,7 @@ async def Upsert( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -884,7 +886,7 @@ async def Upsert( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -900,21 +902,22 @@ async def Upsert( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Upsert, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Upsert) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request self._UpdateSessionIfRequired(headers, result, self.last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def __Post( self, @@ -1178,7 +1181,7 @@ async def ReadConflict( async def Read( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1204,10 +1207,12 @@ async def Read( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, - options) + headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, + documents._OperationType.Read, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Read) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers if response_hook: @@ -1456,16 +1461,18 @@ async def PatchItem( response_hook = kwargs.pop("response_hook", None) path = base.GetPathFromLink(document_link) document_id = base.GetResourceIdOrFullNameFromLink(document_link) - typ = "docs" + resource_type = "docs" if options is None: options = {} initial_headers = self.default_headers - headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, + headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Patch) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Patch) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1540,7 +1547,7 @@ async def Replace( self, resource: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1566,10 +1573,12 @@ async def Replace( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, - options) + headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, + documents._OperationType.Replace, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Replace) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1864,7 +1873,7 @@ async def DeleteConflict( async def DeleteResource( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1889,10 +1898,12 @@ async def DeleteResource( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, - options) + headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, + documents._OperationType.Delete, options) + await base.set_session_token_header_async(self, headers, path, resource_type, + documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Delete) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2005,6 +2016,8 @@ async def _Batch( base._populate_batch_headers(initial_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "docs", documents._OperationType.Batch, options) + await base.set_session_token_header_async(self, headers, path, "docs", + documents._OperationType.Read, options) request_params = _request_object.RequestObject("docs", documents._OperationType.Batch) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2808,7 +2821,7 @@ async def QueryFeed( async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements,too-many-locals self, path: str, - typ: str, + resource_type: str, id_: Optional[str], result_fn: Callable[[Dict[str, Any]], List[Dict[str, Any]]], create_fn: Optional[Callable[['CosmosClientConnection', Dict[str, Any]], Dict[str, Any]]], @@ -2822,7 +2835,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, """Query for more than one Azure Cosmos resources. :param str path: - :param str typ: + :param str resource_type: :param str id_: :param function result_fn: :param function create_fn: @@ -2858,12 +2871,13 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if query is None: # Query operations will use ReadEndpoint even though it uses GET(for feed requests) request_params = _request_object.RequestObject( - typ, + resource_type, documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed ) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) - + headers = base.GetHeaders(self, initial_headers, "get", path, id_, resource_type, + request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, headers, path, resource_type, + request_params.operation_type, options, partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: await change_feed_state.populate_request_headers_async(self._routing_map_provider, headers) @@ -2889,9 +2903,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, request_params.operation_type, - options, partition_key_range_id) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.SqlQuery) + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, + request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, req_headers, path, resource_type, + request_params.operation_type, options, partition_key_range_id) # check if query has prefix partition key cont_prop = kwargs.pop("containerProperties", None) @@ -3030,7 +3046,7 @@ def _UpdateSessionIfRequired( if is_session_consistency and self.session: # update session - self.session.update_session(response_result, response_headers) + self.session.update_session(self, response_result, response_headers) PartitionResolverErrorMessage = ( "Couldn't find any partition resolvers for the database link provided. " @@ -3258,6 +3274,8 @@ async def DeleteAllItemsByPartitionKey( initial_headers = dict(self.default_headers) headers = base.GetHeaders(self, initial_headers, "post", path, collection_id, "partitionkey", documents._OperationType.Delete, options) + await base.set_session_token_header_async(self, headers, path, "partitionkey", + documents._OperationType.Delete, options) request_params = _request_object.RequestObject("partitionkey", documents._OperationType.Delete) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index ab9307db3443..fde1e57be711 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -80,7 +80,7 @@ def test_clear_session_token(self): self.created_collection.read_item(item=created_document['id'], partition_key='mypk') except exceptions.CosmosHttpResponseError as e: self.assertEqual(self.client.client_connection.session.get_session_token( - 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id), "") + 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id, "Read"), "") self.assertEqual(e.status_code, StatusCodes.NOT_FOUND) self.assertEqual(e.sub_status, SubStatusCodes.READ_SESSION_NOTAVAILABLE) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_container.py b/sdk/cosmos/azure-cosmos/tests/test_session_container.py index 2ee352571204..12d513b66455 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session_container.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session_container.py @@ -35,7 +35,7 @@ def test_create_collection(self): u'id': u'sample collection'} create_collection_response_header = {'x-ms-session-token': '0:0#409#24=-1#12=-1', 'x-ms-alt-content-path': 'dbs/sample%20database'} - self.session.update_session(create_collection_response_result, create_collection_response_header) + self.session.update_session(None, create_collection_response_result, create_collection_response_header) token = self.session.get_session_token(u'/dbs/sample%20database/colls/sample%20collection') assert token == '0:0#409#24=-1#12=-1' @@ -53,7 +53,7 @@ def test_document_requests(self): 'x-ms-alt-content-path': 'dbs/sample%20database/colls/sample%20collection', 'x-ms-content-path': 'DdAkAPS2rAA='} - self.session.update_session(create_document_response_result, create_document_response_header) + self.session.update_session(None, create_document_response_result, create_document_response_header) token = self.session.get_session_token(u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/docs/DdAkAPS2rAACAAAAAAAAAA==/') assert token == '0:0#406#24=-1#12=-1'