diff --git a/.github/workflows/ci.yaml b/.github/workflows/ci.yaml
index 2b9ea4c6..586542ee 100644
--- a/.github/workflows/ci.yaml
+++ b/.github/workflows/ci.yaml
@@ -38,13 +38,13 @@ jobs:
       working-directory: libs/infinity_emb
     secrets: inherit
-  lint-embed_package:
-    uses:
-      ./.github/workflows/linting.yaml
-    with:
-      working-directory: libs/embed_package
-      extra_poetry: "--with test,lint,codespell"
-    secrets: inherit
+  # lint-embed_package:
+  #   uses:
+  #     ./.github/workflows/linting.yaml
+  #   with:
+  #     working-directory: libs/embed_package
+  #     extra_poetry: "--with test,lint,codespell"
+  #   secrets: inherit
 
   test-infinity_emb:
     uses:
       ./.github/workflows/test.yaml
@@ -54,11 +54,11 @@
       upload_coverage: true
     secrets: inherit
 
-  test-embed_package:
-    uses:
-      ./.github/workflows/test.yaml
-    with:
-      working-directory: libs/embed_package
-      upload_coverage: false
-      extra_poetry: "--with test"
-    secrets: inherit
\ No newline at end of file
+  # test-embed_package:
+  #   uses:
+  #     ./.github/workflows/test.yaml
+  #   with:
+  #     working-directory: libs/embed_package
+  #     upload_coverage: false
+  #     extra_poetry: "--with test"
+  #   secrets: inherit
\ No newline at end of file
diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml
index 130e84d7..7aa3b907 100644
--- a/.github/workflows/test.yaml
+++ b/.github/workflows/test.yaml
@@ -31,22 +31,22 @@ jobs:
       matrix:
         os: [ubuntu-latest, windows-latest] # macos-latest
         python-version:
-          - "3.9"
-          - "3.10"
-          - "3.11"
+          # - "3.9"
+          # - "3.10"
+          # - "3.11"
           - "3.12"
-        coverage_tests: ["unit_test", "end_to_end"]
+        coverage_tests: ["unit_test"] #, "end_to_end"]
         exclude:
           # Exclude unit tests on macOS due to compatibility issues
-          - python-version: "3.9"
-            os: macos-latest
-            coverage_tests: "unit_test"
-          - python-version: "3.10"
-            os: macos-latest
-            coverage_tests: "unit_test"
-          - python-version: "3.11"
-            os: macos-latest
-            coverage_tests: "unit_test"
+          # - python-version: "3.9"
+          #   os: macos-latest
+          #   coverage_tests: "unit_test"
+          # - python-version: "3.10"
+          #   os: macos-latest
+          #   coverage_tests: "unit_test"
+          # - python-version: "3.11"
+          #   os: macos-latest
+          #   coverage_tests: "unit_test"
           - python-version: "3.12"
             os: macos-latest
             coverage_tests: "unit_test"
diff --git a/libs/infinity_emb/Docker.template.yaml b/libs/infinity_emb/Docker.template.yaml
index d2e2e6af..35b1f6d7 100644
--- a/libs/infinity_emb/Docker.template.yaml
+++ b/libs/infinity_emb/Docker.template.yaml
@@ -16,8 +16,18 @@ cpu:
   main_install: |
     # "RUN poetry install --no-interaction --no-ansi --no-root --extras \"${EXTRAS}\" --without lint,test && poetry cache clear pypi --all"
    COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+    RUN apt update -y && apt install git -y
    RUN ./requirements_install_from_poetry.sh --no-root --without lint,test "https://download.pytorch.org/whl/cpu"
-    RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
+
+    RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" poetry run python -m pip install -U --pre optimum-intel@git+https://github.com/huggingface/optimum-intel.git \
+        openvino-tokenizers[transformers]==2024.5.* \
+        openvino==2024.5.* \
+        nncf>=2.11.0 \
+        sentence_transformers==3.1.1 \
+        openai \
+        "transformers>4.45" \
+        einops
+    # RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
   extra_env_variables: |
     # Sets default to onnx
     ENV INFINITY_ENGINE="optimum"
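The template change above swaps the single onnxruntime-openvino wheel for a pinned optimum-intel stack (OpenVINO 2024.5, nncf, sentence-transformers 3.1.1). A minimal smoke test for the resulting CPU image, assuming the poetry virtualenv is on PATH; the script itself is hypothetical and not part of this diff:

import openvino
import openvino_tokenizers  # noqa: F401  # importing registers the tokenizer ops
from optimum.intel import OVModelForFeatureExtraction  # noqa: F401

# The CPU image should expose at least the "CPU" device.
core = openvino.Core()
print("openvino", openvino.__version__, "devices:", core.available_devices)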
diff --git a/libs/infinity_emb/Dockerfile.cpu_auto b/libs/infinity_emb/Dockerfile.cpu_auto
index 9008b8b5..e9932ddb 100644
--- a/libs/infinity_emb/Dockerfile.cpu_auto
+++ b/libs/infinity_emb/Dockerfile.cpu_auto
@@ -42,15 +42,35 @@ COPY poetry.lock poetry.toml pyproject.toml README.md /app/
 #
 # "RUN poetry install --no-interaction --no-ansi --no-root --extras \"${EXTRAS}\" --without lint,test && poetry cache clear pypi --all"
 COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+RUN apt update -y && apt install git -y
 RUN ./requirements_install_from_poetry.sh --no-root --without lint,test "https://download.pytorch.org/whl/cpu"
-RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
+
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" poetry run python -m pip install -U --pre optimum-intel@git+https://github.com/huggingface/optimum-intel.git \
+    openvino-tokenizers[transformers]==2024.5.* \
+    openvino==2024.5.* \
+    nncf>=2.11.0 \
+    sentence_transformers==3.1.1 \
+    openai \
+    "transformers>4.45" \
+    einops
+# RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
 
 COPY infinity_emb infinity_emb
 # Install dependency with infinity_emb package
 # "RUN poetry install --no-interaction --no-ansi --extras \"${EXTRAS}\" --without lint,test && poetry cache clear pypi --all"
 COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+RUN apt update -y && apt install git -y
 RUN ./requirements_install_from_poetry.sh --without lint,test "https://download.pytorch.org/whl/cpu"
-RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
+
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" poetry run python -m pip install -U --pre optimum-intel@git+https://github.com/huggingface/optimum-intel.git \
+    openvino-tokenizers[transformers]==2024.5.* \
+    openvino==2024.5.* \
+    nncf>=2.11.0 \
+    sentence_transformers==3.1.1 \
+    openai \
+    "transformers>4.45" \
+    einops
+# RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
 
 #
@@ -59,8 +79,18 @@ FROM builder as testing
 # install lint and test dependencies
 # "RUN poetry install --no-interaction --no-ansi --extras \"${EXTRAS}\" --with lint,test && poetry cache clear pypi --all"
 COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+RUN apt update -y && apt install git -y
 RUN ./requirements_install_from_poetry.sh --with lint,test "https://download.pytorch.org/whl/cpu"
-RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
+
+RUN PIP_EXTRA_INDEX_URL="https://download.pytorch.org/whl/cpu" poetry run python -m pip install -U --pre optimum-intel@git+https://github.com/huggingface/optimum-intel.git \
+    openvino-tokenizers[transformers]==2024.5.* \
+    openvino==2024.5.* \
+    nncf>=2.11.0 \
+    sentence_transformers==3.1.1 \
+    openai \
+    "transformers>4.45" \
+    einops
+# RUN poetry run $PYTHON -m pip install --no-cache-dir onnxruntime-openvino
 
 # lint
 RUN poetry run ruff check .
diff --git a/libs/infinity_emb/Dockerfile.intel_auto b/libs/infinity_emb/Dockerfile.intel_auto
new file mode 100644
index 00000000..2a8ffe99
--- /dev/null
+++ b/libs/infinity_emb/Dockerfile.intel_auto
@@ -0,0 +1,132 @@
+# Autogenerated warning:
+# This file is generated from Dockerfile.jinja2. Do not edit the Dockerfile.cuda|cpu|amd file directly.
+# Only contribute to the Dockerfile.jinja2 and dockerfile_template.yaml and regenerate the Dockerfile.cuda|cpu|amd
+
+FROM ubuntu:22.04 AS base
+
+ENV PYTHONUNBUFFERED=1 \
+    \
+    # pip
+    PIP_NO_CACHE_DIR=off \
+    PIP_DISABLE_PIP_VERSION_CHECK=on \
+    PIP_DEFAULT_TIMEOUT=100 \
+    \
+    # make poetry create the virtual environment in the project's root
+    # it gets named `.venv`
+    POETRY_VIRTUALENVS_CREATE="true" \
+    POETRY_VIRTUALENVS_IN_PROJECT="true" \
+    # do not ask any interactive question
+    POETRY_NO_INTERACTION=1 \
+    EXTRAS="all" \
+    PYTHON="python3.11"
+RUN apt-get update && apt-get install --no-install-recommends -y build-essential python3-dev libsndfile1 $PYTHON-venv $PYTHON curl
+WORKDIR /app
+
+FROM base as builder
+# Set the working directory for the app
+# Define the version of Poetry to install (default is 1.7.1)
+# Define the directory to install Poetry to (default is /opt/poetry)
+ARG POETRY_VERSION=1.8.4
+ARG POETRY_HOME=/opt/poetry
+# Create a Python virtual environment for Poetry and install it
+RUN curl -sSL https://install.python-poetry.org | POETRY_HOME=$POETRY_HOME POETRY_VERSION=$POETRY_VERSION $PYTHON -
+ENV PATH=$POETRY_HOME/bin:$PATH
+# Test if Poetry is installed in the expected path
+RUN echo "Poetry version:" && poetry --version
+# Copy the rest of the app source code (this layer will be invalidated and rebuilt whenever the source code changes)
+COPY poetry.lock poetry.toml pyproject.toml README.md /app/
+# Install dependencies only
+#
+# "RUN poetry install --no-interaction --no-ansi --no-root --extras \"${EXTRAS}\" --without lint,test && poetry cache clear pypi --all"
+COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+RUN ./requirements_install_from_poetry.sh --no-root --without lint,test "https://download.pytorch.org/whl/cpu"
+
+RUN poetry run python -m pip install --upgrade --upgrade-strategy eager "optimum[openvino]"
+
+COPY infinity_emb infinity_emb
+# Install dependency with infinity_emb package
+# "RUN poetry install --no-interaction --no-ansi --extras \"${EXTRAS}\" --without lint,test && poetry cache clear pypi --all"
+COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+RUN ./requirements_install_from_poetry.sh --without lint,test "https://download.pytorch.org/whl/cpu"
+
+#
+
+
+FROM builder as testing
+# install lint and test dependencies
+# "RUN poetry install --no-interaction --no-ansi --extras \"${EXTRAS}\" --with lint,test && poetry cache clear pypi --all"
+COPY requirements_install_from_poetry.sh requirements_install_from_poetry.sh
+RUN ./requirements_install_from_poetry.sh --with lint,test "https://download.pytorch.org/whl/cpu"
+
+# # lint
+# # RUN poetry run ruff check .
+# # RUN poetry run mypy .
+# # pytest
+# COPY tests tests
+# # run end to end tests because of duration of build in github ci.
+# # Run tests/end_to_end on TARGETPLATFORM x86_64 otherwise run tests/end_to_end_gpu
+# # poetry run python -m pytest tests/end_to_end -x # TODO: does not work.
+# RUN if [ -z "$TARGETPLATFORM" ]; then \
+#         ARCH=$(uname -m); \
+#         if [ "$ARCH" = "x86_64" ]; then \
+#             TARGETPLATFORM="linux/amd64"; \
+#         elif [ "$ARCH" = "aarch64" ] || [ "$ARCH" = "arm64" ]; then \
+#             TARGETPLATFORM="linux/arm64"; \
+#         else \
+#             echo "Unsupported architecture: $ARCH"; exit 1; \
+#         fi; \
+#     fi; \
+#     echo "Running tests on TARGETPLATFORM=$TARGETPLATFORM"; \
+#     if [ "$TARGETPLATFORM" = "linux/arm64" ] ; then \
+#         poetry run python -m pytest tests/end_to_end/test_api_with_dummymodel.py -x ; \
+#     else \
+#         poetry run python -m pytest tests/end_to_end/test_api_with_dummymodel.py tests/end_to_end/test_sentence_transformers.py -m "not performance" -x ; \
+#     fi
+# RUN echo "all tests passed" > "test_results.txt"
+
+
+# # Use a multi-stage build -> production version, with download
+# FROM base AS tested-builder
+# COPY --from=builder /app /app
+# # force testing stage to run
+# COPY --from=testing /app/test_results.txt /app/test_results.txt
+# ENV HF_HOME=/app/.cache/huggingface
+# ENV PATH=/app/.venv/bin:$PATH
+# # do nothing
+# RUN echo "copied all files"
+
+
+# Export with tensorrt, not recommended.
+# docker buildx build --target=production-tensorrt -f Dockerfile .
+# FROM nvidia/cuda:11.8.0-cudnn8-devel-ubuntu22.04 AS production-tensorrt
+# ENV PYTHONUNBUFFERED=1 \
+#     PIP_NO_CACHE_DIR=off \
+#     PYTHON="python3.11"
+# RUN apt-get update && apt-get install python3-dev python3-pip $PYTHON build-essential curl -y
+# COPY --from=builder /app /app
+# # force testing stage to run
+# COPY --from=testing /app/test_results.txt /app/test_results.txt
+# ENV HF_HOME=/app/.cache/torch
+# ENV PATH=/app/.venv/bin:$PATH
+# RUN pip install --no-cache-dir "onnxruntime-gpu==1.17.0" "tensorrt==8.6.*"
+# ENV LD_LIBRARY_PATH /app/.venv/lib/$(PYTHON)/site-packages/tensorrt:/usr/lib/x86_64-linux-gnu:/app/.venv/lib/$(PYTHON)/site-packages/tensorrt_libs:${LD_LIBRARY_PATH}
+# ENV PATH /app/.venv/lib/$(PYTHON)/site-packages/tensorrt/bin:${PATH}
+# ENTRYPOINT ["infinity_emb"]
+
+
+# # Use a multi-stage build -> production version, with download
+# # docker buildx build --target=production-with-download \
+# #   --build-arg MODEL_NAME=BAAI/bge-small-en-v1.5 --build-arg ENGINE=torch -f Dockerfile -t infinity-BAAI-small .
+# FROM tested-builder AS production-with-download
+# # collect model name and engine from build args
+# ARG MODEL_NAME
+# RUN if [ -z "${MODEL_NAME}" ]; then echo "Error: Build argument MODEL_NAME not set." && exit 1; fi
+# ARG ENGINE
+# RUN if [ -z "${ENGINE}" ]; then echo "Error: Build argument ENGINE not set." && exit 1; fi
+# # will exit with 3 if model is downloaded # TODO: better exit code
+# RUN infinity_emb v2 --model-id $MODEL_NAME --engine $ENGINE --preload-only || [ $? -eq 3 ]
+# ENTRYPOINT ["infinity_emb"]
+
+# # Use a multi-stage build -> production version
+# FROM tested-builder AS production
+# ENTRYPOINT ["infinity_emb"]
diff --git a/libs/infinity_emb/infinity_emb/_optional_imports.py b/libs/infinity_emb/infinity_emb/_optional_imports.py
index 6606a146..ab51a9d7 100644
--- a/libs/infinity_emb/infinity_emb/_optional_imports.py
+++ b/libs/infinity_emb/infinity_emb/_optional_imports.py
@@ -69,6 +69,7 @@ def _raise_error(self) -> None:
     "optimum.neuron",
     "",
 )
+CHECK_OPTIMUM_INTEL = OptionalImports("optimum.intel", "optimum")
 CHECK_PIL = OptionalImports("PIL", "vision")
 CHECK_POSTHOG = OptionalImports("posthog", "server")
 CHECK_PYDANTIC = OptionalImports("pydantic", "server")
diff --git a/libs/infinity_emb/infinity_emb/primitives.py b/libs/infinity_emb/infinity_emb/primitives.py
index 7ff8d404..c3b2b001 100644
--- a/libs/infinity_emb/infinity_emb/primitives.py
+++ b/libs/infinity_emb/infinity_emb/primitives.py
@@ -106,6 +106,7 @@ def default_value():
 
 class Device(EnumType):
     cpu = "cpu"
+    openvino = "openvino"
     cuda = "cuda"
     mps = "mps"
     tensorrt = "tensorrt"
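The new `Device.openvino` member makes the backend addressable by name; provider resolution itself happens in `device_to_onnx` (patched further below), which prefers OpenVINO on CPU whenever optimum-intel is importable. A minimal sketch of that flow, using `device="cpu"` exactly as the new unit test does; the model id is only an example:

from infinity_emb.args import EngineArgs
from infinity_emb.transformer.utils_optimum import device_to_onnx

args = EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", device="cpu")
# With optimum-intel installed, this should print "OpenVINOExecutionProvider";
# otherwise it falls back to the plain CPU provider.
print(device_to_onnx(args.device))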
diff --git a/libs/infinity_emb/infinity_emb/transformer/embedder/optimum.py b/libs/infinity_emb/infinity_emb/transformer/embedder/optimum.py
index e209a742..4fe89b05 100644
--- a/libs/infinity_emb/infinity_emb/transformer/embedder/optimum.py
+++ b/libs/infinity_emb/infinity_emb/transformer/embedder/optimum.py
@@ -6,7 +6,11 @@
 
 import numpy as np
 
-from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_TRANSFORMERS
+from infinity_emb._optional_imports import (
+    CHECK_ONNXRUNTIME,
+    CHECK_TRANSFORMERS,
+    CHECK_OPTIMUM_INTEL,
+)
 from infinity_emb.args import EngineArgs
 from infinity_emb.primitives import EmbeddingReturnType, PoolingMethod
 from infinity_emb.transformer.abstract import BaseEmbedder
@@ -14,7 +18,7 @@
 from infinity_emb.transformer.utils_optimum import (
     cls_token_pooling,
     device_to_onnx,
-    get_onnx_files,
+    # get_onnx_files,
     mean_pooling,
     normalize,
     optimize_model,
@@ -25,43 +29,80 @@
     from optimum.onnxruntime import (  # type: ignore[import-untyped]
         ORTModelForFeatureExtraction,
    )
+    from infinity_emb.transformer.utils_optimum import get_onnx_files
 except (ImportError, RuntimeError, Exception) as ex:
     CHECK_ONNXRUNTIME.mark_dirty(ex)
 
+
+if CHECK_OPTIMUM_INTEL.is_available:
+    try:
+        from optimum.intel import OVModelForFeatureExtraction  # type: ignore[import-untyped]
+        from infinity_emb.transformer.utils_optimum import get_openvino_files
+
+    except (ImportError, RuntimeError, Exception) as ex:
+        CHECK_OPTIMUM_INTEL.mark_dirty(ex)
+
+
 if CHECK_TRANSFORMERS.is_available:
     from transformers import AutoConfig, AutoTokenizer  # type: ignore[import-untyped]
 
 
 class OptimumEmbedder(BaseEmbedder):
     def __init__(self, *, engine_args: EngineArgs):
-        CHECK_ONNXRUNTIME.mark_required()
         provider = device_to_onnx(engine_args.device)
+        self.provider = provider
+
+        if "openvino" in provider.lower():  # OpenVINO Executor
+            CHECK_OPTIMUM_INTEL.mark_required()
+            filename = ""
+            try:
+                openvino_file = get_openvino_files(
+                    model_name_or_path=engine_args.model_name_or_path,
+                    revision=engine_args.revision,
+                    use_auth_token=True,
+                )
+                filename = openvino_file.as_posix()
+            except Exception as e:  # show the error, then let optimum-intel compress on the fly
+                print(str(e))
+
+            self.model = optimize_model(
+                model_name_or_path=engine_args.model_name_or_path,
+                revision=engine_args.revision,
+                trust_remote_code=engine_args.trust_remote_code,
+                execution_provider=provider,
+                file_name=filename,
+                optimize_model=not os.environ.get(
+                    "INFINITY_ONNX_DISABLE_OPTIMIZE", False
+                ),  # TODO: make this env variable public
+                model_class=OVModelForFeatureExtraction,
+            )
-
-        onnx_file = get_onnx_files(
-            model_name_or_path=engine_args.model_name_or_path,
-            revision=engine_args.revision,
-            use_auth_token=True,
-            prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
-        )
+        else:
+            CHECK_ONNXRUNTIME.mark_required()
+            onnx_file = get_onnx_files(
+                model_name_or_path=engine_args.model_name_or_path,
+                revision=engine_args.revision,
+                use_auth_token=True,
+                prefer_quantized=("cpu" in provider.lower() or "openvino" in provider.lower()),
+            )
+            self.model = optimize_model(
+                model_name_or_path=engine_args.model_name_or_path,
+                revision=engine_args.revision,
+                trust_remote_code=engine_args.trust_remote_code,
+                execution_provider=provider,
+                file_name=onnx_file.as_posix(),
+                optimize_model=not os.environ.get(
+                    "INFINITY_ONNX_DISABLE_OPTIMIZE", False
+                ),  # TODO: make this env variable public
+                model_class=ORTModelForFeatureExtraction,
+            )
+        self.model.use_io_binding = False
 
         self.pooling = (
             mean_pooling if engine_args.pooling_method == PoolingMethod.mean else cls_token_pooling
         )
 
-        self.model = optimize_model(
-            model_name_or_path=engine_args.model_name_or_path,
-            revision=engine_args.revision,
-            trust_remote_code=engine_args.trust_remote_code,
-            execution_provider=provider,
-            file_name=onnx_file.as_posix(),
-            optimize_model=not os.environ.get(
-                "INFINITY_ONNX_DISABLE_OPTIMIZE", False
-            ),  # TODO: make this env variable public
-            model_class=ORTModelForFeatureExtraction,
-        )
-        self.model.use_io_binding = False
-
         self.tokenizer = AutoTokenizer.from_pretrained(
             engine_args.model_name_or_path,
             revision=engine_args.revision,
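With the branch above in place, an OpenVINO-backed embedder is driven through the same three-phase API as the ONNX path. A minimal usage sketch mirroring the unit test at the end of this diff; the model name is only an example:

import numpy as np

from infinity_emb.args import EngineArgs
from infinity_emb.transformer.embedder.optimum import OptimumEmbedder

# device="cpu" resolves to OpenVINOExecutionProvider when optimum-intel is importable
embedder = OptimumEmbedder(
    engine_args=EngineArgs(model_name_or_path="BAAI/bge-small-en-v1.5", device="cpu")
)
features = embedder.encode_pre(["This is awesome."])  # tokenize
hidden = embedder.encode_core(features)  # forward pass through the OV/ORT model
embeddings = embedder.encode_post(hidden)  # pool and normalize
print(np.asarray(embeddings).shape)  # (1, hidden_size)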
diff --git a/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py b/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
index 76176c57..a22dfda6 100644
--- a/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
+++ b/libs/infinity_emb/infinity_emb/transformer/utils_optimum.py
@@ -8,11 +8,15 @@
 from huggingface_hub import HfApi, HfFolder  # type: ignore
 from huggingface_hub.constants import HUGGINGFACE_HUB_CACHE  # type: ignore
 
-from infinity_emb._optional_imports import CHECK_ONNXRUNTIME, CHECK_OPTIMUM_AMD
-
+from infinity_emb._optional_imports import (
+    CHECK_ONNXRUNTIME,
+    CHECK_OPTIMUM_AMD,
+    CHECK_OPTIMUM_INTEL,
+)
 from infinity_emb.log_handler import logger
 from infinity_emb.primitives import Device
 
+
 if CHECK_ONNXRUNTIME.is_available:
     try:
         import onnxruntime as ort  # type: ignore
@@ -25,6 +29,17 @@
     except (ImportError, RuntimeError, Exception) as ex:
         CHECK_ONNXRUNTIME.mark_dirty(ex)
 
+if CHECK_OPTIMUM_INTEL.is_available:
+    try:
+        from optimum.intel import (  # type: ignore
+            OVModelForFeatureExtraction,
+            OVWeightQuantizationConfig,
+            OVConfig,
+            OVQuantizer,
+        )
+    except (ImportError, RuntimeError, Exception) as ex:
+        CHECK_OPTIMUM_INTEL.mark_dirty(ex)
+
 
 def mean_pooling(last_hidden_states: np.ndarray, attention_mask: np.ndarray):
     input_mask_expanded = (np.expand_dims(attention_mask, axis=-1)).astype(float)
@@ -50,6 +65,8 @@ def normalize(input_array, p=2, dim=1, eps=1e-12):
 def device_to_onnx(device: Device) -> str:
     CHECK_ONNXRUNTIME.mark_required()
     available = ort.get_available_providers()
+    if CHECK_OPTIMUM_INTEL.is_available:
+        available.append("OpenVINOExecutionProvider")
 
     if device == Device.cpu:
         if "OpenVINOExecutionProvider" in available:
@@ -135,24 +152,52 @@
             file_name=file_name,
         )
 
-    ## path to find if model has been optimized
-    CHECK_ONNXRUNTIME.mark_required()
-    path_folder = (
-        Path(HUGGINGFACE_HUB_CACHE) / "infinity_onnx" / execution_provider / model_name_or_path
-    )
-    OPTIMIZED_SUFFIX = "_optimized.onnx"
-    files_optimized = list(path_folder.glob(f"**/*{OPTIMIZED_SUFFIX}"))
+    file_optimized: Union[Path, str] = ""
+
+    extra_args = {}
 
-    logger.info(f"files_optimized: {files_optimized}")
-    if files_optimized:
-        file_optimized = files_optimized[-1]
+    logger.info(f"file_name: {file_name}")
+
+    if execution_provider == "OpenVINOExecutionProvider":  # Optimum Intel OpenVINO path
+        CHECK_OPTIMUM_INTEL.mark_required()
+        path_folder = (
+            Path(HUGGINGFACE_HUB_CACHE)
+            / "infinity_openvino"
+            / execution_provider
+            / model_name_or_path
+        )
+        OPTIMIZED_PREFIX = "openvino_model"
+        files_optimized = sorted(list(path_folder.glob(f"**/{OPTIMIZED_PREFIX}*")))
+        if files_optimized:
+            file_optimized = files_optimized[-1]
+        if file_name:
+            file_optimized = file_name
+
+        extra_args = {"ov_config": {"INFERENCE_PRECISION_HINT": "bf16"}}
+
+    else:  # Optimum onnx path
+        CHECK_ONNXRUNTIME.mark_required()
+        path_folder = (
+            Path(HUGGINGFACE_HUB_CACHE) / "infinity_onnx" / execution_provider / model_name_or_path
+        )
+        OPTIMIZED_SUFFIX = "_optimized.onnx"
+        files_optimized = list(path_folder.glob(f"**/*{OPTIMIZED_SUFFIX}"))
+        if files_optimized:
+            file_optimized = files_optimized[0]
+
+    if file_optimized:
         logger.info(f"Optimized model found at {file_optimized}, skipping optimization")
         return model_class.from_pretrained(
-            file_optimized.parent.as_posix(),
+            file_optimized.parent.as_posix()
+            if not isinstance(file_optimized, str)
+            else model_name_or_path,
             revision=revision,
             trust_remote_code=trust_remote_code,
-            provider=execution_provider,
-            file_name=file_optimized.name,
+            provider=execution_provider,  # will be ignored by optimum intel
+            file_name=file_optimized.name
+            if not isinstance(file_optimized, str)
+            else file_optimized,
+            **extra_args,
         )
 
     unoptimized_model = model_class.from_pretrained(
@@ -166,35 +211,71 @@
         return unoptimized_model
     try:
         logger.info("Optimizing model")
+        if execution_provider == "OpenVINOExecutionProvider":
+            logger.info("Optimizing model OpenVINOExecutionProvider")
+            ov_model = OVModelForFeatureExtraction.from_pretrained(
+                model_name_or_path,
+                export=True,
+                # ov_config={"INFERENCE_PRECISION_HINT": "fp32"}
+                # ov_config={"INFERENCE_PRECISION_HINT": "fp16"}  # fp16 has better precision than bf16
+                ov_config={
+                    "INFERENCE_PRECISION_HINT": "bf16"
+                },  # bf16 for now; the fp32/fp16 hints above are kept for comparison
+            )
+            quantizer = OVQuantizer.from_pretrained(
+                ov_model, task="feature-extraction", export=True
+            )
+            ov_config = OVConfig(
+                quantization_config=OVWeightQuantizationConfig(
+                    bits=4,
+                    sym=False,
+                    ratio=1.0,
+                    group_size=128,
+                    all_layers=None,
+                )
+            )
+            quantizer.quantize(ov_config=ov_config, save_directory=path_folder.as_posix())
+            model = OVModelForFeatureExtraction.from_pretrained(
+                path_folder.as_posix(),
+                # ov_config={"INFERENCE_PRECISION_HINT": "fp32"}
+                # ov_config={"INFERENCE_PRECISION_HINT": "fp16"}  # fp16 has better precision than bf16
+                ov_config={
+                    "INFERENCE_PRECISION_HINT": "bf16"
+                },  # bf16 for now; the fp32/fp16 hints above are kept for comparison
+                export=False,
+            )
+            logger.info("Successfully loaded optimized model OpenVINOExecutionProvider")
 
-        optimizer = ORTOptimizer.from_pretrained(unoptimized_model)
+        else:  # Optimum onnx and optimum amd path
+            optimizer = ORTOptimizer.from_pretrained(unoptimized_model)
 
-        is_gpu = not (
-            "cpu" in execution_provider.lower() or "openvino" in execution_provider.lower()
-        )
-        optimization_config = OptimizationConfig(
-            optimization_level=99,
-            optimize_with_onnxruntime_only=False,
-            optimize_for_gpu=is_gpu,
-            fp16=is_gpu,
-            # enable_gelu_approximation=True,
-            # enable_gemm_fast_gelu_fusion=True,  # might not work
-        )
+            is_gpu = not (
+                "cpu" in execution_provider.lower() or "openvino" in execution_provider.lower()
+            )
+            optimization_config = OptimizationConfig(
+                optimization_level=99,
+                optimize_with_onnxruntime_only=False,
+                optimize_for_gpu=is_gpu,
+                fp16=is_gpu,
+                # enable_gelu_approximation=True,
+                # enable_gemm_fast_gelu_fusion=True,  # might not work
+            )
 
-        optimized_model_path = optimizer.optimize(
-            optimization_config=optimization_config,
-            save_dir=path_folder.as_posix(),
-            # if larger than 2gb use external data format
-            one_external_file=True,
-        )
+            optimized_model_path = optimizer.optimize(
+                optimization_config=optimization_config,
+                save_dir=path_folder.as_posix(),
+                # if larger than 2gb use external data format
+                one_external_file=True,
+            )
+
+            model = model_class.from_pretrained(
+                optimized_model_path,
+                revision=revision,
+                trust_remote_code=trust_remote_code,
+                provider=execution_provider,
+                file_name=Path(file_name).name.replace(".onnx", OPTIMIZED_SUFFIX),
+            )
 
-        model = model_class.from_pretrained(
-            optimized_model_path,
-            revision=revision,
-            trust_remote_code=trust_remote_code,
-            provider=execution_provider,
-            file_name=Path(file_name).name.replace(".onnx", OPTIMIZED_SUFFIX),
-        )
     except Exception as e:
         logger.warning(f"Optimization failed with {e}. Going to use the unoptimized model.")
         model = unoptimized_model
@@ -251,3 +332,31 @@ def get_onnx_files(
         return onnx_files[0]
     else:
         raise ValueError(f"No onnx files found for {model_name_or_path} and revision {revision}")
+
+
+def get_openvino_files(
+    *,
+    model_name_or_path: str,
+    revision: Union[str, None] = None,
+    use_auth_token: Union[bool, str] = True,
+) -> Path:
+    """gets the openvino files from the repo"""
+    repo_files = _list_all_repo_files(
+        model_name_or_path=model_name_or_path,
+        revision=revision,
+        use_auth_token=use_auth_token,
+    )
+    pattern = "**openvino_model.*"
+    openvino_files = sorted([p for p in repo_files if p.match(pattern)])
+
+    if len(openvino_files) > 1:
+        logger.info(f"Found {len(openvino_files)} openvino files: {openvino_files}")
+        openvino_file = openvino_files[-1]
+        logger.info(f"Using {openvino_file} as the model")
+        return openvino_file
+    elif len(openvino_files) == 1:
+        return openvino_files[0]
+    else:
+        raise ValueError(
+            f"No openvino files found for {model_name_or_path} and revision {revision}"
+        )
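For reference, the export-and-quantize path that `optimize_model` now takes for the OpenVINO provider can be reproduced standalone with the same optimum-intel calls used above; the model id and save directory below are illustrative only:

from optimum.intel import (
    OVConfig,
    OVModelForFeatureExtraction,
    OVQuantizer,
    OVWeightQuantizationConfig,
)

model_id = "BAAI/bge-small-en-v1.5"
# Export the PyTorch checkpoint to an OpenVINO IR on the fly.
ov_model = OVModelForFeatureExtraction.from_pretrained(model_id, export=True)
quantizer = OVQuantizer.from_pretrained(ov_model, task="feature-extraction")
# Same int4 weight-only settings as optimize_model above.
ov_config = OVConfig(
    quantization_config=OVWeightQuantizationConfig(bits=4, sym=False, ratio=1.0, group_size=128)
)
quantizer.quantize(ov_config=ov_config, save_directory="/tmp/bge-small-int4")
# Reload the quantized IR without re-exporting.
model = OVModelForFeatureExtraction.from_pretrained("/tmp/bge-small-int4", export=False)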
diff --git a/libs/infinity_emb/poetry.lock b/libs/infinity_emb/poetry.lock
index 5364d966..eb2f824c 100644
--- a/libs/infinity_emb/poetry.lock
+++ b/libs/infinity_emb/poetry.lock
@@ -1,4 +1,4 @@
-# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
+# This file is automatically @generated by Poetry 1.8.4 and should not be changed by hand.
 
 [[package]]
 name = "accelerate"
@@ -2367,6 +2367,28 @@ packaging = "*"
 protobuf = "*"
 sympy = "*"
 
+[[package]]
+name = "onnxruntime-openvino"
+version = "1.20.0"
+description = "ONNX Runtime is a runtime accelerator for Machine Learning models"
+optional = true
+python-versions = "*"
+files = [
+    {file = "onnxruntime_openvino-1.20.0-cp310-cp310-manylinux_2_28_x86_64.whl", hash = "sha256:ae9089466ad3930cced192e8604de161c17fe833b962e511c7133a3b148e6c87"},
+    {file = "onnxruntime_openvino-1.20.0-cp311-cp311-manylinux_2_28_x86_64.whl", hash = "sha256:a5e28a369394b895a0f7048d6ad940f1510a445aa3c89ad4039b57c1a006f68f"},
+    {file = "onnxruntime_openvino-1.20.0-cp311-cp311-win_amd64.whl", hash = "sha256:d5b3b547e887cbc4081dad940db7d9aef6103dcce30a6746f2042400ad70676f"},
+    {file = "onnxruntime_openvino-1.20.0-cp312-cp312-manylinux_2_28_x86_64.whl", hash = "sha256:97f424b05feb18b4dbb6e9a85d2bfbd4c928508dc8846622b1c12b4086ce937c"},
+    {file = "onnxruntime_openvino-1.20.0-cp312-cp312-win_amd64.whl", hash = "sha256:dbd39d1dedf798997393f8fdf8cb89ee4ed905c9a8ea000abdce7c288181b829"},
+]
+
+[package.dependencies]
+coloredlogs = "*"
+flatbuffers = "*"
+numpy = ">=1.21.6"
+packaging = "*"
+protobuf = "*"
+sympy = "*"
+
 [[package]]
 name = "openai"
 version = "1.52.0"
@@ -2391,6 +2413,81 @@ typing-extensions = ">=4.11,<5"
 [package.extras]
 datalib = ["numpy (>=1)", "pandas (>=1.2.3)", "pandas-stubs (>=1.1.0.11)"]
 
+[[package]]
+name = "openvino"
+version = "2024.4.0"
+description = "OpenVINO(TM) Runtime"
+optional = true
+python-versions = "*"
+files = [
+    {file = "openvino-2024.4.0-16579-cp310-cp310-macosx_10_15_x86_64.whl", hash = "sha256:8f19d4200ea04ab315a02f8279268851362f434beaa1a70b4f35d2eea1efa402"},
+    {file = "openvino-2024.4.0-16579-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:4a1da4c8db12559bf2760c8a7c8455e0b4373a20364eaee2c9832a6bb23c88a9"},
+    {file = "openvino-2024.4.0-16579-cp310-cp310-manylinux2014_x86_64.whl", hash = "sha256:98325dec5ca8bd79f19ea10fd45ad4631a33d9ab50e30659a91a6239ae02d8f4"},
+    {file = "openvino-2024.4.0-16579-cp310-cp310-manylinux_2_31_aarch64.whl", hash = "sha256:61f68366017262603be0d876e2e9b7015789ee6b319da8f1792da28b733193f8"},
+    {file = "openvino-2024.4.0-16579-cp310-cp310-win_amd64.whl", hash = "sha256:a5499d6daa91c358803441561b8792231dd964c5432e838df653c1e5df8de945"},
+    {file = "openvino-2024.4.0-16579-cp311-cp311-macosx_10_15_x86_64.whl", hash = "sha256:e333a5f8155ae357f74e54b664d52d85fa4036a5ccea5da49a7df7f78826c1ce"},
+    {file = "openvino-2024.4.0-16579-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:b59fb073da74c7ab6d89f2559e3024044f340750b3e519e25975426beb154942"},
+    {file = "openvino-2024.4.0-16579-cp311-cp311-manylinux2014_x86_64.whl", hash = "sha256:4bd3e21a70eff12166bae3b5ea824787b8c766f975f39e3f2d93729b47b74cb6"},
+    {file = "openvino-2024.4.0-16579-cp311-cp311-manylinux_2_31_aarch64.whl", hash = "sha256:49b578c4d7325e4a519eb66ee5655871a2b7cd5be9d2de0d5109df23301d10a9"},
+    {file = "openvino-2024.4.0-16579-cp311-cp311-win_amd64.whl", hash = "sha256:ab42204c185a4f0df5600a0adb4a4a0c97cebdf630696f94f9d06732714385bc"},
+    {file = "openvino-2024.4.0-16579-cp312-cp312-macosx_10_15_x86_64.whl", hash = "sha256:d67d30f830fa3147eb37f31f53c9eaee424a4e93f33eed00d8288f304ef0250a"},
+    {file = "openvino-2024.4.0-16579-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:50e05d59a90f7950c205d95bb1559e9a8a7d655fe843449d3d426c579fe665f1"},
+    {file = "openvino-2024.4.0-16579-cp312-cp312-manylinux2014_x86_64.whl", hash = "sha256:d4cf229fb240b8af44d14686b5bdd94f9eabc70120b9808fc804252fac9ef615"},
+    {file = "openvino-2024.4.0-16579-cp312-cp312-manylinux_2_31_aarch64.whl", hash = "sha256:5c8ceeb537019280f69dbe86049c3136e648e94fa9f3da9ef0433975e479ad09"},
+    {file = "openvino-2024.4.0-16579-cp312-cp312-win_amd64.whl", hash = "sha256:83af7df6f9b7e2a96dfc5d63a774e6ca3f87d64c7372d14f7ae339387474fc5c"},
+    {file = "openvino-2024.4.0-16579-cp38-cp38-macosx_10_15_x86_64.whl", hash = "sha256:b23fd5114bf42f04773f293b16965a541d58e46e6847053f1417cd6e47acddf5"},
+    {file = "openvino-2024.4.0-16579-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:fa6e9fffd31c095e052f6cecb9ac3ff95e0c122418b81b9926b7687465475742"},
+    {file = "openvino-2024.4.0-16579-cp38-cp38-manylinux2014_x86_64.whl", hash = "sha256:24b28379bd8d43963829b0b4df957d659269fa647f4f842bd0b3d2f8db76782b"},
+    {file = "openvino-2024.4.0-16579-cp38-cp38-manylinux_2_31_aarch64.whl", hash = "sha256:4ed049ab7a2ffb624690e6cf38366383630cd58736320953cc62c78e8b31eae5"},
+    {file = "openvino-2024.4.0-16579-cp38-cp38-win_amd64.whl", hash = "sha256:0cccaa53a61629b44408fe0c7537db637be913697b0f3c54c78756e95dfc4498"},
+    {file = "openvino-2024.4.0-16579-cp39-cp39-macosx_10_15_x86_64.whl", hash = "sha256:3b0834492ff5bc129debb506a705d26b640bca99a10e641af8f710bd081c9af0"},
+    {file = "openvino-2024.4.0-16579-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:0e01c22a9989470ebcbb7b05fd07e4297bf6c5ecdca202b05f5dc9d2b3186f39"},
+    {file = "openvino-2024.4.0-16579-cp39-cp39-manylinux2014_x86_64.whl", hash = "sha256:fb4781dd9691dc4cfbc6b69f56f724286699e309d4ddc7894fed3ee77b683e2f"},
+    {file = "openvino-2024.4.0-16579-cp39-cp39-manylinux_2_31_aarch64.whl", hash = "sha256:74094f8ad81c2ae8500d3361ac087455316a6583016f693d7e1dd1500217ceec"},
+    {file = "openvino-2024.4.0-16579-cp39-cp39-win_amd64.whl", hash = "sha256:be834d95405fe3724f104b54f3836e7053646d952c4f8a5dad2267665f55d88f"},
+]
+
+[package.dependencies]
+numpy = ">=1.16.6,<2.1.0"
+openvino-telemetry = ">=2023.2.1"
+packaging = "*"
+
+[[package]]
+name = "openvino-telemetry"
+version = "2024.5.0"
+description = "OpenVINO™ Telemetry package for sending statistics with user's consent, used in combination with other OpenVINO™ packages."
+optional = true
+python-versions = "*"
+files = [
+    {file = "openvino_telemetry-2024.5.0-py3-none-any.whl", hash = "sha256:c29073f4b0c4d4229be5d10612b072f90a6aea97bcb17005085248f1404ec2ab"},
+    {file = "openvino_telemetry-2024.5.0.tar.gz", hash = "sha256:592d266954903e8f800d984a7573f218af8118a6c15fc623545ea0b5b0fa72e1"},
+]
+
+[[package]]
+name = "openvino-tokenizers"
+version = "2024.4.0.0"
+description = "Convert tokenizers into OpenVINO models"
+optional = true
+python-versions = ">=3.8"
+files = [
+    {file = "openvino_tokenizers-2024.4.0.0-py3-none-macosx_10_15_x86_64.whl", hash = "sha256:0b9fe78b01e796e9124a04d21cf3074c98cf1a719136f7b14fbab933ac5b85dd"},
+    {file = "openvino_tokenizers-2024.4.0.0-py3-none-macosx_11_0_arm64.whl", hash = "sha256:a9649be7949be9c72fea02577ea61f320173f43d6ba47d735d1dc9c1901e97c8"},
+    {file = "openvino_tokenizers-2024.4.0.0-py3-none-manylinux2014_x86_64.whl", hash = "sha256:0ec426329d951c2e13c843da42a0d01926b2f40f7f2844e51207dbb5073f1662"},
+    {file = "openvino_tokenizers-2024.4.0.0-py3-none-manylinux_2_31_aarch64.whl", hash = "sha256:e43c9ef643ebccf713d551186efb68edc394834a19f894822e0366413079bfc2"},
+    {file = "openvino_tokenizers-2024.4.0.0-py3-none-win_amd64.whl", hash = "sha256:ee9f99f6e230a364f0708d0625273295e4f60dee54d915f8dc6d015bc0921015"},
+]
+
+[package.dependencies]
+openvino = "==2024.4.*"
+
+[package.extras]
+all = ["openvino_tokenizers[dev,transformers]"]
+benchmark = ["openvino_tokenizers[transformers]", "pandas", "seaborn", "tqdm"]
+dev = ["bandit", "openvino_tokenizers[torch,transformers]", "pandas", "pytest", "pytest_harvest", "ruff"]
+fuzzing = ["atheris", "openvino_tokenizers[transformers]"]
+torch = ["torch"]
+transformers = ["tiktoken", "transformers[sentencepiece] (>=4.36.0)"]
+
 [[package]]
 name = "optimum"
 version = "1.23.3"
@@ -5198,6 +5295,7 @@ ct2 = ["ctranslate2", "sentence-transformers", "torch", "transformers"]
 einops = ["einops"]
 logging = ["rich"]
 onnxruntime-gpu = ["onnxruntime-gpu"]
+openvino = ["onnxruntime-openvino", "openvino", "openvino-tokenizers"]
 optimum = ["optimum"]
 server = ["fastapi", "orjson", "posthog", "prometheus-fastapi-instrumentator", "pydantic", "rich", "typer", "uvicorn"]
 tensorrt = ["tensorrt"]
@@ -5207,4 +5305,4 @@ vision = ["colpali-engine", "pillow", "timm", "torchvision"]
 [metadata]
 lock-version = "2.0"
 python-versions = ">=3.9,<4"
-content-hash = "7d3ee9f0dbde4965c672639f86264c05689b8753c266e23400e5fd72b78dfed2"
+content-hash = "23feae6cd9a95ff4a6ed50da692d28ba9b514d3067adc9bcc8e4860a70a13942"
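With the lock regenerated under Poetry 1.8.4, the three openvino packages now resolve together. A quick, assumption-light way to confirm the pins inside an environment installed with the `openvino` extra (standard-library only):

from importlib.metadata import version

for pkg in ("onnxruntime-openvino", "openvino", "openvino-tokenizers"):
    print(pkg, version(pkg))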
diff --git a/libs/infinity_emb/pyproject.toml b/libs/infinity_emb/pyproject.toml
index 654cdf13..9fa0cf3f 100644
--- a/libs/infinity_emb/pyproject.toml
+++ b/libs/infinity_emb/pyproject.toml
@@ -41,9 +41,9 @@ timm = {version = "*", optional=true}
 colpali-engine = {version="^0.3.1", optional=true}
 # openvino
 # optimum-intel = {version=">=1.20.0", optional=true, extras=["openvino"]}
-# onnxruntime-openvino = {version=">=1.19.0", optional=true}
-# openvino = {version="2024.4.0", optional=true}
-# openvino-tokenizers = {version="2024.4.0.0", optional=true}
+onnxruntime-openvino = {version=">=1.19.0", optional=true}
+openvino = {version="2024.4.0", optional=true}
+openvino-tokenizers = {version="2024.4.0.0", optional=true}
 
 
 # pin torchvision to a specific source, but default to pypi. use sed to overwrite.
@@ -106,7 +106,7 @@ einops=["einops"]
 logging=["rich"]
 cache=["diskcache"]
 vision=["colpali-engine","pillow","timm","torchvision"]
-# openvino=["onnxruntime-openvino","openvino","openvino-tokenizers"]
+openvino=["onnxruntime-openvino","openvino","openvino-tokenizers"]
 audio=["soundfile"]
 server=[
     "fastapi",
diff --git a/libs/infinity_emb/tests/unit_test/transformer/embedder/test_optimum.py b/libs/infinity_emb/tests/unit_test/transformer/embedder/test_optimum.py
index c612d305..70c6b7a0 100644
--- a/libs/infinity_emb/tests/unit_test/transformer/embedder/test_optimum.py
+++ b/libs/infinity_emb/tests/unit_test/transformer/embedder/test_optimum.py
@@ -25,3 +25,25 @@ def test_embedder_optimum(size="large"):
         cosine_sim = np.dot(r, e) / (np.linalg.norm(e) * np.linalg.norm(r))
         assert cosine_sim > 0.94
     np.testing.assert_allclose(embeds, embeds_orig, atol=0.25)
+
+
+def test_embedder_optimum_openvino_cpu(size="large"):
+    model = OptimumEmbedder(
+        engine_args=EngineArgs(model_name_or_path=f"BAAI/bge-{size}-en-v1.5", device="cpu")
+    )
+    st_model = SentenceTransformer(model_name_or_path=f"BAAI/bge-{size}-en-v1.5", device="cpu")
+
+    sentences = ["This is awesome.", "I am depressed."]
+
+    encode_pre = model.encode_pre(sentences)
+    encode_core = model.encode_core(encode_pre)
+    embeds = model.encode_post(encode_core)
+
+    embeds_orig = st_model.encode(sentences)
+
+    assert len(embeds) == len(sentences)
+
+    for r, e in zip(embeds, embeds_orig):
+        cosine_sim = np.dot(r, e) / (np.linalg.norm(e) * np.linalg.norm(r))
+        assert cosine_sim > 0.94
+    np.testing.assert_allclose(embeds, embeds_orig, atol=0.25)
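Note on the new test: despite the `_openvino_cpu` suffix it passes `device="cpu"`; it exercises the OpenVINO code path only because `device_to_onnx` now prefers `OpenVINOExecutionProvider` on CPU whenever optimum-intel is importable, and falls back to plain ONNX Runtime otherwise. It can be run in isolation with `poetry run python -m pytest tests/unit_test/transformer/embedder/test_optimum.py -k openvino`.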