libs/infinity_emb/infinity_emb/transformer/acceleration.py
39 changes: 33 additions & 6 deletions

@@ -7,11 +7,24 @@
 from infinity_emb._optional_imports import CHECK_OPTIMUM, CHECK_TORCH, CHECK_TRANSFORMERS
 from infinity_emb.primitives import Device
 
-if CHECK_OPTIMUM.is_available:
-    from optimum.bettertransformer import (  # type: ignore[import-untyped]
-        BetterTransformer,
-        BetterTransformerManager,
-    )
+# lazy imports to avoid issues with deprecated BetterTransformer
+BetterTransformer = None
+BetterTransformerManager = None
+
+def _import_bettertransformer():
+    """Lazy import BetterTransformer to avoid import errors when it's not needed."""
+    global BetterTransformer, BetterTransformerManager
+    if BetterTransformer is None and CHECK_OPTIMUM.is_available:
+        try:
+            from optimum.bettertransformer import (  # type: ignore[import-untyped]
+                BetterTransformer as _BetterTransformer,
+                BetterTransformerManager as _BetterTransformerManager,
+            )
+            BetterTransformer = _BetterTransformer
+            BetterTransformerManager = _BetterTransformerManager
+        except Exception:
+            # If import fails, keep them as None
+            pass
 
 if CHECK_TORCH.is_available:
     import torch
@@ -37,6 +50,11 @@ def check_if_bettertransformer_possible(engine_args: "EngineArgs") -> bool:
     if not engine_args.bettertransformer:
         return False
 
+    _import_bettertransformer()
+
+    if BetterTransformerManager is None:
+        return False
+
     config = AutoConfig.from_pretrained(
         pretrained_model_name_or_path=engine_args.model_name_or_path,
         revision=engine_args.revision,
@@ -65,6 +83,15 @@ def to_bettertransformer(model: "PreTrainedModel", engine_args: "EngineArgs", lo
"INFINITY_DISABLE_OPTIMUM is no longer supported, please use the CLI / ENV for that."
)

_import_bettertransformer()

if BetterTransformer is None:
logger.warning(
"BetterTransformer is not available (likely due to transformers version incompatibility). "
"Continue without bettertransformer modeling code."
)
return model

if (
hasattr(model.config, "_attn_implementation")
and model.config._attn_implementation != "eager"
@@ -80,7 +107,7 @@
"Since torch 2.5.0, this combination leads to a segfault. Please report if you find this check to be incorrect."
)
try:
model = BetterTransformer.transform(model)
model = BetterTransformer.transform(model) # type: ignore
Review comment (Contributor):
style: The type: ignore comment should explain why the type is being ignored, e.g. # type: ignore[attr-defined] since BetterTransformer could be None
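
A minimal sketch of the situation behind that suggestion, assuming the ignore is needed because the module-level name is initialized to None; the attr-defined code is the reviewer's example, not verified against a particular mypy version:

```python
# Hypothetical reproduction, not infinity_emb code: a module-level name that
# may stay None makes mypy flag the attribute access on the call below.
BetterTransformer = None  # replaced by the real class if the lazy import succeeds


def transform(model):
    # Scoping the ignore documents which error is expected, per the review.
    return BetterTransformer.transform(model)  # type: ignore[attr-defined]
```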

     except Exception as ex:
         # if level is debug then show the exception
         if logger.level <= 10:
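
Stepping back, the lazy-import-with-fallback pattern this PR introduces can be exercised on its own. A minimal self-contained sketch, assuming only that optimum.bettertransformer may or may not import cleanly; maybe_accelerate and the __main__ demo are illustrative, not part of infinity_emb:

```python
# Sketch of the pattern from this diff; only optimum.bettertransformer is a
# real external name, the rest is demo scaffolding.

BetterTransformer = None


def _import_bettertransformer():
    """Import on first use; swallow failures so callers can fall back."""
    global BetterTransformer
    if BetterTransformer is None:
        try:
            from optimum.bettertransformer import BetterTransformer as _BT

            BetterTransformer = _BT
        except Exception:
            # Deprecated or incompatible optimum/transformers combinations can
            # raise at import time; None signals "unavailable" to callers.
            pass


def maybe_accelerate(model):
    """Return a transformed model when possible, the original otherwise."""
    _import_bettertransformer()
    if BetterTransformer is None:
        return model  # graceful fallback, mirroring to_bettertransformer()
    return BetterTransformer.transform(model)


if __name__ == "__main__":
    _import_bettertransformer()
    print("BetterTransformer available:", BetterTransformer is not None)
```

The key design point mirrored here is that import failure is treated as "feature unavailable" rather than a hard error, which keeps the module importable on transformers versions where BetterTransformer has been removed.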