Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
41 changes: 39 additions & 2 deletions src/gaia/rag/sdk.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,14 +30,19 @@
except ImportError:
PdfReader = None

# Not just ImportError: a broken native dependency (e.g. torchcodec/FFmpeg
# pulled in by sentence-transformers, or an arch-mismatched faiss build) raises
# RuntimeError/OSError at import. Treat that the same as "not installed" so a
# bad install can't crash every module that transitively imports RAG; the loud,
# actionable error is deferred to RAGSDK._check_dependencies() at point of use.
try:
from sentence_transformers import SentenceTransformer
except ImportError:
except Exception: # pylint: disable=broad-except
SentenceTransformer = None

try:
import faiss
except ImportError:
except Exception: # pylint: disable=broad-except
faiss = None

from gaia.chat.sdk import AgentConfig, AgentSDK
Expand Down Expand Up @@ -225,6 +230,38 @@ def _check_dependencies(self):
f"Or install packages directly:\n"
f" uv pip install {' '.join(missing)}\n"
)
# A package that is installed but failed to import (broken native
# deps) needs a different fix than a missing one — name the cause.
# Re-import on this (already-failing) path to recover the reason,
# skipping genuinely-missing packages (ImportError) which the
# install instructions above already cover.
broken = []
for pkg, label in (
("sentence_transformers", "sentence-transformers"),
("faiss", "faiss"),
):
if (pkg == "sentence_transformers" and SentenceTransformer is None) or (
pkg == "faiss" and faiss is None
):
try:
# Use the import statement (__import__), not
# importlib.import_module — the latter bypasses
# builtins.__import__, so this path can't be exercised
# by tests that intercept imports, and re-running the
# real import is what re-surfaces the native cause.
__import__(pkg)
except ImportError:
pass # genuinely missing → covered by install instructions
except Exception as exc: # pylint: disable=broad-except
broken.append(f" {label}: {exc}")
if broken:
error_msg += (
"\nThe package(s) below are installed but failed to load — "
"reinstalling won't help until the underlying error is fixed "
"(e.g. a missing FFmpeg for torchcodec):\n"
+ "\n".join(broken)
+ "\n"
)
raise ImportError(error_msg)

def _safe_open(self, file_path: str, mode="rb"):
Expand Down
72 changes: 72 additions & 0 deletions tests/unit/rag/test_dependency_guard.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,72 @@
# Copyright(C) 2024-2026 Advanced Micro Devices, Inc. All rights reserved.
# SPDX-License-Identifier: MIT

"""Unit tests for RAG dependency-import guards.

A broken native dependency (e.g. torchcodec/FFmpeg under sentence-transformers,
or an arch-mismatched faiss build) raises ``RuntimeError``/``OSError`` at import
rather than ``ImportError``. The guard in ``gaia.rag.sdk`` must treat that the
same as "not installed" so it cannot crash every module that transitively
imports RAG, while still surfacing a loud, actionable error at point of use.
"""

import builtins
import importlib

import pytest

# Importing the module must NOT raise even when an optional native dep is broken
# in the environment — that is the regression this guard protects against.
sdk = importlib.import_module("gaia.rag.sdk")


def _bare_sdk():
"""An RAGSDK instance without running __init__ (it only needs the method)."""
return sdk.RAGSDK.__new__(sdk.RAGSDK)


def test_module_imports_without_optional_deps():
"""The module is importable regardless of optional-dependency health."""
assert hasattr(sdk, "RAGSDK")
assert hasattr(sdk, "SentenceTransformer")
assert hasattr(sdk, "faiss")


def test_broken_install_reports_actionable_cause(monkeypatch):
"""An installed-but-broken dep surfaces the captured cause, not just 'install it'."""
monkeypatch.setattr(sdk, "SentenceTransformer", None)

real_import = builtins.__import__

def fake_import(name, *args, **kwargs):
if name == "sentence_transformers":
raise RuntimeError("Could not load libtorchcodec (FFmpeg not found)")
return real_import(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", fake_import)

with pytest.raises(ImportError) as excinfo:
_bare_sdk()._check_dependencies()
msg = str(excinfo.value)
assert "installed but failed to load" in msg
assert "libtorchcodec" in msg # the captured cause is named


def test_genuinely_missing_dep_omits_broken_section(monkeypatch):
"""A simply-missing dep gets install instructions, not the broken-load hint."""
monkeypatch.setattr(sdk, "SentenceTransformer", None)

real_import = builtins.__import__

def fake_import(name, *args, **kwargs):
if name == "sentence_transformers":
raise ImportError("No module named 'sentence_transformers'")
return real_import(name, *args, **kwargs)

monkeypatch.setattr(builtins, "__import__", fake_import)

with pytest.raises(ImportError) as excinfo:
_bare_sdk()._check_dependencies()
msg = str(excinfo.value)
assert "sentence-transformers" in msg
assert "installed but failed to load" not in msg
Loading