From c7d747c696c450d3b2fae4f0a793324580b41e23 Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Thu, 3 Apr 2025 23:30:21 -0700 Subject: [PATCH 1/3] Added Cosmos Fabric native integration tests --- sdk/cosmos/azure-cosmos/tests/conftest.py | 6 +- .../tests/fabric_token_credential.py | 20 + sdk/cosmos/azure-cosmos/tests/test_config.py | 3 +- .../tests/test_fabric_change_feed.py | 253 +++ .../azure-cosmos/tests/test_fabric_crud.py | 1699 +++++++++++++++++ .../tests/test_fabric_crud_container.py | 967 ++++++++++ .../azure-cosmos/tests/test_fabric_query.py | 607 ++++++ .../tests/testing_fabric_intergation.py | 96 + 8 files changed, 3648 insertions(+), 3 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/fabric_token_credential.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fabric_query.py create mode 100644 sdk/cosmos/azure-cosmos/tests/testing_fabric_intergation.py diff --git a/sdk/cosmos/azure-cosmos/tests/conftest.py b/sdk/cosmos/azure-cosmos/tests/conftest.py index 0842d837931f..5b5f78605554 100644 --- a/sdk/cosmos/azure-cosmos/tests/conftest.py +++ b/sdk/cosmos/azure-cosmos/tests/conftest.py @@ -1,10 +1,12 @@ # The MIT License (MIT) # Copyright (c) Microsoft Corporation. All rights reserved. +import fabric_token_credential import test_config from azure.cosmos import CosmosClient as CosmosSyncClient -cosmos_sync_client = CosmosSyncClient(test_config.TestConfig.host, test_config.TestConfig.masterKey) +credential = fabric_token_credential.FabricTokenCredential() +cosmos_sync_client = CosmosSyncClient(test_config.TestConfig.fabric_host, credential=credential) def pytest_configure(config): @@ -32,7 +34,7 @@ def pytest_sessionfinish(session, exitstatus): returning the exit status to the system. 
""" config = test_config.TestConfig - config.try_delete_database(cosmos_sync_client) + # config.try_delete_database(cosmos_sync_client) def pytest_unconfigure(config): diff --git a/sdk/cosmos/azure-cosmos/tests/fabric_token_credential.py b/sdk/cosmos/azure-cosmos/tests/fabric_token_credential.py new file mode 100644 index 000000000000..d4f50fad49ab --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/fabric_token_credential.py @@ -0,0 +1,20 @@ +# -*- coding: utf-8 -*- +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. +from typing import Optional, Any + +from azure.core.credentials import TokenCredential, AccessToken +from azure.identity import InteractiveBrowserCredential + + +class FabricTokenCredential(TokenCredential): + + def __init__(self): + self.token_credential = InteractiveBrowserCredential() + self.token_credential.authority = '' + + def get_token(self, *scopes: str, claims: Optional[str] = None, tenant_id: Optional[str] = None, + enable_cae: bool = False, **kwargs: Any) -> AccessToken: + scopes = ["https://cosmos.azure.com/.default"] + return self.token_credential.get_token(*scopes, claims=claims, tenant_id=tenant_id, enable_cae=enable_cae, + **kwargs) diff --git a/sdk/cosmos/azure-cosmos/tests/test_config.py b/sdk/cosmos/azure-cosmos/tests/test_config.py index 7334d7f7a88a..0682bb334196 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_config.py +++ b/sdk/cosmos/azure-cosmos/tests/test_config.py @@ -32,6 +32,7 @@ class TestConfig(object): masterKey = os.getenv('ACCOUNT_KEY', 'C2y6yDjf5/R+ob0N8A7Cgv30VRDJIWEHLM+4QDU5DE2nQ9nDuVTqobD4b8mGGyPMbIZnqyMsEcaGQy67XIw/Jw==') host = os.getenv('ACCOUNT_HOST', local_host) + fabric_host = "" connection_str = os.getenv('ACCOUNT_CONNECTION_STR', 'AccountEndpoint={};AccountKey={};'.format(host, masterKey)) connectionPolicy = documents.ConnectionPolicy() @@ -54,7 +55,7 @@ class TestConfig(object): THROUGHPUT_FOR_5_PARTITIONS = 30000 THROUGHPUT_FOR_1_PARTITION = 400 - TEST_DATABASE_ID = 
os.getenv('COSMOS_TEST_DATABASE_ID', "Python SDK Test Database " + str(uuid.uuid4())) + TEST_DATABASE_ID = os.getenv('COSMOS_TEST_DATABASE_ID', 'dkunda-fabric-cdb') TEST_SINGLE_PARTITION_CONTAINER_ID = "Single Partition Test Container " + str(uuid.uuid4()) TEST_MULTI_PARTITION_CONTAINER_ID = "Multi Partition Test Container " + str(uuid.uuid4()) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py new file mode 100644 index 000000000000..6cb07f7b98c0 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py @@ -0,0 +1,253 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import unittest +import uuid +from datetime import datetime, timedelta, timezone +from time import sleep + +import pytest +from _pytest.outcomes import fail + +import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos.partition_key import PartitionKey +from tests import fabric_token_credential + + +@pytest.fixture(scope="class") +def setup(): + config = test_config.TestConfig() + credential = fabric_token_credential.FabricTokenCredential() + test_client = cosmos_client.CosmosClient(config.fabric_host, credential=credential), + return { + "created_db": test_client[0].get_database_client(config.TEST_DATABASE_ID), + "is_emulator": config.is_emulator + } + +def round_time(): + utc_now = datetime.now(timezone.utc) + return utc_now - timedelta(microseconds=utc_now.microsecond) + +@pytest.mark.cosmosQuery +@pytest.mark.unittest +@pytest.mark.usefixtures("setup") +class TestChangeFeed: + """Test to ensure escaping of non-ascii characters from partition key""" + + def test_get_feed_ranges(self, setup): + created_collection = setup["created_db"].create_container("get_feed_ranges_" + str(uuid.uuid4()), + PartitionKey(path="/pk")) + result = list(created_collection.read_feed_ranges()) + assert 
len(result) == 1 + + @pytest.mark.parametrize("change_feed_filter_param", ["partitionKey", "partitionKeyRangeId", "feedRange"]) + # @pytest.mark.parametrize("change_feed_filter_param", ["partitionKeyRangeId"]) + def test_query_change_feed_with_different_filter(self, change_feed_filter_param, setup): + created_collection = setup["created_db"].create_container(f"change_feed_test_{change_feed_filter_param}_{str(uuid.uuid4())}", + PartitionKey(path="/pk")) + # Read change feed without passing any options + query_iterable = created_collection.query_items_change_feed() + iter_list = list(query_iterable) + assert len(iter_list) == 0 + + if change_feed_filter_param == "partitionKey": + filter_param = {"partition_key": "pk"} + elif change_feed_filter_param == "partitionKeyRangeId": + filter_param = {"partition_key_range_id": "0"} + elif change_feed_filter_param == "feedRange": + feed_ranges = list(created_collection.read_feed_ranges()) + assert len(feed_ranges) == 1 + filter_param = {"feed_range": feed_ranges[0]} + else: + filter_param = None + + # Read change feed from current should return an empty list + query_iterable = created_collection.query_items_change_feed(**filter_param) + iter_list = list(query_iterable) + assert len(iter_list) == 0 + assert 'etag' in created_collection.client_connection.last_response_headers + assert created_collection.client_connection.last_response_headers['etag'] !='' + + # Read change feed from beginning should return an empty list + query_iterable = created_collection.query_items_change_feed( + is_start_from_beginning=True, + **filter_param + ) + iter_list = list(query_iterable) + assert len(iter_list) == 0 + assert 'etag' in created_collection.client_connection.last_response_headers + continuation1 = created_collection.client_connection.last_response_headers['etag'] + assert continuation1 != '' + + # Create a document. 
Read change feed should be able to read that document + document_definition = {'pk': 'pk', 'id': 'doc1'} + created_collection.create_item(body=document_definition) + query_iterable = created_collection.query_items_change_feed( + is_start_from_beginning=True, + **filter_param + ) + iter_list = list(query_iterable) + assert len(iter_list) == 1 + assert iter_list[0]['id'] == 'doc1' + assert 'etag' in created_collection.client_connection.last_response_headers + continuation2 = created_collection.client_connection.last_response_headers['etag'] + assert continuation2 != '' + assert continuation2 != continuation1 + + # Create two new documents. Verify that change feed contains the 2 new documents + # with page size 1 and page size 100 + document_definition = {'pk': 'pk', 'id': 'doc2'} + created_collection.create_item(body=document_definition) + document_definition = {'pk': 'pk3', 'id': 'doc3'} + created_collection.create_item(body=document_definition) + + for pageSize in [1, 100]: + # verify iterator + query_iterable = created_collection.query_items_change_feed( + continuation=continuation2, + max_item_count=pageSize, + **filter_param + ) + it = query_iterable.__iter__() + expected_ids = 'doc2.doc3.' + if "partition_key" in filter_param: + expected_ids = 'doc2.' + actual_ids = '' + for item in it: + actual_ids += item['id'] + '.' 
+ assert actual_ids == expected_ids + + # verify by_page + # the options are not copied, therefore they need to be restored + query_iterable = created_collection.query_items_change_feed( + continuation=continuation2, + max_item_count=pageSize, + **filter_param + ) + count = 0 + expected_count = 2 + if "partition_key" in filter_param: + expected_count = 1 + all_fetched_res = [] + for page in query_iterable.by_page(): + fetched_res = list(page) + assert len(fetched_res) == min(pageSize, expected_count - count) + count += len(fetched_res) + all_fetched_res.extend(fetched_res) + + actual_ids = '' + for item in all_fetched_res: + actual_ids += item['id'] + '.' + assert actual_ids == expected_ids + + # verify reading change feed from the beginning + query_iterable = created_collection.query_items_change_feed( + is_start_from_beginning=True, + **filter_param + ) + expected_ids = 'doc1.doc2.doc3.' + if "partition_key" in filter_param: + expected_ids = 'doc1.doc2.' + it = query_iterable.__iter__() + actual_ids = '' + for item in it: + actual_ids += item['id'] + '.' 
+ assert actual_ids == expected_ids + assert 'etag' in created_collection.client_connection.last_response_headers + continuation3 = created_collection.client_connection.last_response_headers['etag'] + + # verify reading empty change feed + query_iterable = created_collection.query_items_change_feed( + continuation=continuation3, + is_start_from_beginning=True, + **filter_param + ) + iter_list = list(query_iterable) + assert len(iter_list) == 0 + setup["created_db"].delete_container(created_collection.id) + + def test_query_change_feed_with_start_time(self, setup): + created_collection = setup["created_db"].create_container_if_not_exists("query_change_feed_start_time_test", + PartitionKey(path="/pk")) + batchSize = 50 + + def create_random_items(container, batch_size): + for _ in range(batch_size): + # Generate a Random partition key + partition_key = 'pk' + str(uuid.uuid4()) + + # Generate a random item + item = { + 'id': 'item' + str(uuid.uuid4()), + 'partitionKey': partition_key, + 'content': 'This is some random content', + } + + try: + # Create the item in the container + container.upsert_item(item) + except exceptions.CosmosHttpResponseError as e: + fail(e) + + # Create first batch of random items + create_random_items(created_collection, batchSize) + + # wait for 1 second and record the time, then wait another second + sleep(1) + start_time = round_time() + not_utc_time = datetime.now() + sleep(1) + + # now create another batch of items + create_random_items(created_collection, batchSize) + + # now query change feed based on start time + change_feed_iter = list(created_collection.query_items_change_feed(start_time=start_time)) + totalCount = len(change_feed_iter) + + # now check if the number of items that were changed match the batch size + assert totalCount == batchSize + + # negative test: pass in a valid time in the future + future_time = start_time + timedelta(hours=1) + change_feed_iter = 
list(created_collection.query_items_change_feed(start_time=future_time)) + totalCount = len(change_feed_iter) + # A future time should return 0 + assert totalCount == 0 + + # test a date that is not utc, will be converted to utc by sdk + change_feed_iter = list(created_collection.query_items_change_feed(start_time=not_utc_time)) + totalCount = len(change_feed_iter) + # Should equal batch size + assert totalCount == batchSize + + setup["created_db"].delete_container(created_collection.id) + + def test_query_change_feed_with_multi_partition(self, setup): + created_collection = setup["created_db"].create_container("change_feed_test_" + str(uuid.uuid4()), + PartitionKey(path="/pk"), + offer_throughput=11000) + + # create one doc and make sure change feed query can return the document + new_documents = [ + {'pk': 'pk', 'id': 'doc1'}, + {'pk': 'pk2', 'id': 'doc2'}, + {'pk': 'pk3', 'id': 'doc3'}, + {'pk': 'pk4', 'id': 'doc4'}] + expected_ids = ['doc1', 'doc2', 'doc3', 'doc4'] + + for document in new_documents: + created_collection.create_item(body=document) + + query_iterable = created_collection.query_items_change_feed(start_time="Beginning") + it = query_iterable.__iter__() + actual_ids = [] + for item in it: + actual_ids.append(item['id']) + + assert actual_ids == expected_ids + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py new file mode 100644 index 000000000000..2276ed09e26f --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py @@ -0,0 +1,1699 @@ +# -*- coding: utf-8 -*- +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""End-to-end test. 
+""" + +import time +import unittest +import urllib.parse as urllib +import uuid + +import pytest +import requests +from azure.core import MatchConditions +from azure.core.exceptions import AzureError, ServiceResponseError +from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse +from urllib3.util.retry import Retry + +import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.documents as documents +import azure.cosmos.exceptions as exceptions +import fabric_token_credential +import test_config +from azure.cosmos import _retry_utility +from azure.cosmos.http_constants import HttpHeaders, StatusCodes +from azure.cosmos.partition_key import PartitionKey + + +class TimeoutTransport(RequestsTransport): + + def __init__(self, response): + self._response = response + super(TimeoutTransport, self).__init__() + + def send(self, *args, **kwargs): + if kwargs.pop("passthrough", False): + return super(TimeoutTransport, self).send(*args, **kwargs) + + time.sleep(5) + if isinstance(self._response, Exception): + raise self._response + output = requests.Response() + output.status_code = self._response + response = RequestsTransportResponse(None, output) + return response + + +@pytest.mark.cosmosLong +class TestCRUDOperations(unittest.TestCase): + """Python CRUD Tests. + """ + + configs = test_config.TestConfig + host = configs.host + fabric_host = configs.fabric_host + fabric_credential = fabric_token_credential.FabricTokenCredential() + connectionPolicy = configs.connectionPolicy + last_headers = [] + client: cosmos_client.CosmosClient = None + + def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): + """Assert HTTP failure with status. 
+ + :Parameters: + - `status_code`: int + - `func`: function + """ + try: + func(*args, **kwargs) + self.assertFalse(True, 'function should fail.') + except exceptions.CosmosHttpResponseError as inst: + self.assertEqual(inst.status_code, status_code) + + @classmethod + def setUpClass(cls): + cls.client = cosmos_client.CosmosClient(cls.fabric_host, credential=cls.fabric_credential) + cls.databaseForTest = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID) + + @unittest.skip + def test_partitioned_collection_document_crud_and_query(self): + created_db = self.databaseForTest + + created_collection = created_db.create_container("crud-query-container", partition_key=PartitionKey("/pk")) + + document_definition = {'id': 'document', + 'key': 'value', + 'pk': 'pk'} + + created_document = created_collection.create_item( + body=document_definition + ) + + self.assertEqual(created_document.get('id'), document_definition.get('id')) + self.assertEqual(created_document.get('key'), document_definition.get('key')) + + # read document + read_document = created_collection.read_item( + item=created_document.get('id'), + partition_key=created_document.get('pk') + ) + + self.assertEqual(read_document.get('id'), created_document.get('id')) + self.assertEqual(read_document.get('key'), created_document.get('key')) + + # Read document feed doesn't require partitionKey as it's always a cross partition query + documentlist = list(created_collection.read_all_items()) + self.assertEqual(1, len(documentlist)) + + # replace document + document_definition['key'] = 'new value' + + replaced_document = created_collection.replace_item( + item=read_document, + body=document_definition + ) + + self.assertEqual(replaced_document.get('key'), document_definition.get('key')) + + # upsert document(create scenario) + document_definition['id'] = 'document2' + document_definition['key'] = 'value2' + + upserted_document = created_collection.upsert_item(body=document_definition) + + 
self.assertEqual(upserted_document.get('id'), document_definition.get('id')) + self.assertEqual(upserted_document.get('key'), document_definition.get('key')) + + documentlist = list(created_collection.read_all_items()) + self.assertEqual(2, len(documentlist)) + + # delete document + created_collection.delete_item(item=upserted_document, partition_key=upserted_document.get('pk')) + + # query document on the partition key specified in the predicate will pass even without setting enableCrossPartitionQuery or passing in the partitionKey value + documentlist = list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.id=\'' + replaced_document.get('id') + '\'' # nosec + }, enable_cross_partition_query=True)) + self.assertEqual(1, len(documentlist)) + + # query document on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value + try: + list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'' # nosec + })) + except Exception: + pass + + # cross partition query + documentlist = list(created_collection.query_items( + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec + enable_cross_partition_query=True + )) + + self.assertEqual(1, len(documentlist)) + + # query document by providing the partitionKey value + documentlist = list(created_collection.query_items( + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec + partition_key=replaced_document.get('pk') + )) + + self.assertEqual(1, len(documentlist)) + created_db.delete_container(created_collection.id) + + def test_partitioned_collection_execute_stored_procedure(self): + created_collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + document_id = str(uuid.uuid4()) + + sproc = { + 'id': 'storedProcedure' + str(uuid.uuid4()), + 'body': 
( + 'function () {' + + ' var client = getContext().getCollection();' + + ' client.createDocument(client.getSelfLink(), { id: "' + document_id + '", pk : 2}, ' + + ' {}, function(err, docCreated, options) { ' + + ' if(err) throw new Error(\'Error while creating document: \' + err.message);' + + ' else {' + + ' getContext().getResponse().setBody(1);' + + ' }' + + ' });}') + } + + created_sproc = created_collection.scripts.create_stored_procedure(sproc) + + # Partition Key value same as what is specified in the stored procedure body + result = created_collection.scripts.execute_stored_procedure(sproc=created_sproc['id'], partition_key=2) + self.assertEqual(result, 1) + + # Partition Key value different than what is specified in the stored procedure body will cause a bad request(400) error + self.__AssertHTTPFailureWithStatus( + StatusCodes.BAD_REQUEST, + created_collection.scripts.execute_stored_procedure, + created_sproc['id'], + 3) + + def test_script_logging_execute_stored_procedure(self): + created_collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + stored_proc_id = 'storedProcedure-1-' + str(uuid.uuid4()) + + sproc = { + 'id': stored_proc_id, + 'body': ( + 'function () {' + + ' var mytext = \'x\';' + + ' var myval = 1;' + + ' try {' + + ' console.log(\'The value of %s is %s.\', mytext, myval);' + + ' getContext().getResponse().setBody(\'Success!\');' + + ' }' + + ' catch (err) {' + + ' getContext().getResponse().setBody(\'inline err: [\' + err.number + \'] \' + err);' + + ' }' + '}') + } + + created_sproc = created_collection.scripts.create_stored_procedure(sproc) + + result = created_collection.scripts.execute_stored_procedure( + sproc=created_sproc['id'], + partition_key=1 + ) + + self.assertEqual(result, 'Success!') + self.assertFalse( + HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) + + result = created_collection.scripts.execute_stored_procedure( + 
sproc=created_sproc['id'], + enable_script_logging=True, + partition_key=1 + ) + + self.assertEqual(result, 'Success!') + self.assertEqual(urllib.quote('The value of x is 1.'), + created_collection.scripts.client_connection.last_response_headers.get( + HttpHeaders.ScriptLogResults)) + + result = created_collection.scripts.execute_stored_procedure( + sproc=created_sproc['id'], + enable_script_logging=False, + partition_key=1 + ) + + self.assertEqual(result, 'Success!') + self.assertFalse( + HttpHeaders.ScriptLogResults in created_collection.scripts.client_connection.last_response_headers) + + def test_stored_procedure_functionality(self): + # create database + db = self.databaseForTest + # create collection + collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + + stored_proc_id = 'storedProcedure-1-' + str(uuid.uuid4()) + + sproc1 = { + 'id': stored_proc_id, + 'body': ( + 'function () {' + + ' for (var i = 0; i < 1000; i++) {' + + ' var item = getContext().getResponse().getBody();' + + ' if (i > 0 && item != i - 1) throw \'body mismatch\';' + + ' getContext().getResponse().setBody(i);' + + ' }' + + '}') + } + + retrieved_sproc = collection.scripts.create_stored_procedure(sproc1) + result = collection.scripts.execute_stored_procedure( + sproc=retrieved_sproc['id'], + partition_key=1 + ) + self.assertEqual(result, 999) + stored_proc_id_2 = 'storedProcedure-2-' + str(uuid.uuid4()) + sproc2 = { + 'id': stored_proc_id_2, + 'body': ( + 'function () {' + + ' for (var i = 0; i < 10; i++) {' + + ' getContext().getResponse().appendValue(\'Body\', i);' + + ' }' + + '}') + } + retrieved_sproc2 = collection.scripts.create_stored_procedure(sproc2) + result = collection.scripts.execute_stored_procedure( + sproc=retrieved_sproc2['id'], + partition_key=1 + ) + self.assertEqual(int(result), 123456789) + stored_proc_id_3 = 'storedProcedure-3-' + str(uuid.uuid4()) + sproc3 = { + 'id': stored_proc_id_3, + 'body': ( + 'function (input) 
{' + + ' getContext().getResponse().setBody(' + + ' \'a\' + input.temp);' + + '}') + } + retrieved_sproc3 = collection.scripts.create_stored_procedure(sproc3) + result = collection.scripts.execute_stored_procedure( + sproc=retrieved_sproc3['id'], + params={'temp': 'so'}, + partition_key=1 + ) + self.assertEqual(result, 'aso') + + def test_partitioned_collection_permissions(self): + created_db = self.databaseForTest + + collection_id = 'test_partitioned_collection_permissions all collection' + str(uuid.uuid4()) + + all_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/key', kind=documents.PartitionKind.Hash) + ) + + collection_id = 'test_partitioned_collection_permissions read collection' + str(uuid.uuid4()) + + read_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/key', kind=documents.PartitionKind.Hash) + ) + + user = created_db.create_user(body={'id': 'user' + str(uuid.uuid4())}) + + permission_definition = { + 'id': 'all permission', + 'permissionMode': documents.PermissionMode.All, + 'resource': all_collection.container_link, + 'resourcePartitionKey': [1] + } + + all_permission = user.create_permission(body=permission_definition) + + permission_definition = { + 'id': 'read permission', + 'permissionMode': documents.PermissionMode.Read, + 'resource': read_collection.container_link, + 'resourcePartitionKey': [1] + } + + read_permission = user.create_permission(body=permission_definition) + + resource_tokens = {} + # storing the resource tokens based on Resource IDs + resource_tokens["dbs/" + created_db.id + "/colls/" + all_collection.id] = (all_permission.properties['_token']) + resource_tokens["dbs/" + created_db.id + "/colls/" + read_collection.id] = ( + read_permission.properties['_token']) + + restricted_client = cosmos_client.CosmosClient( + TestCRUDOperations.host, resource_tokens, "Session", connection_policy=TestCRUDOperations.connectionPolicy) + + 
document_definition = {'id': 'document1', + 'key': 1 + } + + all_collection.client_connection = restricted_client.client_connection + read_collection.client_connection = restricted_client.client_connection + + # Create document in all_collection should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object and it has all permissions + created_document = all_collection.create_item(body=document_definition) + + # Create document in read_collection should fail since it has only read permissions for this collection + self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + read_collection.create_item, + document_definition) + + document_definition['key'] = 2 + # Create document should fail since the partitionKey is 2 which is different that what is specified as resourcePartitionKey in permission object + self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + all_collection.create_item, + document_definition) + + document_definition['key'] = 1 + # Delete document should succeed since the partitionKey is 1 which is what specified as resourcePartitionKey in permission object + created_document = all_collection.delete_item(item=created_document['id'], + partition_key=document_definition['key']) + + # Delete document in read_collection should fail since it has only read permissions for this collection + self.__AssertHTTPFailureWithStatus( + StatusCodes.FORBIDDEN, + read_collection.delete_item, + document_definition['id'], + document_definition['id'] + ) + + created_db.delete_container(all_collection) + created_db.delete_container(read_collection) + + def test_partitioned_collection_partition_key_value_types(self): + created_db = self.databaseForTest + + created_collection = created_db.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': None, + 'spam': 'eggs'} + + # create document with partitionKey set as None here + 
created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'spam': 'eggs'} + + # create document with partitionKey set as Undefined here + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': True, + 'spam': 'eggs'} + + # create document with bool partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': 'value', + 'spam': 'eggs'} + + # create document with string partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': 100, + 'spam': 'eggs'} + + # create document with int partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'id': 'document1' + str(uuid.uuid4()), + 'pk': 10.50, + 'spam': 'eggs'} + + # create document with float partitionKey + created_collection.create_item(body=document_definition) + + document_definition = {'name': 'sample document', + 'spam': 'eggs', + 'pk': 'value'} + + # Should throw an error because automatic id generation is disabled always. 
        # (continuation: closing assertion of a test method that begins before this chunk —
        # expects BAD_REQUEST when creating `document_definition` on `created_collection`)
        self.__AssertHTTPFailureWithStatus(
            StatusCodes.BAD_REQUEST,
            created_collection.create_item,
            document_definition
        )

    def test_document_crud(self):
        """Item CRUD round-trip on the multi-partition container: create (auto-id,
        explicit id, duplicate-id conflict), read_all/query, conditional replace via
        etag + MatchConditions, point read, delete, and NOT_FOUND after delete."""
        # create database
        created_db = self.databaseForTest
        # create collection
        created_collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID)
        # read documents
        documents = list(created_collection.read_all_items())
        # baseline count before any creates in this test
        before_create_documents_count = len(documents)

        # create a document with auto ID generation
        document_definition = {'name': 'sample document',
                               'spam': 'eggs',
                               'key': 'value',
                               'pk': 'pk'}

        # no_response=True asks the service to omit the response payload -> empty dict
        no_response = created_collection.create_item(body=document_definition, enable_automatic_id_generation=True, no_response=True)
        self.assertDictEqual(no_response, {})

        created_document = created_collection.create_item(body=document_definition, enable_automatic_id_generation=True)
        self.assertEqual(created_document.get('name'),
                         document_definition['name'])

        document_definition = {'name': 'sample document',
                               'spam': 'eggs',
                               'key': 'value',
                               'pk': 'pk',
                               'id': str(uuid.uuid4())}

        created_document = created_collection.create_item(body=document_definition)
        self.assertEqual(created_document.get('name'),
                         document_definition['name'])
        self.assertEqual(created_document.get('id'),
                         document_definition['id'])

        # duplicated documents are not allowed when 'id' is provided.
        duplicated_definition_with_id = document_definition.copy()
        self.__AssertHTTPFailureWithStatus(StatusCodes.CONFLICT,
                                           created_collection.create_item,
                                           duplicated_definition_with_id)
        # read documents after creation (two auto-id creates + one explicit-id create)
        documents = list(created_collection.read_all_items())
        self.assertEqual(
            len(documents),
            before_create_documents_count + 3,
            'create should increase the number of documents')
        # query documents
        documents = list(created_collection.query_items(
            {
                'query': 'SELECT * FROM root r WHERE r.name=@name',
                'parameters': [
                    {'name': '@name', 'value': document_definition['name']}
                ]
            }, enable_cross_partition_query=True
        ))
        self.assertTrue(documents)
        documents = list(created_collection.query_items(
            {
                'query': 'SELECT * FROM root r WHERE r.name=@name',
                'parameters': [
                    {'name': '@name', 'value': document_definition['name']}
                ],
            }, enable_cross_partition_query=True,
            enable_scan_in_query=True
        ))
        self.assertTrue(documents)
        # replace document.
        created_document['name'] = 'replaced document'
        created_document['spam'] = 'not eggs'
        old_etag = created_document['_etag']
        replaced_document = created_collection.replace_item(
            item=created_document['id'],
            body=created_document
        )
        # NOTE(review): the failure messages below say "id property" but the asserts
        # check 'name'/'spam' — looks like copy-paste; messages only, behavior unaffected.
        self.assertEqual(replaced_document['name'],
                         'replaced document',
                         'document id property should change')
        self.assertEqual(replaced_document['spam'],
                         'not eggs',
                         'property should have changed')
        self.assertEqual(created_document['id'],
                         replaced_document['id'],
                         'document id should stay the same')

        # replace document based on condition
        replaced_document['name'] = 'replaced document based on condition'
        replaced_document['spam'] = 'new spam field'

        # should fail for stale etag (old_etag predates the replace above)
        self.__AssertHTTPFailureWithStatus(
            StatusCodes.PRECONDITION_FAILED,
            created_collection.replace_item,
            replaced_document['id'],
            replaced_document,
            if_match=old_etag,
        )

        # should fail if only etag specified
        with self.assertRaises(ValueError):
            created_collection.replace_item(
                etag=replaced_document['_etag'],
                item=replaced_document['id'],
                body=replaced_document
            )

        # should fail if only match condition specified
        with self.assertRaises(ValueError):
            created_collection.replace_item(
                match_condition=MatchConditions.IfNotModified,
                item=replaced_document['id'],
                body=replaced_document
            )
        with self.assertRaises(ValueError):
            created_collection.replace_item(
                match_condition=MatchConditions.IfModified,
                item=replaced_document['id'],
                body=replaced_document
            )

        # should fail if invalid match condition specified (an etag string is not a MatchConditions)
        with self.assertRaises(TypeError):
            created_collection.replace_item(
                match_condition=replaced_document['_etag'],
                item=replaced_document['id'],
                body=replaced_document
            )

        # should pass for most recent etag
        replaced_document_conditional = created_collection.replace_item(
            match_condition=MatchConditions.IfNotModified,
            etag=replaced_document['_etag'],
            item=replaced_document['id'],
            body=replaced_document
        )
        self.assertEqual(replaced_document_conditional['name'],
                         'replaced document based on condition',
                         'document id property should change')
        self.assertEqual(replaced_document_conditional['spam'],
                         'new spam field',
                         'property should have changed')
        self.assertEqual(replaced_document_conditional['id'],
                         replaced_document['id'],
                         'document id should stay the same')
        # read document
        one_document_from_read = created_collection.read_item(
            item=replaced_document['id'],
            partition_key=replaced_document['pk']
        )
        self.assertEqual(replaced_document['id'],
                         one_document_from_read['id'])
        # delete document
        created_collection.delete_item(
            item=replaced_document,
            partition_key=replaced_document['pk']
        )
        # read documents after deletion — NOTE(review): the id is passed as the
        # partition key here ('id' != 'pk' path), which still yields NOT_FOUND
        self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND,
                                           created_collection.read_item,
                                           replaced_document['id'],
                                           replaced_document['id'])
    def test_document_upsert(self):
        """Upsert semantics: create-on-miss, replace-on-hit (same count), create on a
        new id (count grows), and PRECONDITION failure for a stale-etag upsert."""
        # create database
        created_db = self.databaseForTest

        # create collection
        created_collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID)

        # read documents and check count
        documents = list(created_collection.read_all_items())
        before_create_documents_count = len(documents)

        # create document definition
        document_definition = {'id': 'doc',
                               'name': 'sample document',
                               'spam': 'eggs',
                               'pk': 'pk',
                               'key': 'value'}

        # create document using Upsert API
        created_document = created_collection.upsert_item(body=document_definition)

        # verify id property
        self.assertEqual(created_document['id'],
                         document_definition['id'])

        # test error for non-string id
        # NOTE(review): this mutation of document_definition['id'] to 7 persists after
        # the assertRaises block; the dict is not reused afterwards, so it is harmless here.
        with self.assertRaises(TypeError):
            document_definition['id'] = 7
            created_collection.upsert_item(body=document_definition)

        # read documents after creation and verify updated count
        documents = list(created_collection.read_all_items())
        self.assertEqual(
            len(documents),
            before_create_documents_count + 1,
            'create should increase the number of documents')

        # update document
        created_document['name'] = 'replaced document'
        created_document['spam'] = 'not eggs'

        # should replace document since it already exists
        upserted_document = created_collection.upsert_item(body=created_document)

        # verify the changed properties
        self.assertEqual(upserted_document['name'],
                         created_document['name'],
                         'document name property should change')
        self.assertEqual(upserted_document['spam'],
                         created_document['spam'],
                         'property should have changed')

        # verify id property
        self.assertEqual(upserted_document['id'],
                         created_document['id'],
                         'document id should stay the same')

        # read documents after upsert and verify count doesn't increases again
        documents = list(created_collection.read_all_items())
        self.assertEqual(
            len(documents),
            before_create_documents_count + 1,
            'number of documents should remain same')

        created_document['id'] = 'new id'

        # Upsert should create new document since the id is different
        new_document = created_collection.upsert_item(body=created_document)

        # Test modified access conditions: the plain upsert bumps the server etag,
        # so the conditional upsert against new_document's old etag must fail
        created_document['spam'] = 'more eggs'
        created_collection.upsert_item(body=created_document)
        with self.assertRaises(exceptions.CosmosHttpResponseError):
            created_collection.upsert_item(
                body=created_document,
                match_condition=MatchConditions.IfNotModified,
                etag=new_document['_etag'])

        # verify id property
        self.assertEqual(created_document['id'],
                         new_document['id'],
                         'document id should be same')

        # read documents after upsert and verify count increases
        documents = list(created_collection.read_all_items())
        self.assertEqual(
            len(documents),
            before_create_documents_count + 2,
            'upsert should increase the number of documents')

        # delete documents
        created_collection.delete_item(item=upserted_document, partition_key=upserted_document['pk'])
        created_collection.delete_item(item=new_document, partition_key=new_document['pk'])

        # read documents after delete and verify count is same as original
        documents = list(created_collection.read_all_items())
        self.assertEqual(
            len(documents),
            before_create_documents_count,
            'number of documents should remain same')

    def test_geospatial_index(self):
        """Creates a container with a partial spatial indexing policy, inserts two
        Point documents, and verifies an ST_DISTANCE query returns only the near one."""
        db = self.databaseForTest
        # partial policy specified
        collection = db.create_container(
            id='collection with spatial index ' + str(uuid.uuid4()),
            indexing_policy={
                'includedPaths': [
                    {
                        'path': '/"Location"/?',
                        'indexes': [
                            {
                                'kind': 'Spatial',
                                'dataType': 'Point'
                            }
                        ]
                    },
                    {
                        'path': '/'
                    }
                ]
            },
            partition_key=PartitionKey(path='/id', kind='Hash')
        )
        collection.create_item(
            body={
                'id': 'loc1',
                'Location': {
                    'type': 'Point',
                    'coordinates': [20.0, 20.0]
                }
            }
        )
        collection.create_item(
            body={
                'id': 'loc2',
                'Location': {
                    'type': 'Point',
                    'coordinates': [100.0, 100.0]
                }
            }
        )
        # only loc1 lies within 20 km of (20.1, 20)
        results = list(collection.query_items(
            query="SELECT * FROM root WHERE (ST_DISTANCE(root.Location, {type: 'Point', coordinates: [20.1, 20]}) < 20000)",
            enable_cross_partition_query=True
        ))
        self.assertEqual(1, len(results))
        self.assertEqual('loc1', results[0]['id'])

        db.delete_container(container=collection)

    # CRUD test for User resource
    # crud on user doesn't work in fabric
    @unittest.skip
    def test_user_crud(self):
        """User CRUD: create, list, query, replace, read, delete, NOT_FOUND after delete.
        Skipped: user resources are not supported on Fabric."""
        # Should do User CRUD operations successfully.
        # create database
        db = self.databaseForTest
        # list users
        users = list(db.list_users())
        before_create_count = len(users)
        # create user
        user_id = 'new user' + str(uuid.uuid4())
        user = db.create_user(body={'id': user_id})
        self.assertEqual(user.id, user_id, 'user id error')
        # list users after creation
        users = list(db.list_users())
        self.assertEqual(len(users), before_create_count + 1)
        # query users
        results = list(db.query_users(
            query='SELECT * FROM root r WHERE r.id=@id',
            parameters=[
                {'name': '@id', 'value': user_id}
            ]
        ))
        self.assertTrue(results)

        # replace user
        replaced_user_id = 'replaced user' + str(uuid.uuid4())
        user_properties = user.read()
        user_properties['id'] = replaced_user_id
        replaced_user = db.replace_user(user_id, user_properties)
        self.assertEqual(replaced_user.id,
                         replaced_user_id,
                         'user id should change')
        self.assertEqual(user_properties['id'],
                         replaced_user.id,
                         'user id should stay the same')
        # read user
        user = db.get_user_client(replaced_user.id)
        self.assertEqual(replaced_user.id, user.id)
        # delete user
        db.delete_user(user.id)
        # read user after deletion (get_user_client is lazy; the read() triggers NOT_FOUND)
        deleted_user = db.get_user_client(user.id)
        self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND,
                                           deleted_user.read)

    # crud on user doesn't work in fabric
    @unittest.skip
    def test_user_upsert(self):
        """User upsert: create-on-miss, replace-on-hit, create on new id, then cleanup.
        Skipped: user resources are not supported on Fabric."""
        # create database
        db = self.databaseForTest

        # read users and check count
        users = \
