diff --git a/sdk/cosmos/azure-cosmos/CHANGELOG.md b/sdk/cosmos/azure-cosmos/CHANGELOG.md index ba4f51186488..29b17c0b1dbf 100644 --- a/sdk/cosmos/azure-cosmos/CHANGELOG.md +++ b/sdk/cosmos/azure-cosmos/CHANGELOG.md @@ -20,8 +20,11 @@ #### Bugs Fixed * Fixed how resource tokens are parsed for metadata calls in the lifecycle of a document operation. See [PR 40302](https://github.com/Azure/azure-sdk-for-python/pull/40302). * Fixed issue where Query Change Feed did not return items if the container uses legacy Hash V1 Partition Keys. This also fixes issues with not being able to change feed query for Specific Partition Key Values for HPK. See [PR 41270](https://github.com/Azure/azure-sdk-for-python/pull/41270/) +* Fixed session container compound session token logic. The SDK will now only send the relevant partition-local session tokens for each read request, as opposed to the entire compound session token for the container. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). +* Write requests for single-write region accounts will no longer send session tokens when using session consistency. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). #### Other Changes +* Cross-partition queries will now always send a query plan before attempting to execute. See [PR 40366](https://github.com/Azure/azure-sdk-for-python/pull/40366). * Added Client Generated Activity IDs to all Requests. Cosmos Diagnostics Logs will more clearly show the Activity ID for each request and response. [PR 41013](https://github.com/Azure/azure-sdk-for-python/pull/41013) ### 4.12.0b1 (2025-05-19) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py index 6e1bdaa99d51..b6c9ff648584 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_base.py @@ -44,7 +44,9 @@ if TYPE_CHECKING: from ._cosmos_client_connection import CosmosClientConnection from .aio._cosmos_client_connection_async import CosmosClientConnection as AsyncClientConnection + from ._request_object import RequestObject +# pylint: disable=protected-access _COMMON_OPTIONS = { 'initial_headers': 'initialHeaders', @@ -172,37 +174,9 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches if options.get("indexingDirective"): headers[http_constants.HttpHeaders.IndexingDirective] = options["indexingDirective"] - consistency_level = None - - # get default client consistency level - default_client_consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) - - # set consistency level. check if set via options, this will override the default + # set request consistency level - if session consistency, the client should be setting this on its own if options.get("consistencyLevel"): - consistency_level = options["consistencyLevel"] - # TODO: move this line outside of if-else cause to remove the code duplication - headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level - elif default_client_consistency_level is not None: - consistency_level = default_client_consistency_level - headers[http_constants.HttpHeaders.ConsistencyLevel] = consistency_level - - # figure out if consistency level for this request is session - is_session_consistency = consistency_level == documents.ConsistencyLevel.Session - - # set session token if required - if is_session_consistency is True and not IsMasterResource(resource_type): - # if there is a token set via option, then use it to override default - if options.get("sessionToken"): - headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] - else: - # check if the client's default consistency is session (and request consistency level is same), - # then update from session container - if default_client_consistency_level == documents.ConsistencyLevel.Session and \ - cosmos_client_connection.session: - # populate session token from the client's session container - headers[http_constants.HttpHeaders.SessionToken] = cosmos_client_connection.session.get_session_token( - path - ) + headers[http_constants.HttpHeaders.ConsistencyLevel] = options["consistencyLevel"] if options.get("enableScanInQuery"): headers[http_constants.HttpHeaders.EnableScanInQuery] = options["enableScanInQuery"] @@ -345,6 +319,76 @@ def GetHeaders( # pylint: disable=too-many-statements,too-many-branches return headers +def _is_session_token_request( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + request_object: "RequestObject") -> bool: + consistency_level = headers.get(http_constants.HttpHeaders.ConsistencyLevel) + # Figure out if consistency level for this request is session + is_session_consistency = consistency_level == documents.ConsistencyLevel.Session + + # Verify that it is not a metadata request, and that it is either a read request, batch request, or an account + # configured to use multiple write regions + return (is_session_consistency is True and cosmos_client_connection.session is not None + and not IsMasterResource(request_object.resource_type) + and (documents._OperationType.IsReadOnlyOperation(request_object.operation_type) + or request_object.operation_type == "Batch" + or cosmos_client_connection._global_endpoint_manager.can_use_multiple_write_locations(request_object))) + + +def set_session_token_header( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + path: str, + request_object: "RequestObject", + options: Mapping[str, Any], + partition_key_range_id: Optional[str] = None) -> None: + # set session token if required + if _is_session_token_request(cosmos_client_connection, headers, request_object): + # if there is a token set via option, then use it to override default + if options.get("sessionToken"): + headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] + else: + # check if the client's default consistency is session (and request consistency level is same), + # then update from session container + if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ + cosmos_client_connection.session: + # populate session token from the client's session container + session_token = ( + cosmos_client_connection.session.get_session_token(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id)) + if session_token != "": + headers[http_constants.HttpHeaders.SessionToken] = session_token + +async def set_session_token_header_async( + cosmos_client_connection: Union["CosmosClientConnection", "AsyncClientConnection"], + headers: dict, + path: str, + request_object: "RequestObject", + options: Mapping[str, Any], + partition_key_range_id: Optional[str] = None) -> None: + # set session token if required + if _is_session_token_request(cosmos_client_connection, headers, request_object): + # if there is a token set via option, then use it to override default + if options.get("sessionToken"): + headers[http_constants.HttpHeaders.SessionToken] = options["sessionToken"] + else: + # check if the client's default consistency is session (and request consistency level is same), + # then update from session container + if headers[http_constants.HttpHeaders.ConsistencyLevel] == documents.ConsistencyLevel.Session and \ + cosmos_client_connection.session: + # populate session token from the client's session container + session_token = \ + await cosmos_client_connection.session.get_session_token_async(path, + options.get('partitionKey'), + cosmos_client_connection._container_properties_cache, + cosmos_client_connection._routing_map_provider, + partition_key_range_id) + if session_token != "": + headers[http_constants.HttpHeaders.SessionToken] = session_token def GetResourceIdOrFullNameFromLink(resource_link: str) -> str: """Gets resource id or full name from resource link. diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py index 3ee11eeccdd8..14c054883926 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_cosmos_client_connection.py @@ -1119,7 +1119,9 @@ def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInse fetch_function=fetch_fn, collection_link=database_or_container_link, page_iterator_class=query_iterable.QueryIterable, - response_hook=response_hook + response_hook=response_hook, + raw_response_hook=kwargs.get('raw_response_hook'), + resource_type=http_constants.ResourceType.Document ) def QueryItemsChangeFeed( @@ -2072,6 +2074,7 @@ def PatchItem( documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -2164,6 +2167,7 @@ def _Batch( documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) return cast( Tuple[List[Dict[str, Any]], CaseInsensitiveDict], self.__Post(path, request_params, batch_operations, headers, **kwargs) @@ -2234,6 +2238,7 @@ def DeleteAllItemsByPartitionKey( body=None, **kwargs ) + self._UpdateSessionIfRequired(headers, None, last_response_headers) self.last_response_headers = last_response_headers if response_hook: response_hook(last_response_headers, None) @@ -2658,7 +2663,7 @@ def Create( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2668,7 +2673,7 @@ def Create( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2685,12 +2690,12 @@ def Create( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Create, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - - request_params = RequestObject(typ, documents._OperationType.Create, headers) + request_params = RequestObject(resource_type, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2704,7 +2709,7 @@ def Upsert( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2714,7 +2719,7 @@ def Upsert( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2731,13 +2736,13 @@ def Upsert( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert, - options) + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True - # Upsert will use WriteEndpoint since it uses POST operation - request_params = RequestObject(typ, documents._OperationType.Upsert, headers) + request_params = RequestObject(resource_type, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request @@ -2750,7 +2755,7 @@ def Replace( self, resource: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2760,7 +2765,7 @@ def Replace( :param dict resource: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2777,16 +2782,17 @@ def Replace( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, - options) + headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, + documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = RequestObject(typ, documents._OperationType.Replace, headers) + request_params = RequestObject(resource_type, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers # update session for request mutates data on server side - self._UpdateSessionIfRequired(headers, result, self.last_response_headers) + self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, result) return CosmosDict(result, response_headers=last_response_headers) @@ -2794,7 +2800,7 @@ def Replace( def Read( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2803,7 +2809,7 @@ def Read( """Reads an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2820,11 +2826,16 @@ def Read( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, options) + headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, + documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = RequestObject(typ, documents._OperationType.Read, headers) + request_params = RequestObject(resource_type, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Get(path, request_params, headers, **kwargs) + # update session for request mutates data on server side + self._UpdateSessionIfRequired(headers, result, last_response_headers) + self.last_response_headers = last_response_headers if response_hook: response_hook(last_response_headers, result) @@ -2833,7 +2844,7 @@ def Read( def DeleteResource( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -2842,7 +2853,7 @@ def DeleteResource( """Deletes an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -2859,16 +2870,17 @@ def DeleteResource( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, - options) + headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, + documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = RequestObject(typ, documents._OperationType.Delete, headers) + request_params = RequestObject(resource_type, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options) result, last_response_headers = self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers # update session for request mutates data on server side - self._UpdateSessionIfRequired(headers, result, self.last_response_headers) + self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, None) @@ -3095,7 +3107,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: initial_headers = self.default_headers.copy() # Copy to make sure that default_headers won't be changed. if query is None: - op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed + op_type = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) headers = base.GetHeaders( self, @@ -3104,17 +3116,17 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, resource_id, resource_type, - op_typ, + op_type, options, partition_key_range_id ) - request_params = RequestObject( resource_type, - op_typ, + op_type, headers ) request_params.set_excluded_location_from_options(options) + base.set_session_token_header(self, headers, path, request_params, options, partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: @@ -3131,10 +3143,6 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: query = self.__CheckAndUnifyQueryFormat(query) - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" - if not is_query_plan: - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" # TODO: check why we have this weird logic - if (self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, CosmosClientConnection._QueryCompatibilityMode.Query)): initial_headers[http_constants.HttpHeaders.ContentType] = runtime_constants.MediaTypes.QueryJson @@ -3155,9 +3163,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: options, partition_key_range_id ) - request_params = RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) request_params.set_excluded_location_from_options(options) + if not is_query_plan: + req_headers[http_constants.HttpHeaders.IsQuery] = "true" + base.set_session_token_header(self, req_headers, path, request_params, options, partition_key_range_id) # check if query has prefix partition key isPrefixPartitionQuery = kwargs.pop("isPrefixPartitionQuery", None) @@ -3210,6 +3220,7 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: path, request_params, query, req_headers, **kwargs ) self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) if results: # add up all the query results from all over lapping ranges results["Documents"].extend(partial_result["Documents"]) @@ -3222,12 +3233,12 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: return __GetBodiesFromQueryResult(results), last_response_headers result, last_response_headers = self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(req_headers, result, last_response_headers) if last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = last_response_headers[INDEX_METRICS_HEADER] last_response_headers[INDEX_METRICS_HEADER] = _utils.get_index_metrics_info(index_metrics_raw) - self.last_response_headers = last_response_headers - if response_hook: response_hook(last_response_headers, result) @@ -3264,7 +3275,6 @@ def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, excluded_l } if excluded_locations is not None: options["excludedLocations"] = excluded_locations - resource_link = base.TrimBeginningAndEndingSlashes(resource_link) path = base.GetPathFromLink(resource_link, http_constants.ResourceType.Document) resource_id = base.GetResourceIdOrFullNameFromLink(resource_link) @@ -3419,9 +3429,10 @@ def _UpdateSessionIfRequired( if documents.ConsistencyLevel.Session == request_headers[http_constants.HttpHeaders.ConsistencyLevel]: is_session_consistency = True - if is_session_consistency and self.session: + if (is_session_consistency and self.session and + not base.IsMasterResource(request_headers[http_constants.HttpHeaders.ThinClientProxyResourceType])): # update session - self.session.update_session(response_result, response_headers) + self.session.update_session(self, response_result, response_headers) def _get_partition_key_definition( self, diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py index 382e0ee0f2b4..7584ef142cbd 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/document_producer.py @@ -41,7 +41,7 @@ class _DocumentProducer(object): """ def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options, - response_hook): + response_hook, raw_response_hook): """ Constructor """ @@ -62,7 +62,7 @@ def __init__(self, partition_key_target_range, client, collection_link, query, d async def fetch_fn(options): return await self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"], - response_hook=response_hook) + response_hook=response_hook, raw_response_hook=raw_response_hook) self._ex_context = _DefaultQueryExecutionContext(client, self._options, fetch_fn) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py index dc4046bff2f4..56fc2ddd89e9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/execution_dispatcher.py @@ -47,7 +47,8 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl to _MultiExecutionContextAggregator """ - def __init__(self, client, resource_link, query, options, fetch_function, response_hook): + def __init__(self, client, resource_link, query, options, fetch_function, + response_hook, raw_response_hook, resource_type): """ Constructor """ @@ -57,7 +58,17 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon self._resource_link = resource_link self._query = query self._fetch_function = fetch_function + self._resource_type = resource_type self._response_hook = response_hook + self._raw_response_hook = raw_response_hook + self._fetched_query_plan = False + + async def _create_execution_context_with_query_plan(self): + self._fetched_query_plan = True + query_to_use = self._query if self._query is not None else "Select * from root r" + query_execution_info = _PartitionedQueryExecutionInfo(await self._client._GetQueryPlanThroughGateway + (query_to_use, self._resource_link, self._options.get('excludedLocations'))) + self._execution_context = await self._create_pipelined_execution_context(query_execution_info) async def __anext__(self): """Returns the next query result. @@ -70,12 +81,8 @@ async def __anext__(self): try: return await self._execution_context.__anext__() except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e): - query_to_use = self._query if self._query is not None else "Select * from root r" - query_plan_dict = await self._client._GetQueryPlanThroughGateway( - query_to_use, self._resource_link, self._options.get('excludedLocations')) - query_execution_info = _PartitionedQueryExecutionInfo(query_plan_dict) - self._execution_context = await self._create_pipelined_execution_context(query_execution_info) + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + await self._create_execution_context_with_query_plan() else: raise e @@ -94,11 +101,7 @@ async def fetch_next_block(self): return await self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - query_to_use = self._query if self._query is not None else "Select * from root r" - query_plan_dict = await self._client._GetQueryPlanThroughGateway( - query_to_use, self._resource_link, self._options.get('excludedLocations')) - query_execution_info = _PartitionedQueryExecutionInfo(query_plan_dict) - self._execution_context = await self._create_pipelined_execution_context(query_execution_info) + await self._create_execution_context_with_query_plan() else: raise e @@ -112,6 +115,8 @@ async def _create_pipelined_execution_context(self, query_execution_info): and self._options["enableCrossPartitionQuery"]): raise CosmosHttpResponseError(StatusCodes.BAD_REQUEST, "Cross partition query only supports 'VALUE ' for aggregates") + if self._fetched_query_plan: + self._options.pop("enableCrossPartitionQuery", None) # throw exception here for vector search query without limit filter or limit > max_limit if query_execution_info.get_non_streaming_order_by(): @@ -131,7 +136,8 @@ async def _create_pipelined_execution_context(self, query_execution_info): self._query, self._options, query_execution_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) await execution_context_aggregator._configure_partition_ranges() elif query_execution_info.has_hybrid_search_query_info(): hybrid_search_query_info = query_execution_info._query_execution_info['hybridSearchQueryInfo'] @@ -142,15 +148,13 @@ async def _create_pipelined_execution_context(self, query_execution_info): self._options, query_execution_info, hybrid_search_query_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) await execution_context_aggregator._run_hybrid_search() else: - execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator(self._client, - self._resource_link, - self._query, - self._options, - query_execution_info, - self._response_hook) + execution_context_aggregator = multi_execution_aggregator._MultiExecutionContextAggregator( + self._client, self._resource_link, self._query, self._options, query_execution_info, + self._response_hook, self._raw_response_hook) await execution_context_aggregator._configure_partition_ranges() return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator, query_execution_info) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py index 2725a1d1a85a..a403bdbc4062 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/hybrid_search_aggregator.py @@ -45,7 +45,7 @@ class _HybridSearchContextAggregator(_QueryExecutionContextBase): """ def __init__(self, client, resource_link, options, partitioned_query_execution_info, - hybrid_search_query_info, response_hook): + hybrid_search_query_info, response_hook, raw_response_hook): super(_HybridSearchContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -58,6 +58,7 @@ def __init__(self, client, resource_link, options, partitioned_query_execution_i self._aggregated_global_statistics = None self._document_producer_comparator = None self._response_hook = response_hook + self._raw_response_hook = raw_response_hook async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-statements # Check if we need to run global statistics queries, and if so do for every partition in the container @@ -76,7 +77,8 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma global_statistics_query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) @@ -119,7 +121,8 @@ async def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-ma rewritten_query['rewrittenQuery'], self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) # verify all document producers have items/ no splits @@ -236,7 +239,8 @@ async def _repair_document_producer(self, query, target_all_ranges=False): query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py index 9f6e2b51df00..7d763b6f55cf 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/multi_execution_aggregator.py @@ -62,7 +62,8 @@ def peek(self): def size(self): return len(self._heap) - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_MultiExecutionContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -73,6 +74,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._partitioned_query_ex_info = partitioned_query_ex_info self._sort_orders = partitioned_query_ex_info.get_order_by() self._response_hook = response_hook + self._raw_response_hook = raw_response_hook if self._sort_orders: self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders) @@ -155,7 +157,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) async def _get_target_partition_key_range(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py index 1a6ed820d80c..20e0471ed68c 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/aio/non_streaming_order_by_aggregator.py @@ -22,7 +22,8 @@ class _NonStreamingOrderByContextAggregator(_QueryExecutionContextBase): by the user. """ - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_NonStreamingOrderByContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -31,11 +32,12 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._resource_link = resource_link self._query = query self._partitioned_query_ex_info = partitioned_query_ex_info - self._sort_orders = partitioned_query_ex_info.get_order_by() self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() self._doc_producers = [] - self._document_producer_comparator = document_producer._NonStreamingOrderByComparator(self._sort_orders) + self._document_producer_comparator = ( + document_producer._NonStreamingOrderByComparator(partitioned_query_ex_info.get_order_by())) self._response_hook = response_hook + self._raw_response_hook = raw_response_hook async def __anext__(self): @@ -100,7 +102,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) async def _get_target_partition_key_range(self): @@ -138,11 +141,12 @@ async def _configure_partition_ranges(self): pq_size = self._partitioned_query_ex_info.get_top() or\ self._partitioned_query_ex_info.get_limit() + self._partitioned_query_ex_info.get_offset() + sort_orders = self._partitioned_query_ex_info.get_order_by() for doc_producer in self._doc_producers: while True: try: result = await doc_producer.peek() - item_result = document_producer._NonStreamingItemResultProducer(result, self._sort_orders) + item_result = document_producer._NonStreamingItemResultProducer(result, sort_orders) await self._orderByPQ.push_async(item_result, self._document_producer_comparator) await doc_producer.__anext__() except StopAsyncIteration: diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py index dc01334f1905..f77504d3e9a1 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/document_producer.py @@ -40,7 +40,7 @@ class _DocumentProducer(object): """ def __init__(self, partition_key_target_range, client, collection_link, query, document_producer_comp, options, - response_hook): + response_hook, raw_response_hook): """ Constructor """ @@ -61,7 +61,7 @@ def __init__(self, partition_key_target_range, client, collection_link, query, d def fetch_fn(options): return self._client.QueryFeed(path, collection_id, query, options, partition_key_target_range["id"], - response_hook=response_hook) + response_hook=response_hook, raw_response_hook=raw_response_hook) self._ex_context = _DefaultQueryExecutionContext(client, self._options, fetch_fn) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py index 7cedfd23c7df..cbe6d67a0da9 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/execution_dispatcher.py @@ -77,7 +77,8 @@ class _ProxyQueryExecutionContext(_QueryExecutionContextBase): # pylint: disabl to _MultiExecutionContextAggregator """ - def __init__(self, client, resource_link, query, options, fetch_function, response_hook): + def __init__(self, client, resource_link, query, options, fetch_function, response_hook, + raw_response_hook, resource_type): """ Constructor """ @@ -87,7 +88,17 @@ def __init__(self, client, resource_link, query, options, fetch_function, respon self._resource_link = resource_link self._query = query self._fetch_function = fetch_function + self._resource_type = resource_type self._response_hook = response_hook + self._raw_response_hook = raw_response_hook + self._fetched_query_plan = False + + def _create_execution_context_with_query_plan(self): + self._fetched_query_plan = True + query_to_use = self._query if self._query is not None else "Select * from root r" + query_execution_info = _PartitionedQueryExecutionInfo(self._client._GetQueryPlanThroughGateway + (query_to_use, self._resource_link, self._options.get('excludedLocations'))) + self._execution_context = self._create_pipelined_execution_context(query_execution_info) def __next__(self): """Returns the next query result. @@ -100,12 +111,8 @@ def __next__(self): try: return next(self._execution_context) except CosmosHttpResponseError as e: - if _is_partitioned_execution_info(e): - query_to_use = self._query if self._query is not None else "Select * from root r" - query_plan_dict = self._client._GetQueryPlanThroughGateway( - query_to_use, self._resource_link, self._options.get('excludedLocations')) - query_execution_info = _PartitionedQueryExecutionInfo(query_plan_dict) - self._execution_context = self._create_pipelined_execution_context(query_execution_info) + if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): + self._create_execution_context_with_query_plan() else: raise e @@ -120,18 +127,11 @@ def fetch_next_block(self): :return: List of results. :rtype: list """ - # TODO: NEED to change this - make every query retrieve a query plan - # also, we can't have this logic being returned to so often - there should be no need for this - # need to split up query plan logic and actual query iterating logic try: return self._execution_context.fetch_next_block() except CosmosHttpResponseError as e: if _is_partitioned_execution_info(e) or _is_hybrid_search_query(self._query, e): - query_to_use = self._query if self._query is not None else "Select * from root r" - query_plan_dict = self._client._GetQueryPlanThroughGateway( - query_to_use, self._resource_link, self._options.get('excludedLocations')) - query_execution_info = _PartitionedQueryExecutionInfo(query_plan_dict) - self._execution_context = self._create_pipelined_execution_context(query_execution_info) + self._create_execution_context_with_query_plan() else: raise e @@ -145,6 +145,8 @@ def _create_pipelined_execution_context(self, query_execution_info): raise CosmosHttpResponseError( StatusCodes.BAD_REQUEST, "Cross partition query only supports 'VALUE ' for aggregates") + # if self._fetched_query_plan: + # self._options.pop("enableCrossPartitionQuery", None) # throw exception here for vector search query without limit filter or limit > max_limit if query_execution_info.get_non_streaming_order_by(): @@ -164,7 +166,8 @@ def _create_pipelined_execution_context(self, query_execution_info): self._query, self._options, query_execution_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) elif query_execution_info.has_hybrid_search_query_info(): hybrid_search_query_info = query_execution_info._query_execution_info['hybridSearchQueryInfo'] _verify_valid_hybrid_search_query(hybrid_search_query_info) @@ -174,7 +177,8 @@ def _create_pipelined_execution_context(self, query_execution_info): self._options, query_execution_info, hybrid_search_query_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) execution_context_aggregator._run_hybrid_search() else: execution_context_aggregator = \ @@ -183,7 +187,9 @@ def _create_pipelined_execution_context(self, query_execution_info): self._query, self._options, query_execution_info, - self._response_hook) + self._response_hook, + self._raw_response_hook) + execution_context_aggregator._configure_partition_ranges() return _PipelineExecutionContext(self._client, self._options, execution_context_aggregator, query_execution_info) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py index a59f3ac28ee9..f9da4225a0a5 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/hybrid_search_aggregator.py @@ -151,7 +151,7 @@ class _HybridSearchContextAggregator(_QueryExecutionContextBase): """ def __init__(self, client, resource_link, options, - partitioned_query_execution_info, hybrid_search_query_info, response_hook): + partitioned_query_execution_info, hybrid_search_query_info, response_hook, raw_response_hook): super(_HybridSearchContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -164,6 +164,7 @@ def __init__(self, client, resource_link, options, self._aggregated_global_statistics = None self._document_producer_comparator = None self._response_hook = response_hook + self._raw_response_hook = raw_response_hook def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-statements # Check if we need to run global statistics queries, and if so do for every partition in the container @@ -182,7 +183,8 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta global_statistics_query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) @@ -224,7 +226,8 @@ def _run_hybrid_search(self): # pylint: disable=too-many-branches, too-many-sta rewritten_query['rewrittenQuery'], self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) # verify all document producers have items/ no splits @@ -363,7 +366,8 @@ def _repair_document_producer(self, query, target_all_ranges=False): query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) ) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py index b7330b74c095..8c24ab41721a 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/multi_execution_aggregator.py @@ -63,7 +63,8 @@ def peek(self): def size(self): return len(self._heap) - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_MultiExecutionContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -74,42 +75,15 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._partitioned_query_ex_info = partitioned_query_ex_info self._sort_orders = partitioned_query_ex_info.get_order_by() self._response_hook = response_hook + self._raw_response_hook = raw_response_hook if self._sort_orders: self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders) else: self._document_producer_comparator = document_producer._PartitionKeyRangeDocumentProducerComparator() - # will be a list of (partition_min, partition_max) tuples - targetPartitionRanges = self._get_target_partition_key_range() - - targetPartitionQueryExecutionContextList = [] - for partitionTargetRange in targetPartitionRanges: - # create and add the child execution context for the target range - targetPartitionQueryExecutionContextList.append( - self._createTargetPartitionQueryExecutionContext(partitionTargetRange) - ) - self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() - for targetQueryExContext in targetPartitionQueryExecutionContextList: - try: - # TODO: we can also use more_itertools.peekable to be more python friendly - targetQueryExContext.peek() - # if there are matching results in the target ex range add it to the priority queue - - self._orderByPQ.push(targetQueryExContext) - - except exceptions.CosmosHttpResponseError as e: - if exceptions._partition_range_is_gone(e): - # repairing document producer context on partition split - self._repair_document_producer() - else: - raise - - except StopIteration: - continue - def __next__(self): """Returns the next result @@ -137,6 +111,35 @@ def fetch_next_block(self): raise NotImplementedError("You should use pipeline's fetch_next_block.") + def _configure_partition_ranges(self): + # will be a list of (partition_min, partition_max) tuples + targetPartitionRanges = self._get_target_partition_key_range() + + targetPartitionQueryExecutionContextList = [] + for partitionTargetRange in targetPartitionRanges: + # create and add the child execution context for the target range + targetPartitionQueryExecutionContextList.append( + self._createTargetPartitionQueryExecutionContext(partitionTargetRange) + ) + + for targetQueryExContext in targetPartitionQueryExecutionContextList: + try: + # TODO: we can also use more_itertools.peekable to be more python friendly + targetQueryExContext.peek() + # if there are matching results in the target ex range add it to the priority queue + + self._orderByPQ.push(targetQueryExContext) + + except exceptions.CosmosHttpResponseError as e: + if exceptions._partition_range_is_gone(e): + # repairing document producer context on partition split + self._repair_document_producer() + else: + raise + + except StopIteration: + continue + def _repair_document_producer(self): """Repairs the document producer context by using the re-initialized routing map provider in the client, which loads in a refreshed partition key range cache to re-create the partition key ranges. @@ -187,7 +190,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) def _get_target_partition_key_range(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py index 0bfc514e00d8..c07864d7d767 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_execution_context/non_streaming_order_by_aggregator.py @@ -22,7 +22,8 @@ class _NonStreamingOrderByContextAggregator(_QueryExecutionContextBase): by the user. """ - def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, response_hook): + def __init__(self, client, resource_link, query, options, partitioned_query_ex_info, + response_hook, raw_response_hook): super(_NonStreamingOrderByContextAggregator, self).__init__(client, options) # use the routing provider in the client @@ -31,14 +32,15 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i self._resource_link = resource_link self._query = query self._partitioned_query_ex_info = partitioned_query_ex_info - self._sort_orders = partitioned_query_ex_info.get_order_by() self._orderByPQ = _MultiExecutionContextAggregator.PriorityQueue() self._response_hook = response_hook + self._raw_response_hook = raw_response_hook # will be a list of (partition_min, partition_max) tuples targetPartitionRanges = self._get_target_partition_key_range() - self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(self._sort_orders) + sort_orders = partitioned_query_ex_info.get_order_by() + self._document_producer_comparator = document_producer._OrderByDocumentProducerComparator(sort_orders) targetPartitionQueryExecutionContextList = [] for partitionTargetRange in targetPartitionRanges: @@ -68,7 +70,7 @@ def __init__(self, client, resource_link, query, options, partitioned_query_ex_i while True: try: result = doc_producer.peek() - item_result = document_producer._NonStreamingItemResultProducer(result, self._sort_orders) + item_result = document_producer._NonStreamingItemResultProducer(result, sort_orders) self._orderByPQ.push(item_result) next(doc_producer) except StopIteration: @@ -143,7 +145,8 @@ def _createTargetPartitionQueryExecutionContext(self, partition_key_target_range query, self._document_producer_comparator, self._options, - self._response_hook + self._response_hook, + self._raw_response_hook ) def _get_target_partition_key_range(self): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py index 6663628dad5f..be06b24478a8 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_query_iterable.py @@ -43,7 +43,9 @@ def __init__( database_link=None, partition_key=None, continuation_token=None, + resource_type=None, response_hook=None, + raw_response_hook=None, ): """Instantiates a QueryIterable for non-client side partitioning queries. @@ -54,7 +56,7 @@ def __init__( :param (str or dict) query: :param dict options: The request options for the request. :param method fetch_function: - :param method resource_type: The type of the resource being queried + :param str resource_type: The type of the resource being queried :param str resource_link: If this is a Document query/feed collection_link is required. Example of `fetch_function`: @@ -74,7 +76,8 @@ def __init__( self._database_link = database_link self._partition_key = partition_key self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( - self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook) + self._client, self._collection_link, self._query, self._options, self._fetch_function, + response_hook, raw_response_hook, resource_type) super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) def _unpack(self, block): diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py index 84dd5914e208..d07e6f6ba2e0 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/_session.py @@ -25,12 +25,17 @@ import sys import traceback import threading +from typing import Any, Dict, Optional from . import _base from . import http_constants +from ._routing.routing_map_provider import SmartRoutingMapProvider +from ._routing.aio.routing_map_provider import SmartRoutingMapProvider as SmartRoutingMapProviderAsync from ._vector_session_token import VectorSessionToken from .exceptions import CosmosHttpResponseError +from .partition_key import PartitionKey +# pylint: disable=protected-access,too-many-nested-blocks class SessionContainer(object): def __init__(self): @@ -38,17 +43,29 @@ def __init__(self): self.rid_to_session_token = {} self.session_lock = threading.RLock() - def get_session_token(self, resource_path): - """Get Session Token for collection_link. + def get_session_token( + self, + resource_path: str, + pk_value: Any, + container_properties_cache: Dict[str, Dict[str, Any]], + routing_map_provider: SmartRoutingMapProvider, + partition_key_range_id: Optional[int]) -> str: + """Get Session Token for the given collection and partition key information. :param str resource_path: Self link / path to the resource - :return: Session Token dictionary for the collection_id - :rtype: dict + :param ~azure.cosmos.SmartRoutingMapProvider routing_map_provider: routing map containing relevant session + information, such as partition key ranges for a given collection + :param Any pk_value: The partition key value being used for the operation + :param container_properties_cache: Container properties cache used to fetch partition key definitions + :type container_properties_cache: Dict[str, Dict[str, Any]] + :param int partition_key_range_id: The partition key range ID used for the operation + :return: Session Token dictionary for the collection_id, will be empty string if not found or if the operation + does not require a session token (single master write operations). + :rtype: str """ with self.session_lock: is_name_based = _base.IsNameBased(resource_path) - collection_rid = "" session_token = "" try: @@ -59,23 +76,117 @@ def get_session_token(self, resource_path): else: collection_rid = _base.GetItemContainerLink(resource_path) - if collection_rid in self.rid_to_session_token: + if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: token_dict = self.rid_to_session_token[collection_rid] - session_token_list = [] - for key in token_dict.keys(): - session_token_list.append("{0}:{1}".format(key, token_dict[key].convert_to_string())) - session_token = ",".join(session_token_list) + if partition_key_range_id is not None: + # if we find a cached session token for the relevant pk range id, use that session token + if token_dict.get(partition_key_range_id): + vector_session_token = token_dict.get(partition_key_range_id) + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token.session_token) + # if we don't find it, we do a session token merge for the parent pk ranges + # this should only happen immediately after a partition split + else: + container_routing_map = \ + routing_map_provider._collection_routing_map_by_item[collection_name] + current_range = container_routing_map._rangeById.get(partition_key_range_id) + if current_range is not None: + vector_session_token = self._resolve_partition_local_session_token(current_range, + token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) + elif pk_value is not None: + collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] + partition_key = PartitionKey(path=collection_pk_definition['paths'], + kind=collection_pk_definition['kind'], + version=collection_pk_definition['version']) + epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) + pk_range = routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) + if len(pk_range) > 0: + partition_key_range_id = pk_range[0]['id'] + vector_session_token = self._resolve_partition_local_session_token(pk_range, token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) return session_token + return "" + except Exception: # pylint: disable=broad-except + return "" + + async def get_session_token_async( + self, + resource_path: str, + pk_value: Any, + container_properties_cache: Dict[str, Dict[str, Any]], + routing_map_provider: SmartRoutingMapProviderAsync, + partition_key_range_id: Optional[str]) -> str: + """Get Session Token for the given collection and partition key information. + + :param str resource_path: Self link / path to the resource + :param ~azure.cosmos.SmartRoutingMapProviderAsync routing_map_provider: routing map containing relevant session + information, such as partition key ranges for a given collection + :param Any pk_value: The partition key value being used for the operation + :param container_properties_cache: Container properties cache used to fetch partition key definitions + :type container_properties_cache: Dict[str, Dict[str, Any]] + :param Any routing_map_provider: The routing map provider containing the partition key range cache logic + :param str partition_key_range_id: The partition key range ID used for the operation + :return: Session Token dictionary for the collection_id, will be empty string if not found or if the operation + does not require a session token (single master write operations). + :rtype: str + """ + + with self.session_lock: + is_name_based = _base.IsNameBased(resource_path) + session_token = "" - # return empty token if not found + try: + if is_name_based: + # get the collection name + collection_name = _base.GetItemContainerLink(resource_path) + collection_rid = self.collection_name_to_rid[collection_name] + else: + collection_rid = _base.GetItemContainerLink(resource_path) + + if collection_rid in self.rid_to_session_token and collection_name in container_properties_cache: + token_dict = self.rid_to_session_token[collection_rid] + if partition_key_range_id is not None: + # if we find a cached session token for the relevant pk range id, use that session token + if token_dict.get(partition_key_range_id): + vector_session_token = token_dict.get(partition_key_range_id) + session_token = "{0}:{1}".format(partition_key_range_id, + vector_session_token.session_token) + # if we don't find it, we do a session token merge for the parent pk ranges + # this should only happen immediately after a partition split + else: + container_routing_map = \ + routing_map_provider._collection_routing_map_by_item[collection_name] + current_range = container_routing_map._rangeById.get(partition_key_range_id) + if current_range is not None: + vector_session_token = self._resolve_partition_local_session_token(current_range, + token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) + elif pk_value is not None: + collection_pk_definition = container_properties_cache[collection_name]["partitionKey"] + partition_key = PartitionKey(path=collection_pk_definition['paths'], + kind=collection_pk_definition['kind'], + version=collection_pk_definition['version']) + epk_range = partition_key._get_epk_range_for_partition_key(pk_value=pk_value) + pk_range = await routing_map_provider.get_overlapping_ranges(collection_name, [epk_range]) + if len(pk_range) > 0: + partition_key_range_id = pk_range[0]['id'] + vector_session_token = self._resolve_partition_local_session_token(pk_range, token_dict) + if vector_session_token is not None: + session_token = "{0}:{1}".format(partition_key_range_id, vector_session_token) + return session_token return "" except Exception: # pylint: disable=broad-except return "" - def set_session_token(self, response_result, response_headers): + def set_session_token(self, client_connection, response_result, response_headers): """Session token must only be updated from response of requests that successfully mutate resource on the server side (write, replace, delete etc). + :param client_connection: Client connection used to refresh the partition key range cache if needed + :type client_connection: Union[azure.cosmos.CosmosClientConnection, azure.cosmos.aio.CosmosClientConnection] :param dict response_result: :param dict response_headers: :return: None @@ -86,11 +197,12 @@ def set_session_token(self, response_result, response_headers): # x-ms-alt-content-path which is the string representation of the resource with self.session_lock: - collection_rid = "" - collection_name = "" - try: - self_link = response_result["_self"] + self_link = response_result.get("_self") + # query results don't directly have a self_link - need to fetch it directly from one of the items + if self_link is None: + if 'Documents' in response_result and len(response_result['Documents']) > 0: + self_link = response_result['Documents'][0].get('_self') # extract alternate content path from the response_headers # (only document level resource updates will have this), @@ -102,13 +214,25 @@ def set_session_token(self, response_result, response_headers): response_result_id = None if alt_content_path_key in response_headers: alt_content_path = response_headers[http_constants.HttpHeaders.AlternateContentPath] - response_result_id = response_result[response_result_id_key] + if response_result_id_key in response_result: + response_result_id = response_result[response_result_id_key] else: return - collection_rid, collection_name = _base.GetItemContainerInfo( - self_link, alt_content_path, response_result_id - ) - + if self_link is not None: + collection_rid, collection_name = _base.GetItemContainerInfo(self_link, alt_content_path, + response_result_id) + else: + # if for whatever reason we don't have a _self link at this point, we use the container name + collection_name = alt_content_path + collection_rid = self.collection_name_to_rid.get(collection_name) + # if the response came in with a new partition key range id after a split, refresh the pk range cache + partition_key_range_id = response_headers.get(http_constants.HttpHeaders.PartitionKeyRangeID) + collection_ranges = None + if client_connection: + collection_ranges = \ + client_connection._routing_map_provider._collection_routing_map_by_item.get(collection_name) + if collection_ranges and not collection_ranges._rangeById.get(partition_key_range_id): + client_connection.refresh_routing_map_provider() except ValueError: return except Exception: # pylint: disable=broad-except @@ -143,10 +267,9 @@ def set_session_token(self, response_result, response_headers): self.rid_to_session_token[collection_rid][id_] = parsed_tokens[id_] else: self.rid_to_session_token[collection_rid][id_] = parsed_tokens[id_].merge(old_session_token) - self.collection_name_to_rid[collection_name] = collection_rid else: self.rid_to_session_token[collection_rid] = parsed_tokens - self.collection_name_to_rid[collection_name] = collection_rid + self.collection_name_to_rid[collection_name] = collection_rid def clear_session_token(self, response_headers): with self.session_lock: @@ -192,9 +315,39 @@ def parse_session_token(response_headers): id_to_sessionlsn[id_] = sessionToken return id_to_sessionlsn + def _format_session_token(self, pk_range, token_dict): + session_token_list = [] + parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) + for parent in parents: + vector_session_token = token_dict.get(parent) + if vector_session_token is not None: + session_token = "{0}:{1}".format(parent, vector_session_token.session_token) + session_token_list.append(session_token) + session_token = ",".join(session_token_list) + return session_token + + def _resolve_partition_local_session_token(self, pk_range, token_dict): + parent_session_token = None + parents = pk_range[0].get('parents').copy() + parents.append(pk_range[0]['id']) + for parent in parents: + session_token = token_dict.get(parent) + if session_token is not None: + vector_session_token = session_token.session_token + if parent_session_token is None: + parent_session_token = vector_session_token + # if initial token is already set, and the next parent's token is cached, merge vector session tokens + else: + vector_token_1 = VectorSessionToken.create(parent_session_token) + vector_token_2 = VectorSessionToken.create(vector_session_token) + vector_token = vector_token_1.merge(vector_token_2) + parent_session_token = vector_token.session_token + return parent_session_token + class Session(object): - """State of a Azure Cosmos session. + """State of an Azure Cosmos session. This session object can be shared across clients within the same process. @@ -209,8 +362,15 @@ def __init__(self, url_connection): def clear_session_token(self, response_headers): self.session_container.clear_session_token(response_headers) - def update_session(self, response_result, response_headers): - self.session_container.set_session_token(response_result, response_headers) + def update_session(self, client_connection, response_result, response_headers): + self.session_container.set_session_token(client_connection, response_result, response_headers) + + def get_session_token(self, resource_path, pk_value, container_properties_cache, routing_map_provider, + partition_key_range_id): + return self.session_container.get_session_token(resource_path, pk_value, container_properties_cache, + routing_map_provider, partition_key_range_id) - def get_session_token(self, resource_path): - return self.session_container.get_session_token(resource_path) + async def get_session_token_async(self, resource_path, pk_value, container_properties_cache, routing_map_provider, + partition_key_range_id): + return await self.session_container.get_session_token_async(resource_path, pk_value, container_properties_cache, + routing_map_provider, partition_key_range_id) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py index cbcd3ccafba7..4d9136af4a29 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_cosmos_client_connection_async.py @@ -756,7 +756,7 @@ async def Create( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -766,7 +766,7 @@ async def Create( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -782,12 +782,12 @@ async def Create( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, documents._OperationType.Create, options) # Create will use WriteEndpoint since it uses POST operation - - request_params = _request_object.RequestObject(typ, documents._OperationType.Create, headers) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Create, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers @@ -895,7 +895,7 @@ async def Upsert( self, body: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -905,7 +905,7 @@ async def Upsert( :param dict body: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -921,22 +921,21 @@ async def Upsert( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "post", path, id, typ, documents._OperationType.Upsert, - options) - + headers = base.GetHeaders(self, initial_headers, "post", path, id, resource_type, + documents._OperationType.Upsert, options) headers[http_constants.HttpHeaders.IsUpsert] = True # Upsert will use WriteEndpoint since it uses POST operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Upsert, headers) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Upsert, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Post(path, request_params, body, headers, **kwargs) self.last_response_headers = last_response_headers # update session for write request - self._UpdateSessionIfRequired(headers, result, self.last_response_headers) + self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def __Post( self, @@ -1200,7 +1199,7 @@ async def ReadConflict( async def Read( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1209,7 +1208,7 @@ async def Read( """Reads an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -1226,17 +1225,19 @@ async def Read( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "get", path, id, typ, documents._OperationType.Read, - options) + headers = base.GetHeaders(self, initial_headers, "get", path, id, resource_type, + documents._OperationType.Read, options) # Read will use ReadEndpoint since it uses GET operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Read, headers) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Read, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) + # update session for request mutates data on server side + self._UpdateSessionIfRequired(headers, result, last_response_headers) self.last_response_headers = last_response_headers if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def __Get( self, @@ -1483,17 +1484,18 @@ async def PatchItem( response_hook = kwargs.pop("response_hook", None) path = base.GetPathFromLink(document_link) document_id = base.GetResourceIdOrFullNameFromLink(document_link) - typ = http_constants.ResourceType.Document + resource_type = http_constants.ResourceType.Document if options is None: options = {} initial_headers = self.default_headers - headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, typ, + headers = base.GetHeaders(self, initial_headers, "patch", path, document_id, resource_type, documents._OperationType.Patch, options) # Patch will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Patch, headers) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Patch, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) request_data = {} if options.get("filterPredicate"): request_data["condition"] = options.get("filterPredicate") @@ -1505,8 +1507,7 @@ async def PatchItem( self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def ReplaceOffer( self, @@ -1569,7 +1570,7 @@ async def Replace( self, resource: Dict[str, Any], path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1579,7 +1580,7 @@ async def Replace( :param dict resource: :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -1595,11 +1596,12 @@ async def Replace( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "put", path, id, typ, documents._OperationType.Replace, - options) + headers = base.GetHeaders(self, initial_headers, "put", path, id, resource_type, + documents._OperationType.Replace, options) # Replace will use WriteEndpoint since it uses PUT operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Replace, headers) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Replace, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Put(path, request_params, resource, headers, **kwargs) self.last_response_headers = last_response_headers @@ -1607,8 +1609,7 @@ async def Replace( self._UpdateSessionIfRequired(headers, result, self.last_response_headers) if response_hook: response_hook(last_response_headers, result) - return CosmosDict(result, - response_headers=last_response_headers) + return CosmosDict(result, response_headers=last_response_headers) async def __Put( self, @@ -1895,7 +1896,7 @@ async def DeleteConflict( async def DeleteResource( self, path: str, - typ: str, + resource_type: str, id: Optional[str], initial_headers: Optional[Mapping[str, Any]], options: Optional[Mapping[str, Any]] = None, @@ -1904,7 +1905,7 @@ async def DeleteResource( """Deletes an Azure Cosmos resource and returns it. :param str path: - :param str typ: + :param str resource_type: :param str id: :param dict initial_headers: :param dict options: @@ -1920,11 +1921,12 @@ async def DeleteResource( options = {} initial_headers = initial_headers or self.default_headers - headers = base.GetHeaders(self, initial_headers, "delete", path, id, typ, documents._OperationType.Delete, - options) + headers = base.GetHeaders(self, initial_headers, "delete", path, id, resource_type, + documents._OperationType.Delete, options) # Delete will use WriteEndpoint since it uses DELETE operation - request_params = _request_object.RequestObject(typ, documents._OperationType.Delete, headers) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.Delete, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result, last_response_headers = await self.__Delete(path, request_params, headers, **kwargs) self.last_response_headers = last_response_headers @@ -2041,6 +2043,7 @@ async def _Batch( request_params = _request_object.RequestObject(http_constants.ResourceType.Document, documents._OperationType.Batch, headers) request_params.set_excluded_location_from_options(options) + await base.set_session_token_header_async(self, headers, path, request_params, options) result = await self.__Post(path, request_params, batch_operations, headers, **kwargs) return cast(Tuple[List[Dict[str, Any]], CaseInsensitiveDict], result) @@ -2103,7 +2106,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.PartitionKeyRange ) def ReadDatabases( @@ -2308,7 +2312,9 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca fetch_function=fetch_fn, collection_link=database_or_container_link, page_iterator_class=query_iterable.QueryIterable, - response_hook=response_hook + response_hook=response_hook, + raw_response_hook=kwargs.get('raw_response_hook'), + resource_type=http_constants.ResourceType.Document ) def QueryItemsChangeFeed( @@ -2440,7 +2446,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca query, options, fetch_function=fetch_fn, - page_iterator_class=query_iterable.QueryIterable + page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Offer ) def ReadUsers( @@ -2502,7 +2509,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.User ) def ReadPermissions( @@ -2551,19 +2559,21 @@ def QueryPermissions( if options is None: options = {} - path = base.GetPathFromLink(user_link, "permissions") + path = base.GetPathFromLink(user_link, http_constants.ResourceType.Permission) user_id = base.GetResourceIdOrFullNameFromLink(user_link) async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], CaseInsensitiveDict]: return ( await self.__QueryFeed( - path, "permissions", user_id, lambda r: r["Permissions"], lambda _, b: b, query, options, **kwargs + path, http_constants.ResourceType.Permission, user_id, lambda r: r["Permissions"], lambda _, b: b, + query, options, **kwargs ), self.last_response_headers, ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Permission ) def ReadStoredProcedures( @@ -2625,7 +2635,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.StoredProcedure ) def ReadTriggers( @@ -2687,7 +2698,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Trigger ) def ReadUserDefinedFunctions( @@ -2750,7 +2762,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.UserDefinedFunction ) def ReadConflicts( @@ -2811,7 +2824,8 @@ async def fetch_fn(options: Mapping[str, Any]) -> Tuple[List[Dict[str, Any]], Ca ) return AsyncItemPaged( - self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable + self, query, options, fetch_function=fetch_fn, page_iterator_class=query_iterable.QueryIterable, + resource_type=http_constants.ResourceType.Conflict ) async def QueryFeed( @@ -2851,7 +2865,7 @@ async def QueryFeed( async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements,too-many-locals self, path: str, - typ: str, + resource_type: str, id_: Optional[str], result_fn: Callable[[Dict[str, Any]], List[Dict[str, Any]]], create_fn: Optional[Callable[['CosmosClientConnection', Dict[str, Any]], Dict[str, Any]]], @@ -2865,7 +2879,7 @@ async def __QueryFeed( # pylint: disable=too-many-branches,too-many-statements, """Query for more than one Azure Cosmos resources. :param str path: - :param str typ: + :param str resource_type: :param str id_: :param function result_fn: :param function create_fn: @@ -2904,17 +2918,20 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: # Copy to make sure that default_headers won't be changed. if query is None: - op_typ = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed + op_type = documents._OperationType.QueryPlan if is_query_plan else documents._OperationType.ReadFeed # Query operations will use ReadEndpoint even though it uses GET(for feed requests) - headers = base.GetHeaders(self, initial_headers, "get", path, id_, typ, op_typ, + headers = base.GetHeaders(self, initial_headers, "get", path, id_, resource_type, op_type, options, partition_key_range_id) request_params = _request_object.RequestObject( - typ, - op_typ, + resource_type, + op_type, headers ) request_params.set_excluded_location_from_options(options) - + headers = base.GetHeaders(self, initial_headers, "get", path, id_, resource_type, + request_params.operation_type, options, partition_key_range_id) + await base.set_session_token_header_async(self, headers, path, request_params, options, + partition_key_range_id) change_feed_state: Optional[ChangeFeedState] = options.get("changeFeedState") if change_feed_state is not None: feed_options = {} @@ -2923,20 +2940,20 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: await change_feed_state.populate_request_headers_async(self._routing_map_provider, headers, feed_options) - result, self.last_response_headers = await self.__Get(path, request_params, headers, **kwargs) + result, last_response_headers = await self.__Get(path, request_params, headers, **kwargs) + self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(headers, result, last_response_headers) if response_hook: response_hook(self.last_response_headers, result) return __GetBodiesFromQueryResult(result) query = self.__CheckAndUnifyQueryFormat(query) - initial_headers[http_constants.HttpHeaders.IsQuery] = "true" if not is_query_plan: initial_headers[http_constants.HttpHeaders.IsQuery] = "true" - if ( - self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, - CosmosClientConnection._QueryCompatibilityMode.Query)): + if (self._query_compatibility_mode in (CosmosClientConnection._QueryCompatibilityMode.Default, + CosmosClientConnection._QueryCompatibilityMode.Query)): initial_headers[http_constants.HttpHeaders.ContentType] = runtime_constants.MediaTypes.QueryJson elif self._query_compatibility_mode == CosmosClientConnection._QueryCompatibilityMode.SqlQuery: initial_headers[http_constants.HttpHeaders.ContentType] = runtime_constants.MediaTypes.SQL @@ -2944,11 +2961,13 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: raise SystemError("Unexpected query compatibility mode.") # Query operations will use ReadEndpoint even though it uses POST(for regular query operations) - req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, typ, - documents._OperationType.SqlQuery, - options, partition_key_range_id) - request_params = _request_object.RequestObject(typ, documents._OperationType.SqlQuery, req_headers) + req_headers = base.GetHeaders(self, initial_headers, "post", path, id_, resource_type, + documents._OperationType.SqlQuery, options, partition_key_range_id) + request_params = _request_object.RequestObject(resource_type, documents._OperationType.SqlQuery, req_headers) request_params.set_excluded_location_from_options(options) + if not is_query_plan: + await base.set_session_token_header_async(self, req_headers, path, request_params, options, + partition_key_range_id) # check if query has prefix partition key partition_key_value = options.get("partitionKey", None) @@ -2998,13 +3017,15 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: req_headers[http_constants.HttpHeaders.StartEpkString] = EPK_sub_range.min req_headers[http_constants.HttpHeaders.EndEpkString] = EPK_sub_range.max req_headers[http_constants.HttpHeaders.ReadFeedKeyType] = "EffectivePartitionKeyRange" - partial_result, self.last_response_headers = await self.__Post( + partial_result, last_response_headers = await self.__Post( path, request_params, query, req_headers, **kwargs ) + self.last_response_headers = last_response_headers + self._UpdateSessionIfRequired(req_headers, partial_result, last_response_headers) if results: # add up all the query results from all over lapping ranges results["Documents"].extend(partial_result["Documents"]) @@ -3016,7 +3037,11 @@ def __GetBodiesFromQueryResult(result: Dict[str, Any]) -> List[Dict[str, Any]]: if results: return __GetBodiesFromQueryResult(results) - result, self.last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) + result, last_response_headers = await self.__Post(path, request_params, query, req_headers, **kwargs) + self.last_response_headers = last_response_headers + # update session for request mutates data on server side + self._UpdateSessionIfRequired(req_headers, result, last_response_headers) + # TODO: this part might become an issue since HTTP/2 can return read-only headers if self.last_response_headers.get(http_constants.HttpHeaders.IndexUtilization) is not None: INDEX_METRICS_HEADER = http_constants.HttpHeaders.IndexUtilization index_metrics_raw = self.last_response_headers[INDEX_METRICS_HEADER] @@ -3084,9 +3109,10 @@ def _UpdateSessionIfRequired( if documents.ConsistencyLevel.Session == request_headers[http_constants.HttpHeaders.ConsistencyLevel]: is_session_consistency = True - if is_session_consistency and self.session: + if (is_session_consistency and self.session and + not base.IsMasterResource(request_headers[http_constants.HttpHeaders.ThinClientProxyResourceType])): # update session - self.session.update_session(response_result, response_headers) + self.session.update_session(self, response_result, response_headers) PartitionResolverErrorMessage = ( "Couldn't find any partition resolvers for the database link provided. " @@ -3272,7 +3298,6 @@ async def _GetQueryPlanThroughGateway(self, query: str, resource_link: str, } if excluded_locations is not None: options["excludedLocations"] = excluded_locations - resource_link = base.TrimBeginningAndEndingSlashes(resource_link) path = base.GetPathFromLink(resource_link, http_constants.ResourceType.Document) resource_id = base.GetResourceIdOrFullNameFromLink(resource_link) @@ -3324,6 +3349,7 @@ async def DeleteAllItemsByPartitionKey( request_params.set_excluded_location_from_options(options) _, last_response_headers = await self.__Post(path=path, request_params=request_params, req_headers=headers, body=None, **kwargs) + self._UpdateSessionIfRequired(headers, None, last_response_headers) self.last_response_headers = last_response_headers if response_hook: response_hook(last_response_headers, None) diff --git a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py index 4a67671606dd..d0304ccfde60 100644 --- a/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py +++ b/sdk/cosmos/azure-cosmos/azure/cosmos/aio/_query_iterable_async.py @@ -44,7 +44,9 @@ def __init__( database_link=None, partition_key=None, continuation_token=None, + resource_type=None, response_hook=None, + raw_response_hook=None, ): """Instantiates a QueryIterable for non-client side partitioning queries. @@ -55,7 +57,7 @@ def __init__( :param (str or dict) query: :param dict options: The request options for the request. :param method fetch_function: - :param method resource_type: The type of the resource being queried + :param str resource_type: The type of the resource being queried :param str resource_link: If this is a Document query/feed collection_link is required. Example of `fetch_function`: @@ -75,7 +77,8 @@ def __init__( self._database_link = database_link self._partition_key = partition_key self._ex_context = execution_dispatcher._ProxyQueryExecutionContext( - self._client, self._collection_link, self._query, self._options, self._fetch_function, response_hook) + self._client, self._collection_link, self._query, self._options, self._fetch_function, + response_hook, raw_response_hook, resource_type) super(QueryIterable, self).__init__(self._fetch_next, self._unpack, continuation_token=continuation_token) async def _unpack(self, block): diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index f8c2f7832bdb..51d1df4cd6b0 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -3,6 +3,7 @@ import collections import os +import random import time import unittest import uuid @@ -10,7 +11,7 @@ from azure.cosmos._retry_utility import _has_database_account_header, _has_read_retryable_headers, _configure_timeout from azure.cosmos.cosmos_client import CosmosClient from azure.cosmos.exceptions import CosmosHttpResponseError -from azure.cosmos.http_constants import StatusCodes +from azure.cosmos.http_constants import StatusCodes, HttpHeaders from azure.cosmos.partition_key import PartitionKey from azure.cosmos import (ContainerProxy, DatabaseProxy, documents, exceptions, http_constants, _retry_utility) @@ -58,12 +59,12 @@ class TestConfig(object): THROUGHPUT_FOR_2_PARTITIONS = 12000 THROUGHPUT_FOR_1_PARTITION = 400 - TEST_DATABASE_ID = os.getenv('COSMOS_TEST_DATABASE_ID', "Python SDK Test Database " + str(uuid.uuid4())) + TEST_DATABASE_ID = os.getenv('COSMOS_TEST_DATABASE_ID', "PythonSDKTestDatabase-" + str(uuid.uuid4())) - TEST_SINGLE_PARTITION_CONTAINER_ID = "Single Partition Test Container " + str(uuid.uuid4()) - TEST_MULTI_PARTITION_CONTAINER_ID = "Multi Partition Test Container " + str(uuid.uuid4()) - TEST_SINGLE_PARTITION_PREFIX_PK_CONTAINER_ID = "Single Partition With Prefix PK Test Container " + str(uuid.uuid4()) - TEST_MULTI_PARTITION_PREFIX_PK_CONTAINER_ID = "Multi Partition With Prefix PK Test Container " + str(uuid.uuid4()) + TEST_SINGLE_PARTITION_CONTAINER_ID = "SinglePartitionTestContainer-" + str(uuid.uuid4()) + TEST_MULTI_PARTITION_CONTAINER_ID = "MultiPartitionTestContainer-" + str(uuid.uuid4()) + TEST_SINGLE_PARTITION_PREFIX_PK_CONTAINER_ID = "SinglePartitionWithPrefixPKTestContainer-" + str(uuid.uuid4()) + TEST_MULTI_PARTITION_PREFIX_PK_CONTAINER_ID = "MultiPartitionWithPrefixPKTestContainer-" + str(uuid.uuid4()) TEST_CONTAINER_PARTITION_KEY = "pk" TEST_CONTAINER_PREFIX_PARTITION_KEY = ["pk1", "pk2"] @@ -291,6 +292,32 @@ def get_full_text_policy(path): ] } +def get_test_item(): + test_item = { + 'id': 'Item_' + str(uuid.uuid4()), + 'test_object': True, + 'lastName': 'Smith', + 'attr1': random.randint(0, 10) + } + return test_item + +def pre_split_hook(response): + request_headers = response.http_request.headers + session_token = request_headers.get('x-ms-session-token') + assert len(session_token) <= 20 + assert session_token.startswith('0') + assert session_token.count(':') == 1 + assert session_token.count(',') == 0 + +def post_split_hook(response): + request_headers = response.http_request.headers + session_token = request_headers.get('x-ms-session-token') + assert len(session_token) > 30 + assert len(session_token) < 60 # should only be 0-1 or 0-2, not 0-1-2 + assert session_token.startswith('0') is False + assert session_token.count(':') == 2 + assert session_token.count(',') == 1 + class ResponseHookCaller: def __init__(self): self.count = 0 @@ -322,6 +349,14 @@ def __init__(self, headers=None, status_code=200, message="test-message"): def body(self): return None +def no_token_response_hook(raw_response): + request_headers = raw_response.http_request.headers + assert request_headers.get(HttpHeaders.SessionToken) is None + +def token_response_hook(raw_response): + request_headers = raw_response.http_request.headers + assert request_headers.get(HttpHeaders.SessionToken) is not None + class MockConnectionRetryPolicy(RetryPolicy): def __init__(self, resource_type, error=None, **kwargs): diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py index 7fd18a53326b..7ea83b17749b 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token.py @@ -93,16 +93,10 @@ def test_latest_session_token_from_pk(self): phys_feed_ranges_and_session_tokens) phys_session_token = container.get_latest_session_token(phys_feed_ranges_and_session_tokens, phys_target_feed_range) - assert is_compound_session_token(phys_session_token) - session_tokens = phys_session_token.split(",") - assert len(session_tokens) == 2 - pk_range_id1, session_token1 = parse_session_token(session_tokens[0]) - pk_range_id2, session_token2 = parse_session_token(session_tokens[1]) - pk_range_ids = [pk_range_id1, pk_range_id2] - - assert 620 <= (session_token1.global_lsn + session_token2.global_lsn) - assert '1' in pk_range_ids - assert '2' in pk_range_ids + pk_range_id, session_token = parse_session_token(phys_session_token) + + assert session_token.global_lsn >= 360 + assert '2' in pk_range_id self.database.delete_container(container.id) def test_latest_session_token_hpk(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py index 139078683bd3..5e1fbffa5921 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py +++ b/sdk/cosmos/azure-cosmos/tests/test_latest_session_token_async.py @@ -95,16 +95,10 @@ async def test_latest_session_token_from_pk_async(self): phys_feed_ranges_and_session_tokens) phys_session_token = await container.get_latest_session_token(phys_feed_ranges_and_session_tokens, phys_target_feed_range) - assert is_compound_session_token(phys_session_token) - session_tokens = phys_session_token.split(",") - assert len(session_tokens) == 2 - pk_range_id1, session_token1 = parse_session_token(session_tokens[0]) - pk_range_id2, session_token2 = parse_session_token(session_tokens[1]) - pk_range_ids = [pk_range_id1, pk_range_id2] - - assert 620 <= (session_token1.global_lsn + session_token2.global_lsn) - assert '1' in pk_range_ids - assert '2' in pk_range_ids + pk_range_id, session_token = parse_session_token(phys_session_token) + + assert session_token.global_lsn >= 360 + assert '2' in pk_range_id await self.database.delete_container(container.id) async def test_latest_session_token_hpk(self): diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py index 4bb7d160a7fe..c22cd57b17a4 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query.py @@ -4,7 +4,6 @@ import random import time import unittest -import uuid import os import pytest @@ -12,17 +11,7 @@ import azure.cosmos.cosmos_client as cosmos_client import test_config from azure.cosmos import DatabaseProxy, PartitionKey, ContainerProxy -from azure.cosmos.exceptions import CosmosClientTimeoutError, CosmosHttpResponseError - - -def get_test_item(): - test_item = { - 'id': 'Item_' + str(uuid.uuid4()), - 'test_object': True, - 'lastName': 'Smith', - 'attr1': random.randint(0, 10) - } - return test_item +from azure.cosmos.exceptions import CosmosHttpResponseError def run_queries(container, iterations): @@ -52,6 +41,7 @@ class TestPartitionSplitQuery(unittest.TestCase): throughput = 400 TEST_DATABASE_ID = configs.TEST_DATABASE_ID TEST_CONTAINER_ID = "Single-partition-container-without-throughput" + MAX_TIME = 60 * 7 # 7 minutes for the test to complete, should be enough for partition split to complete @classmethod def setUpClass(cls): @@ -59,7 +49,8 @@ def setUpClass(cls): cls.database = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.container = cls.database.create_container( id=cls.TEST_CONTAINER_ID, - partition_key=PartitionKey(path="/id")) + partition_key=PartitionKey(path="/id"), + offer_throughput=cls.throughput) if cls.host == "https://localhost:8081/": os.environ["AZURE_COSMOS_DISABLE_NON_STREAMING_ORDER_BY"] = "True" @@ -72,12 +63,12 @@ def tearDownClass(cls) -> None: def test_partition_split_query(self): for i in range(100): - body = get_test_item() + body = test_config.get_test_item() self.container.create_item(body=body) start_time = time.time() print("created items, changing offer to 11k and starting queries") - self.database.replace_throughput(11000) + self.container.replace_throughput(11000) offer_time = time.time() print("changed offer to 11k") print("--------------------------------") @@ -85,13 +76,13 @@ def test_partition_split_query(self): run_queries(self.container, 100) # initial check for queries before partition split print("initial check succeeded, now reading offer until replacing is done") - offer = self.database.get_throughput() + offer = self.container.get_throughput() while True: - if time.time() - start_time > 60 * 25: # timeout test at 25 minutes - raise unittest.SkipTest("Partition split didn't complete in time") + if time.time() - start_time > self.MAX_TIME: # timeout test at 25 minutes + self.skipTest("Partition split didn't complete in time") if offer.properties['content'].get('isOfferReplacePending', False): - time.sleep(10) - offer = self.database.get_throughput() + time.sleep(30) # wait for the offer to be replaced, check every 30 seconds + offer = self.container.get_throughput() else: print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) run_queries(self.container, 100) # check queries work post partition split diff --git a/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py new file mode 100644 index 000000000000..56b912ca506d --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_partition_split_query_async.py @@ -0,0 +1,93 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import time +import unittest +import random + +import pytest + +import test_config +from azure.cosmos import PartitionKey +from azure.cosmos.aio import CosmosClient, DatabaseProxy, ContainerProxy + +async def run_queries(container, iterations): + ret_list = [] + for i in range(iterations): + curr = str(random.randint(0, 10)) + query = 'SELECT * FROM c WHERE c.attr1=' + curr + ' order by c.attr1' + qlist = [item async for item in container.query_items(query=query, enable_cross_partition_query=True)] + ret_list.append((curr, qlist)) + for ret in ret_list: + curr = ret[0] + if len(ret[1]) != 0: + for results in ret[1]: + attr_number = results['attr1'] + assert str(attr_number) == curr # verify that all results match their randomly generated attributes + print("validation succeeded for all query results") + + +@pytest.mark.cosmosQuery +class TestPartitionSplitQueryAsync(unittest.IsolatedAsyncioTestCase): + database: DatabaseProxy = None + container: ContainerProxy = None + client: CosmosClient = None + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + throughput = 400 + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + TEST_CONTAINER_ID = "Single-partition-container-without-throughput-async" + MAX_TIME = 60 * 7 # 7 minutes for the test to complete, should be enough for partition split to complete + + @classmethod + def setUpClass(cls): + if (cls.masterKey == '[YOUR_KEY_HERE]' or + cls.host == '[YOUR_ENDPOINT_HERE]'): + raise Exception( + "You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + async def asyncSetUp(self): + self.client = CosmosClient(self.host, self.masterKey) + await self.client.__aenter__() + self.created_database = self.client.get_database_client(self.TEST_DATABASE_ID) + self.container = await self.created_database.create_container( + id=self.TEST_CONTAINER_ID, + partition_key=PartitionKey(path="/id"), + offer_throughput=self.throughput) + + async def asyncTearDown(self): + await self.client.close() + + async def test_partition_split_query_async(self): + for i in range(100): + body = test_config.get_test_item() + await self.container.create_item(body=body) + + start_time = time.time() + print("created items, changing offer to 11k and starting queries") + await self.container.replace_throughput(11000) + offer_time = time.time() + print("changed offer to 11k") + print("--------------------------------") + print("now starting queries") + + await run_queries(self.container, 100) # initial check for queries before partition split + print("initial check succeeded, now reading offer until replacing is done") + offer = await self.container.get_throughput() + while True: + if time.time() - start_time > self.MAX_TIME: # timeout test at 25 minutes + self.skipTest("Partition split didn't complete in time.") + if offer.properties['content'].get('isOfferReplacePending', False): + time.sleep(30) # wait for the offer to be replaced, check every 30 seconds + offer = await self.container.get_throughput() + else: + print("offer replaced successfully, took around {} seconds".format(time.time() - offer_time)) + await run_queries(self.container, 100) # check queries work post partition split + self.assertTrue(offer.offer_throughput > self.throughput) + return + +if __name__ == '__main__': + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_session.py b/sdk/cosmos/azure-cosmos/tests/test_session.py index ab9307db3443..15cac3f2e2b5 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_session.py +++ b/sdk/cosmos/azure-cosmos/tests/test_session.py @@ -7,13 +7,15 @@ import pytest -import azure.cosmos._synchronized_request as synchronized_request import azure.cosmos.cosmos_client as cosmos_client import azure.cosmos.exceptions as exceptions import test_config -from azure.cosmos import DatabaseProxy +from _fault_injection_transport import FaultInjectionTransport +from azure.core.rest import HttpRequest +from azure.cosmos import DatabaseProxy, PartitionKey from azure.cosmos import _retry_utility from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders +from typing import Callable @pytest.mark.cosmosEmulator @@ -44,25 +46,114 @@ def setUpClass(cls): cls.created_db = cls.client.get_database_client(cls.TEST_DATABASE_ID) cls.created_collection = cls.created_db.get_container_client(cls.TEST_COLLECTION_ID) - def _MockRequest(self, global_endpoint_manager, request_params, connection_policy, pipeline_client, request): - if HttpHeaders.SessionToken in request.headers: - self.last_session_token_sent = request.headers[HttpHeaders.SessionToken] - else: - self.last_session_token_sent = None - return self._OriginalRequest(global_endpoint_manager, request_params, connection_policy, pipeline_client, - request) - - def test_session_token_not_sent_for_master_resource_ops(self): - self._OriginalRequest = synchronized_request._Request - synchronized_request._Request = self._MockRequest - created_document = self.created_collection.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}) - self.created_collection.read_item(item=created_document['id'], partition_key='mypk') - self.assertNotEqual(self.last_session_token_sent, None) - self.created_db.get_container_client(container=self.created_collection).read() - self.assertEqual(self.last_session_token_sent, None) - self.created_collection.read_item(item=created_document['id'], partition_key='mypk') - self.assertNotEqual(self.last_session_token_sent, None) - synchronized_request._Request = self._OriginalRequest + def test_session_token_sm_for_ops(self): + + # Session token should not be sent for control plane operations + test_container = self.created_db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), raw_response_hook=test_config.no_token_response_hook) + self.created_db.get_container_client(container=self.created_collection).read(raw_response_hook=test_config.no_token_response_hook) + self.created_db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for document read/batch requests only - verify it is not sent for write requests + up_item = self.created_collection.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + replaced_item = self.created_collection.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + created_document = self.created_collection.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = self.created_collection.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = self.created_collection.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = self.created_collection.execute_item_batch(batch_operations, 'mypk', raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Verify no session tokens are sent for delete requests either - but verify session token is updated + self.created_collection.delete_item(replaced_item['id'], replaced_item['pk'], raw_response_hook=test_config.no_token_response_hook) + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + + def test_session_token_mwr_for_ops(self): + # For multiple write regions, all document requests should send out session tokens + # We will use fault injection to simulate the regions the emulator needs + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransport() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransport.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation = \ + lambda r, inner: FaultInjectionTransport.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + client = cosmos_client.CosmosClient(self.host, self.masterKey, consistency_level="Session", + transport=custom_transport, multiple_write_locations=True) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_COLLECTION_ID) + + # Session token should not be sent for control plane operations + test_container = db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), + raw_response_hook=test_config.no_token_response_hook) + db.get_container_client(container=self.created_collection).read( + raw_response_hook=test_config.no_token_response_hook) + db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for all document requests since we have mwr configuration + # First write request won't have since tokens need to be populated on the client first + container.upsert_item(body={'id': '0' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + up_item = container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + replaced_item = container.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', + 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + created_document = container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = container.execute_item_batch(batch_operations, 'mypk', + raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Should get sent now that we have mwr configuration + container.delete_item(replaced_item['id'], replaced_item['pk'], + raw_response_hook=test_config.token_response_hook) + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.READ_SESSION_NOTAVAILABLE}) @@ -80,7 +171,11 @@ def test_clear_session_token(self): self.created_collection.read_item(item=created_document['id'], partition_key='mypk') except exceptions.CosmosHttpResponseError as e: self.assertEqual(self.client.client_connection.session.get_session_token( - 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id), "") + 'dbs/' + self.created_db.id + '/colls/' + self.created_collection.id, + None, + None, + None, + None), "") self.assertEqual(e.status_code, StatusCodes.NOT_FOUND) self.assertEqual(e.sub_status, SubStatusCodes.READ_SESSION_NOTAVAILABLE) _retry_utility.ExecuteFunction = self.OriginalExecuteFunction diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_async.py b/sdk/cosmos/azure-cosmos/tests/test_session_async.py new file mode 100644 index 000000000000..b30102435dd0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_session_async.py @@ -0,0 +1,202 @@ +# -*- coding: utf-8 -*- +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import unittest +import uuid + +import pytest + +from _fault_injection_transport_async import FaultInjectionTransportAsync +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos.aio import CosmosClient, _retry_utility_async +from azure.cosmos import DatabaseProxy, PartitionKey +from azure.cosmos.http_constants import StatusCodes, SubStatusCodes, HttpHeaders +from azure.core.pipeline.transport._aiohttp import AioHttpTransportResponse +from azure.core.rest import HttpRequest, AsyncHttpResponse +from typing import Awaitable, Callable + + + +@pytest.mark.cosmosEmulator +class TestSessionAsync(unittest.IsolatedAsyncioTestCase): + """Test to ensure escaping of non-ascii characters from partition key""" + + created_db: DatabaseProxy = None + client: CosmosClient = None + host = test_config.TestConfig.host + masterKey = test_config.TestConfig.masterKey + connectionPolicy = test_config.TestConfig.connectionPolicy + configs = test_config.TestConfig + TEST_DATABASE_ID = configs.TEST_DATABASE_ID + TEST_COLLECTION_ID = configs.TEST_MULTI_PARTITION_CONTAINER_ID + + @classmethod + def setUpClass(cls): + if cls.masterKey == '[YOUR_KEY_HERE]' or cls.host == '[YOUR_ENDPOINT_HERE]': + raise Exception("You must specify your Azure Cosmos account values for " + "'masterKey' and 'host' at the top of this class to run the " + "tests.") + + async def asyncSetUp(self): + self.client = CosmosClient(self.host, self.masterKey) + await self.client.__aenter__() + self.created_db = self.client.get_database_client(self.TEST_DATABASE_ID) + self.created_container = self.created_db.get_container_client(self.TEST_COLLECTION_ID) + + async def asyncTearDown(self): + await self.client.close() + + async def test_session_token_swr_for_ops_async(self): + # Session token should not be sent for control plane operations + test_container = await self.created_db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), raw_response_hook=test_config.no_token_response_hook) + await self.created_db.get_container_client(container=self.created_container).read(raw_response_hook=test_config.no_token_response_hook) + await self.created_db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for document read/batch requests only - verify it is not sent for write requests + up_item = await self.created_container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + replaced_item = await self.created_container.replace_item(item=up_item['id'], body={'id': up_item['id'], 'song': 'song', 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + created_document = await self.created_container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = await self.created_container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = await self.created_container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = await self.created_container.execute_item_batch(batch_operations, 'mypk', raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Verify no session tokens are sent for delete requests either - but verify session token is updated + await self.created_container.delete_item(replaced_item['id'], replaced_item['pk'], raw_response_hook=test_config.no_token_response_hook) + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert self.created_db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + + async def test_session_token_mwr_for_ops_async(self): + # For multiple write regions, all document requests should send out session tokens + # We will use fault injection to simulate the regions the emulator needs + first_region_uri: str = test_config.TestConfig.local_host.replace("localhost", "127.0.0.1") + custom_transport = FaultInjectionTransportAsync() + + # Inject topology transformation that would make Emulator look like a multiple write region account + # account with two read regions + is_get_account_predicate: Callable[[HttpRequest], bool] = lambda r: FaultInjectionTransportAsync.predicate_is_database_account_call(r) + emulator_as_multi_write_region_account_transformation: Callable[[HttpRequest, Callable[[HttpRequest], Awaitable[AsyncHttpResponse]]], AioHttpTransportResponse] = \ + lambda r, inner: FaultInjectionTransportAsync.transform_topology_mwr( + first_region_name="First Region", + second_region_name="Second Region", + inner=inner) + custom_transport.add_response_transformation( + is_get_account_predicate, + emulator_as_multi_write_region_account_transformation) + client = CosmosClient(self.host, self.masterKey, consistency_level="Session", + transport=custom_transport, multiple_write_locations=True) + db = client.get_database_client(self.TEST_DATABASE_ID) + container = db.get_container_client(self.TEST_COLLECTION_ID) + await client.__aenter__() + + # Session token should not be sent for control plane operations + test_container = await db.create_container(str(uuid.uuid4()), PartitionKey(path="/id"), + raw_response_hook=test_config.no_token_response_hook) + await db.get_container_client(container=self.created_container).read( + raw_response_hook=test_config.no_token_response_hook) + await db.delete_container(test_container, raw_response_hook=test_config.no_token_response_hook) + + # Session token should be sent for all document requests since we have mwr configuration + # First write request won't have since tokens need to be populated on the client first + await container.upsert_item(body={'id': '0' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.no_token_response_hook) + up_item = await container.upsert_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + replaced_item = await container.replace_item(item=up_item['id'], + body={'id': up_item['id'], 'song': 'song', + 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + created_document = await container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}, + raw_response_hook=test_config.token_response_hook) + response_session_token = created_document.get_response_headers().get(HttpHeaders.SessionToken) + read_item = await container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + read_item2 = await container.read_item(item=created_document['id'], partition_key='mypk', + raw_response_hook=test_config.token_response_hook) + + # Since the session hasn't been updated (no write requests have happened) verify session is still the same + assert (read_item.get_response_headers().get(HttpHeaders.SessionToken) == + read_item2.get_response_headers().get(HttpHeaders.SessionToken) == + response_session_token) + # Verify session tokens are sent for batch requests too + batch_operations = [ + ("create", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ("replace", (read_item2['id'], {"id": str(uuid.uuid4()), "pk": 'mypk'})), + ("read", (replaced_item['id'],)), + ("upsert", ({"id": str(uuid.uuid4()), "pk": 'mypk'},)), + ] + batch_result = await container.execute_item_batch(batch_operations, 'mypk', + raw_response_hook=test_config.token_response_hook) + batch_response_token = batch_result.get_response_headers().get(HttpHeaders.SessionToken) + assert batch_response_token != response_session_token + + # Should get sent now that we have mwr configuration + await container.delete_item(replaced_item['id'], replaced_item['pk'], + raw_response_hook=test_config.token_response_hook) + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) is not None + assert db.client_connection.last_response_headers.get(HttpHeaders.SessionToken) != batch_response_token + await client.close() + + + def _MockExecuteFunctionSessionReadFailureOnce(self, function, *args, **kwargs): + response = test_config.FakeResponse({HttpHeaders.SubStatus: SubStatusCodes.READ_SESSION_NOTAVAILABLE}) + raise exceptions.CosmosHttpResponseError( + status_code=StatusCodes.NOT_FOUND, + message="Read Session not available", + response=response) + + async def test_clear_session_token_async(self): + created_document = await self.created_container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}) + + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunctionSessionReadFailureOnce + try: + await self.created_container.read_item(item=created_document['id'], partition_key='mypk') + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(self.client.client_connection.session.get_session_token( + 'dbs/' + self.created_db.id + '/colls/' + self.created_container.id, + None, + None, + None, + None), "") + self.assertEqual(e.status_code, StatusCodes.NOT_FOUND) + self.assertEqual(e.sub_status, SubStatusCodes.READ_SESSION_NOTAVAILABLE) + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction + + async def _MockExecuteFunctionInvalidSessionTokenAsync(self, function, *args, **kwargs): + response = {'_self': 'dbs/90U1AA==/colls/90U1AJ4o6iA=/docs/90U1AJ4o6iABCT0AAAAABA==/', 'id': '1'} + headers = {HttpHeaders.SessionToken: '0:2', + HttpHeaders.AlternateContentPath: 'dbs/testDatabase/colls/testCollection'} + return (response, headers) + + async def test_internal_server_error_raised_for_invalid_session_token_received_from_server_async(self): + self.OriginalExecuteFunction = _retry_utility_async.ExecuteFunctionAsync + _retry_utility_async.ExecuteFunctionAsync = self._MockExecuteFunctionInvalidSessionTokenAsync + try: + await self.created_container.create_item(body={'id': '1' + str(uuid.uuid4()), 'pk': 'mypk'}) + self.fail('Test did not fail as expected') + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.http_error_message, "Could not parse the received session token: 2") + self.assertEqual(e.status_code, StatusCodes.INTERNAL_SERVER_ERROR) + _retry_utility_async.ExecuteFunctionAsync = self.OriginalExecuteFunction diff --git a/sdk/cosmos/azure-cosmos/tests/test_session_container.py b/sdk/cosmos/azure-cosmos/tests/test_session_container.py deleted file mode 100644 index 2ee352571204..000000000000 --- a/sdk/cosmos/azure-cosmos/tests/test_session_container.py +++ /dev/null @@ -1,67 +0,0 @@ -# The MIT License (MIT) -# Copyright (c) Microsoft Corporation. All rights reserved. - -import unittest - -import pytest - -import azure.cosmos.cosmos_client as cosmos_client -import test_config - - -# from types import * - -@pytest.mark.cosmosEmulator -class TestSessionContainer(unittest.TestCase): - # this test doesn't need real credentials, or connection to server - host = test_config.TestConfig.host - master_key = test_config.TestConfig.masterKey - connectionPolicy = test_config.TestConfig.connectionPolicy - - def setUp(self): - self.client = cosmos_client.CosmosClient(self.host, self.master_key, consistency_level="Session", - connection_policy=self.connectionPolicy) - self.session = self.client.client_connection.Session - - def tearDown(self): - pass - - def test_create_collection(self): - # validate session token population after create collection request - session_token = self.session.get_session_token('') - assert session_token == '' - - create_collection_response_result = {u'_self': u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/', u'_rid': u'DdAkAPS2rAA=', - u'id': u'sample collection'} - create_collection_response_header = {'x-ms-session-token': '0:0#409#24=-1#12=-1', - 'x-ms-alt-content-path': 'dbs/sample%20database'} - self.session.update_session(create_collection_response_result, create_collection_response_header) - - token = self.session.get_session_token(u'/dbs/sample%20database/colls/sample%20collection') - assert token == '0:0#409#24=-1#12=-1' - - token = self.session.get_session_token(u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/') - assert token == '0:0#409#24=-1#12=-1' - return True - - def test_document_requests(self): - # validate session token for rid based requests - create_document_response_result = {u'_self': u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/docs/DdAkAPS2rAACAAAAAAAAAA==/', - u'_rid': u'DdAkAPS2rAACAAAAAAAAAA==', - u'id': u'eb391181-5c49-415a-ab27-848ce21d5d11'} - create_document_response_header = {'x-ms-session-token': '0:0#406#24=-1#12=-1', - 'x-ms-alt-content-path': 'dbs/sample%20database/colls/sample%20collection', - 'x-ms-content-path': 'DdAkAPS2rAA='} - - self.session.update_session(create_document_response_result, create_document_response_header) - - token = self.session.get_session_token(u'dbs/DdAkAA==/colls/DdAkAPS2rAA=/docs/DdAkAPS2rAACAAAAAAAAAA==/') - assert token == '0:0#406#24=-1#12=-1' - - token = self.session.get_session_token( - u'dbs/sample%20database/colls/sample%20collection/docs/eb391181-5c49-415a-ab27-848ce21d5d11') - assert token == '0:0#406#24=-1#12=-1' - - -if __name__ == '__main__': - unittest.main()