Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
import posixpath
from pathlib import Path

from azure.storage.blob import BlobServiceClient
from azure.storage.blob import BlobServiceClient, BlobProperties


# Application ID for telemetry (prepended to User-Agent)
Expand Down Expand Up @@ -173,18 +173,19 @@ def list_files(
paths = []
if recursive:
# Recursive: list all blobs under the prefix
blob_items = container_client.list_blobs(name_starts_with=prefix)
blob_items = container_client.list_blobs(name_starts_with=prefix, include=["metadata"])
for item in blob_items:
if hasattr(item, 'name'):
if hasattr(item, 'name') and not _is_adls_directory(item):
paths.append(item.name)
else:
# Non-recursive: use walk_blobs with delimiter to only get blobs at this level
# This is more efficient as Azure service handles the filtering
blob_items = container_client.walk_blobs(name_starts_with=prefix, delimiter='/')
blob_items = container_client.walk_blobs(name_starts_with=prefix, delimiter='/', include=["metadata"])
for item in blob_items:
# walk_blobs returns BlobProperties for blobs and BlobPrefix for directories
# We only want blobs (files), not prefixes (directories)
if hasattr(item, 'name') and not item.name.endswith('/'):

if hasattr(item, 'name') and not item.name.endswith('/') and not _is_adls_directory(item):
paths.append(item.name)

# Filter out directories (blobs ending with /)
Expand Down Expand Up @@ -218,3 +219,14 @@ def removeprefix(s: str, prefix: str) -> str:
if s.startswith(prefix):
return s[len(prefix):]
return s

def _is_adls_directory(blob: "BlobProperties") -> bool:
    """Return True if ``blob`` is an ADLS directory stub rather than a file.

    When listing against ADLS we may hit directory stubs that have no
    trailing slash in their name, so the name alone cannot be used to
    filter them out. Such a stub is recognised here as a zero-byte blob
    whose metadata contains the ``hdi_isfolder`` key set to ``"true"``.

    The key is compared case-insensitively because the service does not
    return it with consistent casing (e.g. ``Hdi_isfolder``).

    Args:
        blob: Listing entry exposing ``size`` and ``metadata`` attributes.
            ``metadata`` may be ``None`` when it was not returned.

    Returns:
        True when the blob is a directory stub, False otherwise.
    """
    # Anything with content, or without metadata, cannot be a directory stub.
    if blob.size != 0 or not blob.metadata:
        return False
    return any(
        key.lower() == "hdi_isfolder" and value == "true"
        for key, value in blob.metadata.items()
    )
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import unittest
from unittest.mock import MagicMock
import runai_model_streamer_azure.files.files as files
from azure.storage.blob import BlobProperties, BlobServiceClient, ContainerClient


class TestFiles(unittest.TestCase):
Expand Down Expand Up @@ -33,5 +35,77 @@ def test_removeprefix_no(self):
self.assertEqual(res, "test_prefix_string")


class TestListFiles(unittest.TestCase):
    """Exercise ``files.list_files`` against mocked Azure blob clients."""

    def make_blob(self, name, size=10, metadata=None):
        """Build a ``BlobProperties``-spec'd mock carrying the given fields."""
        fake = MagicMock(spec=BlobProperties)
        # ``name`` cannot be passed to the Mock constructor (it would name the
        # mock itself), so every attribute is assigned after creation.
        fake.configure_mock(name=name, size=size, metadata=metadata)
        return fake

    def setUp(self):
        """Wire a mocked service client to hand out a mocked container client."""
        container = MagicMock(spec=ContainerClient)
        service = MagicMock(spec=BlobServiceClient)
        service.get_container_client.return_value = container
        self.mock_container_client = container
        self.mock_blob_client = service

    def test_listfiles_recursive(self):
        """Recursive listing drops directory entries but keeps empty files."""
        expected = ["file1.txt", "dir1/test.txt", "empty-blob", "empty-file"]
        # A plain list stands in for the iterable that list_blobs returns.
        listing = [
            self.make_blob("file1.txt"),
            self.make_blob("dir1/test.txt"),
            # ADLS directory stub: zero bytes, folder marker, no trailing slash.
            self.make_blob("adls-dir1", size=0, metadata={"hdi_isfolder": "true"}),
            self.make_blob("dir2/", size=0, metadata=None),
            self.make_blob("empty-blob", size=0),
            # Same folder marker, with the key capitalized.
            self.make_blob("adls-caps", size=0, metadata={"Hdi_isfolder": "true"}),
            self.make_blob("empty-file", size=0, metadata={"Hdi_isfolder": "false"}),
        ]
        self.mock_container_client.list_blobs.return_value = listing

        _, _, found = files.list_files(
            self.mock_blob_client, "az://container/", recursive=True
        )

        self.assertEqual(found, expected)

    def test_listfiles_non_recursive(self):
        """Listing without the ``recursive`` argument applies the same filtering."""
        expected = ["file1.txt", "empty-blob", "empty-file"]
        listing = [
            self.make_blob("file1.txt"),
            self.make_blob("adls-dir1", size=0, metadata={"hdi_isfolder": "true"}),
            self.make_blob("dir2/"),
            self.make_blob("empty-blob", size=0),
            self.make_blob("adls-caps", size=0, metadata={"Hdi_isfolder": "true"}),
            self.make_blob("empty-file", size=0, metadata={"Hdi_isfolder": "false"}),
        ]
        self.mock_container_client.walk_blobs.return_value = listing

        _, _, found = files.list_files(self.mock_blob_client, "az://container/")

        self.assertEqual(found, expected)

    def test_listfiles_with_allow_pattern(self):
        """Only names matching ``allow_pattern`` are returned."""
        self.mock_container_client.list_blobs.return_value = [
            self.make_blob("models/weights/config.json"),
            self.make_blob("models/weights/model.safetensors"),
            self.make_blob("models/README"),
        ]

        _, _, found = files.list_files(
            self.mock_blob_client,
            "az://container/",
            allow_pattern=["*.safetensors"],
            recursive=True,
        )

        self.assertEqual(found, ["models/weights/model.safetensors"])

    def test_listfiles_with_ignore_pattern(self):
        """Names matching ``ignore_pattern`` are left out of the result."""
        self.mock_container_client.list_blobs.return_value = [
            self.make_blob("models/weights/config.json"),
            self.make_blob("models/weights/model.safetensors"),
            self.make_blob("models/README"),
        ]

        _, _, found = files.list_files(
            self.mock_blob_client,
            "az://container/",
            ignore_pattern=["*.safetensors"],
            recursive=True,
        )

        self.assertEqual(found, ["models/weights/config.json", "models/README"])

# Run this module's tests when the file is executed directly as a script.
if __name__ == "__main__":
    unittest.main()
Loading