Skip to content

Commit c8607a1

Browse files
authored
Add a check to import_utils.py to allow for use of faiss_gpu installation (#37997)
Adding check to import_utils.py for faiss_gpu
1 parent fb1e3a4 commit c8607a1

File tree

1 file changed

+6
-2
lines changed

1 file changed

+6
-2
lines changed

src/transformers/utils/import_utils.py

+6-2
Original file line numberDiff line numberDiff line change
@@ -131,7 +131,7 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
131131
_yt_dlp_available = importlib.util.find_spec("yt_dlp") is not None
132132
_datasets_available = _is_package_available("datasets")
133133
_detectron2_available = _is_package_available("detectron2")
134-
# We need to check both `faiss` and `faiss-cpu`.
134+
# We need to check `faiss`, `faiss-cpu` and `faiss-gpu`.
135135
_faiss_available = importlib.util.find_spec("faiss") is not None
136136
try:
137137
_faiss_version = importlib.metadata.version("faiss")
@@ -141,7 +141,11 @@ def _is_package_available(pkg_name: str, return_version: bool = False) -> Union[
141141
_faiss_version = importlib.metadata.version("faiss-cpu")
142142
logger.debug(f"Successfully imported faiss version {_faiss_version}")
143143
except importlib.metadata.PackageNotFoundError:
144-
_faiss_available = False
144+
try:
145+
_faiss_version = importlib.metadata.version("faiss-gpu")
146+
logger.debug(f"Successfully imported faiss version {_faiss_version}")
147+
except importlib.metadata.PackageNotFoundError:
148+
_faiss_available = False
145149
_ftfy_available = _is_package_available("ftfy")
146150
_g2p_en_available = _is_package_available("g2p_en")
147151
_hadamard_available = _is_package_available("fast_hadamard_transform")

0 commit comments

Comments
 (0)