From 12b77f61232fb0e9d8515822c49798953be856fb Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sat, 12 Apr 2025 21:40:51 -0400 Subject: [PATCH 1/8] Ported legacy transport tests to a separate file --- .../tests/test_common_blob.py | 54 ------------- .../tests/test_common_blob_async.py | 55 ------------- .../tests/test_transports.py | 76 ++++++++++++++++++ .../tests/test_transports_async.py | 78 +++++++++++++++++++ 4 files changed, 154 insertions(+), 109 deletions(-) create mode 100644 sdk/storage/azure-storage-blob/tests/test_transports.py create mode 100644 sdk/storage/azure-storage-blob/tests/test_transports_async.py diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob.py b/sdk/storage/azure-storage-blob/tests/test_common_blob.py index 3afdb1103bd3..b7e8936b3656 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob.py @@ -53,7 +53,6 @@ from devtools_testutils import FakeTokenCredential, recorded_by_proxy from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers import MockStorageTransport # ------------------------------------------------------------------------------ TEST_CONTAINER_PREFIX = 'container' @@ -3532,57 +3531,4 @@ def test_upload_blob_partial_stream_chunked(self, **kwargs): result = blob.download_blob().readall() assert result == data[:length] - @BlobPreparer() - def test_mock_transport_no_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - content = blob_client.download_blob() - assert content is not None - - props = blob_client.get_blob_properties() - assert props is not None - - data = b"Hello World!" - resp = blob_client.upload_blob(data, overwrite=True) - assert resp is not None - - blob_data = blob_client.download_blob().read() - assert blob_data == b"Hello World!" # data is fixed by mock transport - - resp = blob_client.delete_blob() - assert resp is None - - @BlobPreparer() - def test_mock_transport_with_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - data = b"Hello World!" - resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) - assert resp is not None - - blob_data = blob_client.download_blob(validate_content=True).read() - assert blob_data == b"Hello World!" # data is fixed by mock transport - # ------------------------------------------------------------------------------ \ No newline at end of file diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py index aef6d9680603..aea2f763f738 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py @@ -53,7 +53,6 @@ from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers_async import AsyncStream, MockStorageTransport # ------------------------------------------------------------------------------ TEST_CONTAINER_PREFIX = 'container' @@ -3456,58 +3455,4 @@ async def test_upload_blob_partial_stream_chunked(self, **kwargs): result = await (await blob.download_blob()).readall() assert result == data[:length] - @BlobPreparer() - async def test_mock_transport_no_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - content = await blob_client.download_blob() - assert content is not None - - props = await blob_client.get_blob_properties() - assert props is not None - - data = b"Hello Async World!" - stream = AsyncStream(data) - resp = await blob_client.upload_blob(stream, overwrite=True) - assert resp is not None - - blob_data = await (await blob_client.download_blob()).read() - assert blob_data == b"Hello Async World!" # data is fixed by mock transport - - resp = await blob_client.delete_blob() - assert resp is None - - @BlobPreparer() - async def test_mock_transport_with_content_validation(self, **kwargs): - storage_account_name = kwargs.pop("storage_account_name") - storage_account_key = kwargs.pop("storage_account_key") - - transport = MockStorageTransport() - blob_client = BlobClient( - self.account_url(storage_account_name, "blob"), - container_name='test_cont', - blob_name='test_blob', - credential=storage_account_key, - transport=transport, - retry_total=0 - ) - - data = b"Hello Async World!" - stream = AsyncStream(data) - resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) - assert resp is not None - - blob_data = await (await blob_client.download_blob(validate_content=True)).read() - assert blob_data == b"Hello Async World!" # data is fixed by mock transport # ------------------------------------------------------------------------------ diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py new file mode 100644 index 000000000000..27285c14f69b --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -0,0 +1,76 @@ +from azure.storage.blob import BlobClient, BlobServiceClient +from azure.core.exceptions import ResourceExistsError + +from devtools_testutils.storage import StorageRecordedTestCase +from settings.testcase import BlobPreparer +from test_helpers import MockStorageTransport + + +class TestStorageTransports(StorageRecordedTestCase): + def _setup(self, storage_account_name, key): + self.bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=key) + self.container_name = self.get_resource_name('utcontainer') + self.source_container_name = self.get_resource_name('utcontainersource') + if self.is_live: + try: + self.bsc.create_container(self.container_name, timeout=5) + except ResourceExistsError: + pass + try: + self.bsc.create_container(self.source_container_name, timeout=5) + except ResourceExistsError: + pass + self.byte_data = self.get_random_bytes(1024) + + @BlobPreparer() + def test_legacy_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockStorageTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='test_cont', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + content = blob_client.download_blob() + assert content is not None + + props = blob_client.get_blob_properties() + assert props is not None + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True) + assert resp is not None + + blob_data = blob_client.download_blob().read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + def test_legacy_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockStorageTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='test_cont', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = blob_client.download_blob(validate_content=True).read() + assert blob_data == b"Hello World!" # data is fixed by mock transport diff --git a/sdk/storage/azure-storage-blob/tests/test_transports_async.py b/sdk/storage/azure-storage-blob/tests/test_transports_async.py new file mode 100644 index 000000000000..af4ed6f8f167 --- /dev/null +++ b/sdk/storage/azure-storage-blob/tests/test_transports_async.py @@ -0,0 +1,78 @@ +from azure.storage.blob.aio import BlobClient, BlobServiceClient +from azure.core.exceptions import ResourceExistsError + +from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase +from settings.testcase import BlobPreparer +from test_helpers_async import AsyncStream, MockStorageTransport + + +class TestStorageTransportsAsync(AsyncStorageRecordedTestCase): + async def _setup(self, storage_account_name, key): + self.bsc = BlobServiceClient(self.account_url(storage_account_name, "blob"), credential=key) + self.container_name = self.get_resource_name('utcontainer') + self.source_container_name = self.get_resource_name('utcontainersource') + self.byte_data = self.get_random_bytes(1024) + if self.is_live: + try: + await self.bsc.create_container(self.container_name) + except ResourceExistsError: + pass + try: + await self.bsc.create_container(self.source_container_name) + except ResourceExistsError: + pass + + @BlobPreparer() + async def test_legacy_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockStorageTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='test_cont', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + content = await blob_client.download_blob() + assert content is not None + + props = await blob_client.get_blob_properties() + assert props is not None + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" # data is fixed by mock transport + + resp = await blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + async def test_legacy_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockStorageTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='test_cont', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" # data is fixed by mock transport \ No newline at end of file From 0cd044e49f87f3bbc47a3c5798f862f760c01d30 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sat, 12 Apr 2025 21:48:52 -0400 Subject: [PATCH 2/8] MockStorageTransport -> MockLegacyTransport --- sdk/storage/azure-storage-blob/tests/test_helpers.py | 4 ++-- sdk/storage/azure-storage-blob/tests/test_helpers_async.py | 4 ++-- sdk/storage/azure-storage-blob/tests/test_transports.py | 6 +++--- .../azure-storage-blob/tests/test_transports_async.py | 6 +++--- 4 files changed, 10 insertions(+), 10 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index 6f309bd756cc..f4c7dd86b6ea 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -70,7 +70,7 @@ def __init__( self.raw = HTTPResponse() -class MockStorageTransport(HttpTransport): +class MockLegacyTransport(HttpTransport): """ This transport returns legacy http response objects from azure core and is intended only to test our backwards compatibility support. @@ -137,7 +137,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse ) ) else: - raise ValueError("The request is not accepted as part of MockStorageTransport.") + raise ValueError("The request is not accepted as part of MockLegacyTransport.") return rest_response def __enter__(self) -> Self: diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py index 0cbf222f8f94..6a1a1db3540d 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py @@ -84,7 +84,7 @@ def __init__( self.reason = reason -class MockStorageTransport(AsyncHttpTransport): +class MockLegacyTransport(AsyncHttpTransport): """ This transport returns legacy http response objects from azure core and is intended only to test our backwards compatibility support. @@ -155,7 +155,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes decompress=False ) else: - raise ValueError("The request is not accepted as part of MockStorageTransport.") + raise ValueError("The request is not accepted as part of MockLegacyTransport.") await rest_response.load_body() return rest_response diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py index 27285c14f69b..0347fa268c6e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -3,7 +3,7 @@ from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers import MockStorageTransport +from test_helpers import MockLegacyTransport class TestStorageTransports(StorageRecordedTestCase): @@ -27,7 +27,7 @@ def test_legacy_transport(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - transport = MockStorageTransport() + transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), container_name='test_cont', @@ -58,7 +58,7 @@ def test_legacy_transport_content_validation(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - transport = MockStorageTransport() + transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), container_name='test_cont', diff --git a/sdk/storage/azure-storage-blob/tests/test_transports_async.py b/sdk/storage/azure-storage-blob/tests/test_transports_async.py index af4ed6f8f167..98ad578dc38d 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports_async.py @@ -3,7 +3,7 @@ from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers_async import AsyncStream, MockStorageTransport +from test_helpers_async import AsyncStream, MockLegacyTransport class TestStorageTransportsAsync(AsyncStorageRecordedTestCase): @@ -27,7 +27,7 @@ async def test_legacy_transport(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - transport = MockStorageTransport() + transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), container_name='test_cont', @@ -59,7 +59,7 @@ async def test_legacy_transport_content_validation(self, **kwargs): storage_account_name = kwargs.pop("storage_account_name") storage_account_key = kwargs.pop("storage_account_key") - transport = MockStorageTransport() + transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), container_name='test_cont', From 67aa9f6b247af066a77c9fb0aee54f53f06c8ed8 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sat, 12 Apr 2025 22:01:40 -0400 Subject: [PATCH 3/8] Added core transport tests --- .../azure-storage-blob/tests/test_helpers.py | 85 ++++++++++++++++++- .../tests/test_transports.py | 55 +++++++++++- 2 files changed, 138 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index f4c7dd86b6ea..147f1413df8b 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -8,7 +8,11 @@ from typing import Any, Dict, Optional from typing_extensions import Self -from azure.core.pipeline.transport import HttpTransport, RequestsTransportResponse +from azure.core.pipeline.transport import ( + HttpTransport, + RequestsTransport, + RequestsTransportResponse +) from azure.core.rest import HttpRequest from requests import Response from urllib3 import HTTPResponse @@ -151,3 +155,82 @@ def open(self) -> None: def close(self) -> None: pass + + +class MockCoreTransport(RequestsTransport): + def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse: + if request.method == 'GET': + # download_blob + headers = { + "Content-Type": "application/octet-stream", + "Content-Range": "bytes 0-17/18", + "Content-Length": "18", + } + + if "x-ms-range-get-content-md5" in request.headers: + headers["Content-MD5"] = "7Qdih1MuhjZehB6Sv8UNjA==" # cspell:disable-line + + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockHttpClientResponse( + request.url, + b"Hello World!", + headers, + ) + ) + elif request.method == 'HEAD': + # get_blob_properties + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockHttpClientResponse( + request.url, + b"", + { + "Content-Type": "application/octet-stream", + "Content-Length": "1024", + }, + ) + ) + elif request.method == 'PUT': + # upload_blob + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockHttpClientResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 201, + "Created" + ) + ) + elif request.method == 'DELETE': + # delete_blob + rest_response = RequestsTransportResponse( + request=request, + requests_response=MockHttpClientResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 202, + "Accepted" + ) + ) + else: + raise ValueError("The request is not accepted as part of MockCoreTransport.") + return rest_response + + def __enter__(self) -> Self: + return self + + def __exit__(self, *args: Any) -> None: + pass + + def open(self) -> None: + pass + + def close(self) -> None: + pass diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py index 0347fa268c6e..3440824e674a 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -3,7 +3,7 @@ from devtools_testutils.storage import StorageRecordedTestCase from settings.testcase import BlobPreparer -from test_helpers import MockLegacyTransport +from test_helpers import MockCoreTransport, MockLegacyTransport class TestStorageTransports(StorageRecordedTestCase): @@ -74,3 +74,56 @@ def test_legacy_transport_content_validation(self, **kwargs): blob_data = blob_client.download_blob(validate_content=True).read() assert blob_data == b"Hello World!" # data is fixed by mock transport + + @BlobPreparer() + def test_core_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockCoreTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='test_cont', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + content = blob_client.download_blob() + assert content is not None + + props = blob_client.get_blob_properties() + assert props is not None + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True) + assert resp is not None + + blob_data = blob_client.download_blob().read() + assert blob_data == b"Hello World!" # data is fixed by mock transport + + resp = blob_client.delete_blob() + assert resp is None + + @BlobPreparer() + def test_core_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = MockCoreTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='test_cont', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello World!" + resp = blob_client.upload_blob(data, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = blob_client.download_blob(validate_content=True).read() + assert blob_data == b"Hello World!" # data is fixed by mock transport From acb52add11c317aa1aaa0a90cc68c6fbf6ac5fa2 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sat, 12 Apr 2025 22:02:27 -0400 Subject: [PATCH 4/8] Copyright header --- sdk/storage/azure-storage-blob/tests/test_transports.py | 6 ++++++ .../azure-storage-blob/tests/test_transports_async.py | 6 ++++++ 2 files changed, 12 insertions(+) diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py index 3440824e674a..172ea745095e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + from azure.storage.blob import BlobClient, BlobServiceClient from azure.core.exceptions import ResourceExistsError diff --git a/sdk/storage/azure-storage-blob/tests/test_transports_async.py b/sdk/storage/azure-storage-blob/tests/test_transports_async.py index 98ad578dc38d..00a249b1d30b 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports_async.py @@ -1,3 +1,9 @@ +# ------------------------------------------------------------------------- +# Copyright (c) Microsoft Corporation. All rights reserved. +# Licensed under the MIT License. See License.txt in the project root for +# license information. +# -------------------------------------------------------------------------- + from azure.storage.blob.aio import BlobClient, BlobServiceClient from azure.core.exceptions import ResourceExistsError From 553fa1029767dab43bc20ca52cf6b903d851d2fb Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sun, 13 Apr 2025 04:04:00 -0400 Subject: [PATCH 5/8] Fixed transport download/upload --- .../storage/blob/_shared/policies_async.py | 4 - .../azure/storage/blob/aio/_download_async.py | 7 +- .../tests/test_common_blob_async.py | 1 + .../azure-storage-blob/tests/test_helpers.py | 24 ++-- .../tests/test_helpers_async.py | 123 +++++++++++++++++- .../tests/test_transports.py | 8 +- .../tests/test_transports_async.py | 69 ++++++++-- .../filedatalake/_shared/policies_async.py | 4 - .../fileshare/_shared/policies_async.py | 4 - .../storage/queue/_shared/policies_async.py | 4 - 10 files changed, 200 insertions(+), 48 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index c44e19ca06ea..5324c66a205b 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -45,10 +45,6 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index 8a929647e78f..d784c94d7d95 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -46,8 +46,11 @@ async def process_content(data: Any, start_offset: int, end_offset: int, encryption: Dict[str, Any]) -> bytes: if data is None: raise ValueError("Response cannot be None.") - await data.response.load_body() - content = cast(bytes, data.response.body()) + if hasattr(data.response, "read"): + content = b"".join([d async for d in data]) + else: + await data.response.load_body() + content = cast(bytes, data.response.body()) if encryption.get('key') is not None or encryption.get('resolver') is not None: try: return decrypt_blob( diff --git a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py index aea2f763f738..2ebaad59279d 100644 --- a/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_common_blob_async.py @@ -53,6 +53,7 @@ from devtools_testutils.aio import recorded_by_proxy_async from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer +from test_helpers_async import AsyncStream # ------------------------------------------------------------------------------ TEST_CONTAINER_PREFIX = 'container' diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers.py b/sdk/storage/azure-storage-blob/tests/test_helpers.py index 147f1413df8b..dc4932e31673 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers.py @@ -53,7 +53,7 @@ def tell(self): return self.wrapped_stream.tell() -class MockHttpClientResponse(Response): +class MockClientResponse(Response): def __init__( self, url: str, body_bytes: bytes, @@ -61,7 +61,7 @@ def __init__( status: int = 200, reason: str = "OK" ) -> None: - super(MockHttpClientResponse).__init__() + super(MockClientResponse).__init__() self._url = url self._body = body_bytes self._content = body_bytes @@ -93,7 +93,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"Hello World!", headers, @@ -103,7 +103,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # get_blob_properties rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -116,7 +116,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # upload_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -130,7 +130,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # delete_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -158,6 +158,10 @@ def close(self) -> None: class MockCoreTransport(RequestsTransport): + """ + This transport returns http response objects from azure core pipelines and is + intended only to test our backwards compatibility support. + """ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse: if request.method == 'GET': # download_blob @@ -172,7 +176,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"Hello World!", headers, @@ -182,7 +186,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # get_blob_properties rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -195,7 +199,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # upload_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { @@ -209,7 +213,7 @@ def send(self, request: HttpRequest, **kwargs: Any) -> RequestsTransportResponse # delete_blob rest_response = RequestsTransportResponse( request=request, - requests_response=MockHttpClientResponse( + requests_response=MockClientResponse( request.url, b"", { diff --git a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py index 6a1a1db3540d..d59caf32449e 100644 --- a/sdk/storage/azure-storage-blob/tests/test_helpers_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_helpers_async.py @@ -6,9 +6,16 @@ from io import IOBase, UnsupportedOperation from typing import Any, Dict, Optional -from azure.core.pipeline.transport import AioHttpTransportResponse, AsyncHttpTransport +from azure.core.pipeline.transport import ( + AioHttpTransportResponse, + AsyncHttpTransport, + AsyncioRequestsTransport, + AsyncioRequestsTransportResponse +) from azure.core.rest import HttpRequest from aiohttp import ClientResponse +from urllib3 import HTTPResponse +from requests import Response class ProgressTracker: @@ -66,7 +73,7 @@ async def read(self, size: int = -1) -> bytes: return data -class MockAioHttpClientResponse(ClientResponse): +class MockAsyncClientResponse(ClientResponse): def __init__( self, url: str, body_bytes: bytes, @@ -74,7 +81,7 @@ def __init__( status: int = 200, reason: str = "OK" ) -> None: - super(MockAioHttpClientResponse).__init__() + super(MockAsyncClientResponse).__init__() self._url = url self._body = body_bytes self._headers = headers @@ -103,7 +110,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"Hello Async World!", headers, @@ -114,7 +121,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes # get_blob_properties rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"", { @@ -128,7 +135,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes # upload_blob rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"", { @@ -143,7 +150,7 @@ async def send(self, request: HttpRequest, **kwargs: Any) -> AioHttpTransportRes # delete_blob rest_response = AioHttpTransportResponse( request=request, - aiohttp_response=MockAioHttpClientResponse( + aiohttp_response=MockAsyncClientResponse( request.url, b"", { @@ -171,3 +178,105 @@ async def open(self): async def close(self): pass + + +class MockAsyncResponse(Response): + def __init__( + self, url: str, + body_bytes: bytes, + headers: Dict[str, Any], + status: int = 200, + reason: str = "OK" + ) -> None: + super(MockAsyncResponse).__init__() + self._content = body_bytes + self._content_consumed = False + self.url = url + self.headers = headers + self.status_code = status + self.reason = reason + self.raw = HTTPResponse(body_bytes) + + +class MockCoreTransport(AsyncioRequestsTransport): + """ + This transport returns legacy http response objects from azure core and is + intended only to test our backwards compatibility support. + """ + async def send(self, request: HttpRequest, **kwargs: Any) -> AsyncioRequestsTransportResponse: + if request.method == 'GET': + # download_blob + headers = { + "Content-Type": "application/octet-stream", + "Content-Range": "bytes 0-17/18", + "Content-Length": "18", + } + + if "x-ms-range-get-content-md5" in request.headers: + headers["Content-MD5"] = "I3pVbaOCUTom+G9F9uKFoA==" + + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"Hello Async World!", + headers, + ) + ) + elif request.method == 'HEAD': + # get_blob_properties + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"", + { + "Content-Type": "application/octet-stream", + "Content-Length": "1024", + }, + ) + ) + elif request.method == 'PUT': + # upload_blob + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 201, + "Created" + ) + ) + elif request.method == 'DELETE': + # delete_blob + rest_response = AsyncioRequestsTransportResponse( + request=request, + requests_response=MockAsyncResponse( + request.url, + b"", + { + "Content-Length": "0", + }, + 202, + "Accepted" + ) + ) + else: + raise ValueError("The request is not accepted as part of MockCoreTransport.") + + return rest_response + + async def __aenter__(self): + return self + + async def __aexit__(self, *args): + pass + + async def open(self): + pass + + async def close(self): + pass diff --git a/sdk/storage/azure-storage-blob/tests/test_transports.py b/sdk/storage/azure-storage-blob/tests/test_transports.py index 172ea745095e..9b48a67b5369 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports.py @@ -36,7 +36,7 @@ def test_legacy_transport(self, **kwargs): transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), - container_name='test_cont', + container_name='container', blob_name='test_blob', credential=storage_account_key, transport=transport, @@ -67,7 +67,7 @@ def test_legacy_transport_content_validation(self, **kwargs): transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), - container_name='test_cont', + container_name='container', blob_name='test_blob', credential=storage_account_key, transport=transport, @@ -89,7 +89,7 @@ def test_core_transport(self, **kwargs): transport = MockCoreTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), - container_name='test_cont', + container_name='container', blob_name='test_blob', credential=storage_account_key, transport=transport, @@ -120,7 +120,7 @@ def test_core_transport_content_validation(self, **kwargs): transport = MockCoreTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), - container_name='test_cont', + container_name='container', blob_name='test_blob', credential=storage_account_key, transport=transport, diff --git a/sdk/storage/azure-storage-blob/tests/test_transports_async.py b/sdk/storage/azure-storage-blob/tests/test_transports_async.py index 00a249b1d30b..60c5663dddb7 100644 --- a/sdk/storage/azure-storage-blob/tests/test_transports_async.py +++ b/sdk/storage/azure-storage-blob/tests/test_transports_async.py @@ -4,8 +4,11 @@ # license information. # -------------------------------------------------------------------------- +import pytest + from azure.storage.blob.aio import BlobClient, BlobServiceClient from azure.core.exceptions import ResourceExistsError +from azure.core.pipeline.transport import AsyncioRequestsTransport from devtools_testutils.storage.aio import AsyncStorageRecordedTestCase from settings.testcase import BlobPreparer @@ -36,24 +39,21 @@ async def test_legacy_transport(self, **kwargs): transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), - container_name='test_cont', + container_name='container', blob_name='test_blob', credential=storage_account_key, transport=transport, retry_total=0 ) - content = await blob_client.download_blob() - assert content is not None - - props = await blob_client.get_blob_properties() - assert props is not None - data = b"Hello Async World!" stream = AsyncStream(data) resp = await blob_client.upload_blob(stream, overwrite=True) assert resp is not None + props = await blob_client.get_blob_properties() + assert props is not None + blob_data = await (await blob_client.download_blob()).read() assert blob_data == b"Hello Async World!" # data is fixed by mock transport @@ -68,7 +68,58 @@ async def test_legacy_transport_content_validation(self, **kwargs): transport = MockLegacyTransport() blob_client = BlobClient( self.account_url(storage_account_name, "blob"), - container_name='test_cont', + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True, validate_content=True) + assert resp is not None + + blob_data = await (await blob_client.download_blob(validate_content=True)).read() + assert blob_data == b"Hello Async World!" # data is fixed by mock transport + + @pytest.mark.live_test_only + @BlobPreparer() + async def test_core_transport(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = AsyncioRequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', + blob_name='test_blob', + credential=storage_account_key, + transport=transport, + retry_total=0 + ) + + data = b"Hello Async World!" + stream = AsyncStream(data) + resp = await blob_client.upload_blob(stream, overwrite=True) + assert resp is not None + + props = await blob_client.get_blob_properties() + assert props is not None + + blob_data = await (await blob_client.download_blob()).read() + assert blob_data == b"Hello Async World!" + + @pytest.mark.live_test_only + @BlobPreparer() + async def test_core_transport_content_validation(self, **kwargs): + storage_account_name = kwargs.pop("storage_account_name") + storage_account_key = kwargs.pop("storage_account_key") + + transport = AsyncioRequestsTransport() + blob_client = BlobClient( + self.account_url(storage_account_name, "blob"), + container_name='container', blob_name='test_blob', credential=storage_account_key, transport=transport, @@ -81,4 +132,4 @@ async def test_legacy_transport_content_validation(self, **kwargs): assert resp is not None blob_data = await (await blob_client.download_blob(validate_content=True)).read() - assert blob_data == b"Hello Async World!" # data is fixed by mock transport \ No newline at end of file + assert blob_data == b"Hello Async World!" diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index 807a51dd297c..f244cf3badfb 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -45,10 +45,6 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index 807a51dd297c..f244cf3badfb 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -45,10 +45,6 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index 807a51dd297c..f244cf3badfb 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -45,10 +45,6 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): - try: - await response.http_response.load_body() # Load the body in memory and close the socket - except (StreamClosedError, StreamConsumedError): - pass computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) if response.http_response.headers['content-md5'] != computed_md5: From 02c08bbe5082f8b684fd53f3c17075d25f4444e6 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sun, 13 Apr 2025 11:07:40 -0400 Subject: [PATCH 6/8] Propagated process_content logic to file share --- .../azure/storage/fileshare/aio/_download_async.py | 7 +++++-- 1 file changed, 5 insertions(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index 44b45084a670..96cb311cc624 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -33,8 +33,11 @@ async def process_content(data: Any) -> bytes: raise ValueError("Response cannot be None.") try: - await data.response.load_body() - return cast(bytes, data.response.body()) + if hasattr(data.response, "read"): + return b"".join([d async for d in data]) + else: + await data.response.load_body() + return cast(bytes, data.response.body()) except Exception as error: raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error) from error From a85868ed17420e388566728c7e204fb90c065641 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sun, 13 Apr 2025 13:45:59 -0400 Subject: [PATCH 7/8] Fixed content validation path for upload/download --- .../azure/storage/blob/_shared/policies_async.py | 5 +++++ .../azure/storage/blob/aio/_download_async.py | 6 +++--- .../azure/storage/filedatalake/_shared/policies_async.py | 5 +++++ .../azure/storage/fileshare/_shared/policies_async.py | 5 +++++ .../azure/storage/fileshare/aio/_download_async.py | 6 +++--- .../azure/storage/queue/_shared/policies_async.py | 5 +++++ 6 files changed, 26 insertions(+), 6 deletions(-) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py index 5324c66a205b..53f1582ac0a9 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/_shared/policies_async.py @@ -44,6 +44,11 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) diff --git a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py index d784c94d7d95..8710b5fac227 100644 --- a/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py +++ b/sdk/storage/azure-storage-blob/azure/storage/blob/aio/_download_async.py @@ -46,11 +46,11 @@ async def process_content(data: Any, start_offset: int, end_offset: int, encryption: Dict[str, Any]) -> bytes: if data is None: raise ValueError("Response cannot be None.") - if hasattr(data.response, "read"): - content = b"".join([d async for d in data]) - else: + if hasattr(data.response, "load_body"): await data.response.load_body() content = cast(bytes, data.response.body()) + else: + content = b"".join([d async for d in data]) if encryption.get('key') is not None or encryption.get('resolver') is not None: try: return decrypt_blob( diff --git a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py index f244cf3badfb..7aa5f97ace40 100644 --- a/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-datalake/azure/storage/filedatalake/_shared/policies_async.py @@ -44,6 +44,11 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py index f244cf3badfb..7aa5f97ace40 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/_shared/policies_async.py @@ -44,6 +44,11 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index 96cb311cc624..a20573d6e1fc 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -33,11 +33,11 @@ async def process_content(data: Any) -> bytes: raise ValueError("Response cannot be None.") try: - if hasattr(data.response, "read"): - return b"".join([d async for d in data]) - else: + if hasattr(data.response, "load_body"): await data.response.load_body() return cast(bytes, data.response.body()) + else: + return b"".join([d async for d in data]) except Exception as error: raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error) from error diff --git a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py index f244cf3badfb..7aa5f97ace40 100644 --- a/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py +++ b/sdk/storage/azure-storage-queue/azure/storage/queue/_shared/policies_async.py @@ -44,6 +44,11 @@ async def retry_hook(settings, **kwargs): async def is_checksum_retry(response): # retry if invalid content md5 + if hasattr(response.http_response, "load_body"): + try: + await response.http_response.load_body() + except (StreamClosedError, StreamConsumedError): + pass if response.context.get('validate_content', False) and response.http_response.headers.get('content-md5'): computed_md5 = response.http_request.headers.get('content-md5', None) or \ encode_base64(StorageContentValidation.get_content_md5(response.http_response.body())) From ba5240a82b2a3de6f93f9923cb703a6aecfab198 Mon Sep 17 00:00:00 2001 From: Peter Wu Date: Sun, 13 Apr 2025 16:28:31 -0400 Subject: [PATCH 8/8] Fixed pylint error --- .../azure/storage/fileshare/aio/_download_async.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py index a20573d6e1fc..62b589b2ed9c 100644 --- a/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py +++ b/sdk/storage/azure-storage-file-share/azure/storage/fileshare/aio/_download_async.py @@ -36,8 +36,7 @@ async def process_content(data: Any) -> bytes: if hasattr(data.response, "load_body"): await data.response.load_body() return cast(bytes, data.response.body()) - else: - return b"".join([d async for d in data]) + return b"".join([d async for d in data]) except Exception as error: raise HttpResponseError(message="Download stream interrupted.", response=data.response, error=error) from error