list(db.list_users())
        # (continues the `users = \` assignment cut at the previous chunk boundary)
        before_create_count = len(users)

        # create user using Upsert API
        user_id = 'user' + str(uuid.uuid4())
        user = db.upsert_user(body={'id': user_id})

        # verify id property
        self.assertEqual(user.id, user_id, 'user id error')

        # read users after creation and verify updated count
        users = list(db.list_users())
        self.assertEqual(len(users), before_create_count + 1)

        # Should replace the user since it already exists, there is no public property to change here
        user_properties = user.read()
        upserted_user = db.upsert_user(user_properties)

        # verify id property
        self.assertEqual(upserted_user.id,
                         user.id,
                         'user id should remain same')

        # read users after upsert and verify count doesn't increase again
        users = list(db.list_users())
        self.assertEqual(len(users), before_create_count + 1)

        user_properties = user.read()
        user_properties['id'] = 'new user' + str(uuid.uuid4())
        user.id = user_properties['id']

        # Upsert should create new user since id is different
        new_user = db.upsert_user(user_properties)

        # verify id property
        self.assertEqual(new_user.id, user.id, 'user id error')

        # read users after upsert and verify count increases
        users = list(db.list_users())
        self.assertEqual(len(users), before_create_count + 2)

        # delete users
        db.delete_user(upserted_user.id)
        db.delete_user(new_user.id)

        # read users after delete and verify count remains the same
        users = list(db.list_users())
        self.assertEqual(len(users), before_create_count)

    # crud on user doesn't work in fabric
    @unittest.skip
    def test_permission_crud(self):
        """Permission CRUD on a user: create, list, query, replace, read, delete,
        NOT_FOUND after delete. Skipped: user/permission resources unsupported on Fabric."""
        # Should do Permission CRUD operations successfully
        # create database
        db = self.databaseForTest
        # create user
        user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())})
        # list permissions
        permissions = list(user.list_permissions())
        before_create_count = len(permissions)
        permission = {
            'id': 'new permission',
            'permissionMode': documents.PermissionMode.Read,
            'resource': 'dbs/AQAAAA==/colls/AQAAAJ0fgTc='  # A random one.
        }
        # create permission
        permission = user.create_permission(permission)
        self.assertEqual(permission.id,
                         'new permission',
                         'permission id error')
        # list permissions after creation
        permissions = list(user.list_permissions())
        self.assertEqual(len(permissions), before_create_count + 1)
        # query permissions
        results = list(user.query_permissions(
            query='SELECT * FROM root r WHERE r.id=@id',
            parameters=[
                {'name': '@id', 'value': permission.id}
            ]
        ))
        self.assertTrue(results)

        # replace permission
        change_permission = permission.properties.copy()
        permission.properties['id'] = 'replaced permission'
        permission.id = permission.properties['id']
        replaced_permission = user.replace_permission(change_permission['id'], permission.properties)
        self.assertEqual(replaced_permission.id,
                         'replaced permission',
                         'permission id should change')
        self.assertEqual(permission.id,
                         replaced_permission.id,
                         'permission id should stay the same')
        # read permission
        permission = user.get_permission(replaced_permission.id)
        self.assertEqual(replaced_permission.id, permission.id)
        # delete permission
        user.delete_permission(replaced_permission.id)
        # read permission after deletion
        self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND,
                                           user.get_permission,
                                           permission.id)

    # crud on user doesn't work in fabric
    @unittest.skip
    def test_permission_upsert(self):
        """Permission upsert: create-on-miss, replace-on-hit (mode change), create when
        id+resource change. Skipped: user/permission resources unsupported on Fabric."""
        # create database
        db = self.databaseForTest

        # create user
        user = db.create_user(body={'id': 'new user' + str(uuid.uuid4())})

        # read permissions and check count
        permissions = list(user.list_permissions())
        before_create_count = len(permissions)

        permission_definition = {
            'id': 'permission',
            'permissionMode': documents.PermissionMode.Read,
            'resource': 'dbs/AQAAAA==/colls/AQAAAJ0fgTc='  # A random one.
        }

        # create permission using Upsert API
        created_permission = user.upsert_permission(permission_definition)

        # verify id property
        self.assertEqual(created_permission.id,
                         permission_definition['id'],
                         'permission id error')

        # read permissions after creation and verify updated count
        permissions = list(user.list_permissions())
        self.assertEqual(len(permissions), before_create_count + 1)

        # update permission mode
        permission_definition['permissionMode'] = documents.PermissionMode.All

        # should replace the permission since it already exists
        upserted_permission = user.upsert_permission(permission_definition)
        # verify id property
        self.assertEqual(upserted_permission.id,
                         created_permission.id,
                         'permission id should remain same')

        # verify changed property
        self.assertEqual(upserted_permission.permission_mode,
                         permission_definition['permissionMode'],
                         'permissionMode should change')

        # read permissions and verify count doesn't increase again
        permissions = list(user.list_permissions())
        self.assertEqual(len(permissions), before_create_count + 1)

        # update permission id
        created_permission.properties['id'] = 'new permission'
        created_permission.id = created_permission.properties['id']
        # resource needs to be changed along with the id in order to create a new permission
        created_permission.properties['resource'] = 'dbs/N9EdAA==/colls/N9EdAIugXgA='
        created_permission.resource_link = created_permission.properties['resource']

        # should create new permission since id has changed
        new_permission = user.upsert_permission(created_permission.properties)

        # verify id and resource property
        self.assertEqual(new_permission.id,
                         created_permission.id,
                         'permission id should be same')

        self.assertEqual(new_permission.resource_link,
                         created_permission.resource_link,
                         'permission resource should be same')

        # read permissions and verify count increases
        permissions = list(user.list_permissions())
        self.assertEqual(len(permissions), before_create_count + 2)

        # delete permissions
        user.delete_permission(upserted_permission.id)
        user.delete_permission(new_permission.id)

        # read permissions and verify count remains the same
        permissions = list(user.list_permissions())
        self.assertEqual(len(permissions), before_create_count)

    # Doesn't apply to Managed Identity test
    @unittest.skip
    def test_authorization(self):
        """Resource-token authorization: a collection-scoped token can read but not
        delete; a document-scoped token can read and delete only its document.
        Skipped: key/resource-token auth does not apply to the AAD/Fabric run."""
        def __SetupEntities(client):
            """
            Sets up entities for this test.

            :Parameters:
                - `client`: cosmos_client_connection.CosmosClientConnection

            :Returns:
                dict

            """
            # create database
            db = self.databaseForTest
            # create collection
            collection = db.create_container(
                id='test_authorization' + str(uuid.uuid4()),
                partition_key=PartitionKey(path='/id', kind='Hash')
            )
            # create document1
            document = collection.create_item(
                body={'id': 'doc1',
                      'spam': 'eggs',
                      'key': 'value'},
            )

            # create user
            user = db.create_user(body={'id': 'user' + str(uuid.uuid4())})

            # create permission for collection
            permission = {
                'id': 'permission On Coll',
                'permissionMode': documents.PermissionMode.Read,
                'resource': "dbs/" + db.id + "/colls/" + collection.id
            }
            permission_on_coll = user.create_permission(body=permission)
            self.assertIsNotNone(permission_on_coll.properties['_token'],
                                 'permission token is invalid')

            # create permission for document
            permission = {
                'id': 'permission On Doc',
                'permissionMode': documents.PermissionMode.All,
                'resource': "dbs/" + db.id + "/colls/" + collection.id + "/docs/" + document["id"]
            }
            permission_on_doc = user.create_permission(body=permission)
            self.assertIsNotNone(permission_on_doc.properties['_token'],
                                 'permission token is invalid')

            entities = {
                'db': db,
                'coll': collection,
                'doc': document,
                'user': user,
                'permissionOnColl': permission_on_coll,
                'permissionOnDoc': permission_on_doc,
            }
            return entities

        # Client without any authorization will fail.
        try:
            cosmos_client.CosmosClient(TestCRUDOperations.host, {}, "Session",
                                       connection_policy=TestCRUDOperations.connectionPolicy)
            raise Exception("Test did not fail as expected.")
        except exceptions.CosmosHttpResponseError as error:
            self.assertEqual(error.status_code, StatusCodes.UNAUTHORIZED)

        # Client with master key.
        client = cosmos_client.CosmosClient(TestCRUDOperations.host,
                                            TestCRUDOperations.masterKey,
                                            "Session",
                                            connection_policy=TestCRUDOperations.connectionPolicy)
        # setup entities
        entities = __SetupEntities(client)
        resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id:
                               entities['permissionOnColl'].properties['_token']}
        col_client = cosmos_client.CosmosClient(
            TestCRUDOperations.host, resource_tokens, "Session", connection_policy=TestCRUDOperations.connectionPolicy)
        db = entities['db']

        # swap the database proxy onto the token-scoped connection for the checks below
        old_client_connection = db.client_connection
        db.client_connection = col_client.client_connection
        # 1. Success-- Use Col Permission to Read
        success_coll = db.get_container_client(container=entities['coll'])
        # 2. Failure-- Use Col Permission to delete
        self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN,
                                           db.delete_container,
                                           success_coll)
        # 3. Success-- Use Col Permission to Read All Docs
        success_documents = list(success_coll.read_all_items())
        self.assertTrue(success_documents != None,
                        'error reading documents')
        self.assertEqual(len(success_documents),
                         1,
                         'Expected 1 Document to be successfully read')
        # 4. Success-- Use Col Permission to Read Doc

        docId = entities['doc']['id']
        success_doc = success_coll.read_item(
            item=docId,
            partition_key=docId
        )
        self.assertTrue(success_doc != None, 'error reading document')
        self.assertEqual(
            success_doc['id'],
            entities['doc']['id'],
            'Expected to read children using parent permissions')

        # 5. Failure-- Use Col Permission to Delete Doc
        self.__AssertHTTPFailureWithStatus(StatusCodes.FORBIDDEN,
                                           success_coll.delete_item,
                                           docId, docId)

        resource_tokens = {"dbs/" + entities['db'].id + "/colls/" + entities['coll'].id + "/docs/" + docId:
                               entities['permissionOnDoc'].properties['_token']}

        doc_client = cosmos_client.CosmosClient(
            TestCRUDOperations.host, resource_tokens, "Session", connection_policy=TestCRUDOperations.connectionPolicy)

        # 6. Success-- Use Doc permission to read doc
        read_doc = doc_client.get_database_client(db.id).get_container_client(success_coll.id).read_item(docId, docId)
        self.assertEqual(read_doc["id"], docId)

        # 6. Success-- Use Doc permission to delete doc
        # NOTE(review): duplicate step number "6." — should read "7."; comment only.
        doc_client.get_database_client(db.id).get_container_client(success_coll.id).delete_item(docId, docId)
        self.assertEqual(read_doc["id"], docId)

        db.client_connection = old_client_connection
        db.delete_container(entities['coll'])

    def test_client_request_timeout(self):
        """A near-zero RequestTimeout must make the first data-plane call raise.
        Guarded off on the emulator (localhost), where it is flaky."""
        # Test is flaky on Emulator
        if not ('localhost' in self.host or '127.0.0.1' in self.host):
            connection_policy = documents.ConnectionPolicy()
            # making timeout 0 ms to make sure it will throw
            connection_policy.RequestTimeout = 0.000000000001

            # client does a getDatabaseAccount on initialization, which will not time out because
            # there is a forced timeout for those calls
            client = cosmos_client.CosmosClient(TestCRUDOperations.host, TestCRUDOperations.masterKey, "Session",
                                                connection_policy=connection_policy)
            with self.assertRaises(Exception):
                databaseForTest = client.get_database_client(self.configs.TEST_DATABASE_ID)
                container = databaseForTest.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
                container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'})

    # NOTE(review): an `async def` inside a unittest.TestCase is never awaited by the
    # default runner — confirm this is collected by an async-aware runner or skipped.
    async def test_read_timeout_async(self):
        """A near-zero database-account read timeout must fail client initialization."""
        connection_policy = documents.ConnectionPolicy()
        # making timeout 0 ms to make sure it will throw
        connection_policy.DBAReadTimeout = \
0.000000000001
        # (continues the `connection_policy.DBAReadTimeout = \` assignment cut at the previous chunk boundary)
        with self.assertRaises(ServiceResponseError):
            # this will make a get database account call
            with cosmos_client.CosmosClient(self.host, self.masterKey, connection_policy=connection_policy):
                print('initialization')

    @unittest.skip
    def test_client_request_timeout_when_connection_retry_configuration_specified(self):
        """Near-zero RequestTimeout plus an explicit urllib3 Retry config must still
        surface an AzureError on the first data-plane call. Skipped in this suite."""
        connection_policy = documents.ConnectionPolicy()
        # making timeout 0 ms to make sure it will throw
        connection_policy.RequestTimeout = 0.000000000001
        connection_policy.ConnectionRetryConfiguration = Retry(
            total=3,
            read=3,
            connect=3,
            backoff_factor=0.3,
            status_forcelist=(500, 502, 504)
        )
        # client does a getDatabaseAccount on initialization, which will not time out because
        # there is a forced timeout for those calls
        # NOTE(review): relies on self.fabric_host / self.fabric_credential being defined
        # on the test class — confirm against the class setup (outside this chunk).
        with cosmos_client.CosmosClient(self.fabric_host, self.fabric_credential, connection_policy=connection_policy) as client:
            with self.assertRaises(AzureError):
                databaseForTest = client.get_database_client(self.configs.TEST_DATABASE_ID)
                container = databaseForTest.get_container_client(self.configs.TEST_SINGLE_PARTITION_CONTAINER_ID)
                container.create_item(body={'id': str(uuid.uuid4()), 'name': 'sample'})

    # TODO: Skipping this test to debug later
    @unittest.skip
    def test_client_connection_retry_configuration(self):
        """More connection retries should take measurably longer against a dead endpoint."""
        total_time_for_two_retries = self.initialize_client_with_connection_core_retry_config(2)
        total_time_for_three_retries = self.initialize_client_with_connection_core_retry_config(3)
        self.assertGreater(total_time_for_three_retries, total_time_for_two_retries)

    def initialize_client_with_connection_core_retry_config(self, retries):
        """Helper: time how long client construction against an unreachable endpoint
        takes with `retries` retries; returns the elapsed seconds after AzureError."""
        start_time = time.time()
        try:
            cosmos_client.CosmosClient(
                "https://localhost:9999",
                TestCRUDOperations.masterKey,
                "Session",
                retry_total=retries,
                retry_read=retries,
                retry_connect=retries,
                retry_status=retries)
            self.fail()
        except AzureError as e:
            end_time = time.time()
            return end_time - start_time

    # TODO: Skipping this test to debug later
    @unittest.skip
    def test_absolute_client_timeout(self):
        """CosmosClientTimeoutError paths: unreachable endpoint with client-level
        timeout, then transport-injected response error / 500 / 429 with per-call
        timeout=2. Skipped pending debugging."""
        with self.assertRaises(exceptions.CosmosClientTimeoutError):
            cosmos_client.CosmosClient(
                "https://localhost:9999",
                TestCRUDOperations.masterKey,
                "Session",
                retry_total=3,
                timeout=1)

        error_response = ServiceResponseError("Read timeout")
        timeout_transport = TimeoutTransport(error_response)
        client = cosmos_client.CosmosClient(
            self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True)

        with self.assertRaises(exceptions.CosmosClientTimeoutError):
            client.create_database_if_not_exists("test", timeout=2)

        status_response = 500  # Users connection level retry
        timeout_transport = TimeoutTransport(status_response)
        client = cosmos_client.CosmosClient(
            self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True)
        with self.assertRaises(exceptions.CosmosClientTimeoutError):
            client.create_database("test", timeout=2)

        databases = client.list_databases(timeout=2)
        with self.assertRaises(exceptions.CosmosClientTimeoutError):
            list(databases)

        status_response = 429  # Uses Cosmos custom retry
        timeout_transport = TimeoutTransport(status_response)
        client = cosmos_client.CosmosClient(
            self.host, self.masterKey, "Session", transport=timeout_transport, passthrough=True)
        with self.assertRaises(exceptions.CosmosClientTimeoutError):
            client.create_database_if_not_exists("test", timeout=2)

        databases = client.list_databases(timeout=2)
        with self.assertRaises(exceptions.CosmosClientTimeoutError):
            list(databases)

    def test_query_iterable_functionality(self):
        """read_all_items paging: full iteration via list() and for, honoring
        max_item_count=2 across 3 documents."""
        collection = self.databaseForTest.create_container("query-iterable-container",
                                                           partition_key=PartitionKey("/pk"))

        doc1 = collection.create_item(body={'id': 'doc1', 'prop1': 'value1', 'pk': 'pk'})
        doc2 = collection.create_item(body={'id': 'doc2', 'prop1': 'value2', 'pk': 'pk'})
        doc3 = collection.create_item(body={'id': 'doc3', 'prop1': 'value3', 'pk': 'pk'})
        resources = {
            'coll': collection,
            'doc1': doc1,
            'doc2': doc2,
            'doc3': doc3
        }

        results = resources['coll'].read_all_items(max_item_count=2)
        docs = list(iter(results))
        self.assertEqual(3,
                         len(docs),
                         'QueryIterable should return all documents' +
                         ' using continuation')
        self.assertEqual(resources['doc1']['id'], docs[0]['id'])
        self.assertEqual(resources['doc2']['id'], docs[1]['id'])
        self.assertEqual(resources['doc3']['id'], docs[2]['id'])

        # Validate QueryIterable iterator with 'for'.
        results = resources['coll'].read_all_items(max_item_count=2)
        counter = 0
        # test QueryIterable with 'for'.
        for doc in iter(results):
            counter += 1
            if counter == 1:
                self.assertEqual(resources['doc1']['id'],
                                 doc['id'],
                                 'first document should be doc1')
            elif counter == 2:
                self.assertEqual(resources['doc2']['id'],
                                 doc['id'],
                                 'second document should be doc2')
            elif counter == 3:
                self.assertEqual(resources['doc3']['id'],
                                 doc['id'],
                                 'third document should be doc3')
        self.assertEqual(counter, 3)

        # Get query results page by page.
        # (continues test_query_iterable_functionality from the previous chunk:
        # validates by_page() yields a 2-item page then a 1-item page, then StopIteration)
        results = resources['coll'].read_all_items(max_item_count=2)

        page_iter = results.by_page()
        first_block = list(next(page_iter))
        self.assertEqual(2, len(first_block), 'First block should have 2 entries.')
        self.assertEqual(resources['doc1']['id'], first_block[0]['id'])
        self.assertEqual(resources['doc2']['id'], first_block[1]['id'])
        self.assertEqual(1, len(list(next(page_iter))), 'Second block should have 1 entry.')
        with self.assertRaises(StopIteration):
            next(page_iter)

        self.databaseForTest.delete_container(collection.id)

    # Crud on user is not supported on fabric integration
    @unittest.skip
    def test_get_resource_with_dictionary_and_object(self):
        """Every get_*_client / get_* accessor accepts an id, a proxy instance, or a
        properties dict interchangeably (db, container, item, sproc, trigger, udf,
        user, permission). Skipped: user/permission CRUD unsupported on Fabric."""
        created_db = self.databaseForTest

        # read database with id
        read_db = self.client.get_database_client(created_db.id)
        self.assertEqual(read_db.id, created_db.id)

        # read database with instance
        read_db = self.client.get_database_client(created_db)
        self.assertEqual(read_db.id, created_db.id)

        # read database with properties
        read_db = self.client.get_database_client(created_db.read())
        self.assertEqual(read_db.id, created_db.id)

        created_container = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID)

        # read container with id
        read_container = created_db.get_container_client(created_container.id)
        self.assertEqual(read_container.id, created_container.id)

        # read container with instance
        read_container = created_db.get_container_client(created_container)
        self.assertEqual(read_container.id, created_container.id)

        # read container with properties
        created_properties = created_container.read()
        read_container = created_db.get_container_client(created_properties)
        self.assertEqual(read_container.id, created_container.id)

        created_item = created_container.create_item({'id': '1' + str(uuid.uuid4()), 'pk': 'pk'})

        # read item with id
        read_item = created_container.read_item(item=created_item['id'], partition_key=created_item['pk'])
        self.assertEqual(read_item['id'], created_item['id'])

        # read item with properties
        read_item = created_container.read_item(item=created_item, partition_key=created_item['pk'])
        self.assertEqual(read_item['id'], created_item['id'])

        created_sproc = created_container.scripts.create_stored_procedure({
            'id': 'storedProcedure' + str(uuid.uuid4()),
            'body': 'function () { }'
        })

        # read sproc with id
        read_sproc = created_container.scripts.get_stored_procedure(created_sproc['id'])
        self.assertEqual(read_sproc['id'], created_sproc['id'])

        # read sproc with properties
        read_sproc = created_container.scripts.get_stored_procedure(created_sproc)
        self.assertEqual(read_sproc['id'], created_sproc['id'])

        created_trigger = created_container.scripts.create_trigger({
            'id': 'sample trigger' + str(uuid.uuid4()),
            'serverScript': 'function() {var x = 10;}',
            'triggerType': documents.TriggerType.Pre,
            'triggerOperation': documents.TriggerOperation.All
        })

        # read trigger with id
        read_trigger = created_container.scripts.get_trigger(created_trigger['id'])
        self.assertEqual(read_trigger['id'], created_trigger['id'])

        # read trigger with properties
        read_trigger = created_container.scripts.get_trigger(created_trigger)
        self.assertEqual(read_trigger['id'], created_trigger['id'])

        created_udf = created_container.scripts.create_user_defined_function({
            'id': 'sample udf' + str(uuid.uuid4()),
            'body': 'function() {var x = 10;}'
        })

        # read udf with id
        read_udf = created_container.scripts.get_user_defined_function(created_udf['id'])
        self.assertEqual(created_udf['id'], read_udf['id'])

        # read udf with properties
        read_udf = created_container.scripts.get_user_defined_function(created_udf)
        self.assertEqual(created_udf['id'], read_udf['id'])

        created_user = created_db.create_user({
            'id': 'user' + str(uuid.uuid4())
        })

        # read user with id
        read_user = created_db.get_user_client(created_user.id)
        self.assertEqual(read_user.id, created_user.id)

        # read user with instance
        read_user = created_db.get_user_client(created_user)
        self.assertEqual(read_user.id, created_user.id)

        # read user with properties
        created_user_properties = created_user.read()
        read_user = created_db.get_user_client(created_user_properties)
        self.assertEqual(read_user.id, created_user.id)

        created_permission = created_user.create_permission({
            'id': 'all permission' + str(uuid.uuid4()),
            'permissionMode': documents.PermissionMode.All,
            'resource': created_container.container_link,
            'resourcePartitionKey': [1]
        })

        # read permission with id
        read_permission = created_user.get_permission(created_permission.id)
        self.assertEqual(read_permission.id, created_permission.id)

        # read permission with instance
        read_permission = created_user.get_permission(created_permission)
        self.assertEqual(read_permission.id, created_permission.id)

        # read permission with properties
        read_permission = created_user.get_permission(created_permission.properties)
        self.assertEqual(read_permission.id, created_permission.id)

    # Skipping for fabric integration
    @unittest.skip
    def test_delete_all_items_by_partition_key(self):
        """delete_all_items_by_partition_key removes only the targeted partition's
        items; a non-existent/None key deletes nothing. Emulator-only; skipped on Fabric."""
        # enable the test only for the emulator
        if "localhost" not in self.host and "127.0.0.1" not in self.host:
            return
        # create database
        created_db = self.databaseForTest

        # create container
        created_collection = created_db.create_container(
            id='test_delete_all_items_by_partition_key ' + str(uuid.uuid4()),
            partition_key=PartitionKey(path='/pk', kind='Hash')
        )
        # Create two partition keys
        partition_key1 = "{}-{}".format("Partition Key 1", str(uuid.uuid4()))
        partition_key2 = "{}-{}".format("Partition Key 2", str(uuid.uuid4()))

        # add items for partition key 1
        for i in range(1, 3):
            created_collection.upsert_item(
                dict(id="item{}".format(i), pk=partition_key1)
            )

        # add items for
        # partition key 2 (comment continues from the previous chunk boundary)

        pk2_item = created_collection.upsert_item(dict(id="item{}".format(3), pk=partition_key2))

        # delete all items for partition key 1
        created_collection.delete_all_items_by_partition_key(partition_key1)

        # check that only items from partition key 1 have been deleted
        items = list(created_collection.read_all_items())

        # items should only have 1 item, and it should equal pk2_item
        self.assertDictEqual(pk2_item, items[0])

        # attempting to delete a non-existent partition key or passing none should not delete
        # anything and leave things unchanged
        created_collection.delete_all_items_by_partition_key(None)

        # check that no changes were made by checking if the only item is still there
        items = list(created_collection.read_all_items())

        # items should only have 1 item, and it should equal pk2_item
        self.assertDictEqual(pk2_item, items[0])

        created_db.delete_container(created_collection)

    def test_patch_operations(self):
        """patch_item: add/remove/replace/set/incr/move all succeed, then each op is
        exercised against an invalid target and must return BAD_REQUEST.
        NOTE(review): the negative branches only assert inside `except`; if the service
        unexpectedly succeeds, no assertion fires — consider assertRaises. Comment only."""
        created_container = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID)

        pkValue = "patch_item_pk" + str(uuid.uuid4())
        # Create item to patch
        item = {
            "id": "patch_item",
            "pk": pkValue,
            "prop": "prop1",
            "address": {
                "city": "Redmond"
            },
            "company": "Microsoft",
            "number": 3}
        created_container.create_item(item)
        # Define and run patch operations
        operations = [
            {"op": "add", "path": "/color", "value": "yellow"},
            {"op": "remove", "path": "/prop"},
            {"op": "replace", "path": "/company", "value": "CosmosDB"},
            {"op": "set", "path": "/address/new_city", "value": "Atlanta"},
            {"op": "incr", "path": "/number", "value": 7},
            {"op": "move", "from": "/color", "path": "/favorite_color"}
        ]
        patched_item = created_container.patch_item(item="patch_item", partition_key=pkValue,
                                                    patch_operations=operations)
        # Verify results from patch operations
        # "color" was moved to "favorite_color", so it must be gone
        self.assertTrue(patched_item.get("color") is None)
        self.assertTrue(patched_item.get("prop") is None)
        self.assertEqual(patched_item.get("company"), "CosmosDB")
        self.assertEqual(patched_item.get("address").get("new_city"), "Atlanta")
        self.assertEqual(patched_item.get("number"), 10)
        self.assertEqual(patched_item.get("favorite_color"), "yellow")

        # Negative test - attempt to replace non-existent field
        operations = [{"op": "replace", "path": "/wrong_field", "value": "wrong_value"}]
        try:
            created_container.patch_item(item="patch_item", partition_key=pkValue, patch_operations=operations)
        except exceptions.CosmosHttpResponseError as e:
            self.assertEqual(e.status_code, StatusCodes.BAD_REQUEST)

        # Negative test - attempt to remove non-existent field
        operations = [{"op": "remove", "path": "/wrong_field"}]
        try:
            created_container.patch_item(item="patch_item", partition_key=pkValue, patch_operations=operations)
        except exceptions.CosmosHttpResponseError as e:
            self.assertEqual(e.status_code, StatusCodes.BAD_REQUEST)

        # Negative test - attempt to increment non-number field
        operations = [{"op": "incr", "path": "/company", "value": 3}]
        try:
            created_container.patch_item(item="patch_item", partition_key=pkValue, patch_operations=operations)
        except exceptions.CosmosHttpResponseError as e:
            self.assertEqual(e.status_code, StatusCodes.BAD_REQUEST)

        # Negative test - attempt to move from non-existent field
        operations = [{"op": "move", "from": "/wrong_field", "path": "/other_field"}]
        try:
            created_container.patch_item(item="patch_item", partition_key=pkValue, patch_operations=operations)
        except exceptions.CosmosHttpResponseError as e:
            self.assertEqual(e.status_code, StatusCodes.BAD_REQUEST)

    def test_conditional_patching(self):
        """Conditional patch_item via filter predicate / etag (definition continues
        beyond this chunk)."""
        created_container = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID)
        # Create item to patch
        pkValue = "patch_item_pk" + str(uuid.uuid4())
        item = {
            "id": "conditional_patch_item",
            "pk": pkValue,
            "prop": "prop1",
            "address": {
                "city": "Redmond"
}, + "company": "Microsoft", + "number": 3} + created_container.create_item(item) + + # Define patch operations + operations = [ + {"op": "add", "path": "/color", "value": "yellow"}, + {"op": "remove", "path": "/prop"}, + {"op": "replace", "path": "/company", "value": "CosmosDB"}, + {"op": "set", "path": "/address/new_city", "value": "Atlanta"}, + {"op": "incr", "path": "/number", "value": 7}, + {"op": "move", "from": "/color", "path": "/favorite_color"} + ] + + # Run patch operations with wrong filter + num_false = item.get("number") + 1 + filter_predicate = "from root where root.number = " + str(num_false) + try: + created_container.patch_item(item="conditional_patch_item", partition_key=pkValue, + patch_operations=operations, filter_predicate=filter_predicate) + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, StatusCodes.PRECONDITION_FAILED) + + # Run patch operations with correct filter + filter_predicate = "from root where root.number = " + str(item.get("number")) + patched_item = created_container.patch_item(item="conditional_patch_item", partition_key=pkValue, + patch_operations=operations, filter_predicate=filter_predicate) + # Verify results from patch operations + self.assertTrue(patched_item.get("color") is None) + self.assertTrue(patched_item.get("prop") is None) + self.assertEqual(patched_item.get("company"), "CosmosDB") + self.assertEqual(patched_item.get("address").get("new_city"), "Atlanta") + self.assertEqual(patched_item.get("number"), 10) + self.assertEqual(patched_item.get("favorite_color"), "yellow") + + # Temporarily commenting analytical storage tests until emulator support comes. 
+ # def test_create_container_with_analytical_store_off(self): + # # don't run test, for the time being, if running against the emulator + # if 'localhost' in self.host or '127.0.0.1' in self.host: + # return + + # created_db = self.databaseForTest + # collection_id = 'test_create_container_with_analytical_store_off_' + str(uuid.uuid4()) + # collection_indexing_policy = {'indexingMode': 'consistent'} + # created_recorder = RecordDiagnostics() + # created_collection = created_db.create_container(id=collection_id, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key not in properties or properties[ttl_key] == None) + + # def test_create_container_with_analytical_store_on(self): + # # don't run test, for the time being, if running against the emulator + # if 'localhost' in self.host or '127.0.0.1' in self.host: + # return + + # created_db = self.databaseForTest + # collection_id = 'test_create_container_with_analytical_store_on_' + str(uuid.uuid4()) + # collection_indexing_policy = {'indexingMode': 'consistent'} + # created_recorder = RecordDiagnostics() + # created_collection = created_db.create_container(id=collection_id, + # analytical_storage_ttl=-1, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) + + # def test_create_container_if_not_exists_with_analytical_store_on(self): + # # don't run test, for the time being, if running against the emulator + # if 'localhost' in self.host or '127.0.0.1' in self.host: + # return + + # # first, try when we know the container doesn't exist. 
+ # created_db = self.databaseForTest + # collection_id = 'test_create_container_if_not_exists_with_analytical_store_on_' + str(uuid.uuid4()) + # collection_indexing_policy = {'indexingMode': 'consistent'} + # created_recorder = RecordDiagnostics() + # created_collection = created_db.create_container_if_not_exists(id=collection_id, + # analytical_storage_ttl=-1, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) + + # # next, try when we know the container DOES exist. This way both code paths are tested. + # created_collection = created_db.create_container_if_not_exists(id=collection_id, + # analytical_storage_ttl=-1, + # indexing_policy=collection_indexing_policy, + # partition_key=PartitionKey(path="/pk", kind="Hash"), + # response_hook=created_recorder) + # properties = created_collection.read() + # ttl_key = "analyticalStorageTtl" + # self.assertTrue(ttl_key in properties and properties[ttl_key] == -1) + + def test_priority_level(self): + # These test verify if headers for priority level are sent + # Feature must be enabled at the account level + # If feature is not enabled the test will still pass as we just verify the headers were sent + created_container = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + + item1 = {"id": "item1", "pk": "pk1"} + item2 = {"id": "item2", "pk": "pk2"} + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + priority_headers = [] + + # mock execute function to check if priority level set in headers + + def priority_mock_execute_function(function, *args, **kwargs): + if args: + priority_headers.append(args[4].headers[HttpHeaders.PriorityLevel] + if HttpHeaders.PriorityLevel in args[4].headers else '') + return 
self.OriginalExecuteFunction(function, *args, **kwargs) + + _retry_utility.ExecuteFunction = priority_mock_execute_function + # upsert item with high priority + created_container.upsert_item(body=item1, priority="High") + # check if the priority level was passed + self.assertEqual(priority_headers[-1], "High") + # upsert item with low priority + created_container.upsert_item(body=item2, priority="Low") + # check that headers passed low priority + self.assertEqual(priority_headers[-1], "Low") + # Repeat for read operations + item1_read = created_container.read_item("item1", "pk1", priority="High") + self.assertEqual(priority_headers[-1], "High") + item2_read = created_container.read_item("item2", "pk2", priority="Low") + self.assertEqual(priority_headers[-1], "Low") + # repeat for query + query = list(created_container.query_items("Select * from c", partition_key="pk1", priority="High")) + + self.assertEqual(priority_headers[-1], "High") + + # Negative Test: Verify that if we send a value other than High or Low that it will not set the header value + # and result in bad request + try: + item2_read = created_container.read_item("item2", "pk2", priority="Medium") + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, StatusCodes.BAD_REQUEST) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + + def _MockExecuteFunction(self, function, *args, **kwargs): + if HttpHeaders.PartitionKey in args[4].headers: + self.last_headers.append(args[4].headers[HttpHeaders.PartitionKey]) + return self.OriginalExecuteFunction(function, *args, **kwargs) + + +if __name__ == '__main__': + try: + unittest.main() + except SystemExit as inst: + if inst.args[0] is True: # raised by sys.exit(True) when tests failed + raise diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py new file mode 100644 index 000000000000..b7860d341043 --- /dev/null +++ 
# -*- coding: utf-8 -*-
# The MIT License (MIT)
# Copyright (c) Microsoft Corporation. All rights reserved.

