diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index c44e19ca06ea..53f1582ac0a9 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -44,11 +44,12 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): + if hasattr(response.http_response, "load_body"): try: - await response.http_response.load_body() # Load the body in memory and close the socket + await response.http_response.load_body() except (StreamClosedError, StreamConsumedError): pass + if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 8a929647e78f..8710b5fac227 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -46,8 +46,11 @@ async def process_content(data: Any, start_offset: int, end_offset: int, encryption: Dict[str, Any]) -> bytes: if data is None: raise ValueError("Response cannot be None.") - await data.response.load_body() - content = cast(bytes, data.response.body()) + if hasattr(data.response, "load_body"): + await data.response.load_body() + content = cast(bytes, data.response.body()) + else: + content = b"".join([d async for d in data]) if encryption.get('key') is not None or encryption.get('resolver') is not None: try: return decrypt_blob( diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob.py b/sdk/storage/azure-storage-blob/tests/test_common_blob.py index 3afdb1103bd3..b7e8936b3656 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob.py @@ -53,7 +53,6 @@ from devtools_testutils import FakeTokenCredential, recorded_by_proxy from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers import MockStorageTransport # ------------------------------------------------------------------------------ TEST_CONTAINER_PREFIX = 'container' @@ -3532,57 +3531,4 @@ def test_upload_blob_partial_stream_chunked(self, **kwargs): result = blob.download_blob().readall() assert result == data[:length] - @BlobPreparer() - def test_mock_transport_no_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - content = blob_client.download_blob() - assert content is not None - - props = blob_client.get_blob_properties() - assert props is not None - - data = b"Hello World!" - resp = blob_client.upload_blob(data, overwrite=True) - assert resp is not None - - blob_data = blob_client.download_blob().read() - assert blob_data == b"Hello World!" # data is fixed by mock transport - - resp = blob_client.delete_blob() - assert resp is None - - @BlobPreparer() - def test_mock_transport_with_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - data = b"Hello World!" - resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) - assert resp is not None - - blob_data = blob_client.download_blob(validate_content=True).read() - assert blob_data == b"Hello World!" # data is fixed by mock transport - # ------------------------------------------------------------------------------ \ No newline at end of file diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py index aef6d9680603..2ebaad59279d 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py @@ -53,7 +53,7 @@ from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers_async import AsyncStream, MockStorageTransport +from test_helpers_async import AsyncStream # ------------------------------------------------------------------------------ TEST_CONTAINER_PREFIX = 'container' @@ -3456,58 +3456,4 @@ async def test_upload_blob_partial_stream_chunked(self, **kwargs): result = await (await blob.download_blob()).readall() assert result == data[:length] - @BlobPreparer() - async def test_mock_transport_no_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - content = await blob_client.download_blob() - assert content is not None - - props = await blob_client.get_blob_properties() - assert props is not None - - data = b"Hello Async World!" - stream = AsyncStream(data) - resp = await blob_client.upload_blob(stream, overwrite=True) - assert resp is not None - - blob_data = await (await blob_client.download_blob()).read() - assert blob_data == b"Hello Async World!" # data is fixed by mock transport - - resp = await blob_client.delete_blob() - assert resp is None - - @BlobPreparer() - async def test_mock_transport_with_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - data = b"Hello Async World!" - stream = AsyncStream(data) - resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) - assert resp is not None - - blob_data = await (await blob_client.download_blob(validate_content=True)).read() - assert blob_data == b"Hello Async World!" # data is fixed by mock transport # ------------------------------------------------------------------------------ diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index 6f309bd756cc..dc4932e31673 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -8,7 +8,11 @@ from typing import Any, Dict, Optional from typing_extensions import Self -from azure.core.pipeline.transport import HttpTransport, RequestsTransportResponse +from azure.core.pipeline.transport import ( + HttpTransport, + RequestsTransport, + RequestsTransportResponse +) from azure.core.rest import HttpRequest from requests import Response from urllib3 import HTTPResponse @@ -49,7 +53,7 @@ def tell(self): return self.wrapped_stream.tell() -class MockHttpClientResponse(Response): +class MockClientResponse(Response): def __init__( self, url: str, body_bytes: bytes, @@ -57,7 +61,7 @@ def __init__( status: int = 200, reason: str = "OK" ) -> None: - super(MockHttpClientResponse).__init__() + super(MockClientResponse).__init__() self._url = url self._body = body_bytes self._content = body_bytes @@ -70,7 +74,7 @@ def __init__( self.raw = HTTPResponse() -class MockStorageTransport(HttpTransport): +class MockLegacyTransport(HttpTransport): """ This transport returns legacy http response objects from azure core and is intended only to test our backwards compatibility support. @@ -89,7 +93,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"Hello World!", headers, @@ -99,7 +103,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # get_blob_properties rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -112,7 +116,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # upload_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -126,7 +130,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # delete_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -137,7 +141,90 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse ) ) else: - raise ValueError("The request is not accepted as part of MockStorageTransport.") + raise ValueError("The request is not accepted as part of MockLegacyTransport.") + return rest_response + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any) -> None: + pass + + def open(self) -> None: + pass + + def close(self) -> None: + pass + + +class MockCoreTransport(RequestsTransport): + """ + This transport returns http response objects from azure core pipelines and is + intended only to test our backwards compatibility support. + """ + def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse: + if request.method == 'GET': + # download_blob + headers = { + "Content-Type": "application/octet-stream", + "Content-Range": "bytes 0-17/18", + "Content-Length": "18", + } + + if "x-ms-range-get-content-md5" in request.headers: + headers["Content-MD5"] = "7Qdih1MuhjZehB6Sv8UNjA==" # cspell:disable-line + + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockClientResponse( + request.url, + b"Hello World!", + headers, + ) + ) + elif request.method == 'HEAD': + # get_blob_properties + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockClientResponse( + request.url, + b"", + { + "Content-Type": "application/octet-stream", + "Content-Length": "1024", + }, + ) + ) + elif request.method == 'PUT': + # upload_blob + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockClientResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 201, + "Created" + ) + ) + elif request.method == 'DELETE': + # delete_blob + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockClientResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 202, + "Accepted" + ) + ) + else: + raise ValueError("The request is not accepted as part of MockCoreTransport.") return rest_response def __enter__(self) -> Self: diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py index 0cbf222f8f94..d59caf32449e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py @@ -6,9 +6,16 @@ from io import IOBase, UnsupportedOperation from typing import Any, Dict, Optional -from azure.core.pipeline.transport import AioHttpTransportResponse, AsyncHttpTransport +from azure.core.pipeline.transport import ( + AioHttpTransportResponse, + AsyncHttpTransport, + AsyncioRequestsTransport, + AsyncioRequestsTransportResponse +) from azure.core.rest import HttpRequest from aiohttp import ClientResponse +from urllib3 import HTTPResponse +from requests import Response class ProgressTracker: @@ -66,7 +73,7 @@ async def read(self, size: int = -1) -> bytes: return data -class MockAioHttpClientResponse(ClientResponse): +class MockAsyncClientResponse(ClientResponse): def __init__( self, url: str, body_bytes: bytes, @@ -74,7 +81,7 @@ def __init__( status: int = 200, reason: str = "OK" ) -> None: - super(MockAioHttpClientResponse).__init__() + super(MockAsyncClientResponse).__init__() self._url = url self._body = body_bytes self._headers = headers @@ -84,7 +91,7 @@ def __init__( self.reason = reason -class MockStorageTransport(AsyncHttpTransport): +class MockLegacyTransport(AsyncHttpTransport): """ This transport returns legacy http response objects from azure core and is intended only to test our backwards compatibility support. @@ -103,7 +110,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"Hello Async World!", headers, @@ -114,7 +121,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes # get_blob_properties rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"", { @@ -128,7 +135,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes # upload_blob rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"", { @@ -143,7 +150,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes # delete_blob rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"", { @@ -155,7 +162,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes decompress=False ) else: - raise ValueError("The request is not accepted as part of MockStorageTransport.") + raise ValueError("The request is not accepted as part of MockLegacyTransport.") await rest_response.load_body() return rest_response @@ -171,3 +178,105 @@ async def open(self): async def close(self): pass + + +class MockAsyncResponse(Response): + def __init__( + self, url: str, + body_bytes: bytes, + headers: Dict[str, Any], + status: int = 200, + reason: str = "OK" + ) -> None: + super(MockAsyncResponse).__init__() + self._content = body_bytes + self._content_consumed = False + self.url = url + self.headers = headers + self.status_code = status + self.reason = reason + self.raw = HTTPResponse(body_bytes) + + +class MockCoreTransport(AsyncioRequestsTransport): + """ + This transport returns legacy http response objects from azure core and is + intended only to test our backwards compatibility support. + """ + async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncioRequestsTransportResponse: + if request.method == 'GET': + # download_blob + headers = { + "Content-Type": "application/octet-stream", + "Content-Range": "bytes 0-17/18", + "Content-Length": "18", + } + + if "x-ms-range-get-content-md5" in request.headers: + headers["Content-MD5"] = "I3pVbaOCUTom+G9F9uKFoA==" + + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"Hello Async World!", + headers, + ) + ) + elif request.method == 'HEAD': + # get_blob_properties + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"", + { + "Content-Type": "application/octet-stream", + "Content-Length": "1024", + }, + ) + ) + elif request.method == 'PUT': + # upload_blob + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 201, + "Created" + ) + ) + elif request.method == 'DELETE': + # delete_blob + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 202, + "Accepted" + ) + ) + else: + raise ValueError("The request is not accepted as part of MockCoreTransport.") + + return rest_response + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def open(self): + pass + + async def close(self): + pass diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py new file mode 100644 index 000000000000..9b48a67b5369 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +from azure.storage.blob import BlobClient, BlobServiceClient +from azure.core.exceptions import ResourceExistsError + +from devtools_testutils.storage import StorageRecordedTestCase +from settings.testcase import BlobPreparer +from test_helpers import MockCoreTransport, MockLegacyTransport + + +class TestStorageTransports(StorageRecordedTestCase): + def _setup(self, storage_account_name, key): + self.bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=key) + self.container_name = self.get_resource_name('utcontainer') + self.source_container_name = self.get_resource_name('utcontainersource') + if self.is_live: + try: + self.bsc.create_container(self.container_name, timeout=5) + except ResourceExistsError: + pass + try: + self.bsc.create_container(self.source_container_name, timeout=5) + except ResourceExistsError: + pass + self.byte_data = self.get_random_bytes(1024) + + @BlobPreparer() + def test_legacy_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + content = blob_client.download_blob() + assert content is not None + + props = blob_client.get_blob_properties() + assert props is not None + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True) + assert resp is not None + + blob_data = blob_client.download_blob().read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + def test_legacy_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = blob_client.download_blob(validate_content=True).read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + @BlobPreparer() + def test_core_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockCoreTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + content = blob_client.download_blob() + assert content is not None + + props = blob_client.get_blob_properties() + assert props is not None + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True) + assert resp is not None + + blob_data = blob_client.download_blob().read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + def test_core_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockCoreTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = blob_client.download_blob(validate_content=True).read() + assert blob_data == b"Hello World!" # data is fixed by mock transport diff --git a/sdk/storage/azure-storage-blob/tests/test_transports_async.py b/sdk/storage/azure-storage-blob/tests/test_transports_async.py new file mode 100644 index 000000000000..60c5663dddb7 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_transports_async.py @@ -0,0 +1,135 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + +import pytest + +from azure.storage.blob.aio import BlobClient, BlobServiceClient +from azure.core.exceptions import ResourceExistsError +from azure.core.pipeline.transport import AsyncioRequestsTransport + +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase +from settings.testcase import BlobPreparer +from test_helpers_async import AsyncStream, MockLegacyTransport + + +class TestStorageTransportsAsync(AsyncStorageRecordedTestCase): + async def _setup(self, storage_account_name, key): + self.bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=key) + self.container_name = self.get_resource_name('utcontainer') + self.source_container_name = self.get_resource_name('utcontainersource') + self.byte_data = self.get_random_bytes(1024) + if self.is_live: + try: + await self.bsc.create_container(self.container_name) + except ResourceExistsError: + pass + try: + await self.bsc.create_container(self.source_container_name) + except ResourceExistsError: + pass + + @BlobPreparer() + async def test_legacy_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + props = await blob_client.get_blob_properties() + assert props is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" # data is fixed by mock transport + + resp = await blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + async def test_legacy_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockLegacyTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" # data is fixed by mock transport + + @pytest.mark.live_test_only + @BlobPreparer() + async def test_core_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = AsyncioRequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + props = await blob_client.get_blob_properties() + assert props is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" + + @pytest.mark.live_test_only + @BlobPreparer() + async def test_core_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = AsyncioRequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 807a51dd297c..7aa5f97ace40 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -44,11 +44,12 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): + if hasattr(response.http_response, "load_body"): try: - await response.http_response.load_body() # Load the body in memory and close the socket + await response.http_response.load_body() except (StreamClosedError, StreamConsumedError): pass + if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 807a51dd297c..7aa5f97ace40 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -44,11 +44,12 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): + if hasattr(response.http_response, "load_body"): try: - await response.http_response.load_body() # Load the body in memory and close the socket + await response.http_response.load_body() except (StreamClosedError, StreamConsumedError): pass + if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index 44b45084a670..62b589b2ed9c 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -33,8 +33,10 @@ async def process_content(data: Any) -> bytes: raise ValueError("Response cannot be None.") try: - await data.response.load_body() - return cast(bytes, data.response.body()) + if hasattr(data.response, "load_body"): + await data.response.load_body() + return cast(bytes, data.response.body()) + return b"".join([d async for d in data]) except Exception as error: raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error) from error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 807a51dd297c..7aa5f97ace40 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -44,11 +44,12 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 - if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): + if hasattr(response.http_response, "load_body"): try: - await response.http_response.load_body() # Load the body in memory and close the socket + await response.http_response.load_body() except (StreamClosedError, StreamConsumedError): pass + if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: