Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
29 changes: 15 additions & 14 deletions QEfficient/base/modeling_qeff.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,9 +20,7 @@

from QEfficient.base.onnx_transforms import (
BaseOnnxTransform,
FP16ClipTransform,
OnnxTransformPipeline,
SplitTensorsTransform,
)
from QEfficient.base.pytorch_transforms import PytorchTransform
from QEfficient.compile.qnn_compiler import compile as qnn_compile
Expand Down Expand Up @@ -54,8 +52,9 @@ class QEFFBaseModel(ABC):
_pytorch_transforms: List[PytorchTransform]
_onnx_transforms = [BaseOnnxTransform]

def _transform_names(self) -> List[str]:
return [x.__name__ for x in self._pytorch_transforms + self._onnx_transforms]
@classmethod
def _transform_names(cls) -> List[str]:
    """Return the class names of every registered transform.

    Covers both the PyTorch-level and ONNX-level transform lists, in
    application order (PyTorch transforms first).
    """
    registered = [*cls._pytorch_transforms, *cls._onnx_transforms]
    return [transform.__name__ for transform in registered]

def __init__(self, model: torch.nn.Module, **kwargs) -> None:
super().__init__()
Expand Down Expand Up @@ -246,7 +245,10 @@ def _export(
# check if the model is in meta state or weights are offloaded
self._model_offloaded_check()

export_dir.mkdir(parents=True, exist_ok=True)
# Setup temporary paths
tmp_onnx_dir = export_dir / "onnx_tmp"
tmp_onnx_path = tmp_onnx_dir / f"{self.model_name}.onnx"
tmp_onnx_dir.mkdir(parents=True, exist_ok=True)

# Create input_names from example_inputs
input_names = []
Expand Down Expand Up @@ -276,7 +278,7 @@ def _export(
torch.onnx.export(
self.model,
(example_inputs,),
str(onnx_path),
str(tmp_onnx_path),
input_names=input_names,
output_names=output_names,
dynamic_axes=dynamic_axes,
Expand All @@ -285,13 +287,11 @@ def _export(
)
logger.info("PyTorch export successful")
_ = self._offload_model_weights(offload_pt_weights)
model = onnx.load(onnx_path, load_external_data=False)
model = onnx.load(tmp_onnx_path, load_external_data=False)

needs_external_tensor_data = any(
transform in self._onnx_transforms for transform in (FP16ClipTransform, SplitTensorsTransform)
)
# Clear temporary references
transform_kwargs = {
"onnx_base_dir": str(export_dir) if needs_external_tensor_data else None,
"onnx_base_dir": str(tmp_onnx_dir),
"model_name": self.model_name,
}
if onnx_transform_kwargs is not None:
Expand All @@ -306,9 +306,7 @@ def _export(
)
logger.info("ONNX transforms applied")

onnx_path_tmp = onnx_path.with_suffix(onnx_path.suffix + ".tmp")
onnx.save(model, onnx_path_tmp)
onnx_path_tmp.replace(onnx_path)
onnx.save(model, onnx_path)
del model
gc.collect()
logger.info("Transformed ONNX saved")
Expand All @@ -317,6 +315,9 @@ def _export(
logger.error(f"ONNX export or transforms failed: {e}")
raise e

finally:
shutil.rmtree(tmp_onnx_dir, ignore_errors=True)

self.onnx_path = onnx_path
return onnx_path

Expand Down
10 changes: 5 additions & 5 deletions QEfficient/base/onnx_transforms.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,6 +7,7 @@

import logging
import os
import warnings
from concurrent.futures import ThreadPoolExecutor, as_completed
from typing import Any, Dict, List, Optional, Tuple, Type

Expand Down Expand Up @@ -117,15 +118,13 @@ def apply(cls, model: ModelProto) -> bool:

# Add function prototypes to model
existing = {f.name for f in model.functions}

for func_name, onnxscript_func in cls._custom_ops.values():
for _, onnxscript_func in cls._custom_ops.values():
proto = onnxscript_func.to_function_proto()
if proto.name not in used_op_types:
continue
if proto.name not in existing:
model.functions.append(proto)
op_applied = True

return op_applied


Expand Down Expand Up @@ -212,6 +211,8 @@ class OnnxTransformPipeline(BaseOnnxTransform):
"""Pipeline to apply multiple ONNX transformations in sequence."""

def __init__(self, transforms: List[Type[BaseOnnxTransform]]):
    """Initialize the pipeline with an ordered list of transform classes.

    An empty list is accepted but triggers a warning, since applying the
    pipeline would then be a no-op.
    """
    self.transforms = transforms
    if not transforms:
        warnings.warn("Transform list is empty. No transformations will be applied.")

def apply(
Expand All @@ -236,8 +237,7 @@ def apply(
do_split = SplitTensorsTransform in requested
fp16_min, fp16_max = np.finfo(np.float16).min, np.finfo(np.float16).max
file_num_tracker = {"num": 0, "size": 0}
if onnx_base_dir is not None:
external_data_helper.load_external_data_for_model(model, onnx_base_dir)
external_data_helper.load_external_data_for_model(model, onnx_base_dir)

if do_fp16 or do_split:
for tensor in external_data_helper._get_all_tensors(model):
Expand Down
32 changes: 1 addition & 31 deletions QEfficient/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@
# -----------------------------------------------------------------------------

from collections import namedtuple
from typing import TYPE_CHECKING, Dict, Optional, Tuple, Type
from typing import Dict, Optional, Tuple, Type

import torch
import torch.nn as nn
Expand Down Expand Up @@ -88,14 +88,8 @@
WhisperPositionalEmbedding,
)

from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.customop import CustomRMSNormAIC
from QEfficient.proxy.pytorch_transform import QeffProxyModuleTransform
from QEfficient.utils.constants import MIN_MASKED_ATTENTION_VALUE
from QEfficient.utils.logging_utils import logger

if TYPE_CHECKING:
from QEfficient.base.modeling_qeff import QEFFBaseModel

# Placeholder for all non-transformer models
from .models.codegen.modeling_codegen import (
Expand Down Expand Up @@ -197,30 +191,6 @@

# This is for supporting different modelling classes specially written for prefill-only model
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH = {"gpt_oss"}

_PROXY_ONLY_ONNX_TRANSFORMS = (FP16ClipTransform, SplitTensorsTransform)


def _configure_proxy_for_model(instance: "QEFFBaseModel", enable_proxy: bool) -> None:
"""
Configure per-instance transform lists based on proxy mode.

Keep class-defined ONNX transforms by default.
Proxy flow appends additional proxy-only transforms.
"""
instance._pytorch_transforms = list(instance._pytorch_transforms)
instance._onnx_transforms = list(instance._onnx_transforms)
instance._enable_proxy = enable_proxy

if enable_proxy:
if QeffProxyModuleTransform not in instance._pytorch_transforms:
instance._pytorch_transforms.append(QeffProxyModuleTransform)
for transform in _PROXY_ONLY_ONNX_TRANSFORMS:
if transform not in instance._onnx_transforms:
instance._onnx_transforms.append(transform)
logger.info("Proxy Model Enabled for QEfficient Model")


# Define a transformers layers to QEff layers dictionary
# While onboarding new models make sure to add the new layer maps to this dictionary.
TransformersToQEffModulesDict: Dict[Type[nn.Module], Type[nn.Module]] = {
Expand Down
55 changes: 38 additions & 17 deletions QEfficient/transformers/models/modeling_auto.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

import QEfficient
from QEfficient.base.modeling_qeff import QEFFBaseModel
from QEfficient.base.onnx_transforms import FP16ClipTransform
from QEfficient.base.onnx_transforms import FP16ClipTransform, SplitTensorsTransform
from QEfficient.base.pytorch_transforms import SplitGateUpWeightsTransform
from QEfficient.generation.cloud_infer import QAICInferenceSession
from QEfficient.generation.text_generation_inference import (
Expand All @@ -40,10 +40,10 @@
write_io_files,
)
from QEfficient.generation.vlm_generation import VisionLanguageGeneration
from QEfficient.proxy.pytorch_transform import QeffProxyModuleTransform
from QEfficient.transformers.modeling_utils import (
DYNAMIC_SEQ_LEN_SUPPORTED_MODEL_ARCH,
SPECIALIZED_DISAGG_SERVING_MODEL_ARCH,
_configure_proxy_for_model,
)
from QEfficient.transformers.models.pytorch_transforms import (
BlockedKVAttentionTransform,
Expand Down Expand Up @@ -91,7 +91,9 @@ class QEFFTransformersBase(QEFFBaseModel):
_hf_auto_class: type

def __init__(self, model: nn.Module, **kwargs) -> None:
_configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

if (
hasattr(model, "config")
Expand Down Expand Up @@ -231,7 +233,7 @@ class QEFFAutoModel(QEFFTransformersBase):

_hf_auto_class = AutoModel
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = [FP16ClipTransform]
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, pooling=None, **kwargs):
"""
Expand All @@ -248,6 +250,10 @@ def __init__(self, model: nn.Module, pooling=None, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

super().__init__(model, **kwargs)

# Make Embedding specific transforms like appending pooling
Expand Down Expand Up @@ -619,7 +625,7 @@ class QEFFAutoModelForSequenceClassification(QEFFTransformersBase):

_hf_auto_class = AutoModelForSequenceClassification
_pytorch_transforms = [CustomOpsTransform, TextClassificationTransform]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
"""
Expand Down Expand Up @@ -662,8 +668,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
QEFFAutoModelForSequenceClassification
An instance initialized with the pretrained weights.
"""
enable_proxy = kwargs.pop("enable_proxy", False)

if kwargs.get("attn_implementation", None) not in {None, "eager"}:
logger.warning('Updating attn_implementation="eager"')

Expand All @@ -673,7 +677,6 @@ def from_pretrained(cls, pretrained_model_name_or_path, *args, **kwargs):
kwargs.update({"attn_implementation": "eager", "low_cpu_mem_usage": False})

model = cls._hf_auto_class.from_pretrained(pretrained_model_name_or_path, *args, **kwargs)
kwargs.update({"enable_proxy": enable_proxy} if enable_proxy else {})
return cls(model, pretrained_model_name_or_path=pretrained_model_name_or_path, **kwargs)

@property
Expand Down Expand Up @@ -861,7 +864,7 @@ class QEffVisionEncoderForTextImageToTextModel(QEFFBaseModel):
KVCacheTransform,
KVCacheExternalModuleMapperTransform,
]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.modules, **kwargs):
"""
Expand All @@ -874,7 +877,9 @@ def __init__(self, model: nn.modules, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
_configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")
super().__init__(model, **kwargs)
self.model = model.get_qeff_vision_encoder()
self.hash_params["qeff_auto_class"] = self.__class__.__name__
Expand Down Expand Up @@ -1002,7 +1007,7 @@ class QEffCausalLMForTextImageToTextModel(QEFFBaseModel):
VlmKVOffloadTransform,
SplitGateUpWeightsTransform,
]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
"""
Expand All @@ -1018,7 +1023,9 @@ def __init__(self, model, qaic_config: Optional[dict] = None, **kwargs):
**kwargs :
Additional keyword arguments passed to the base class constructor.
"""
_configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")
super().__init__(model, **kwargs)
self.model = model.get_qeff_language_decoder()
self.model.qaic_config = qaic_config
Expand Down Expand Up @@ -1937,7 +1944,7 @@ class _QEFFAutoModelForImageTextToTextSingleQPC(QEFFTransformersBase, Multimodal
VlmNoKVOffloadTransform,
SplitGateUpWeightsTransform,
]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(
self,
Expand Down Expand Up @@ -1971,6 +1978,10 @@ def __init__(
if qaic_config is not None and qaic_config.pop("include_sampler", False):
raise NotImplementedError("On-device sampling is not supported for single QPC multimodal models yet.")

if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

super().__init__(model, **kwargs)

self.model.qaic_config = qaic_config
Expand Down Expand Up @@ -2689,7 +2700,7 @@ class QEFFAutoModelForCausalLM(QEFFBaseModel):
KVCacheExternalModuleMapperTransform,
]

_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def prefill(
self,
Expand Down Expand Up @@ -2766,7 +2777,9 @@ def __init__(
model_class_name = model.__class__.__name__
if not (model_class_name.endswith("ForCausalLM") or model_class_name.endswith("LMHeadModel")):
raise TypeError(f"Required pytorch module for CausalLM or LMHeadModel, got {model_class_name}")
_configure_proxy_for_model(self, kwargs.pop("enable_proxy", False))
if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

# TODO: remove from version 1.20
if kwargs.pop("full_batch_size", None):
Expand Down Expand Up @@ -3654,7 +3667,7 @@ class QEFFAutoModelForSpeechSeq2Seq(QEFFTransformersBase, MultimodalUtilityMixin

_hf_auto_class = AutoModelForSpeechSeq2Seq
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform, KVCacheTransform]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
"""
Expand All @@ -3674,6 +3687,10 @@ def __init__(self, model: nn.Module, **kwargs):
"""
model_class_name = model.__class__.__name__

if kwargs.pop("enable_proxy", False):
self._pytorch_transforms.append(QeffProxyModuleTransform)
logger.info("Proxy Model Enabled for QEfficient Model")

if not (model_class_name.endswith("ForConditionalGeneration")):
raise TypeError(f"Required pytorch module with ForConditionalGeneration, got {model_class_name}")

Expand Down Expand Up @@ -4013,9 +4030,13 @@ class QEFFAutoModelForCTC(QEFFTransformersBase):

_hf_auto_class = AutoModelForCTC
_pytorch_transforms = [CustomOpsTransform, AwqToMatmulNbitsTransform, GPTQToMatmulNbitsTransform]
_onnx_transforms = []
_onnx_transforms = [FP16ClipTransform, SplitTensorsTransform]

def __init__(self, model: nn.Module, **kwargs):
    """Wrap a CTC model for QEfficient export/compilation.

    Parameters
    ----------
    model : nn.Module
        The HuggingFace CTC model to wrap.
    **kwargs :
        Additional keyword arguments forwarded to the base class
        constructor. ``enable_proxy`` (bool, default False) is consumed
        here and, when True, adds the proxy module transform.
    """
    if kwargs.pop("enable_proxy", False):
        # Rebind to a per-instance list instead of appending to the
        # class attribute: mutating the shared class-level list would
        # enable the proxy transform for every future instance of this
        # class and accumulate duplicates across constructions.
        self._pytorch_transforms = [*self._pytorch_transforms, QeffProxyModuleTransform]
        logger.info("Proxy Model Enabled for QEfficient Model")

    super().__init__(model, **kwargs)
    # CTC flow relies on cached attention states during generation.
    self.model.base_model.config.use_cache = True

Expand Down
Loading
Loading