"""End-to-end test.
"""

import json
import os.path
import time
import unittest
import urllib.parse as urllib
import uuid

import pytest
import requests
from azure.core import MatchConditions
from azure.core.exceptions import AzureError, ServiceResponseError
from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse
from urllib3.util.retry import Retry

import azure.cosmos._base as base
import azure.cosmos.cosmos_client as cosmos_client
import azure.cosmos.documents as documents
import azure.cosmos.exceptions as exceptions
import test_config
# Plain module import for consistency with conftest.py (the tests are run with
# the tests directory on sys.path, not as the `tests` package).
import fabric_token_credential
from azure.cosmos import _retry_utility
from azure.cosmos.http_constants import HttpHeaders, StatusCodes
from azure.cosmos.partition_key import PartitionKey


class TimeoutTransport(RequestsTransport):
    """Requests transport that delays every send and then returns (or raises)
    a canned response; used to exercise client-side timeout handling."""

    def __init__(self, response):
        # `response` is either an int status code or an Exception to raise.
        self._response = response
        super(TimeoutTransport, self).__init__()

    def send(self, *args, **kwargs):
        if kwargs.pop("passthrough", False):
            return super(TimeoutTransport, self).send(*args, **kwargs)

        time.sleep(5)  # simulate a slow network
        if isinstance(self._response, Exception):
            raise self._response
        output = requests.Response()
        output.status_code = self._response
        response = RequestsTransportResponse(None, output)
        return response


