Skip to content

fix: azure connection check #503

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 24 commits into from
Apr 10, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion packages/ragbits-core/CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -2,8 +2,8 @@

## Unreleased
- Add new fusion strategies for the hybrid vector store: RRF and DBSF (#413)

- move sources from ragbits-document-search to ragbits-core (#496)
- adding connection check to Azure get_blob_service (#502)

## 0.13.0 (2025-04-02)
- Make the score in VectorStoreResult consistent (always bigger is better)
Expand Down
99 changes: 46 additions & 53 deletions packages/ragbits-core/src/ragbits/core/sources/azure.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,19 +2,18 @@
from collections.abc import Sequence
from contextlib import suppress
from pathlib import Path
from typing import ClassVar, Optional
from typing import ClassVar
from urllib.parse import urlparse

from ragbits.core.audit import trace, traceable
from ragbits.core.sources.base import Source, get_local_storage_dir
from ragbits.core.sources.exceptions import SourceConnectionError, SourceNotFoundError
from ragbits.core.utils.decorators import requires_dependencies

with suppress(ImportError):
from azure.core.exceptions import ResourceNotFoundError
from azure.identity import DefaultAzureCredential
from azure.storage.blob import BlobServiceClient

from ragbits.core.sources.base import Source, get_local_storage_dir
from ragbits.core.sources.exceptions import SourceConnectionError, SourceNotFoundError
from ragbits.core.utils.decorators import requires_dependencies
from azure.storage.blob import BlobServiceClient, ExponentialRetry


class AzureBlobStorageSource(Source):
Expand All @@ -26,7 +25,6 @@ class AzureBlobStorageSource(Source):
account_name: str
container_name: str
blob_name: str
_blob_service: Optional["BlobServiceClient"] = None

@property
def id(self) -> str:
Expand All @@ -35,49 +33,6 @@ def id(self) -> str:
"""
return f"azure://{self.account_name}/{self.container_name}/{self.blob_name}"

@classmethod
@requires_dependencies(["azure.storage.blob", "azure.identity"], "azure")
async def _get_blob_service(cls, account_name: str) -> "BlobServiceClient":
"""
Returns an authenticated BlobServiceClient instance.

Priority:
1. DefaultAzureCredential (if account_name is set and authentication succeeds).
2. Connection string (if authentication with DefaultAzureCredential fails).

If neither method works, an error is raised.

Args:
account_name: The name of the Azure Blob Storage account.

Returns:
BlobServiceClient: The authenticated Blob Storage client.

Raises:
ValueError: If the authentication fails.
"""
try:
credential = DefaultAzureCredential()
account_url = f"https://{account_name}.blob.core.windows.net"
cls._blob_service = BlobServiceClient(account_url=account_url, credential=credential)
return cls._blob_service
except Exception as e:
print(f"Warning: Failed to authenticate using DefaultAzureCredential. \nError: {str(e)}")

connection_string = os.getenv("AZURE_STORAGE_CONNECTION_STRING")
if connection_string:
try:
cls._blob_service = BlobServiceClient.from_connection_string(conn_str=connection_string)
return cls._blob_service
except Exception as e:
raise ValueError("Failed to authenticate using connection string.") from e

# If neither method works, raise an error
raise ValueError(
"No authentication method available. "
"Provide an account_name for identity-based authentication or a connection string."
)

@requires_dependencies(["azure.storage.blob", "azure.core.exceptions"], "azure")
async def fetch(self) -> Path:
"""
Expand All @@ -95,7 +50,7 @@ async def fetch(self) -> Path:
path = container_local_dir / self.blob_name
with trace(account_name=self.account_name, container=self.container_name, blob=self.blob_name) as outputs:
try:
blob_service = await self._get_blob_service(account_name=self.account_name)
blob_service = self._get_blob_service(self.account_name)
blob_client = blob_service.get_blob_client(container=self.container_name, blob=self.blob_name)
Path(path).parent.mkdir(parents=True, exist_ok=True)
stream = blob_client.download_blob()
Expand Down Expand Up @@ -174,12 +129,11 @@ async def list_sources(
List of source objects.

Raises:
ImportError: If the required 'azure-storage-blob' package is not installed
SourceConnectionError: If there's an error connecting to Azure
"""
with trace(account_name=account_name, container=container, blob_name=blob_name) as outputs:
blob_service = await cls._get_blob_service(account_name=account_name)
try:
blob_service = cls._get_blob_service(account_name)
container_client = blob_service.get_container_client(container)
blobs = container_client.list_blobs(name_starts_with=blob_name)
outputs.results = [
Expand All @@ -189,3 +143,42 @@ async def list_sources(
return outputs.results
except Exception as e:
raise SourceConnectionError() from e

@staticmethod
def _get_blob_service(account_name: str) -> "BlobServiceClient":
"""
Returns an authenticated BlobServiceClient instance.

Priority:
1. DefaultAzureCredential.
2. Connection string.

Args:
account_name: The name of the Azure Blob Storage account.

Returns:
The authenticated Blob Storage client.
"""
try:
credential = DefaultAzureCredential()
account_url = f"https://{account_name}.blob.core.windows.net"
blob_service = BlobServiceClient(
account_url=account_url,
credential=credential,
retry_policy=ExponentialRetry(retry_total=0),
)
blob_service.get_account_information()
return blob_service
except Exception as first_exc:
if conn_str := os.getenv("AZURE_STORAGE_CONNECTION_STRING", ""):
try:
service = BlobServiceClient.from_connection_string(
conn_str=conn_str,
retry_policy=ExponentialRetry(retry_total=0),
)
service.get_account_information()
return service
except Exception as second_error:
raise second_error from first_exc

raise first_exc
29 changes: 15 additions & 14 deletions packages/ragbits-core/tests/unit/sources/test_azure.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
from pathlib import Path, PosixPath
from unittest.mock import AsyncMock, MagicMock, mock_open, patch
from unittest.mock import ANY, AsyncMock, MagicMock, mock_open, patch

import pytest
from azure.core.exceptions import ResourceNotFoundError
Expand Down Expand Up @@ -81,31 +81,28 @@ async def test_from_uri_listing():
)


@pytest.mark.asyncio
async def test_get_blob_service_no_credentials():
"""Test that ValueError is raised when no credentials are set."""
def test_get_blob_service_no_credentials():
"""Test that Exception propageted when no credentials are set."""
with (
patch.object(DefaultAzureCredential, "__init__", side_effect=Exception("Authentication failed")),
patch("os.getenv", return_value=None),
pytest.raises(ValueError, match="No authentication method available"),
pytest.raises(Exception, match="Authentication failed"),
):
await AzureBlobStorageSource._get_blob_service(account_name=ACCOUNT_NAME)
AzureBlobStorageSource._get_blob_service(account_name=ACCOUNT_NAME)


@pytest.mark.asyncio
async def test_get_blob_service_with_connection_string():
def test_get_blob_service_with_connection_string():
"""Test that connection string is used when AZURE_STORAGE_ACCOUNT_NAME is not set."""
with (
patch.object(DefaultAzureCredential, "__init__", side_effect=Exception("Authentication failed")),
patch("os.getenv", return_value="mock_connection_string"),
patch("azure.storage.blob.BlobServiceClient.from_connection_string") as mock_from_connection_string,
):
await AzureBlobStorageSource._get_blob_service(account_name="account_name")
mock_from_connection_string.assert_called_once_with(conn_str="mock_connection_string")
AzureBlobStorageSource._get_blob_service(account_name="account_name")
mock_from_connection_string.assert_called_once_with(conn_str="mock_connection_string", retry_policy=ANY)


@pytest.mark.asyncio
async def test_get_blob_service_with_default_credentials():
def test_get_blob_service_with_default_credentials():
"""Test that default credentials are used when the account_name and credentials are available."""
account_url = f"https://{ACCOUNT_NAME}.blob.core.windows.net"

Expand All @@ -114,10 +111,14 @@ async def test_get_blob_service_with_default_credentials():
patch("ragbits.core.sources.azure.BlobServiceClient") as mock_blob_client,
patch("azure.storage.blob.BlobServiceClient.from_connection_string") as mock_from_connection_string,
):
await AzureBlobStorageSource._get_blob_service(ACCOUNT_NAME)
AzureBlobStorageSource._get_blob_service(ACCOUNT_NAME)

mock_credential.assert_called_once()
mock_blob_client.assert_called_once_with(account_url=account_url, credential=mock_credential.return_value)
mock_blob_client.assert_called_once_with(
account_url=account_url,
credential=mock_credential.return_value,
retry_policy=ANY,
)
mock_from_connection_string.assert_not_called()


Expand Down
Loading