@pytest.mark.cosmosLong
class TestCRUDContainerOperations(unittest.TestCase):
    """Python CRUD Tests.
    """

    configs = test_config.TestConfig
    fabric_host = configs.fabric_host
    fabric_credential = fabric_token_credential.FabricTokenCredential()
    connectionPolicy = configs.connectionPolicy
    last_headers = []
    client: cosmos_client.CosmosClient = None

    def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs):
        """Assert that func(*args, **kwargs) raises CosmosHttpResponseError
        with the given HTTP status code.

        :Parameters:
            - `status_code`: int
            - `func`: function
        """
        try:
            func(*args, **kwargs)
            self.fail('function should fail.')
        except exceptions.CosmosHttpResponseError as inst:
            self.assertEqual(inst.status_code, status_code)

    @classmethod
    def setUpClass(cls):
        # One client/database for the whole test class.
        cls.client = cosmos_client.CosmosClient(cls.fabric_host, credential=cls.fabric_credential)
        cls.databaseForTest = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID)

    def test_collection_crud(self):
        """Container create / read / list / query / delete round-trip."""
        created_db = self.databaseForTest
        collections = list(created_db.list_containers())
        # create a collection
        before_create_collections_count = len(collections)
        collection_id = 'test_collection_crud ' + str(uuid.uuid4())
        collection_indexing_policy = {'indexingMode': 'consistent'}
        created_collection = created_db.create_container(id=collection_id,
                                                         indexing_policy=collection_indexing_policy,
                                                         partition_key=PartitionKey(path="/pk", kind="Hash"))
        self.assertEqual(collection_id, created_collection.id)

        created_properties = created_collection.read()
        self.assertEqual('consistent', created_properties['indexingPolicy']['indexingMode'])
        self.assertDictEqual(PartitionKey(path='/pk', kind='Hash'), created_properties['partitionKey'])

        # read collections after creation
        collections = list(created_db.list_containers())
        self.assertEqual(len(collections),
                         before_create_collections_count + 1,
                         'create should increase the number of collections')
        # query collections
        collections = list(created_db.query_containers(
            {
                'query': 'SELECT * FROM root r WHERE r.id=@id',
                'parameters': [
                    {'name': '@id', 'value': collection_id}
                ]
            }))

        self.assertTrue(collections)
        # delete collection
        created_db.delete_container(created_collection.id)
        # read collection after deletion: must fail with 404
        created_container = created_db.get_container_client(created_collection.id)
        self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND,
                                           created_container.read)
is not working + @unittest.skip + def test_partitioned_collection(self): + created_db = self.databaseForTest + + collection_definition = {'id': 'test_partitioned_collection ' + str(uuid.uuid4()), + 'partitionKey': + { + 'paths': ['/id'], + 'kind': documents.PartitionKind.Hash + } + } + + offer_throughput = 10100 + created_collection = created_db.create_container(id=collection_definition['id'], + partition_key=collection_definition['partitionKey'], + offer_throughput=offer_throughput) + + self.assertEqual(collection_definition.get('id'), created_collection.id) + + created_collection_properties = created_collection.read( + populate_partition_key_range_statistics=True, + populate_quota_info=True) + self.assertEqual(collection_definition.get('partitionKey').get('paths')[0], + created_collection_properties['partitionKey']['paths'][0]) + self.assertEqual(collection_definition.get('partitionKey').get('kind'), + created_collection_properties['partitionKey']['kind']) + self.assertIsNotNone(created_collection_properties.get("statistics")) + self.assertIsNotNone(created_db.client_connection.last_response_headers.get("x-ms-resource-usage")) + + expected_offer = created_collection.get_throughput() + + self.assertIsNotNone(expected_offer) + + self.assertEqual(expected_offer.offer_throughput, offer_throughput) + + created_db.delete_container(created_collection.id) + + def test_partitioned_collection_partition_key_extraction(self): + created_db = self.databaseForTest + + collection_id = 'test_partitioned_collection_partition_key_extraction ' + str(uuid.uuid4()) + created_collection = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/address/state', kind=documents.PartitionKind.Hash) + ) + + document_definition = {'id': 'document1', + 'address': {'street': '1 Microsoft Way', + 'city': 'Redmond', + 'state': 'WA', + 'zip code': 98052 + } + } + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = 
self._MockExecuteFunction + # create document without partition key being specified + created_document = created_collection.create_item(body=document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[0], '["WA"]') + del self.last_headers[:] + + self.assertEqual(created_document.get('id'), document_definition.get('id')) + self.assertEqual(created_document.get('address').get('state'), document_definition.get('address').get('state')) + + collection_id = 'test_partitioned_collection_partition_key_extraction1 ' + str(uuid.uuid4()) + created_collection1 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/address', kind=documents.PartitionKind.Hash) + ) + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # Create document with partitionkey not present as a leaf level property but a dict + created_document = created_collection1.create_item(document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[0], [{}]) + del self.last_headers[:] + + # self.assertEqual(options['partitionKey'], documents.Undefined) + + collection_id = 'test_partitioned_collection_partition_key_extraction2 ' + str(uuid.uuid4()) + created_collection2 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/address/state/city', kind=documents.PartitionKind.Hash) + ) + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # Create document with partitionkey not present in the document + created_document = created_collection2.create_item(document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[0], [{}]) + del self.last_headers[:] + + # self.assertEqual(options['partitionKey'], documents.Undefined) + + 
created_db.delete_container(created_collection.id) + created_db.delete_container(created_collection1.id) + created_db.delete_container(created_collection2.id) + + def test_partitioned_collection_partition_key_extraction_special_chars(self): + created_db = self.databaseForTest + + collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars1 ' + str(uuid.uuid4()) + + created_collection1 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/\"level\' 1*()\"/\"le/vel2\"', kind=documents.PartitionKind.Hash) + ) + document_definition = {'id': 'document1', + "level' 1*()": {"le/vel2": 'val1'} + } + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + created_document = created_collection1.create_item(body=document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[0], '["val1"]') + del self.last_headers[:] + + collection_definition2 = { + 'id': 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()), + 'partitionKey': + { + 'paths': ['/\'level\" 1*()\'/\'le/vel2\''], + 'kind': documents.PartitionKind.Hash + } + } + + collection_id = 'test_partitioned_collection_partition_key_extraction_special_chars2 ' + str(uuid.uuid4()) + + created_collection2 = created_db.create_container( + id=collection_id, + partition_key=PartitionKey(path='/\'level\" 1*()\'/\'le/vel2\'', kind=documents.PartitionKind.Hash) + ) + + document_definition = {'id': 'document2', + 'level\" 1*()': {'le/vel2': 'val2'} + } + + self.OriginalExecuteFunction = _retry_utility.ExecuteFunction + _retry_utility.ExecuteFunction = self._MockExecuteFunction + # create document without partition key being specified + created_document = created_collection2.create_item(body=document_definition) + _retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.last_headers[0], 
'["val2"]') + del self.last_headers[:] + + created_db.delete_container(created_collection1.id) + created_db.delete_container(created_collection2.id) + + def test_partitioned_collection_path_parser(self): + test_dir = os.path.dirname(os.path.abspath(__file__)) + with open(os.path.join(test_dir, "BaselineTest.PathParser.json")) as json_file: + entries = json.loads(json_file.read()) + for entry in entries: + parts = base.ParsePaths([entry['path']]) + self.assertEqual(parts, entry['parts']) + + paths = ["/\"Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1\"/*"] + parts = ["Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1", "*"] + self.assertEqual(parts, base.ParsePaths(paths)) + + paths = ["/'Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1'/*"] + parts = ["Ke \\ \\\" \\\' \\? \\a \\\b \\\f \\\n \\\r \\\t \\v y1", "*"] + self.assertEqual(parts, base.ParsePaths(paths)) + + def test_partitioned_collection_conflict_crud_and_query(self): + created_db = self.databaseForTest + + created_collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + + conflict_definition = {'id': 'new conflict', + 'resourceId': 'doc1', + 'operationType': 'create', + 'resourceType': 'document' + } + + # read conflict here will return resource not found(404) since there is no conflict here + self.__AssertHTTPFailureWithStatus( + StatusCodes.NOT_FOUND, + created_collection.get_conflict, + conflict_definition['id'], + conflict_definition['id'] + ) + + # Read conflict feed doesn't require partitionKey to be specified as it's a cross partition thing + conflict_list = list(created_collection.list_conflicts()) + self.assertEqual(0, len(conflict_list)) + + # delete conflict here will return resource not found(404) since there is no conflict here + self.__AssertHTTPFailureWithStatus( + StatusCodes.NOT_FOUND, + created_collection.delete_conflict, + conflict_definition['id'], + conflict_definition['id'] + ) + + # query conflicts on any 
property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value + try: + list(created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get( # nosec + 'resourceType') + '\'' + )) + except Exception: + pass + + conflict_list = list(created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', + # nosec + enable_cross_partition_query=True + )) + + self.assertEqual(0, len(conflict_list)) + + # query conflicts by providing the partitionKey value + options = {'partitionKey': conflict_definition.get('id')} + conflict_list = list(created_collection.query_conflicts( + query='SELECT * FROM root r WHERE r.resourceType=\'' + conflict_definition.get('resourceType') + '\'', + # nosec + partition_key=conflict_definition['id'] + )) + + self.assertEqual(0, len(conflict_list)) + + def test_trigger_crud(self): + # create database + db = self.databaseForTest + # create collection + collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + # read triggers + triggers = list(collection.scripts.list_triggers()) + # create a trigger + before_create_triggers_count = len(triggers) + trigger_id = 'sample trigger-' + str(uuid.uuid4()) + trigger_definition = { + 'id': trigger_id, + 'serverScript': 'function() {var x = 10;}', + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + } + trigger = collection.scripts.create_trigger(body=trigger_definition) + for property in trigger_definition: + if property != "serverScript": + self.assertEqual( + trigger[property], + trigger_definition[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(trigger['body'], + 'function() {var x = 10;}') + + # read triggers after creation + triggers = list(collection.scripts.list_triggers()) + 
self.assertEqual(len(triggers), + before_create_triggers_count + 1, + 'create should increase the number of triggers') + # query triggers + triggers = list(collection.scripts.query_triggers( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': trigger_definition['id']} + ] + )) + self.assertTrue(triggers) + + # replace trigger + change_trigger = trigger.copy() + trigger['body'] = 'function() {var x = 20;}' + replaced_trigger = collection.scripts.replace_trigger(change_trigger['id'], trigger) + for property in trigger_definition: + if property != "serverScript": + self.assertEqual( + replaced_trigger[property], + trigger[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(replaced_trigger['body'], + 'function() {var x = 20;}') + + # read trigger + trigger = collection.scripts.get_trigger(replaced_trigger['id']) + self.assertEqual(replaced_trigger['id'], trigger['id']) + # delete trigger + collection.scripts.delete_trigger(replaced_trigger['id']) + # read triggers after deletion + self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + collection.scripts.delete_trigger, + replaced_trigger['id']) + + def test_udf_crud(self): + # create database + db = self.databaseForTest + # create collection + collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + # read udfs + udfs = list(collection.scripts.list_user_defined_functions()) + # create a udf + before_create_udfs_count = len(udfs) + udf_definition = { + 'id': 'sample udf', + 'body': 'function() {var x = 10;}' + } + udf = collection.scripts.create_user_defined_function(body=udf_definition) + for property in udf_definition: + self.assertEqual( + udf[property], + udf_definition[property], + 'property {property} should match'.format(property=property)) + + # read udfs after creation + udfs = list(collection.scripts.list_user_defined_functions()) + self.assertEqual(len(udfs), + 
before_create_udfs_count + 1, + 'create should increase the number of udfs') + # query udfs + results = list(collection.scripts.query_user_defined_functions( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': udf_definition['id']} + ] + )) + self.assertTrue(results) + # replace udf + change_udf = udf.copy() + udf['body'] = 'function() {var x = 20;}' + replaced_udf = collection.scripts.replace_user_defined_function(udf=udf['id'], body=udf) + for property in udf_definition: + self.assertEqual( + replaced_udf[property], + udf[property], + 'property {property} should match'.format(property=property)) + # read udf + udf = collection.scripts.get_user_defined_function(replaced_udf['id']) + self.assertEqual(replaced_udf['id'], udf['id']) + # delete udf + collection.scripts.delete_user_defined_function(replaced_udf['id']) + # read udfs after deletion + self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + collection.scripts.get_user_defined_function, + replaced_udf['id']) + + def test_sproc_crud(self): + # create database + db = self.databaseForTest + # create collection + collection = self.databaseForTest.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + # read sprocs + sprocs = list(collection.scripts.list_stored_procedures()) + # create a sproc + before_create_sprocs_count = len(sprocs) + sproc_id = 'sample sproc-' + str(uuid.uuid4()) + sproc_definition = { + 'id': sproc_id, + 'serverScript': 'function() {var x = 10;}' + } + sproc = collection.scripts.create_stored_procedure(sproc_definition) + for property in sproc_definition: + if property != "serverScript": + self.assertEqual( + sproc[property], + sproc_definition[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(sproc['body'], 'function() {var x = 10;}') + + # read sprocs after creation + sprocs = list(collection.scripts.list_stored_procedures()) + self.assertEqual(len(sprocs), + 
before_create_sprocs_count + 1, + 'create should increase the number of sprocs') + # query sprocs + sprocs = list(collection.scripts.query_stored_procedures( + query='SELECT * FROM root r WHERE r.id=@id', + parameters=[ + {'name': '@id', 'value': sproc_definition['id']} + ] + )) + self.assertIsNotNone(sprocs) + # replace sproc + change_sproc = sproc.copy() + sproc['body'] = 'function() {var x = 20;}' + replaced_sproc = collection.scripts.replace_stored_procedure(sproc=change_sproc['id'], body=sproc) + for property in sproc_definition: + if property != 'serverScript': + self.assertEqual( + replaced_sproc[property], + sproc[property], + 'property {property} should match'.format(property=property)) + else: + self.assertEqual(replaced_sproc['body'], + "function() {var x = 20;}") + # read sproc + sproc = collection.scripts.get_stored_procedure(replaced_sproc['id']) + self.assertEqual(replaced_sproc['id'], sproc['id']) + # delete sproc + collection.scripts.delete_stored_procedure(replaced_sproc['id']) + # read sprocs after deletion + self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + collection.scripts.get_stored_procedure, + replaced_sproc['id']) + + def test_collection_indexing_policy(self): + # create database + db = self.databaseForTest + # create collection + collection = db.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + + collection_properties = collection.read() + self.assertEqual(collection_properties['indexingPolicy']['indexingMode'], + documents.IndexingMode.Consistent, + 'default indexing mode should be consistent') + + collection_with_indexing_policy = db.create_container( + id='CollectionWithIndexingPolicy ' + str(uuid.uuid4()), + indexing_policy={ + 'automatic': True, + 'indexingMode': documents.IndexingMode.Consistent, + 'includedPaths': [ + { + 'path': '/', + 'indexes': [ + { + 'kind': documents.IndexKind.Hash, + 'dataType': documents.DataType.Number, + 'precision': 2 + } + ] + } + ], + 'excludedPaths': [ + { + 'path': 
'/"systemMetadata"/*' + } + ] + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + + collection_with_indexing_policy_properties = collection_with_indexing_policy.read() + self.assertEqual(1, + len(collection_with_indexing_policy_properties['indexingPolicy']['includedPaths']), + 'Unexpected includedPaths length') + self.assertEqual(2, + len(collection_with_indexing_policy_properties['indexingPolicy']['excludedPaths']), + 'Unexpected excluded path count') + db.delete_container(collection_with_indexing_policy.id) + + def test_create_default_indexing_policy(self): + # create database + db = self.databaseForTest + + # no indexing policy specified + collection = db.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + + collection_properties = collection.read() + self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + + # partial policy specified + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy01' + str(uuid.uuid4()), + indexing_policy={ + 'indexingMode': documents.IndexingMode.Consistent, 'automatic': True + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # default policy + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy03' + str(uuid.uuid4()), + indexing_policy={}, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # missing indexes + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy04' + str(uuid.uuid4()), + indexing_policy={ + 'includedPaths': [ + { + 'path': '/*' + } + ] + }, + 
partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + # missing precision + collection = db.create_container( + id='test_create_default_indexing_policy TestCreateDefaultPolicy05' + str(uuid.uuid4()), + indexing_policy={ + 'includedPaths': [ + { + 'path': '/*', + 'indexes': [ + { + 'kind': documents.IndexKind.Hash, + 'dataType': documents.DataType.String + }, + { + 'kind': documents.IndexKind.Range, + 'dataType': documents.DataType.Number + } + ] + } + ] + }, + partition_key=PartitionKey(path='/id', kind='Hash') + ) + collection_properties = collection.read() + self._check_default_indexing_policy_paths(collection_properties['indexingPolicy']) + db.delete_container(container=collection) + + def test_create_indexing_policy_with_composite_and_spatial_indexes(self): + # create database + db = self.databaseForTest + + indexing_policy = { + "spatialIndexes": [ + { + "path": "/path0/*", + "types": [ + "Point", + "LineString", + "Polygon", + "MultiPolygon" + ] + }, + { + "path": "/path1/*", + "types": [ + "Point", + "LineString", + "Polygon", + "MultiPolygon" + ] + } + ], + "compositeIndexes": [ + [ + { + "path": "/path1", + "order": "ascending" + }, + { + "path": "/path2", + "order": "descending" + }, + { + "path": "/path3", + "order": "ascending" + } + ], + [ + { + "path": "/path4", + "order": "ascending" + }, + { + "path": "/path5", + "order": "descending" + }, + { + "path": "/path6", + "order": "ascending" + } + ] + ] + } + + # TODO: custom_logger = logging.getLogger("CustomLogger") was in old code, check on later + created_container = db.create_container( + id='composite_index_spatial_index' + str(uuid.uuid4()), + indexing_policy=indexing_policy, + partition_key=PartitionKey(path='/id', kind='Hash'), + headers={"Foo": "bar"}, + user_agent="blah", + user_agent_overwrite=True, + 
logging_enable=True, + ) + # TODO: logger was passed into read previously + created_properties = created_container.read() + read_indexing_policy = created_properties['indexingPolicy'] + + # All types are returned for spatial Indexes + self.assertListEqual(indexing_policy['spatialIndexes'], read_indexing_policy['spatialIndexes']) + + self.assertListEqual(indexing_policy['compositeIndexes'], read_indexing_policy['compositeIndexes']) + db.delete_container(container=created_container) + + def _check_default_indexing_policy_paths(self, indexing_policy): + def __get_first(array): + if array: + return array[0] + else: + return None + + # '/_etag' is present in excluded paths by default + self.assertEqual(1, len(indexing_policy['excludedPaths'])) + # included paths should be 1: '/'. + self.assertEqual(1, len(indexing_policy['includedPaths'])) + + root_included_path = __get_first([included_path for included_path in indexing_policy['includedPaths'] + if included_path['path'] == '/*']) + self.assertFalse(root_included_path.get('indexes')) + + def test_trigger_functionality(self): + triggers_in_collection1 = [ + { + 'id': 't1', + 'body': ( + 'function() {' + + ' var item = getContext().getRequest().getBody();' + + ' item.id = item.id.toUpperCase() + \'t1\';' + + ' getContext().getRequest().setBody(item);' + + '}'), + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }, + { + 'id': 'response1', + 'body': ( + 'function() {' + + ' var prebody = getContext().getRequest().getBody();' + + ' if (prebody.id != \'TESTING POST TRIGGERt1\')' + ' throw \'id mismatch\';' + + ' var postbody = getContext().getResponse().getBody();' + + ' if (postbody.id != \'TESTING POST TRIGGERt1\')' + ' throw \'id mismatch\';' + '}'), + 'triggerType': documents.TriggerType.Post, + 'triggerOperation': documents.TriggerOperation.All + }, + { + 'id': 'response2', + # can't be used because setValue is currently disabled + 'body': ( + 'function() {' + + ' var 
predoc = getContext().getRequest().getBody();' + + ' var postdoc = getContext().getResponse().getBody();' + + ' getContext().getResponse().setValue(' + + ' \'predocname\', predoc.id + \'response2\');' + + ' getContext().getResponse().setValue(' + + ' \'postdocname\', postdoc.id + \'response2\');' + + '}'), + 'triggerType': documents.TriggerType.Post, + 'triggerOperation': documents.TriggerOperation.All, + }] + triggers_in_collection2 = [ + { + 'id': "t2", + 'body': "function() { }", # trigger already stringified + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }, + { + 'id': "t3", + 'body': ( + 'function() {' + + ' var item = getContext().getRequest().getBody();' + + ' item.id = item.id.toLowerCase() + \'t3\';' + + ' getContext().getRequest().setBody(item);' + + '}'), + 'triggerType': documents.TriggerType.Pre, + 'triggerOperation': documents.TriggerOperation.All + }] + triggers_in_collection3 = [ + { + 'id': 'triggerOpType', + 'body': 'function() { }', + 'triggerType': documents.TriggerType.Post, + 'triggerOperation': documents.TriggerOperation.Delete, + }] + + def __CreateTriggers(collection, triggers): + """Creates triggers. 
+ + :Parameters: + - `client`: cosmos_client_connection.CosmosClientConnection + - `collection`: dict + + """ + for trigger_i in triggers: + trigger = collection.scripts.create_trigger(body=trigger_i) + for property in trigger_i: + self.assertEqual( + trigger[property], + trigger_i[property], + 'property {property} should match'.format(property=property)) + + # create database + db = self.databaseForTest + # create collections + pkd = PartitionKey(path='/id', kind='Hash') + collection1 = db.create_container(id='test_trigger_functionality 1 ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/key', kind='Hash')) + collection2 = db.create_container(id='test_trigger_functionality 2 ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/key', kind='Hash')) + collection3 = db.create_container(id='test_trigger_functionality 3 ' + str(uuid.uuid4()), + partition_key=PartitionKey(path='/key', kind='Hash')) + # create triggers + __CreateTriggers(collection1, triggers_in_collection1) + __CreateTriggers(collection2, triggers_in_collection2) + __CreateTriggers(collection3, triggers_in_collection3) + # create document + triggers_1 = list(collection1.scripts.list_triggers()) + self.assertEqual(len(triggers_1), 3) + document_1_1 = collection1.create_item( + body={'id': 'doc1', + 'key': 'value'}, + pre_trigger_include='t1' + ) + self.assertEqual(document_1_1['id'], + 'DOC1t1', + 'id should be capitalized') + + document_1_2 = collection1.create_item( + body={'id': 'testing post trigger', 'key': 'value'}, + pre_trigger_include='t1', + post_trigger_include='response1', + ) + self.assertEqual(document_1_2['id'], 'TESTING POST TRIGGERt1') + + document_1_3 = collection1.create_item( + body={'id': 'responseheaders', 'key': 'value'}, + pre_trigger_include='t1' + ) + self.assertEqual(document_1_3['id'], "RESPONSEHEADERSt1") + + triggers_2 = list(collection2.scripts.list_triggers()) + self.assertEqual(len(triggers_2), 2) + document_2_1 = collection2.create_item( + body={'id': 
'doc2', + 'key': 'value2'}, + pre_trigger_include='t2' + ) + self.assertEqual(document_2_1['id'], + 'doc2', + 'id shouldn\'t change') + document_2_2 = collection2.create_item( + body={'id': 'Doc3', + 'prop': 'empty', + 'key': 'value2'}, + pre_trigger_include='t3') + self.assertEqual(document_2_2['id'], 'doc3t3') + + triggers_3 = list(collection3.scripts.list_triggers()) + self.assertEqual(len(triggers_3), 1) + with self.assertRaises(Exception): + collection3.create_item( + body={'id': 'Docoptype', 'key': 'value2'}, + post_trigger_include='triggerOpType' + ) + + db.delete_container(collection1) + db.delete_container(collection2) + db.delete_container(collection3) + + def __ValidateOfferResponseBody(self, offer, expected_coll_link, expected_offer_type): + # type: (Offer, str, Any) -> None + self.assertIsNotNone(offer.properties['id'], 'Id cannot be null.') + self.assertIsNotNone(offer.properties.get('_rid'), 'Resource Id (Rid) cannot be null.') + self.assertIsNotNone(offer.properties.get('_self'), 'Self Link cannot be null.') + self.assertIsNotNone(offer.properties.get('resource'), 'Resource Link cannot be null.') + self.assertTrue(offer.properties['_self'].find(offer.properties['id']) != -1, + 'Offer id not contained in offer self link.') + self.assertEqual(expected_coll_link.strip('/'), offer.properties['resource'].strip('/')) + if (expected_offer_type): + self.assertEqual(expected_offer_type, offer.properties.get('offerType')) + + # get_throughput is not working + @unittest.skip + def test_offer_read_and_query(self): + # Create database. + db = self.databaseForTest + collection = db.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + # Read the offer. + expected_offer = collection.get_throughput() + collection_properties = collection.read() + self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) + + # get_throughput is not working + @unittest.skip + def test_offer_replace(self): + # Create database. 
+ db = self.databaseForTest + # Create collection. + collection = db.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + # Read Offer + expected_offer = collection.get_throughput() + collection_properties = collection.read() + self.__ValidateOfferResponseBody(expected_offer, collection_properties.get('_self'), None) + # Replace the offer. + replaced_offer = collection.replace_throughput(expected_offer.offer_throughput + 100) + collection_properties = collection.read() + self.__ValidateOfferResponseBody(replaced_offer, collection_properties.get('_self'), None) + # Check if the replaced offer is what we expect. + self.assertEqual(expected_offer.properties.get('content').get('offerThroughput') + 100, + replaced_offer.properties.get('content').get('offerThroughput')) + self.assertEqual(expected_offer.offer_throughput + 100, + replaced_offer.offer_throughput) + + def test_index_progress_headers(self): + created_db = self.databaseForTest + created_container = created_db.get_container_client(self.configs.TEST_MULTI_PARTITION_CONTAINER_ID) + created_container.read(populate_quota_info=True) + self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) + self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) + + none_coll = created_db.create_container( + id='test_index_progress_headers none_coll ' + str(uuid.uuid4()), + indexing_policy={ + 'indexingMode': documents.IndexingMode.NoIndex, + 'automatic': False + }, + partition_key=PartitionKey(path="/id", kind='Hash') + ) + created_container = created_db.get_container_client(container=none_coll) + created_container.read(populate_quota_info=True) + self.assertFalse(HttpHeaders.LazyIndexingProgress in created_db.client_connection.last_response_headers) + self.assertTrue(HttpHeaders.IndexTransformationProgress in created_db.client_connection.last_response_headers) + + created_db.delete_container(none_coll) + + + 
+ def _MockExecuteFunction(self, function, *args, **kwargs): + if HttpHeaders.PartitionKey in args[4].headers: + self.last_headers.append(args[4].headers[HttpHeaders.PartitionKey]) + return self.OriginalExecuteFunction(function, *args, **kwargs) + +if __name__ == '__main__': + try: + unittest.main() + except SystemExit as inst: + if inst.args[0] is True: # raised by sys.exit(True) when tests failed + raise \ No newline at end of file diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py new file mode 100644 index 000000000000..cf6d0d9ef084 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py @@ -0,0 +1,607 @@ +# The MIT License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +import os +import unittest +import uuid + +import pytest + +import azure.cosmos._retry_utility as retry_utility +import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos import http_constants, DatabaseProxy, _endpoint_discovery_retry_policy +from azure.cosmos._execution_context.base_execution_context import _QueryExecutionContextBase +from azure.cosmos._execution_context.query_execution_info import _PartitionedQueryExecutionInfo +from azure.cosmos.documents import _DistinctType +from azure.cosmos.partition_key import PartitionKey +from tests import fabric_token_credential + + +class TestQuery(unittest.TestCase): + """Test to ensure escaping of non-ascii characters from partition key""" + + created_db: DatabaseProxy = None + client: cosmos_client.CosmosClient = None + config = test_config.TestConfig + fabric_host = test_config.TestConfig.fabric_host + fabric_credential = fabric_token_credential.FabricTokenCredential() + connectionPolicy = config.connectionPolicy + TEST_DATABASE_ID = config.TEST_DATABASE_ID + is_emulator = config.is_emulator + credential = config.credential + + @classmethod + def setUpClass(cls): + 
cls.client = cosmos_client.CosmosClient(cls.fabric_host, credential=cls.fabric_credential) + cls.created_db = cls.client.get_database_client(cls.config.TEST_DATABASE_ID) + + def test_first_and_last_slashes_trimmed_for_query_string(self): + created_collection = self.created_db.create_container( + "test_trimmed_slashes", PartitionKey(path="/pk")) + doc_id = 'myId' + str(uuid.uuid4()) + document_definition = {'pk': 'pk', 'id': doc_id} + created_collection.create_item(body=document_definition) + + query = 'SELECT * from c' + query_iterable = created_collection.query_items( + query=query, + partition_key='pk' + ) + iter_list = list(query_iterable) + self.assertEqual(iter_list[0]['id'], doc_id) + self.created_db.delete_container(created_collection.id) + + def test_populate_query_metrics(self): + created_collection = self.created_db.create_container("query_metrics_test", + PartitionKey(path="/pk")) + doc_id = 'MyId' + str(uuid.uuid4()) + document_definition = {'pk': 'pk', 'id': doc_id} + created_collection.create_item(body=document_definition) + + query = 'SELECT * from c' + query_iterable = created_collection.query_items( + query=query, + partition_key='pk', + populate_query_metrics=True + ) + + iter_list = list(query_iterable) + self.assertEqual(iter_list[0]['id'], doc_id) + + METRICS_HEADER_NAME = 'x-ms-documentdb-query-metrics' + self.assertTrue(METRICS_HEADER_NAME in created_collection.client_connection.last_response_headers) + metrics_header = created_collection.client_connection.last_response_headers[METRICS_HEADER_NAME] + # Validate header is well-formed: "key1=value1;key2=value2;etc" + metrics = metrics_header.split(';') + self.assertTrue(len(metrics) > 1) + self.assertTrue(all(['=' in x for x in metrics])) + self.created_db.delete_container(created_collection.id) + + def test_populate_index_metrics(self): + created_collection = self.created_db.create_container("query_index_test", + PartitionKey(path="/pk")) + + doc_id = 'MyId' + str(uuid.uuid4()) + 
document_definition = {'pk': 'pk', 'id': doc_id} + created_collection.create_item(body=document_definition) + + query = 'SELECT * from c' + query_iterable = created_collection.query_items( + query=query, + partition_key='pk', + populate_index_metrics=True + ) + + iter_list = list(query_iterable) + self.assertEqual(iter_list[0]['id'], doc_id) + + INDEX_HEADER_NAME = http_constants.HttpHeaders.IndexUtilization + self.assertTrue(INDEX_HEADER_NAME in created_collection.client_connection.last_response_headers) + index_metrics = created_collection.client_connection.last_response_headers[INDEX_HEADER_NAME] + self.assertIsNotNone(index_metrics) + expected_index_metrics = {'UtilizedSingleIndexes': [{'FilterExpression': '', 'IndexSpec': '/pk/?', + 'FilterPreciseSet': True, 'IndexPreciseSet': True, + 'IndexImpactScore': 'High'}], + 'PotentialSingleIndexes': [], 'UtilizedCompositeIndexes': [], + 'PotentialCompositeIndexes': []} + self.assertDictEqual(expected_index_metrics, index_metrics) + self.created_db.delete_container(created_collection.id) + + # TODO: Need to validate the query request count logic + @pytest.mark.skip + def test_max_item_count_honored_in_order_by_query(self): + created_collection = self.created_db.create_container("test-max-item-count" + str(uuid.uuid4()), + PartitionKey(path="/pk")) + docs = [] + for i in range(10): + document_definition = {'pk': 'pk', 'id': 'myId' + str(uuid.uuid4())} + docs.append(created_collection.create_item(body=document_definition)) + + query = 'SELECT * from c ORDER BY c._ts' + query_iterable = created_collection.query_items( + query=query, + max_item_count=1, + enable_cross_partition_query=True + ) + self.validate_query_requests_count(query_iterable, 25) + + query_iterable = created_collection.query_items( + query=query, + max_item_count=100, + enable_cross_partition_query=True + ) + + self.validate_query_requests_count(query_iterable, 5) + self.created_db.delete_container(created_collection.id) + + def 
validate_query_requests_count(self, query_iterable, expected_count): + self.count = 0 + self.OriginalExecuteFunction = retry_utility.ExecuteFunction + retry_utility.ExecuteFunction = self._MockExecuteFunction + for block in query_iterable.by_page(): + assert len(list(block)) != 0 + retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.assertEqual(self.count, expected_count) + self.count = 0 + + def _MockExecuteFunction(self, function, *args, **kwargs): + self.count += 1 + return self.OriginalExecuteFunction(function, *args, **kwargs) + + def test_get_query_plan_through_gateway(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + self._validate_query_plan(query="Select top 10 value count(c.id) from c", + container_link=created_collection.container_link, + top=10, + order_by=[], + aggregate=['Count'], + select_value=True, + offset=None, + limit=None, + distinct=_DistinctType.NoneType) + + self._validate_query_plan(query="Select * from c order by c._ts offset 5 limit 10", + container_link=created_collection.container_link, + top=None, + order_by=['Ascending'], + aggregate=[], + select_value=False, + offset=5, + limit=10, + distinct=_DistinctType.NoneType) + + self._validate_query_plan(query="Select distinct value c.id from c order by c.id", + container_link=created_collection.container_link, + top=None, + order_by=['Ascending'], + aggregate=[], + select_value=True, + offset=None, + limit=None, + distinct=_DistinctType.Ordered) + + def _validate_query_plan(self, query, container_link, top, order_by, aggregate, select_value, offset, limit, + distinct): + query_plan_dict = self.client.client_connection._GetQueryPlanThroughGateway(query, container_link) + query_execution_info = _PartitionedQueryExecutionInfo(query_plan_dict) + self.assertTrue(query_execution_info.has_rewritten_query()) + self.assertEqual(query_execution_info.has_distinct_type(), distinct != "None") + 
self.assertEqual(query_execution_info.get_distinct_type(), distinct) + self.assertEqual(query_execution_info.has_top(), top is not None) + self.assertEqual(query_execution_info.get_top(), top) + self.assertEqual(query_execution_info.has_order_by(), len(order_by) > 0) + self.assertListEqual(query_execution_info.get_order_by(), order_by) + self.assertEqual(query_execution_info.has_aggregates(), len(aggregate) > 0) + self.assertListEqual(query_execution_info.get_aggregates(), aggregate) + self.assertEqual(query_execution_info.has_select_value(), select_value) + self.assertEqual(query_execution_info.has_offset(), offset is not None) + self.assertEqual(query_execution_info.get_offset(), offset) + self.assertEqual(query_execution_info.has_limit(), limit is not None) + self.assertEqual(query_execution_info.get_limit(), limit) + + def test_unsupported_queries(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + queries = ['SELECT COUNT(1) FROM c', 'SELECT COUNT(1) + 5 FROM c', 'SELECT COUNT(1) + SUM(c) FROM c'] + for query in queries: + query_iterable = created_collection.query_items(query=query, enable_cross_partition_query=True) + try: + list(query_iterable) + self.fail() + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, 400) + + def test_query_with_non_overlapping_pk_ranges(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + query_iterable = created_collection.query_items("select * from c where c.pk='1' or c.pk='2'", + enable_cross_partition_query=True) + self.assertListEqual(list(query_iterable), []) + + def test_offset_limit(self): + created_collection = self.created_db.create_container("offset_limit_test_" + str(uuid.uuid4()), + PartitionKey(path="/pk")) + values = [] + for i in range(10): + document_definition = {'pk': i, 'id': 'myId' + str(uuid.uuid4()), 'value': i // 3} + 
values.append(created_collection.create_item(body=document_definition)['pk']) + + self._validate_distinct_offset_limit(created_collection=created_collection, + query='SELECT DISTINCT c["value"] from c ORDER BY c.pk OFFSET 0 LIMIT 2', + results=[0, 1]) + + self._validate_distinct_offset_limit(created_collection=created_collection, + query='SELECT DISTINCT c["value"] from c ORDER BY c.pk OFFSET 2 LIMIT 2', + results=[2, 3]) + + self._validate_distinct_offset_limit(created_collection=created_collection, + query='SELECT DISTINCT c["value"] from c ORDER BY c.pk OFFSET 4 LIMIT 3', + results=[]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 0 LIMIT 5', + results=values[:5]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 5 LIMIT 10', + results=values[5:]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 10 LIMIT 5', + results=[]) + + self._validate_offset_limit(created_collection=created_collection, + query='SELECT * from c ORDER BY c.pk OFFSET 100 LIMIT 1', + results=[]) + self.created_db.delete_container(created_collection.id) + + def _validate_offset_limit(self, created_collection, query, results): + query_iterable = created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + self.assertListEqual(list(map(lambda doc: doc['pk'], list(query_iterable))), results) + + def _validate_distinct_offset_limit(self, created_collection, query, results): + query_iterable = created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + self.assertListEqual(list(map(lambda doc: doc["value"], list(query_iterable))), results) + + def test_distinct(self): + distinct_field = 'distinct_field' + pk_field = "pk" + different_field = "different_field" + + created_collection = self.created_db.create_container( + id='collection with 
composite index ' + str(uuid.uuid4()), + partition_key=PartitionKey(path="/pk", kind="Hash"), + indexing_policy={ + "compositeIndexes": [ + [{"path": "/" + pk_field, "order": "ascending"}, + {"path": "/" + distinct_field, "order": "ascending"}], + [{"path": "/" + distinct_field, "order": "ascending"}, + {"path": "/" + pk_field, "order": "ascending"}] + ] + } + ) + documents = [] + for i in range(5): + j = i + while j > i - 5: + document_definition = {pk_field: i, 'id': str(uuid.uuid4()), distinct_field: j} + documents.append(created_collection.create_item(body=document_definition)) + document_definition = {pk_field: i, 'id': str(uuid.uuid4()), distinct_field: j} + documents.append(created_collection.create_item(body=document_definition)) + document_definition = {pk_field: i, 'id': str(uuid.uuid4())} + documents.append(created_collection.create_item(body=document_definition)) + j -= 1 + + padded_docs = self.config._pad_with_none(documents, distinct_field) + + self._validate_distinct(created_collection=created_collection, # returns {} and is right number + query='SELECT distinct c.%s from c' % distinct_field, # nosec + results=self.config._get_distinct_docs(padded_docs, distinct_field, None, False), + is_select=True, + fields=[distinct_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s, c.%s from c' % (distinct_field, pk_field), # nosec + results=self.config._get_distinct_docs(padded_docs, distinct_field, pk_field, False), + is_select=True, + fields=[distinct_field, pk_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct value c.%s from c' % distinct_field, # nosec + results=self.config._get_distinct_docs(padded_docs, distinct_field, None, True), + is_select=True, + fields=[distinct_field]) + + self._validate_distinct(created_collection=created_collection, + query='SELECT distinct c.%s from c' % different_field, # nosec + results=['None'], + is_select=True, + 
fields=[different_field]) + + self.created_db.delete_container(created_collection.id) + + def _validate_distinct(self, created_collection, query, results, is_select, fields): + query_iterable = created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + query_results = list(query_iterable) + + self.assertEqual(len(results), len(query_results)) + query_results_strings = [] + result_strings = [] + for i in range(len(results)): + query_results_strings.append(self.config._get_query_result_string(query_results[i], fields)) + result_strings.append(str(results[i])) + if is_select: + query_results_strings = sorted(query_results_strings) + result_strings = sorted(result_strings) + self.assertListEqual(result_strings, query_results_strings) + + def test_distinct_on_different_types_and_field_orders(self): + created_collection = self.created_db.create_container( + id="test-distinct-container-" + str(uuid.uuid4()), + partition_key=PartitionKey("/pk"), + offer_throughput=self.config.THROUGHPUT_FOR_5_PARTITIONS) + self.payloads = [ + {'f1': 1, 'f2': 'value', 'f3': 100000000000000000, 'f4': [1, 2, '3'], 'f5': {'f6': {'f7': 2}}}, + {'f2': '\'value', 'f4': [1.0, 2, '3'], 'f5': {'f6': {'f7': 2.0}}, 'f1': 1.0, 'f3': 100000000000000000.00}, + {'f3': 100000000000000000.0, 'f5': {'f6': {'f7': 2}}, 'f2': '\'value', 'f1': 1, 'f4': [1, 2.0, '3']} + ] + self.OriginalExecuteFunction = _QueryExecutionContextBase.__next__ + _QueryExecutionContextBase.__next__ = self._MockNextFunction + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f1 from c", + expected_results=[1], + get_mock_result=lambda x, i: (None, x[i]["f1"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f2 from c", + expected_results=['value', '\'value'], + get_mock_result=lambda x, i: (None, x[i]["f2"]) + ) + + 
self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f2 from c order by c.f2", + expected_results=['\'value', 'value'], + get_mock_result=lambda x, i: (x[i]["f2"], x[i]["f2"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f3 from c", + expected_results=[100000000000000000], + get_mock_result=lambda x, i: (None, x[i]["f3"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f4 from c", + expected_results=[[1, 2, '3']], + get_mock_result=lambda x, i: (None, x[i]["f4"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct value c.f5.f6 from c", + expected_results=[{'f7': 2}], + get_mock_result=lambda x, i: (None, x[i]["f5"]["f6"]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct c.f1, c.f2, c.f3 from c", + expected_results=[self.payloads[0], self.payloads[1]], + get_mock_result=lambda x, i: (None, x[i]) + ) + + self._validate_distinct_on_different_types_and_field_orders( + collection=created_collection, + query="Select distinct c.f1, c.f2, c.f3 from c order by c.f1", + expected_results=[self.payloads[0], self.payloads[1]], + get_mock_result=lambda x, i: (i, x[i]) + ) + + _QueryExecutionContextBase.__next__ = self.OriginalExecuteFunction + _QueryExecutionContextBase.next = self.OriginalExecuteFunction + + self.created_db.delete_container(created_collection.id) + + def test_paging_with_continuation_token(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + + document_definition = {'pk': 'pk', 'id': '1'} + created_collection.create_item(body=document_definition) + document_definition = {'pk': 'pk', 'id': '2'} + 
created_collection.create_item(body=document_definition) + + query = 'SELECT * from c' + query_iterable = created_collection.query_items( + query=query, + partition_key='pk', + max_item_count=1 + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + second_page = list(pager.next())[0] + + pager = query_iterable.by_page(token) + second_page_fetched_with_continuation_token = list(pager.next())[0] + + self.assertEqual(second_page['id'], second_page_fetched_with_continuation_token['id']) + + def test_cross_partition_query_with_continuation_token(self): + created_collection = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + document_definition = {'pk': 'pk1', 'id': str(uuid.uuid4())} + created_collection.create_item(body=document_definition) + document_definition = {'pk': 'pk2', 'id': str(uuid.uuid4())} + created_collection.create_item(body=document_definition) + + query = 'SELECT * from c' + query_iterable = created_collection.query_items( + query=query, + enable_cross_partition_query=True, + max_item_count=1, + ) + pager = query_iterable.by_page() + pager.next() + token = pager.continuation_token + second_page = list(pager.next())[0] + + pager = query_iterable.by_page(token) + second_page_fetched_with_continuation_token = list(pager.next())[0] + + self.assertEqual(second_page['id'], second_page_fetched_with_continuation_token['id']) + + def _validate_distinct_on_different_types_and_field_orders(self, collection, query, expected_results, + get_mock_result): + self.count = 0 + self.get_mock_result = get_mock_result + query_iterable = collection.query_items(query, enable_cross_partition_query=True) + results = list(query_iterable) + for i in range(len(expected_results)): + if isinstance(results[i], dict): + self.assertDictEqual(results[i], expected_results[i]) + elif isinstance(results[i], list): + self.assertListEqual(results[i], expected_results[i]) + else: + self.assertEqual(results[i], 
expected_results[i]) + self.count = 0 + + def test_value_max_query(self): + container = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + query = "Select value max(c.version) FROM c where c.isComplete = true and c.lookupVersion = @lookupVersion" + query_results = container.query_items(query, parameters=[ + {"name": "@lookupVersion", "value": "console_csat"} # cspell:disable-line + ], enable_cross_partition_query=True) + + self.assertListEqual(list(query_results), [None]) + + def test_continuation_token_size_limit_query(self): + container = self.created_db.get_container_client(self.config.TEST_MULTI_PARTITION_CONTAINER_ID) + for i in range(1, 1000): + container.create_item(body=dict(pk='123', id=str(uuid.uuid4()), some_value=str(i % 3))) + query = "Select * from c where c.some_value='2'" + response_query = container.query_items(query, partition_key='123', max_item_count=100, + continuation_token_limit=1) + pager = response_query.by_page() + pager.next() + token = pager.continuation_token + # Continuation token size should be below 1kb + self.assertLessEqual(len(token.encode('utf-8')), 1024) + pager.next() + token = pager.continuation_token + + # verify a second time + self.assertLessEqual(len(token.encode('utf-8')), 1024) + + def test_query_request_params_none_retry_policy(self): + created_collection = self.created_db.create_container( + "query_request_params_none_retry_policy_" + str(uuid.uuid4()), PartitionKey(path="/pk")) + items = [ + {'id': str(uuid.uuid4()), 'pk': 'test', 'val': 5}, + {'id': str(uuid.uuid4()), 'pk': 'test', 'val': 5}, + {'id': str(uuid.uuid4()), 'pk': 'test', 'val': 5}] + + for item in items: + created_collection.create_item(body=item) + + self.OriginalExecuteFunction = retry_utility.ExecuteFunction + # Test session retry will properly push the exception when retries run out + retry_utility.ExecuteFunction = self._MockExecuteFunctionSessionRetry + try: + query = "SELECT * FROM c" + items = 
created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + fetch_results = list(items) + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, 404) + self.assertEqual(e.sub_status, 1002) + + # Test endpoint discovery retry + retry_utility.ExecuteFunction = self._MockExecuteFunctionEndPointRetry + _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy.Max_retry_attempt_count = 3 + _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy.Retry_after_in_milliseconds = 10 + try: + query = "SELECT * FROM c" + items = created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + fetch_results = list(items) + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, http_constants.StatusCodes.FORBIDDEN) + self.assertEqual(e.sub_status, http_constants.SubStatusCodes.WRITE_FORBIDDEN) + _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy.Max_retry_attempt_count = 120 + _endpoint_discovery_retry_policy.EndpointDiscoveryRetryPolicy.Retry_after_in_milliseconds = 1000 + + # Finally lets test timeout failover retry + retry_utility.ExecuteFunction = self._MockExecuteFunctionTimeoutFailoverRetry + try: + query = "SELECT * FROM c" + items = created_collection.query_items( + query=query, + enable_cross_partition_query=True + ) + fetch_results = list(items) + except exceptions.CosmosHttpResponseError as e: + self.assertEqual(e.status_code, http_constants.StatusCodes.REQUEST_TIMEOUT) + retry_utility.ExecuteFunction = self.OriginalExecuteFunction + retry_utility.ExecuteFunction = self.OriginalExecuteFunction + self.created_db.delete_container(created_collection.id) + + + def _MockExecuteFunctionSessionRetry(self, function, *args, **kwargs): + if args: + if args[1].operation_type == 'SqlQuery': + ex_to_raise = exceptions.CosmosHttpResponseError(status_code=http_constants.StatusCodes.NOT_FOUND, + message="Read Session is Not Available") + 
ex_to_raise.sub_status = http_constants.SubStatusCodes.READ_SESSION_NOTAVAILABLE + raise ex_to_raise + return self.OriginalExecuteFunction(function, *args, **kwargs) + + def _MockExecuteFunctionEndPointRetry(self, function, *args, **kwargs): + if args: + if args[1].operation_type == 'SqlQuery': + ex_to_raise = exceptions.CosmosHttpResponseError(status_code=http_constants.StatusCodes.FORBIDDEN, + message="End Point Discovery") + ex_to_raise.sub_status = http_constants.SubStatusCodes.WRITE_FORBIDDEN + raise ex_to_raise + return self.OriginalExecuteFunction(function, *args, **kwargs) + + def _MockExecuteFunctionTimeoutFailoverRetry(self, function, *args, **kwargs): + if args: + if args[1].operation_type == 'SqlQuery': + ex_to_raise = exceptions.CosmosHttpResponseError(status_code=http_constants.StatusCodes.REQUEST_TIMEOUT, + message="Timeout Failover") + raise ex_to_raise + return self.OriginalExecuteFunction(function, *args, **kwargs) + + def _MockNextFunction(self): + if self.count < len(self.payloads): + item, result = self.get_mock_result(self.payloads, self.count) + self.count += 1 + if item is not None: + return {'orderByItems': [{'item': item}], '_rid': 'fake_rid', 'payload': result} + else: + return result + else: + raise StopIteration + + +if __name__ == "__main__": + unittest.main() diff --git a/sdk/cosmos/azure-cosmos/tests/testing_fabric_intergation.py b/sdk/cosmos/azure-cosmos/tests/testing_fabric_intergation.py new file mode 100644 index 000000000000..49673f50f647 --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/testing_fabric_intergation.py @@ -0,0 +1,96 @@ +from azure.cosmos import PartitionKey, cosmos_client +from tests import fabric_token_credential, test_config + + +def run_sample(): + print("Running sample") + fabric_credential = fabric_token_credential.FabricTokenCredential() + fabric_host = test_config.TestConfig.fabric_host + client = cosmos_client.CosmosClient(fabric_host, credential=fabric_credential) + databaseForTest = 
client.get_database_client("dkunda-fabric-cdb") + created_collection = databaseForTest.create_container("testing-container", partition_key=PartitionKey("/pk")) + document_definition = {'id': 'document', + 'key': 'value', + 'pk': 'pk'} + + created_document = created_collection.create_item( + body=document_definition + ) + + assert created_document.get('id'), document_definition.get('id') + assert created_document.get('key'), document_definition.get('key') + + # read document + read_document = created_collection.read_item( + item=created_document.get('id'), + partition_key=created_document.get('pk') + ) + + assert read_document.get('id'), created_document.get('id') + assert read_document.get('key'), created_document.get('key') + + # Read document feed doesn't require partitionKey as it's always a cross partition query + documentlist = list(created_collection.read_all_items()) + assert 1, len(documentlist) + + # replace document + document_definition['key'] = 'new value' + + replaced_document = created_collection.replace_item( + item=read_document, + body=document_definition + ) + + assert replaced_document.get('key'), document_definition.get('key') + + # upsert document(create scenario) + document_definition['id'] = 'document2' + document_definition['key'] = 'value2' + + upserted_document = created_collection.upsert_item(body=document_definition) + + assert upserted_document.get('id'), document_definition.get('id') + assert upserted_document.get('key'), document_definition.get('key') + + documentlist = list(created_collection.read_all_items()) + assert 2, len(documentlist) + + # delete document + created_collection.delete_item(item=upserted_document, partition_key=upserted_document.get('pk')) + + # query document on the partition key specified in the predicate will pass even without setting enableCrossPartitionQuery or passing in the partitionKey value + documentlist = list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.id=\'' + 
replaced_document.get('id') + '\'' # nosec + }, enable_cross_partition_query=True)) + assert 1, len(documentlist) + + # query document on any property other than partitionKey will fail without setting enableCrossPartitionQuery or passing in the partitionKey value + try: + list(created_collection.query_items( + { + 'query': 'SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'' # nosec + })) + except Exception: + pass + + # cross partition query + documentlist = list(created_collection.query_items( + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec + enable_cross_partition_query=True + )) + + assert 1, len(documentlist) + + # query document by providing the partitionKey value + documentlist = list(created_collection.query_items( + query='SELECT * FROM root r WHERE r.key=\'' + replaced_document.get('key') + '\'', # nosec + partition_key=replaced_document.get('pk') + )) + + assert 1, len(documentlist) + databaseForTest.delete_container(created_collection.id) + + +if __name__ == '__main__': + run_sample() From 1d506271d630d5d4de4ad0bce4575736d2c8affa Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Thu, 3 Apr 2025 23:39:33 -0700 Subject: [PATCH 2/3] Added fabric tag --- sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py | 2 +- sdk/cosmos/azure-cosmos/tests/test_fabric_query.py | 1 + 4 files changed, 4 insertions(+), 3 deletions(-) diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py index 6cb07f7b98c0..2c1c4b102871 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_change_feed.py @@ -30,7 +30,7 @@ def round_time(): utc_now = datetime.now(timezone.utc) return utc_now - timedelta(microseconds=utc_now.microsecond) 
-@pytest.mark.cosmosQuery +@pytest.mark.cosmosFabric @pytest.mark.unittest @pytest.mark.usefixtures("setup") class TestChangeFeed: diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py index 2276ed09e26f..dc4c8dfcd2b9 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud.py @@ -46,7 +46,7 @@ def send(self, *args, **kwargs): return response -@pytest.mark.cosmosLong +@pytest.mark.cosmosFabric class TestCRUDOperations(unittest.TestCase): """Python CRUD Tests. """ diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py index b7860d341043..936962e08caa 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py @@ -49,7 +49,7 @@ def send(self, *args, **kwargs): return response -@pytest.mark.cosmosLong +@pytest.mark.cosmosFabric class TestCRUDContainerOperations(unittest.TestCase): """Python CRUD Tests. 
""" diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py index cf6d0d9ef084..24350ad71aa2 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_query.py @@ -19,6 +19,7 @@ from tests import fabric_token_credential +@pytest.mark.cosmosFabric class TestQuery(unittest.TestCase): """Test to ensure escaping of non-ascii characters from partition key""" From af1b5edd0bd2c38f343142375eea0265071ea144 Mon Sep 17 00:00:00 2001 From: Kushagra Thapar Date: Mon, 7 Apr 2025 21:25:32 -0700 Subject: [PATCH 3/3] added database crud test for fabric native integration --- .../tests/test_fabric_crud_container.py | 4 - .../tests/test_fabric_crud_database.py | 235 ++++++++++++++++++ 2 files changed, 235 insertions(+), 4 deletions(-) create mode 100644 sdk/cosmos/azure-cosmos/tests/test_fabric_crud_database.py diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py index 936962e08caa..29e788d7ff49 100644 --- a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_container.py @@ -9,15 +9,11 @@ import os.path import time import unittest -import urllib.parse as urllib import uuid import pytest import requests -from azure.core import MatchConditions -from azure.core.exceptions import AzureError, ServiceResponseError from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse -from urllib3.util.retry import Retry import azure.cosmos._base as base import azure.cosmos.cosmos_client as cosmos_client diff --git a/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_database.py b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_database.py new file mode 100644 index 000000000000..ed7ad45a009c --- /dev/null +++ b/sdk/cosmos/azure-cosmos/tests/test_fabric_crud_database.py @@ -0,0 +1,235 @@ +# -*- coding: utf-8 -*- +# The MIT 
License (MIT) +# Copyright (c) Microsoft Corporation. All rights reserved. + +"""End-to-end test. +""" + +import time +import unittest +import uuid + +import pytest +import requests +from azure.core.pipeline.transport import RequestsTransport, RequestsTransportResponse + +import azure.cosmos.cosmos_client as cosmos_client +import azure.cosmos.exceptions as exceptions +import test_config +from azure.cosmos.http_constants import HttpHeaders, StatusCodes +from tests import fabric_token_credential + + +class TimeoutTransport(RequestsTransport): + + def __init__(self, response): + self._response = response + super(TimeoutTransport, self).__init__() + + def send(self, *args, **kwargs): + if kwargs.pop("passthrough", False): + return super(TimeoutTransport, self).send(*args, **kwargs) + + time.sleep(5) + if isinstance(self._response, Exception): + raise self._response + output = requests.Response() + output.status_code = self._response + response = RequestsTransportResponse(None, output) + return response + + +@pytest.mark.cosmosLong +class TestCRUDDatabaseOperations(unittest.TestCase): + """Python CRUD Tests. + """ + + configs = test_config.TestConfig + host = configs.host + masterKey = configs.masterKey + fabric_host = configs.fabric_host + fabric_credential = fabric_token_credential.FabricTokenCredential() + connectionPolicy = configs.connectionPolicy + last_headers = [] + client: cosmos_client.CosmosClient = None + + def __AssertHTTPFailureWithStatus(self, status_code, func, *args, **kwargs): + """Assert HTTP failure with status. 
+ + :Parameters: + - `status_code`: int + - `func`: function + """ + try: + func(*args, **kwargs) + self.assertFalse(True, 'function should fail.') + except exceptions.CosmosHttpResponseError as inst: + self.assertEqual(inst.status_code, status_code) + + @classmethod + def setUpClass(cls): + cls.client = cosmos_client.CosmosClient(cls.fabric_host, credential=cls.fabric_credential) + cls.databaseForTest = cls.client.get_database_client(cls.configs.TEST_DATABASE_ID) + + def test_database_query(self): + # Query databases + databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id=@id', + 'parameters': [ + {'name': '@id', 'value': self.configs.TEST_DATABASE_ID}, + ] + })) + self.assertTrue(databases, 'number of results for the query should be > 0') + + def test_database_read(self): + read_database = self.databaseForTest.read() + self.assertEqual(read_database["id"], self.configs.TEST_DATABASE_ID) + + def test_database_crud(self): + database_id = str(uuid.uuid4()) + created_db = self.client.create_database(database_id) + self.assertEqual(created_db.id, database_id) + # Read databases after creation. + databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id=@id', + 'parameters': [ + {'name': '@id', 'value': database_id} + ] + })) + self.assertTrue(databases, 'number of results for the query should be > 0') + + # read database. + self.client.get_database_client(created_db.id).read() + + # delete database. 
+ self.client.delete_database(created_db.id) + # read database after deletion + read_db = self.client.get_database_client(created_db.id) + self.__AssertHTTPFailureWithStatus(StatusCodes.NOT_FOUND, + read_db.read) + + database_proxy = self.client.create_database_if_not_exists(id=database_id, offer_throughput=5000) + self.assertEqual(database_id, database_proxy.id) + self.assertEqual(5000, database_proxy.read_offer().offer_throughput) + + database_proxy = self.client.create_database_if_not_exists(id=database_id, offer_throughput=6000) + self.assertEqual(database_id, database_proxy.id) + self.assertEqual(5000, database_proxy.read_offer().offer_throughput) + + self.client.delete_database(database_id) + + def test_database_level_offer_throughput(self): + # Create a database with throughput + offer_throughput = 1000 + database_id = str(uuid.uuid4()) + created_db = self.client.create_database( + id=database_id, + offer_throughput=offer_throughput + ) + self.assertEqual(created_db.id, database_id) + + # Verify offer throughput for database + offer = created_db.read_offer() + self.assertEqual(offer.offer_throughput, offer_throughput) + + # Update database offer throughput + new_offer_throughput = 2000 + offer = created_db.replace_throughput(new_offer_throughput) + self.assertEqual(offer.offer_throughput, new_offer_throughput) + self.client.delete_database(created_db.id) + + def test_sql_query_crud(self): + # create two databases. + db1 = self.client.create_database('database 1' + str(uuid.uuid4())) + db2 = self.client.create_database('database 2' + str(uuid.uuid4())) + + # query with parameters. + databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id=@id', + 'parameters': [ + {'name': '@id', 'value': db1.id} + ] + })) + self.assertEqual(1, len(databases), 'Unexpected number of query results.') + + # query without parameters. 
+ databases = list(self.client.query_databases({ + 'query': 'SELECT * FROM root r WHERE r.id="database non-existing"' + })) + self.assertEqual(0, len(databases), 'Unexpected number of query results.') + + # query with a string. + databases = list(self.client.query_databases('SELECT * FROM root r WHERE r.id="' + db2.id + '"')) # nosec + self.assertEqual(1, len(databases), 'Unexpected number of query results.') + self.client.delete_database(db1.id) + self.client.delete_database(db2.id) + + def test_database_account_functionality(self): + # Validate database account functionality. + database_account = self.client.get_database_account() + self.assertEqual(database_account.DatabasesLink, '/dbs/') + self.assertEqual(database_account.MediaLink, '/media/') + if (HttpHeaders.MaxMediaStorageUsageInMB in + self.client.client_connection.last_response_headers): + self.assertEqual( + database_account.MaxMediaStorageUsageInMB, + self.client.client_connection.last_response_headers[ + HttpHeaders.MaxMediaStorageUsageInMB]) + if (HttpHeaders.CurrentMediaStorageUsageInMB in + self.client.client_connection.last_response_headers): + self.assertEqual( + database_account.CurrentMediaStorageUsageInMB, + self.client.client_connection.last_response_headers[ + HttpHeaders.CurrentMediaStorageUsageInMB]) + self.assertIsNotNone(database_account.ConsistencyPolicy['defaultConsistencyLevel']) + + def test_id_validation(self): + # Id shouldn't end with space. + try: + self.client.create_database(id='id_with_space ') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id ends with a space or newline.', e.args[0]) + # Id shouldn't contain '/'. + + try: + self.client.create_database(id='id_with_illegal/_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + # Id shouldn't contain '\\'. 
+ + try: + self.client.create_database(id='id_with_illegal\\_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + # Id shouldn't contain '?'. + + try: + self.client.create_database(id='id_with_illegal?_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + # Id shouldn't contain '#'. + + try: + self.client.create_database(id='id_with_illegal#_char') + self.assertFalse(True) + except ValueError as e: + self.assertEqual('Id contains illegal chars.', e.args[0]) + + # Id can begin with space + db = self.client.create_database(id=' id_begin_space' + str(uuid.uuid4())) + self.assertTrue(True) + + self.client.delete_database(db.id) + + + +if __name__ == '__main__': + try: + unittest.main() + except SystemExit as inst: + if inst.args[0] is True: # raised by sys.exit(True) when tests failed + raise \ No newline at end of file