diff --git a/.github/workflows/build-container.yaml b/.github/workflows/build-container.yaml index fe12fbf6..04cb4147 100644 --- a/.github/workflows/build-container.yaml +++ b/.github/workflows/build-container.yaml @@ -34,21 +34,12 @@ jobs: TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }} REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} - starlette-tensorflow-cpu: + starlette-pytorch-inf2: uses: ./.github/workflows/docker-build-action.yaml with: - image: inference-tensorflow-cpu - dockerfile: dockerfiles/tensorflow/cpu/Dockerfile + image: inference-pytorch-inf2 + dockerfile: dockerfiles/pytorch/Dockerfile.inf2 secrets: TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }} REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} - REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} - starlette-tensorflow-gpu: - uses: ./.github/workflows/docker-build-action.yaml - with: - image: inference-tensorflow-gpu - dockerfile: dockerfiles/tensorflow/gpu/Dockerfile - secrets: - TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }} - REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }} - REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} + REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }} \ No newline at end of file diff --git a/.github/workflows/integration-test.yaml b/.github/workflows/integration-test.yaml index 7aa1aa2f..5343b571 100644 --- a/.github/workflows/integration-test.yaml +++ b/.github/workflows/integration-test.yaml @@ -22,6 +22,7 @@ jobs: with: test_path: "tests/integ/test_pytorch_local_gpu.py" build_img_cmd: "make inference-pytorch-gpu" + test_parallelism: "1" pytorch-integration-remote-gpu: name: Remote Integration Tests - GPU uses: ./.github/workflows/integration-test-action.yaml @@ -41,4 +42,5 @@ jobs: with: test_path: "tests/integ/test_pytorch_local_cpu.py" build_img_cmd: "make inference-pytorch-cpu" + test_parallelism: "1" runs_on: "['ci']" \ No newline at end of file diff --git a/.gitignore b/.gitignore index bb0c387b..3dcab16d 100644 --- a/.gitignore +++ b/.gitignore @@ -179,4 +179,6 @@ model tests/tmp tmp/ act.sh -.act \ No newline at end of file +.act +tmp* +log-* \ No newline at end of file diff --git a/README.md b/README.md index f3056a89..c7543569 100644 --- a/README.md +++ b/README.md @@ -8,8 +8,16 @@ Hugging Face Inference Toolkit is for serving πŸ€— Transformers models in containers. This library provides default pre-processing, predict and postprocessing for Transformers, Sentence Tranfsformers. It is also possible to define custom `handler.py` for customization. The Toolkit is build to work with the [Hugging Face Hub](https://huggingface.co/models). --- + ## πŸ’» Getting Started with Hugging Face Inference Toolkit +* Clone the repository `git clone https://github.com/huggingface/huggingface-inference-toolkit`` +* Install the dependencies in dev mode `pip install -e ".[torch, st, diffusers, test,quality]"` + * If you develop on AWS inferentia2 install with `pip install -e ".[test,quality]" optimum-neuron[neuronx] --upgrade` +* Unit Testing: `make unit-test` +* Integration testing: `make integ-test` + + ### Local run ```bash @@ -58,6 +66,21 @@ curl --request POST \ }' ``` +### Custom Handler and dependency support + +The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository. 
+For an example check [https://huggingface.co/philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification): +```bash +model.tar.gz/ +|- pytorch_model.bin +|- .... +|- handler.py +|- requirements.txt +``` +In this example, `pytroch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file to add additional dependencies. +The custom module can override the following methods: + + ### Vertex AI Support The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit. @@ -109,6 +132,69 @@ curl --request POST \ }' ``` +### AWS Inferentia2 Support + +The Hugging Face Inference Toolkit provides support for deploying Hugging Face on AWS Inferentia2. To deploy a model on Inferentia2 you have 3 options: +* Provide `HF_MODEL_ID`, the model repo id on huggingface.co which contains the compiled model under `.neuron` format. e.g. `optimum/bge-base-en-v1.5-neuronx` +* Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128` +* Include `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file in the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}` + +The currently supported tasks can be found [here](https://huggingface.co/docs/optimum-neuron/en/package_reference/supported_models). If you plan to deploy an LLM, we recommend taking a look at [Neuronx TGI](https://huggingface.co/blog/text-generation-inference-on-inferentia2), which is purposly build for LLMs. + +#### Local run with HF_MODEL_ID and HF_TASK + +Start Hugging Face Inference Toolkit with the following environment variables. + +_Note: You need to run this on an Inferentia2 instance._ + +- transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` +```bash +mkdir tmp2/ +HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000 +``` +- sentence transformers `feature-extration` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` +```bash +HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000 +``` + +Send request + +```bash +curl --request POST \ + --url http://localhost:5000 \ + --header 'Content-Type: application/json' \ + --data '{ + "inputs": "Wow, this is such a great product. I love it!" +}' +``` + +#### Container run with HF_MODEL_ID and HF_TASK + + +1. build the preferred container for either CPU or GPU for PyTorch o. + +```bash +make inference-pytorch-inf2 +``` + +2. Run the container and provide either environment variables to the HUB model you want to use or mount a volume to the container, where your model is stored. 
+ +```bash +docker run -ti -p 5000:5000 -e HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" -e HF_TASK="text-classification" -e HF_OPTIMUM_BATCH_SIZE=1 -e HF_OPTIMUM_SEQUENCE_LENGTH=128 --device=/dev/neuron0 integration-test-pytorch:inf2 +``` + +3. Send request + +```bash +curl --request POST \ + --url http://localhost:5000 \ + --header 'Content-Type: application/json' \ + --data '{ + "inputs": "Wow, this is such a great product. I love it!", + "parameters": { "top_k": 2 } +}' +``` + --- @@ -168,61 +254,23 @@ The `HF_FRAMEWORK` environment variable defines the base deep learning framework HF_FRAMEWORK="pytorch" ``` -### `HF_ENDPOINT` +#### `HF_OPTIMUM_BATCH_SIZE` -The `HF_ENDPOINT` environment variable indicates whether the service is run inside the HF Inference endpoint service to adjust the `logging` config. +The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted. ```bash -HF_ENDPOINT="True" +HF_OPTIMUM_BATCH_SIZE="1" ``` +#### `HF_OPTIMUM_SEQUENCE_LENGTH` ---- +The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when model is already converted. -## πŸ§‘πŸ»β€πŸ’» Custom Handler and dependency support - -The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository. -For an example check [https://huggingface.co/philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification): ```bash -model.tar.gz/ -|- pytorch_model.bin -|- .... -|- handler.py -|- requirements.txt +HF_OPTIMUM_SEQUENCE_LENGTH="128" ``` -In this example, `pytroch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file to add additional dependencies. -The custom module can override the following methods: - -## β˜‘οΈ Supported & Tested Tasks - -Below you ll find a list of supported and tested transformers and sentence transformers tasks. Each of those are always tested through integration tests. In addition to those tasks you can always provide `custom`, which expect a `handler.py` file to be provided. - -```bash -"text-classification", -"zero-shot-classification", -"ner", -"question-answering", -"fill-mask", -"summarization", -"translation_xx_to_yy", -"text2text-generation", -"text-generation", -"feature-extraction", -"image-classification", -"automatic-speech-recognition", -"audio-classification", -"object-detection", -"image-segmentation", -"table-question-answering", -"conversational" -"sentence-similarity", -"sentence-embeddings", -"sentence-ranking", -# TODO currently not supported due to multimodality input -# "visual-question-answering", -# "zero-shot-image-classification", -``` +--- ## βš™ Supported Frontend @@ -232,21 +280,11 @@ Below you ll find a list of supported and tested transformers and sentence trans - [ ]Β Starlette (SageMaker) --- -## 🀝 Contributing - -### Development - -* Recommended Python version: 3.11 -* We recommend `pyenv` for easily switching between different Python versions -* There are two options for unit and integration tests: - * `Make` - see `makefile` -#### Testing with Make - -* Unit Testing: `make unit-test` -* Integration testing: `make integ-test` +## 🀝 Contributing --- + ## πŸ“œ License TBD. 
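As a companion to the curl examples added in this README, here is a minimal Python sketch of the same request against a locally running container. It is not part of the toolkit's API: it simply assumes a server started as in the README examples is listening on `http://localhost:5000` and accepts the JSON payload shown there; the `predict` helper name and the sample payload are illustrative.

```python
import requests

# Assumed endpoint of a locally running inference container
# (matches the README's uvicorn / docker run examples on port 5000).
BASE_URL = "http://localhost:5000"


def predict(inputs, parameters=None, timeout=30):
    """Send a prediction request mirroring the curl examples above."""
    payload = {"inputs": inputs}
    if parameters is not None:
        payload["parameters"] = parameters
    response = requests.post(
        BASE_URL,
        json=payload,
        headers={"Content-Type": "application/json"},
        timeout=timeout,
    )
    response.raise_for_status()
    return response.json()


if __name__ == "__main__":
    # Same payload as the text-classification curl example in the README.
    print(predict("Wow, this is such a great product. I love it!", parameters={"top_k": 2}))
```

The response body is whatever the configured pipeline returns (for `text-classification`, a list of label/score pairs), so the helper just returns the decoded JSON unchanged.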
diff --git a/dockerfiles/pytorch/Dockerfile b/dockerfiles/pytorch/Dockerfile index c554ce59..348ef232 100644 --- a/dockerfiles/pytorch/Dockerfile +++ b/dockerfiles/pytorch/Dockerfile @@ -51,4 +51,4 @@ ENTRYPOINT ["bash", "-c", "./entrypoint.sh"] from base as vertex # Install Vertex AI requiremented packages -RUN pip install --no-cache-dir google-cloud-storage +RUN pip install --no-cache-dir google-cloud-storage \ No newline at end of file diff --git a/dockerfiles/pytorch/Dockerfile.inf2 b/dockerfiles/pytorch/Dockerfile.inf2 new file mode 100644 index 00000000..bb8459ec --- /dev/null +++ b/dockerfiles/pytorch/Dockerfile.inf2 @@ -0,0 +1,122 @@ +# Build based on https://github.com/aws/deep-learning-containers/blob/master/huggingface/pytorch/inference/docker/2.1/py3/sdk2.18.0/Dockerfile.neuronx +FROM ubuntu:20.04 + +LABEL maintainer="Hugging Face" + +ARG PYTHON=python3.10 +ARG PYTHON_VERSION=3.10.12 +ARG MAMBA_VERSION=23.1.0-4 + +# Neuron SDK components version numbers +ARG NEURONX_FRAMEWORK_VERSION=2.1.2.2.1.0 +ARG NEURONX_DISTRIBUTED_VERSION=0.7.0 +ARG NEURONX_CC_VERSION=2.13.66.0 +ARG NEURONX_TRANSFORMERS_VERSION=0.10.0.21 +ARG NEURONX_COLLECTIVES_LIB_VERSION=2.20.22.0-c101c322e +ARG NEURONX_RUNTIME_LIB_VERSION=2.20.22.0-1b3ca6425 +ARG NEURONX_TOOLS_VERSION=2.17.1.0 + +# HF ARGS +ARG OPTIMUM_NEURON_VERSION=0.0.23 + +# See http://bugs.python.org/issue19846 +ENV LANG C.UTF-8 +ENV LD_LIBRARY_PATH /opt/aws/neuron/lib:/lib/x86_64-linux-gnu:/opt/conda/lib/:$LD_LIBRARY_PATH +ENV PATH /opt/conda/bin:/opt/aws/neuron/bin:$PATH + +RUN apt-get update \ + && apt-get upgrade -y \ + && apt-get install -y --no-install-recommends software-properties-common \ + && add-apt-repository ppa:openjdk-r/ppa \ + && apt-get update \ + && apt-get install -y --no-install-recommends \ + build-essential \ + apt-transport-https \ + ca-certificates \ + cmake \ + curl \ + emacs \ + git \ + jq \ + libgl1-mesa-glx \ + libsm6 \ + libxext6 \ + libxrender-dev \ + openjdk-11-jdk \ + vim \ + wget \ + unzip \ + zlib1g-dev \ + libcap-dev \ + gpg-agent \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /tmp/tmp* \ + && apt-get clean + +RUN echo "deb https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list +RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add - + +# Install Neuronx tools +RUN apt-get update \ + && apt-get install -y \ + aws-neuronx-tools=$NEURONX_TOOLS_VERSION \ + aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \ + aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \ + && rm -rf /var/lib/apt/lists/* \ + && rm -rf /tmp/tmp* \ + && apt-get clean + +# https://github.com/docker-library/openjdk/issues/261 https://github.com/docker-library/openjdk/pull/263/files +RUN keytool -importkeystore -srckeystore /etc/ssl/certs/java/cacerts -destkeystore /etc/ssl/certs/java/cacerts.jks -deststoretype JKS -srcstorepass changeit -deststorepass changeit -noprompt; \ + mv /etc/ssl/certs/java/cacerts.jks /etc/ssl/certs/java/cacerts; \ + /var/lib/dpkg/info/ca-certificates-java.postinst configure; + +RUN curl -L -o ~/mambaforge.sh https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh \ + && chmod +x ~/mambaforge.sh \ + && ~/mambaforge.sh -b -p /opt/conda \ + && rm ~/mambaforge.sh \ + && /opt/conda/bin/conda update -y conda \ + && /opt/conda/bin/conda install -c conda-forge -y \ + python=$PYTHON_VERSION \ + pyopenssl \ + cython \ + mkl-include \ + mkl \ + botocore \ + parso \ + scipy \ + 
typing \ + # Below 2 are included in miniconda base, but not mamba so need to install + conda-content-trust \ + charset-normalizer \ + && /opt/conda/bin/conda update -y conda \ + && /opt/conda/bin/conda clean -ya + +RUN conda install -c conda-forge \ + scikit-learn \ + h5py \ + requests \ + && conda clean -ya \ + && pip install --upgrade pip --trusted-host pypi.org --trusted-host files.pythonhosted.org \ + && ln -s /opt/conda/bin/pip /usr/local/bin/pip3 \ + && pip install --no-cache-dir "protobuf>=3.18.3,<4" setuptools==69.5.1 packaging + +WORKDIR / + +# install Hugging Face libraries and its dependencies +RUN pip install --extra-index-url https://pip.repos.neuron.amazonaws.com --no-cache-dir optimum-neuron[neuronx]==${OPTIMUM_NEURON_VERSION} \ + && pip install --no-deps --no-cache-dir -U torchvision==0.16.* + + +COPY . . +# install wheel and setuptools +RUN pip install --no-cache-dir -U pip ".[st]" + +# copy application +COPY src/huggingface_inference_toolkit huggingface_inference_toolkit +COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py + +# copy entrypoint and change permissions +COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh + +ENTRYPOINT ["bash", "-c", "./entrypoint.sh"] \ No newline at end of file diff --git a/makefile b/makefile index 3502d83e..ef45c469 100644 --- a/makefile +++ b/makefile @@ -1,6 +1,6 @@ .PHONY: quality style unit-test integ-test -check_dirs := src +check_dirs := src tests # run tests @@ -13,12 +13,12 @@ integ-test: # Check that source code meets quality standards quality: - ruff $(check_dirs) + ruff check $(check_dirs) # Format source code automatically style: - ruff $(check_dirs) --fix + ruff check $(check_dirs) --fix inference-pytorch-gpu: docker build -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:gpu . @@ -26,6 +26,9 @@ inference-pytorch-gpu: inference-pytorch-cpu: docker build --build-arg="BASE_IMAGE=ubuntu:22.04" -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:cpu . +inference-pytorch-inf2: + docker build -f dockerfiles/pytorch/Dockerfile.inf2 -t integration-test-pytorch:inf2 . + vertex-pytorch-gpu: docker build -t vertex -f dockerfiles/pytorch/Dockerfile -t integration-test-pytorch:gpu . 
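The new `inference-pytorch-inf2` make target builds the `integration-test-pytorch:inf2` image from `dockerfiles/pytorch/Dockerfile.inf2`. Below is a minimal, hedged sketch of starting that image programmatically with the `docker` Python SDK (already used by `tests/integ/conftest.py`), passing the Neuron device and the compile-time environment variables described in the README. The model id, shapes, container name, and the list-form device mapping are illustrative choices, not the project's test setup.

```python
import docker

client = docker.from_env()

# Illustrative run of the image built by `make inference-pytorch-inf2`,
# mirroring the README's `docker run ... --device=/dev/neuron0` example.
container = client.containers.run(
    "integration-test-pytorch:inf2",
    name="inf2-smoke-test",
    ports={"5000": 5000},
    environment={
        "HF_MODEL_ID": "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
        "HF_TASK": "text-classification",
        "HF_OPTIMUM_BATCH_SIZE": "1",
        "HF_OPTIMUM_SEQUENCE_LENGTH": "128",
    },
    # Expose the first Neuron accelerator to the container
    # (host_path:container_path:cgroup_permissions).
    devices=["/dev/neuron0:/dev/neuron0:rwm"],
    detach=True,
)

# The container starts in the background; tail its logs or send a request
# to http://localhost:5000 once model compilation has finished.
print(container.logs().decode("utf-8"))
```

Compiling the model on the fly can take several minutes on first start, so any smoke test built on this sketch should poll the server (for example `GET /health`) before sending predictions.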
diff --git a/setup.py b/setup.py index 37fdb0b0..a199a4d3 100644 --- a/setup.py +++ b/setup.py @@ -49,6 +49,7 @@ "isort", "ruff" ] +extras["inf2"] = ["optimum-neuron"] setup( name="huggingface-inference-toolkit", diff --git a/src/huggingface_inference_toolkit/diffusers_utils.py b/src/huggingface_inference_toolkit/diffusers_utils.py index 521a85df..f6241032 100644 --- a/src/huggingface_inference_toolkit/diffusers_utils.py +++ b/src/huggingface_inference_toolkit/diffusers_utils.py @@ -1,10 +1,8 @@ import importlib.util -import logging from transformers.utils.import_utils import is_torch_bf16_gpu_available -logger = logging.getLogger(__name__) -logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO) +from huggingface_inference_toolkit.logging import logger _diffusers = importlib.util.find_spec("diffusers") is not None @@ -15,7 +13,11 @@ def is_diffusers_available(): if is_diffusers_available(): import torch - from diffusers import AutoPipelineForText2Image, DPMSolverMultistepScheduler, StableDiffusionPipeline + from diffusers import ( + AutoPipelineForText2Image, + DPMSolverMultistepScheduler, + StableDiffusionPipeline, + ) class IEAutoPipelineForText2Image: @@ -25,11 +27,15 @@ def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16 device_map = "auto" if device == "cuda" else None - self.pipeline = AutoPipelineForText2Image.from_pretrained(model_dir, torch_dtype=dtype, device_map=device_map) + self.pipeline = AutoPipelineForText2Image.from_pretrained( + model_dir, torch_dtype=dtype, device_map=device_map + ) # try to use DPMSolverMultistepScheduler if isinstance(self.pipeline, StableDiffusionPipeline): try: - self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config(self.pipeline.scheduler.config) + self.pipeline.scheduler = DPMSolverMultistepScheduler.from_config( + self.pipeline.scheduler.config + ) except Exception: pass @@ -43,7 +49,9 @@ def __call__( # TODO: add support for more images (Reason is correct output) if "num_images_per_prompt" in kwargs: kwargs.pop("num_images_per_prompt") - logger.warning("Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1.") + logger.warning( + "Sending num_images_per_prompt > 1 to pipeline is not supported. Using default value 1." 
+ ) # Call pipeline with parameters out = self.pipeline(prompt, num_images_per_prompt=1, **kwargs) diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py index 0a8c93b8..5b164af8 100644 --- a/src/huggingface_inference_toolkit/handler.py +++ b/src/huggingface_inference_toolkit/handler.py @@ -1,12 +1,11 @@ -import logging import os from pathlib import Path from typing import Optional, Union -from huggingface_inference_toolkit.utils import check_and_register_custom_pipeline_from_directory, get_pipeline - -logger = logging.getLogger(__name__) -logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO) +from huggingface_inference_toolkit.utils import ( + check_and_register_custom_pipeline_from_directory, + get_pipeline, +) class HuggingFaceHandler: @@ -17,9 +16,7 @@ class HuggingFaceHandler: def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"): self.pipeline = get_pipeline( - model_dir=model_dir, - task=task, - framework=framework + model_dir=model_dir, task=task, framework=framework ) def __call__(self, data): @@ -46,6 +43,7 @@ class VertexAIHandler(HuggingFaceHandler): A Default Vertex AI Hugging Face Inference Handler which abstracts the Vertex AI specific logic for inference. """ + def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"): super().__init__(model_dir, task, framework) @@ -57,7 +55,9 @@ def __call__(self, data): :return: prediction output """ if "instances" not in data: - raise ValueError("The request body must contain a key 'instances' with a list of instances.") + raise ValueError( + "The request body must contain a key 'instances' with a list of instances." + ) parameters = data.pop("parameters", None) predictions = [] @@ -69,9 +69,9 @@ def __call__(self, data): # reutrn predictions return {"predictions": predictions} + def get_inference_handler_either_custom_or_default_handler( - model_dir: Path, - task: Optional[str] = None + model_dir: Path, task: Optional[str] = None ): """ Returns the appropriate inference handler based on the given model directory and task. 
diff --git a/src/huggingface_inference_toolkit/logging.py b/src/huggingface_inference_toolkit/logging.py new file mode 100644 index 00000000..513d94fe --- /dev/null +++ b/src/huggingface_inference_toolkit/logging.py @@ -0,0 +1,29 @@ +import logging +import sys + + +def setup_logging(): + # Remove all existing handlers + for handler in logging.root.handlers[:]: + logging.root.removeHandler(handler) + + # Configure the root logger + logging.basicConfig( + level=logging.INFO, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%Y-%m-%d %H:%M:%S", + stream=sys.stdout, + ) + + # Remove Uvicorn loggers + logging.getLogger("uvicorn").handlers.clear() + logging.getLogger("uvicorn.access").handlers.clear() + logging.getLogger("uvicorn.error").handlers.clear() + + # Create a logger for your application + logger = logging.getLogger("huggingface_inference_toolkit") + return logger + + +# Create and configure the logger +logger = setup_logging() diff --git a/src/huggingface_inference_toolkit/optimum_utils.py b/src/huggingface_inference_toolkit/optimum_utils.py new file mode 100644 index 00000000..39419bb7 --- /dev/null +++ b/src/huggingface_inference_toolkit/optimum_utils.py @@ -0,0 +1,114 @@ +import importlib.util +import os + +from huggingface_inference_toolkit.logging import logger + +_optimum_neuron = False +if importlib.util.find_spec("optimum") is not None: + if importlib.util.find_spec("optimum.neuron") is not None: + _optimum_neuron = True + + +def is_optimum_neuron_available(): + return _optimum_neuron + + +def get_input_shapes(model_dir): + """Method to get input shapes from model config file. If config file is not present, default values are returned.""" + from transformers import AutoConfig + + input_shapes = {} + input_shapes_available = False + # try to get input shapes from config file + try: + config = AutoConfig.from_pretrained(model_dir) + if hasattr(config, "neuron"): + # check if static batch size and sequence length are available + if config.neuron.get("static_batch_size", None) and config.neuron.get( + "static_sequence_length", None + ): + input_shapes["batch_size"] = config.neuron["static_batch_size"] + input_shapes["sequence_length"] = config.neuron[ + "static_sequence_length" + ] + input_shapes_available = True + logger.info( + f"Input shapes found in config file. Using input shapes from config with batch size {input_shapes['batch_size']} and sequence length {input_shapes['sequence_length']}" + ) + else: + # Add warning if environment variables are set but will be ignored + if os.environ.get("HF_OPTIMUM_BATCH_SIZE", None) is not None: + logger.warning( + "HF_OPTIMUM_BATCH_SIZE environment variable is set. Environment variable will be ignored and input shapes from config file will be used." + ) + if os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) is not None: + logger.warning( + "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is set. Environment variable will be ignored and input shapes from config file will be used." + ) + except Exception: + input_shapes_available = False + + # return input shapes if available + if input_shapes_available: + return input_shapes + + # extract input shapes from environment variables + sequence_length = os.environ.get("HF_OPTIMUM_SEQUENCE_LENGTH", None) + if sequence_length is None: + raise ValueError( + "HF_OPTIMUM_SEQUENCE_LENGTH environment variable is not set. Please set HF_OPTIMUM_SEQUENCE_LENGTH to a positive integer." 
+ ) + + if not int(sequence_length) > 0: + raise ValueError( + f"HF_OPTIMUM_SEQUENCE_LENGTH must be set to a positive integer. Current value is {sequence_length}" + ) + batch_size = os.environ.get("HF_OPTIMUM_BATCH_SIZE", 1) + logger.info( + f"Using input shapes from environment variables with batch size {batch_size} and sequence length {sequence_length}" + ) + return {"batch_size": int(batch_size), "sequence_length": int(sequence_length)} + + +def get_optimum_neuron_pipeline(task, model_dir): + """Method to get optimum neuron pipeline for a given task. Method checks if task is supported by optimum neuron and if required environment variables are set, in case model is not converted. If all checks pass, optimum neuron pipeline is returned. If checks fail, an error is raised.""" + logger.info("Getting optimum neuron pipeline.") + from optimum.neuron.pipelines.transformers.base import ( + NEURONX_SUPPORTED_TASKS, + pipeline, + ) + from optimum.neuron.utils import NEURON_FILE_NAME + + # convert from os.path or path + if not isinstance(model_dir, str): + model_dir = str(model_dir) + + # check if task is sentence-embeddings and convert to feature-extraction, as sentence-embeddings is supported in feature-extraction pipeline + if task == "sentence-embeddings": + task = "feature-extraction" + + # check task support + if task not in NEURONX_SUPPORTED_TASKS: + raise ValueError( + f"Task {task} is not supported by optimum neuron and inf2. Supported tasks are: {list(NEURONX_SUPPORTED_TASKS.keys())}" + ) + + # check if model is already converted and has input shapes available + export = True + if NEURON_FILE_NAME in os.listdir(model_dir): + export = False + if export: + logger.info( + "Model is not converted. Checking if required environment variables are set and converting model." 
+ ) + + # get static input shapes to run inference + input_shapes = get_input_shapes(model_dir) + # set NEURON_RT_NUM_CORES to 1 to avoid conflicts with multiple HTTP workers + # TODO: Talk to optimum team what are the best options for encoder models to run on 2 neuron cores + # os.environ["NEURON_RT_NUM_CORES"] = "1" + # get optimum neuron pipeline + neuron_pipe = pipeline( + task, model=model_dir, export=export, input_shapes=input_shapes + ) + return neuron_pipe diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py index 1570317b..a0519d92 100644 --- a/src/huggingface_inference_toolkit/utils.py +++ b/src/huggingface_inference_toolkit/utils.py @@ -1,5 +1,4 @@ import importlib.util -import logging import sys from pathlib import Path from typing import Optional, Union @@ -14,15 +13,16 @@ get_diffusers_pipeline, is_diffusers_available, ) +from huggingface_inference_toolkit.logging import logger +from huggingface_inference_toolkit.optimum_utils import ( + get_optimum_neuron_pipeline, + is_optimum_neuron_available, +) from huggingface_inference_toolkit.sentence_transformers_utils import ( get_sentence_transformers_pipeline, is_sentence_transformers_available, ) -logger = logging.getLogger(__name__) -logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO) - - if is_tf_available(): import tensorflow as tf @@ -69,22 +69,6 @@ def create_artifact_filter(framework): return [] -def wrap_conversation_pipeline(pipeline): - """ - Wrap a Conversation with a helper for better UX when using REST API - """ - - def wrapped_pipeline(inputs, *args, **kwargs): - logger.info(f"Inputs: {inputs}") - logger.info(f"Args: {args}") - logger.info(f"KWArgs: {kwargs}") - prediction = pipeline(inputs, *args, **kwargs) - logger.info(f"Prediction: {prediction}") - return list(prediction) - - return wrapped_pipeline - - def _is_gpu_available(): """ checks if a gpu is available. @@ -150,16 +134,18 @@ def _load_repository_from_hf( # create regex to only include the framework specific weights ignore_regex = create_artifact_filter(framework) - logger.info(f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }") + logger.info( + f"Ignore regex pattern for files, which are not downloaded: { ', '.join(ignore_regex) }" + ) # Download the repository to the workdir and filter out non-framework # specific weights snapshot_download( - repo_id = repository_id, - revision = revision, - local_dir = str(target_dir), - local_dir_use_symlinks = False, - ignore_patterns = ignore_regex, + repo_id=repository_id, + revision=revision, + local_dir=str(target_dir), + local_dir_use_symlinks=False, + ignore_patterns=ignore_regex, ) return target_dir @@ -191,7 +177,9 @@ def check_and_register_custom_pipeline_from_directory(model_dir): Please update to the new format. 
See documentation for more information.""" ) - spec = importlib.util.spec_from_file_location("pipeline.PreTrainedPipeline", legacy_module) + spec = importlib.util.spec_from_file_location( + "pipeline.PreTrainedPipeline", legacy_module + ) if spec: # add the whole directory to path for submodlues sys.path.insert(0, model_dir) @@ -222,14 +210,16 @@ def get_device(): def get_pipeline( task: str, model_dir: Path, - framework = "pytorch", **kwargs, ) -> Pipeline: """ create pipeline class for a specific task based on local saved model """ device = get_device() - logger.info(f"Using device { 'GPU' if device == 0 else 'CPU'}") + if is_optimum_neuron_available(): + logger.info("Using device Neuron") + else: + logger.info(f"Using device { 'GPU' if device == 0 else 'CPU'}") if task is None: raise EnvironmentError( @@ -248,53 +238,38 @@ def get_pipeline( kwargs["feature_extractor"] = model_dir elif task in {"image-to-text"}: pass + elif task == "conversational": + task = "text-generation" else: kwargs["tokenizer"] = model_dir - if is_optimum_available(): - logger.info("Optimum is not implemented yet using default pipeline.") - hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs) + if is_optimum_neuron_available(): + hf_pipeline = get_optimum_neuron_pipeline(task=task, model_dir=model_dir) elif is_sentence_transformers_available() and task in [ "sentence-similarity", "sentence-embeddings", "sentence-ranking", ]: hf_pipeline = get_sentence_transformers_pipeline( - task=task, - model_dir=model_dir, - device=device, - **kwargs + task=task, model_dir=model_dir, device=device, **kwargs ) elif is_diffusers_available() and task == "text-to-image": hf_pipeline = get_diffusers_pipeline( - task=task, - model_dir=model_dir, - device=device, - **kwargs + task=task, model_dir=model_dir, device=device, **kwargs ) else: - hf_pipeline = pipeline( - task=task, - model=model_dir, - device=device, - **kwargs - ) - - # wrap specific pipeline to support better ux - if task == "conversational": - hf_pipeline = wrap_conversation_pipeline(hf_pipeline) + hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs) - elif task == "automatic-speech-recognition" and isinstance( - hf_pipeline.model, - WhisperForConditionalGeneration + if task == "automatic-speech-recognition" and isinstance( + hf_pipeline.model, WhisperForConditionalGeneration ): # set chunk length to 30s for whisper to enable long audio files hf_pipeline._preprocess_params["chunk_length_s"] = 30 - hf_pipeline.model.config.forced_decoder_ids = hf_pipeline.tokenizer.get_decoder_prompt_ids( - language="english", - task="transcribe" + hf_pipeline.model.config.forced_decoder_ids = ( + hf_pipeline.tokenizer.get_decoder_prompt_ids( + language="english", task="transcribe" + ) ) - return hf_pipeline diff --git a/src/huggingface_inference_toolkit/vertex_ai_utils.py b/src/huggingface_inference_toolkit/vertex_ai_utils.py index 19dd41e2..cb588174 100644 --- a/src/huggingface_inference_toolkit/vertex_ai_utils.py +++ b/src/huggingface_inference_toolkit/vertex_ai_utils.py @@ -1,26 +1,20 @@ -import logging import re from pathlib import Path from typing import Union -logger = logging.getLogger(__name__) -logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", level=logging.INFO) - - - -_logger = logging.getLogger(__name__) - +from huggingface_inference_toolkit.logging import logger GCS_URI_PREFIX = "gs://" # copied from 
https://github.com/googleapis/python-aiplatform/blob/94d838d8cfe1599bc2d706e66080c05108821986/google/cloud/aiplatform/utils/prediction_utils.py#L121 -def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path]="/tmp"): +def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path] = "/tmp"): """ Load files from GCS path to target_dir """ from google.cloud import storage - _logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}") + + logger.info(f"Loading model artifacts from {artifact_uri} to {target_dir}") target_dir = Path(target_dir) if artifact_uri.startswith(GCS_URI_PREFIX): @@ -43,4 +37,3 @@ def _load_repository_from_gcs(artifact_uri: str, target_dir: Union[str, Path]="/ blob.download_to_filename(name_without_prefix) return str(target_dir.absolute()) - diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py index 862560dc..1dddb5d3 100644 --- a/src/huggingface_inference_toolkit/webservice_starlette.py +++ b/src/huggingface_inference_toolkit/webservice_starlette.py @@ -1,4 +1,3 @@ -import logging import os from pathlib import Path from time import perf_counter @@ -17,26 +16,19 @@ HF_REVISION, HF_TASK, ) -from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler +from huggingface_inference_toolkit.handler import ( + get_inference_handler_either_custom_or_default_handler, +) +from huggingface_inference_toolkit.logging import logger from huggingface_inference_toolkit.serialization.base import ContentType from huggingface_inference_toolkit.serialization.json_utils import Jsoner -from huggingface_inference_toolkit.utils import _load_repository_from_hf, convert_params_to_int_or_bool +from huggingface_inference_toolkit.utils import ( + _load_repository_from_hf, + convert_params_to_int_or_bool, +) from huggingface_inference_toolkit.vertex_ai_utils import _load_repository_from_gcs -def config_logging(level=logging.INFO): - logging.basicConfig(format="%(asctime)s | %(levelname)s | %(message)s", datefmt="", level=level) - # disable uvicorn access logs to hide /health - uvicorn_access = logging.getLogger("uvicorn.access") - uvicorn_access.disabled = True - # remove double logs for errors - logging.getLogger("uvicorn").removeHandler(logging.getLogger("uvicorn").handlers[0]) - - -config_logging() -logger = logging.getLogger(__name__) - - async def prepare_model_artifacts(): global inference_handler # 1. check if model artifacts available in HF_MODEL_DIR @@ -52,8 +44,10 @@ async def prepare_model_artifacts(): ) # 3. check if in Vertex AI environment and load from GCS # If artifactUri not on Model Creation not set returns an empty string - elif len(os.environ.get("AIP_STORAGE_URI", '')) > 0: - _load_repository_from_gcs(os.environ["AIP_STORAGE_URI"], target_dir=HF_MODEL_DIR) + elif len(os.environ.get("AIP_STORAGE_URI", "")) > 0: + _load_repository_from_gcs( + os.environ["AIP_STORAGE_URI"], target_dir=HF_MODEL_DIR + ) # 4. if not available, raise error else: raise ValueError( @@ -65,7 +59,10 @@ async def prepare_model_artifacts(): logger.info(f"Initializing model from directory:{HF_MODEL_DIR}") # 2. 
determine correct inference handler - inference_handler = get_inference_handler_either_custom_or_default_handler(HF_MODEL_DIR, task=HF_TASK) + inference_handler = get_inference_handler_either_custom_or_default_handler( + HF_MODEL_DIR, task=HF_TASK + ) + print("hello world") logger.info("Model initialized successfully") @@ -78,32 +75,47 @@ async def predict(request): # extracts content from request content_type = request.headers.get("content-Type", None) # try to deserialize payload - deserialized_body = ContentType.get_deserializer(content_type).deserialize(await request.body()) + deserialized_body = ContentType.get_deserializer(content_type).deserialize( + await request.body() + ) # checks if input schema is correct if "inputs" not in deserialized_body and "instances" not in deserialized_body: - raise ValueError(f"Body needs to provide a inputs key, recieved: {orjson.dumps(deserialized_body)}") + raise ValueError( + f"Body needs to provide a inputs key, recieved: {orjson.dumps(deserialized_body)}" + ) # check for query parameter and add them to the body if request.query_params and "parameters" not in deserialized_body: - deserialized_body["parameters"] = convert_params_to_int_or_bool(dict(request.query_params)) + deserialized_body["parameters"] = convert_params_to_int_or_bool( + dict(request.query_params) + ) # tracks request time start_time = perf_counter() # run async not blocking call pred = await async_handler_call(inference_handler, deserialized_body) # log request time - logger.info(f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms") + logger.info( + f"POST {request.url.path} | Duration: {(perf_counter()-start_time) *1000:.2f} ms" + ) # response extracts content from request accept = request.headers.get("accept", None) if accept is None or accept == "*/*": accept = "application/json" # deserialized and resonds with json - serialized_response_body = ContentType.get_serializer(accept).serialize(pred, accept) + serialized_response_body = ContentType.get_serializer(accept).serialize( + pred, accept + ) return Response(serialized_response_body, media_type=accept) except Exception as e: logger.error(e) - return Response(Jsoner.serialize({"error": str(e)}), status_code=400, media_type="application/json") + return Response( + Jsoner.serialize({"error": str(e)}), + status_code=400, + media_type="application/json", + ) + # Create app based on which cloud environment is used if os.getenv("AIP_MODE", None) == "PREDICTION": @@ -112,7 +124,9 @@ async def predict(request): _predict_route = os.getenv("AIP_PREDICT_ROUTE", None) _health_route = os.getenv("AIP_HEALTH_ROUTE", None) if _predict_route is None or _health_route is None: - raise ValueError("AIP_PREDICT_ROUTE and AIP_HEALTH_ROUTE need to be set in Vertex AI environment") + raise ValueError( + "AIP_PREDICT_ROUTE and AIP_HEALTH_ROUTE need to be set in Vertex AI environment" + ) app = Starlette( debug=False, @@ -132,4 +146,4 @@ async def predict(request): Route("/predict", predict, methods=["POST"]), ], on_startup=[prepare_model_artifacts], -) + ) diff --git a/tests/integ/config.py b/tests/integ/config.py index b1d4d605..9fc3d56a 100644 --- a/tests/integ/config.py +++ b/tests/integ/config.py @@ -3,6 +3,8 @@ from tests.integ.utils import ( validate_automatic_speech_recognition, validate_classification, + validate_conversational, + validate_custom, validate_feature_extraction, validate_fill_mask, validate_ner, @@ -14,11 +16,8 @@ validate_text_to_image, validate_translation, validate_zero_shot_classification, - 
validate_custom, - validate_conversational ) - task2model = { "text-classification": { "pytorch": "hf-internal-testing/tiny-random-distilbert", @@ -32,7 +31,7 @@ "pytorch": "hf-internal-testing/tiny-random-bert", "tensorflow": "hf-internal-testing/tiny-random-bert", }, - "ner": { + "token-classification": { "pytorch": "hf-internal-testing/tiny-random-roberta", "tensorflow": "hf-internal-testing/tiny-random-roberta", }, @@ -81,7 +80,7 @@ "tensorflow": "hf-internal-testing/tiny-random-clip-zero-shot-image-classification", }, "conversational": { - #"pytorch": "hf-internal-testing/tiny-random-blenderbot-small", + # "pytorch": "hf-internal-testing/tiny-random-blenderbot-small", "pytorch": "microsoft/DialoGPT-small", "tensorflow": None, }, @@ -119,7 +118,7 @@ "parameters": {"candidate_labels": ["refund", "legal", "faq"]}, }, "feature-extraction": {"inputs": "What is the best book."}, - "ner": {"inputs": "My name is Wolfgang and I live in Berlin"}, + "token-classification": {"inputs": "My name is Wolfgang and I live in Berlin"}, "question-answering": { "inputs": { "question": "What is used for inference?", @@ -135,12 +134,24 @@ "inputs": "question: What is 42 context: 42 is the answer to life, the universe and everything." }, "text-generation": {"inputs": "My name is philipp and I am"}, - "image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), - "zero-shot-image-classification": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), - "object-detection": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), - "image-segmentation": open(os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb").read(), - "automatic-speech-recognition": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(), - "audio-classification": open(os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb").read(), + "image-classification": open( + os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb" + ).read(), + "zero-shot-image-classification": open( + os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb" + ).read(), + "object-detection": open( + os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb" + ).read(), + "image-segmentation": open( + os.path.join(os.getcwd(), "tests/resources/image/tiger.jpeg"), "rb" + ).read(), + "automatic-speech-recognition": open( + os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb" + ).read(), + "audio-classification": open( + os.path.join(os.getcwd(), "tests/resources/audio/sample1.flac"), "rb" + ).read(), "table-question-answering": { "inputs": { "query": "How many stars does the transformers repository have?", @@ -152,27 +163,23 @@ }, } }, - "conversational": {"inputs": [ - { - "role": "user", - "content": "Which movie is the best ?" - }, - { - "role": "assistant", - "content": "It's Die Hard for sure." - }, - { - "role": "user", - "content": "Can you explain why?" 
- } - ]}, + "conversational": { + "inputs": [ + {"role": "user", "content": "Which movie is the best ?"}, + ] + }, "sentence-similarity": { - "inputs": {"source_sentence": "Lets create an embedding", "sentences": ["Lets create an embedding"]} + "inputs": { + "source_sentence": "Lets create an embedding", + "sentences": ["Lets create an embedding"], + } }, "sentence-embeddings": {"inputs": "Lets create an embedding"}, - "sentence-ranking": {"inputs": ["Lets create an embedding", "Lets create an embedding"]}, + "sentence-ranking": { + "inputs": ["Lets create an embedding", "Lets create an embedding"] + }, "text-to-image": {"inputs": "a man on a horse jumps over a broken down airplane."}, - "custom": {"inputs": "this is a test"} + "custom": {"inputs": "this is a test"}, } task2output = { @@ -182,30 +189,67 @@ "labels": ["refund", "faq", "legal"], "scores": [0.96, 0.027, 0.008], }, - "ner": [ - {"word": "Wolfgang", "score": 0.99, "entity": "I-PER", "index": 4, "start": 11, "end": 19}, - {"word": "Berlin", "score": 0.99, "entity": "I-LOC", "index": 9, "start": 34, "end": 40}, + "token-classification": [ + { + "word": "Wolfgang", + "score": 0.99, + "entity": "I-PER", + "index": 4, + "start": 11, + "end": 19, + }, + { + "word": "Berlin", + "score": 0.99, + "entity": "I-LOC", + "index": 9, + "start": 34, + "end": 40, + }, + ], + "question-answering": { + "score": 0.99, + "start": 68, + "end": 77, + "answer": "sagemaker", + }, + "summarization": [ + {"summary_text": " The A The The ANew York City has been installed in the US."} + ], + "translation_xx_to_yy": [ + {"translation_text": "Mein Name ist Sarah und ich lebe in London"} + ], + "text2text-generation": [ + {"generated_text": "42 is the answer to life, the universe and everything"} ], - "question-answering": {"score": 0.99, "start": 68, "end": 77, "answer": "sagemaker"}, - "summarization": [{"summary_text": " The A The The ANew York City has been installed in the US."}], - "translation_xx_to_yy": [{"translation_text": "Mein Name ist Sarah und ich lebe in London"}], - "text2text-generation": [{"generated_text": "42 is the answer to life, the universe and everything"}], "feature-extraction": None, "fill-mask": None, "text-generation": None, "image-classification": [ {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"}, {"score": 0.10940514504909515, "label": "tiger cat"}, - {"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"}, + { + "score": 0.0006216464680619538, + "label": "jaguar, panther, Panthera onca, Felis onca", + }, {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"}, - {"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"}, + { + "score": 0.00030842673731967807, + "label": "lion, king of beasts, Panthera leo", + }, ], "zero-shot-image-classification": [ {"score": 0.8858247399330139, "label": "tiger, Panthera tigris"}, {"score": 0.10940514504909515, "label": "tiger cat"}, - {"score": 0.0006216464680619538, "label": "jaguar, panther, Panthera onca, Felis onca"}, + { + "score": 0.0006216464680619538, + "label": "jaguar, panther, Panthera onca, Felis onca", + }, {"score": 0.0004262699221726507, "label": "dhole, Cuon alpinus"}, - {"score": 0.00030842673731967807, "label": "lion, king of beasts, Panthera leo"}, + { + "score": 0.00030842673731967807, + "label": "lion, king of beasts, Panthera leo", + }, ], "automatic-speech-recognition": { "text": "GOING ALONG SLUSHY COUNTRY ROADS AND SPEAKING TO DAMP OAUDIENCES IN DROFTY SCHOOL ROOMS DAY AFTER DAY FOR A 
FORT NIGHT HE'LL HAVE TO PUT IN AN APPEARANCE AT SOME PLACE OF WORSHIP ON SUNDAY MORNING AND HE CAN COME TO US IMMEDIATELY AFTERWARDS" @@ -218,16 +262,14 @@ "image-segmentation": [{"score": 0.9143241047859192, "label": "cat", "mask": {}}], "table-question-answering": {"answer": "36542"}, "conversational": [ - {'role': 'user', 'content': 'Which movie is the best ?'}, - {'role': 'assistant', 'content': "It's Die Hard for sure."}, - {'role': 'user', 'content': 'Can you explain why?'}, - {'role': 'assistant', 'content': "It's a great movie."}, + {"role": "user", "content": "Which movie is the best ?"}, + {"role": "assistant", "content": "It's Die Hard for sure."}, ], "sentence-similarity": {"similarities": ""}, "sentence-embeddings": {"embeddings": ""}, "sentence-ranking": {"scores": ""}, "text-to-image": bytes, - "custom": {"inputs": "this is a test"} + "custom": {"inputs": "this is a test"}, } @@ -236,7 +278,7 @@ "zero-shot-classification": validate_zero_shot_classification, "zero-shot-image-classification": validate_zero_shot_classification, "feature-extraction": validate_feature_extraction, - "ner": validate_ner, + "token-classification": validate_ner, "question-answering": validate_question_answering, "fill-mask": validate_fill_mask, "summarization": validate_summarization, @@ -254,5 +296,5 @@ "sentence-embeddings": validate_zero_shot_classification, "sentence-ranking": validate_zero_shot_classification, "text-to-image": validate_text_to_image, - "custom": validate_custom + "custom": validate_custom, } diff --git a/tests/integ/conftest.py b/tests/integ/conftest.py index ec282ea8..4b3f6118 100644 --- a/tests/integ/conftest.py +++ b/tests/integ/conftest.py @@ -1,45 +1,37 @@ +import logging +import os +import random +import socket +import time + import docker import pytest -import random -import logging -from tests.integ.config import task2model import tenacity -import time -from huggingface_inference_toolkit.utils import ( - _is_gpu_available, - _load_repository_from_hf -) -from transformers.testing_utils import ( - slow, - _run_slow_tests -) -import uuid -import socket -import os +from huggingface_inference_toolkit.utils import _load_repository_from_hf +from transformers.testing_utils import _run_slow_tests + +from tests.integ.config import task2model HF_HUB_CACHE = os.environ.get("HF_HUB_CACHE", "/home/ubuntu/.cache/huggingface/hub") IS_GPU = _run_slow_tests DEVICE = "gpu" if IS_GPU else "cpu" + @tenacity.retry( - retry = tenacity.retry_if_exception(docker.errors.APIError), - stop = tenacity.stop_after_attempt(10) + retry=tenacity.retry_if_exception(docker.errors.APIError), + stop=tenacity.stop_after_attempt(10), ) -@pytest.fixture(scope = "function") -def remote_container( - device, - task, - framework -): +@pytest.fixture(scope="function") +def remote_container(device, task, framework): time.sleep(random.randint(1, 5)) - #client = docker.DockerClient(base_url='unix://var/run/docker.sock') + # client = docker.DockerClient(base_url='unix://var/run/docker.sock') client = docker.from_env() container_name = f"integration-test-{framework}-{task}-{device}" container_image = f"integration-test-{framework}:{device}" port = random.randint(5000, 9000) model = task2model[task][framework] - #check if port is already open + # check if port is already open sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) while sock.connect_ex(("localhost", port)) == 0: logging.debug(f"Port {port} is already being used; getting a new one...") @@ -48,51 +40,46 @@ def remote_container( logging.debug(f"Image: 
{container_image}") logging.debug(f"Port: {port}") - device_request = [ - docker.types.DeviceRequest( - count=-1, - capabilities=[["gpu"]]) - ] if device == "gpu" else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] + if device == "gpu" + else [] + ) yield client.containers.run( - image = container_image, + image=container_image, name=container_name, ports={"5000": port}, - environment={ - "HF_MODEL_ID": model, - "HF_TASK": task, - "CUDA_LAUNCH_BLOCKING": 1 - }, + environment={"HF_MODEL_ID": model, "HF_TASK": task, "CUDA_LAUNCH_BLOCKING": 1}, detach=True, # GPU device_requests=device_request, ), port - #Teardown + # Teardown previous = client.containers.get(container_name) + logs = previous.logs().decode("utf-8") + logging.info(f"Container logs:\n{logs}") previous.stop() previous.remove() -@tenacity.retry( - stop = tenacity.stop_after_attempt(10), - reraise = True -) -@pytest.fixture(scope = "function") -def local_container( - device, - task, - repository_id, - framework -): +@tenacity.retry(stop=tenacity.stop_after_attempt(10), reraise=True) +@pytest.fixture(scope="function") +def local_container(device, task, repository_id, framework): try: time.sleep(random.randint(1, 5)) - id = uuid.uuid4() if not (task == "custom"): model = task2model[task][framework] id = task else: model = repository_id + id = random.randint(1, 1000) + + env = { + "HF_MODEL_DIR": "/opt/huggingface/model", + "HF_TASK": task, + } logging.info(f"Starting container with model: {model}") @@ -100,7 +87,7 @@ def local_container( message = f"No model supported for {framework}" logging.error(message) raise ValueError(message) - + logging.info(f"Starting container with Model = {model}") client = docker.from_env() container_name = f"integration-test-{framework}-{id}-{device}" @@ -108,7 +95,7 @@ def local_container( port = random.randint(5000, 9000) - #check if port is already open + # check if port is already open sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) while sock.connect_ex(("localhost", port)) == 0: logging.debug(f"Port {port} is already being used; getting a new one...") @@ -117,43 +104,50 @@ def local_container( logging.debug(f"Image: {container_image}") logging.debug(f"Port: {port}") - device_request = [ - docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]]) - ] if device == "gpu" else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] + if device == "gpu" + else None + ) + if device == "inf2": + devices = { + "/dev/neuron0": { + "PathInContainer": "/dev/neuron0", + "CgroupPermissions": "rwm", + } + } + env["HF_OPTIMUM_BATCH_SIZE"] = 1 + env["HF_OPTIMUM_SEQUENCE_LENGTH"] = 128 + else: + devices = None object_id = model.replace("/", "--") model_dir = f"{HF_HUB_CACHE}/{object_id}" - storage_dir = _load_repository_from_hf( - repository_id = model, - target_dir = model_dir, - framework = framework + _storage_dir = _load_repository_from_hf( + repository_id=model, target_dir=model_dir ) yield client.containers.run( container_image, name=container_name, ports={"5000": port}, - environment={ - "HF_MODEL_DIR": "/opt/huggingface/model", - "HF_TASK": task - }, - volumes = { - model_dir: { - "bind": "/opt/huggingface/model", - "mode": "ro" - } - }, + environment=env, + volumes={model_dir: {"bind": "/opt/huggingface/model", "mode": "ro"}}, detach=True, # GPU device_requests=device_request, + # INF2 + devices=devices, ), port - #Teardown + # Teardown previous = client.containers.get(container_name) + time.sleep(5) + logs = 
previous.logs().decode("utf-8") + logging.info(f"Container logs:\n{logs}") previous.stop() previous.remove() except Exception as exception: logging.error(f"Error starting container: {str(exception)}") raise exception - diff --git a/tests/integ/helpers.py b/tests/integ/helpers.py index 0dae2598..e9e5d808 100644 --- a/tests/integ/helpers.py +++ b/tests/integ/helpers.py @@ -1,35 +1,23 @@ +import logging import random import tempfile import time +import traceback + import docker import pytest import requests -from huggingface_inference_toolkit.utils import ( - _is_gpu_available, - _load_repository_from_hf -) -from tests.integ.config import ( - task2input, - task2model, - task2output, - task2validation -) -from transformers.testing_utils import ( - require_torch, - slow, - require_tf, - _run_slow_tests -) -import tenacity from docker import DockerClient -import logging -import traceback -import urllib3 +from huggingface_inference_toolkit.utils import _load_repository_from_hf +from transformers.testing_utils import _run_slow_tests, require_tf, require_torch + +from tests.integ.config import task2input, task2model, task2output, task2validation IS_GPU = _run_slow_tests DEVICE = "gpu" if IS_GPU else "cpu" -client = docker.DockerClient(base_url='unix://var/run/docker.sock') +client = docker.DockerClient(base_url="unix://var/run/docker.sock") + def make_sure_other_containers_are_stopped(client: DockerClient, container_name: str): try: @@ -40,17 +28,13 @@ def make_sure_other_containers_are_stopped(client: DockerClient, container_name: return None -#@tenacity.retry( +# @tenacity.retry( # retry = tenacity.retry_if_exception(ValueError), # stop = tenacity.stop_after_attempt(10), # reraise = True -#) -def wait_for_container_to_be_ready( - base_url, - time_between_retries = 1, - max_retries = 30 -): - +# ) +def wait_for_container_to_be_ready(base_url, time_between_retries=3, max_retries=30): + retries = 0 error = None @@ -62,20 +46,23 @@ def wait_for_container_to_be_ready( logging.info("Container ready!") return True else: - raise ConnectionError(f"Error: {response.status_code}") + raise ConnectionError( + f"Couldn'start container, Error: {response.status_code}" + ) except Exception as exception: error = exception logging.warning(f"Container at {base_url} not ready, trying again...") retries += 1 - + logging.error(f"Unable to start container: {str(error)}") raise error + def verify_task( - #container: DockerClient, + # container: DockerClient, task: str, port: int = 5000, - framework: str = "pytorch" + framework: str = "pytorch", ): BASE_URL = f"http://localhost:{port}" logging.info(f"Base URL: {BASE_URL}") @@ -92,18 +79,24 @@ def verify_task( or task == "zero-shot-image-classification" ): prediction = requests.post( - f"{BASE_URL}", data=task2input[task], headers={"content-type": "image/x-image"} + f"{BASE_URL}", + data=task2input[task], + headers={"content-type": "image/x-image"}, ).json() elif task == "automatic-speech-recognition" or task == "audio-classification": prediction = requests.post( - f"{BASE_URL}", data=task2input[task], headers={"content-type": "audio/x-audio"} + f"{BASE_URL}", + data=task2input[task], + headers={"content-type": "audio/x-audio"}, ).json() elif task == "text-to-image": - prediction = requests.post(f"{BASE_URL}", json=input, headers={"accept": "image/png"}).content + prediction = requests.post( + f"{BASE_URL}", json=input, headers={"accept": "image/png"} + ).content else: prediction = requests.post(f"{BASE_URL}", json=input).json() - + logging.info(f"Input: {input}") 
logging.info(f"Prediction: {prediction}") logging.info(f"Snapshot: {task2output[task]}") @@ -112,10 +105,7 @@ def verify_task( for message in prediction: assert "error" not in message.keys() else: - assert task2validation[task]( - result=prediction, - snapshot=task2output[task] - ) + assert task2validation[task](result=prediction, snapshot=task2output[task]) except Exception as exception: logging.error(f"Base URL: {BASE_URL}") logging.error(f"Task: {task}") @@ -131,7 +121,7 @@ def verify_task( [ "text-classification", "zero-shot-classification", - "ner", + "token-classification", "question-answering", "fill-mask", "summarization", @@ -162,7 +152,9 @@ def test_pt_container_remote_model(task) -> None: framework = "pytorch" model = task2model[task][framework] port = random.randint(5000, 6000) - device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + ) make_sure_other_containers_are_stopped(client, container_name) container = client.containers.run( @@ -176,7 +168,7 @@ def test_pt_container_remote_model(task) -> None: ) time.sleep(5) - verify_task(task = task, port = port) + verify_task(task=task, port=port) container.stop() container.remove() @@ -187,7 +179,7 @@ def test_pt_container_remote_model(task) -> None: [ "text-classification", "zero-shot-classification", - "ner", + "token-classification", "question-answering", "fill-mask", "summarization", @@ -218,11 +210,13 @@ def test_pt_container_local_model(task) -> None: framework = "pytorch" model = task2model[task][framework] port = random.randint(5000, 6000) - device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + ) make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(model, tmpdirname, framework="pytorch") + _storage_dir = _load_repository_from_hf(model, tmpdirname, framework="pytorch") container = client.containers.run( container_image, name=container_name, @@ -247,13 +241,15 @@ def test_pt_container_local_model(task) -> None: def test_pt_container_custom_handler(repository_id) -> None: container_name = "integration-test-custom" container_image = f"starlette-transformers:{DEVICE}" - device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + ) port = random.randint(5000, 6000) make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(repository_id, tmpdirname) + _storage_dir = _load_repository_from_hf(repository_id, tmpdirname) container = client.containers.run( container_image, name=container_name, @@ -284,13 +280,15 @@ def test_pt_container_custom_handler(repository_id) -> None: def test_pt_container_legacy_custom_pipeline(repository_id) -> None: container_name = "integration-test-custom" container_image = f"starlette-transformers:{DEVICE}" - device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + device_request = ( + 
[docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + ) port = random.randint(5000, 6000) make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(repository_id, tmpdirname) + _storage_dir = _load_repository_from_hf(repository_id, tmpdirname) container = client.containers.run( container_image, name=container_name, @@ -319,7 +317,7 @@ def test_pt_container_legacy_custom_pipeline(repository_id) -> None: [ "text-classification", "zero-shot-classification", - "ner", + "token-classification", "question-answering", "fill-mask", "summarization", @@ -347,7 +345,9 @@ def test_tf_container_remote_model(task) -> None: container_image = f"starlette-transformers:{DEVICE}" framework = "tensorflow" model = task2model[task][framework] - device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + ) if model is None: pytest.skip("no supported TF model") port = random.randint(5000, 6000) @@ -373,7 +373,7 @@ def test_tf_container_remote_model(task) -> None: [ "text-classification", "zero-shot-classification", - "ner", + "token-classification", "question-answering", "fill-mask", "summarization", @@ -401,14 +401,16 @@ def test_tf_container_local_model(task) -> None: container_image = f"starlette-transformers:{DEVICE}" framework = "tensorflow" model = task2model[task][framework] - device_request = [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + device_request = ( + [docker.types.DeviceRequest(count=-1, capabilities=[["gpu"]])] if IS_GPU else [] + ) if model is None: pytest.skip("no supported TF model") port = random.randint(5000, 6000) make_sure_other_containers_are_stopped(client, container_name) with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf(model, tmpdirname, framework=framework) + _storage_dir = _load_repository_from_hf(model, tmpdirname, framework=framework) container = client.containers.run( container_image, name=container_name, diff --git a/tests/integ/test_pytorch_local_cpu.py b/tests/integ/test_pytorch_local_cpu.py index 17e651e9..86982367 100644 --- a/tests/integ/test_pytorch_local_cpu.py +++ b/tests/integ/test_pytorch_local_cpu.py @@ -1,27 +1,22 @@ -import tempfile -from tests.integ.helpers import verify_task -from tests.integ.config import ( - task2input, - task2model, - task2output, - task2validation -) -from transformers.testing_utils import ( - require_torch, - slow, - _run_slow_tests -) import pytest +import tenacity +from transformers.testing_utils import require_torch + +from tests.integ.helpers import verify_task -class TestPytorchLocal: +class TestPytorchLocal: + @tenacity.retry( + stop=tenacity.stop_after_attempt(5), + reraise=True, + ) @require_torch @pytest.mark.parametrize( "task", [ "text-classification", "zero-shot-classification", - "ner", + "token-classification", "question-answering", "fill-mask", "summarization", @@ -42,86 +37,53 @@ class TestPytorchLocal: "text-to-image", ], ) - @pytest.mark.parametrize( - "device", - ["cpu"] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] - ) - @pytest.mark.parametrize( - "repository_id", - [""] - ) - 
@pytest.mark.usefixtures('local_container') + @pytest.mark.parametrize("device", ["cpu"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("repository_id", [""]) + @pytest.mark.usefixtures("local_container") def test_pt_container_local_model( - self, - local_container, - task, - framework, - device + self, local_container, task, framework, device ) -> None: - verify_task(task = task, port = local_container[1]) - + verify_task(task=task, port=local_container[1]) + @tenacity.retry( + stop=tenacity.stop_after_attempt(5), + reraise=True, + ) @require_torch @pytest.mark.parametrize( "repository_id", ["philschmid/custom-handler-test", "philschmid/custom-handler-distilbert"], ) - @pytest.mark.parametrize( - "device", - ["cpu"] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] - ) - @pytest.mark.parametrize( - "task", - ["custom"] - ) - @pytest.mark.usefixtures('local_container') + @pytest.mark.parametrize("device", ["cpu"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("task", ["custom"]) + @pytest.mark.usefixtures("local_container") def test_pt_container_custom_handler( - self, - local_container, - task, - device, - repository_id + self, local_container, task, device, repository_id ) -> None: - + verify_task( - task = task, - port = local_container[1], + task=task, + port=local_container[1], ) - + @tenacity.retry( + stop=tenacity.stop_after_attempt(5), + reraise=True, + ) @require_torch @pytest.mark.parametrize( "repository_id", ["philschmid/custom-pipeline-text-classification"], ) - @pytest.mark.parametrize( - "device", - ["cpu"] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] - ) - @pytest.mark.parametrize( - "task", - ["custom"] - ) - @pytest.mark.usefixtures('local_container') + @pytest.mark.parametrize("device", ["cpu"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("task", ["custom"]) + @pytest.mark.usefixtures("local_container") def test_pt_container_legacy_custom_pipeline( - self, - local_container, - repository_id, - device, - task + self, local_container, repository_id, device, task ) -> None: - verify_task(task = task, port = local_container[1]) + verify_task(task=task, port=local_container[1]) diff --git a/tests/integ/test_pytorch_local_gpu.py b/tests/integ/test_pytorch_local_gpu.py index 15ffebde..eb0cb0ae 100644 --- a/tests/integ/test_pytorch_local_gpu.py +++ b/tests/integ/test_pytorch_local_gpu.py @@ -1,17 +1,8 @@ -import tempfile -from tests.integ.helpers import verify_task -from tests.integ.config import ( - task2input, - task2model, - task2output, - task2validation -) -from transformers.testing_utils import ( - require_torch, - slow, - _run_slow_tests -) import pytest +from transformers.testing_utils import require_torch + +from tests.integ.helpers import verify_task + class TestPytorchLocal: @@ -21,7 +12,7 @@ class TestPytorchLocal: [ "text-classification", "zero-shot-classification", - "ner", + "token-classification", "question-answering", "fill-mask", "summarization", @@ -42,86 +33,45 @@ class TestPytorchLocal: "text-to-image", ], ) - @pytest.mark.parametrize( - "device", - ["gpu"] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] - ) - @pytest.mark.parametrize( - "repository_id", - [""] - ) - @pytest.mark.usefixtures('local_container') + @pytest.mark.parametrize("device", ["gpu"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("repository_id", [""]) + @pytest.mark.usefixtures("local_container") def 
test_pt_container_local_model( - self, - local_container, - task, - framework, - device + self, local_container, task, framework, device ) -> None: - verify_task(task = task, port = local_container[1]) - + verify_task(task=task, port=local_container[1]) @require_torch @pytest.mark.parametrize( "repository_id", ["philschmid/custom-handler-test", "philschmid/custom-handler-distilbert"], ) - @pytest.mark.parametrize( - "device", - ["gpu"] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] - ) - @pytest.mark.parametrize( - "task", - ["custom"] - ) - @pytest.mark.usefixtures('local_container') + @pytest.mark.parametrize("device", ["gpu"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("task", ["custom"]) + @pytest.mark.usefixtures("local_container") def test_pt_container_custom_handler( - self, - local_container, - task, - device, - repository_id + self, local_container, task, device, repository_id ) -> None: - + verify_task( - task = task, - port = local_container[1], + task=task, + port=local_container[1], ) - @require_torch @pytest.mark.parametrize( "repository_id", ["philschmid/custom-pipeline-text-classification"], ) - @pytest.mark.parametrize( - "device", - ["gpu"] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] - ) - @pytest.mark.parametrize( - "task", - ["custom"] - ) - @pytest.mark.usefixtures('local_container') + @pytest.mark.parametrize("device", ["gpu"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("task", ["custom"]) + @pytest.mark.usefixtures("local_container") def test_pt_container_legacy_custom_pipeline( - self, - local_container, - repository_id, - device, - task + self, local_container, repository_id, device, task ) -> None: - verify_task(task = task, port = local_container[1]) + verify_task(task=task, port=local_container[1]) diff --git a/tests/integ/test_pytorch_local_inf2.py b/tests/integ/test_pytorch_local_inf2.py new file mode 100644 index 00000000..de0c7b4e --- /dev/null +++ b/tests/integ/test_pytorch_local_inf2.py @@ -0,0 +1,32 @@ +import pytest +from huggingface_inference_toolkit.optimum_utils import is_optimum_neuron_available +from transformers.testing_utils import require_torch + +from tests.integ.helpers import verify_task + +require_inferentia = pytest.mark.skipif( + not is_optimum_neuron_available(), + reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.", +) + + +class TestPytorchLocal: + @require_torch + @require_inferentia + @pytest.mark.parametrize( + "task", + [ + "feature-extraction", + "fill-mask", + "question-answering", + "text-classification", + "token-classification", + ], + ) + @pytest.mark.parametrize("device", ["inf2"]) + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.parametrize("repository_id", [""]) + @pytest.mark.usefixtures("local_container") + def test_pt_container_local_model(self, local_container, task) -> None: + + verify_task(task=task, port=local_container[1]) diff --git a/tests/integ/test_pytorch_remote_cpu.py b/tests/integ/test_pytorch_remote_cpu.py index 14001dda..5eb4edb4 100644 --- a/tests/integ/test_pytorch_remote_cpu.py +++ b/tests/integ/test_pytorch_remote_cpu.py @@ -1,31 +1,18 @@ -import tempfile -from tests.integ.helpers import verify_task -from tests.integ.config import ( - task2input, - task2model, - task2output, - task2validation -) -from transformers.testing_utils import ( - require_torch, - slow, - _run_slow_tests -) +import docker import pytest import tenacity -import 
docker + +from tests.integ.helpers import verify_task + class TestPytorchRemote: @tenacity.retry( - retry = tenacity.retry_if_exception(docker.errors.APIError), - stop = tenacity.stop_after_attempt(5), - reraise = True - ) - @pytest.mark.parametrize( - "device", - ["cpu"] + retry=tenacity.retry_if_exception(docker.errors.APIError), + stop=tenacity.stop_after_attempt(5), + reraise=True, ) + @pytest.mark.parametrize("device", ["cpu"]) @pytest.mark.parametrize( "task", [ @@ -34,7 +21,7 @@ class TestPytorchRemote: "question-answering", "fill-mask", "summarization", - "ner", + "token-classification", "translation_xx_to_yy", "text2text-generation", "text-generation", @@ -49,14 +36,11 @@ class TestPytorchRemote: "sentence-similarity", "sentence-embeddings", "sentence-ranking", - "text-to-image" - ] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] + "text-to-image", + ], ) - @pytest.mark.usefixtures('remote_container') + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.usefixtures("remote_container") def test_inference_remote(self, remote_container, task, framework, device): - verify_task(task = task, port = remote_container[1]) + verify_task(task=task, port=remote_container[1]) diff --git a/tests/integ/test_pytorch_remote_gpu.py b/tests/integ/test_pytorch_remote_gpu.py index ec79f4a5..8c49f9ef 100644 --- a/tests/integ/test_pytorch_remote_gpu.py +++ b/tests/integ/test_pytorch_remote_gpu.py @@ -1,31 +1,18 @@ -import tempfile -from tests.integ.helpers import verify_task -from tests.integ.config import ( - task2input, - task2model, - task2output, - task2validation -) -from transformers.testing_utils import ( - require_torch, - slow, - _run_slow_tests -) +import docker import pytest import tenacity -import docker + +from tests.integ.helpers import verify_task + class TestPytorchRemote: @tenacity.retry( - retry = tenacity.retry_if_exception(docker.errors.APIError), - stop = tenacity.stop_after_attempt(5), - reraise = True - ) - @pytest.mark.parametrize( - "device", - ["gpu"] + retry=tenacity.retry_if_exception(docker.errors.APIError), + stop=tenacity.stop_after_attempt(5), + reraise=True, ) + @pytest.mark.parametrize("device", ["gpu"]) @pytest.mark.parametrize( "task", [ @@ -34,7 +21,7 @@ class TestPytorchRemote: "question-answering", "fill-mask", "summarization", - "ner", + "token-classification", "translation_xx_to_yy", "text2text-generation", "text-generation", @@ -49,14 +36,11 @@ class TestPytorchRemote: "sentence-similarity", "sentence-embeddings", "sentence-ranking", - "text-to-image" - ] - ) - @pytest.mark.parametrize( - "framework", - ["pytorch"] + "text-to-image", + ], ) - @pytest.mark.usefixtures('remote_container') + @pytest.mark.parametrize("framework", ["pytorch"]) + @pytest.mark.usefixtures("remote_container") def test_inference_remote(self, remote_container, task, framework, device): - verify_task(task = task, port = remote_container[1]) + verify_task(task=task, port=remote_container[1]) diff --git a/tests/integ/utils.py b/tests/integ/utils.py index 2b826cdb..24901f9d 100644 --- a/tests/integ/utils.py +++ b/tests/integ/utils.py @@ -1,7 +1,4 @@ import logging -from contextlib import contextmanager -from time import time - def validate_classification(result=None, snapshot=None): @@ -10,7 +7,7 @@ def validate_classification(result=None, snapshot=None): return True def validate_conversational(result=None, snapshot=None): - assert len(result) >= len(snapshot) + assert len(result[0]["generated_text"]) >= len(snapshot) def 
validate_zero_shot_classification(result=None, snapshot=None): diff --git a/tests/resources/custom_handler/custom_utils.py b/tests/resources/custom_handler/custom_utils.py index 30d73a8b..759f83d2 100644 --- a/tests/resources/custom_handler/custom_utils.py +++ b/tests/resources/custom_handler/custom_utils.py @@ -1,3 +1,3 @@ def test_method(input): """reverse string""" - return input[::-1] \ No newline at end of file + return input[::-1] diff --git a/tests/resources/custom_handler/pipeline.py b/tests/resources/custom_handler/pipeline.py index d3adeb4a..2c1ddc1a 100644 --- a/tests/resources/custom_handler/pipeline.py +++ b/tests/resources/custom_handler/pipeline.py @@ -1,4 +1,6 @@ from custom_utils import test_method + + class PreTrainedPipeline: def __init__(self, path): self.path = path diff --git a/tests/unit/conftest.py b/tests/unit/conftest.py index ddba0442..d2d3c2ac 100644 --- a/tests/unit/conftest.py +++ b/tests/unit/conftest.py @@ -1,7 +1,8 @@ import os -import logging + import pytest + @pytest.fixture(scope = "session") def cache_test_dir(): - yield os.environ.get("CACHE_TEST_DIR", "./tests") \ No newline at end of file + yield os.environ.get("CACHE_TEST_DIR", "./tests") diff --git a/tests/unit/test_const.py b/tests/unit/test_const.py index 37d2adcc..75104a98 100644 --- a/tests/unit/test_const.py +++ b/tests/unit/test_const.py @@ -1,5 +1,3 @@ -import os -from unittest import mock def test_if_provided(): diff --git a/tests/unit/test_diffusers.py b/tests/unit/test_diffusers.py index 4384cd4e..890575da 100644 --- a/tests/unit/test_diffusers.py +++ b/tests/unit/test_diffusers.py @@ -1,13 +1,10 @@ -import os +import logging import tempfile -from PIL import Image -from transformers.testing_utils import require_torch, slow - -from huggingface_inference_toolkit.diffusers_utils import get_diffusers_pipeline, IEAutoPipelineForText2Image +from huggingface_inference_toolkit.diffusers_utils import IEAutoPipelineForText2Image from huggingface_inference_toolkit.utils import _load_repository_from_hf, get_pipeline - -import logging +from PIL import Image +from transformers.testing_utils import require_torch, slow logging.basicConfig(level="DEBUG") diff --git a/tests/unit/test_handler.py b/tests/unit/test_handler.py index d1a0a561..44a8f818 100644 --- a/tests/unit/test_handler.py +++ b/tests/unit/test_handler.py @@ -1,20 +1,15 @@ import tempfile -from transformers.testing_utils import ( - require_tf, - require_torch, - slow -) + import pytest from huggingface_inference_toolkit.handler import ( HuggingFaceHandler, get_inference_handler_either_custom_or_default_handler, ) - from huggingface_inference_toolkit.utils import ( _is_gpu_available, - _load_repository_from_hf + _load_repository_from_hf, ) - +from transformers.testing_utils import require_tf, require_torch TASK = "text-classification" MODEL = "hf-internal-testing/tiny-random-distilbert" @@ -24,6 +19,7 @@ @require_torch def test_pt_get_device(): import torch + with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") @@ -38,11 +34,7 @@ def test_pt_get_device(): def test_pt_predict_call(): with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py - storage_dir = _load_repository_from_hf( - MODEL, - tmpdirname, - framework="pytorch" - ) + storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") h = 
HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) prediction = h(INPUT) @@ -56,9 +48,11 @@ def test_pt_custom_pipeline(): storage_dir = _load_repository_from_hf( "philschmid/custom-pipeline-text-classification", tmpdirname, - framework="pytorch" + framework="pytorch", + ) + h = get_inference_handler_either_custom_or_default_handler( + str(storage_dir), task="custom" ) - h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="custom") assert h(INPUT) == INPUT @@ -66,11 +60,11 @@ def test_pt_custom_pipeline(): def test_pt_sentence_transformers_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf( - "sentence-transformers/all-MiniLM-L6-v2", - tmpdirname, - framework="pytorch" + "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="pytorch" + ) + h = get_inference_handler_either_custom_or_default_handler( + str(storage_dir), task="sentence-embeddings" ) - h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="sentence-embeddings") pred = h(INPUT) assert isinstance(pred["embeddings"], list) @@ -81,9 +75,7 @@ def test_tf_get_device(): with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py storage_dir = _load_repository_from_hf( - MODEL, - tmpdirname, - framework="tensorflow" + MODEL, tmpdirname, framework="tensorflow" ) h = HuggingFaceHandler(model_dir=str(storage_dir), task=TASK) if _is_gpu_available(): @@ -97,14 +89,10 @@ def test_tf_predict_call(): with tempfile.TemporaryDirectory() as tmpdirname: # https://github.com/huggingface/infinity/blob/test-ovh/test/integ/utils.py storage_dir = _load_repository_from_hf( - MODEL, - tmpdirname, - framework="tensorflow" + MODEL, tmpdirname, framework="tensorflow" ) handler = HuggingFaceHandler( - model_dir=str(storage_dir), - task=TASK, - framework="tf" + model_dir=str(storage_dir), task=TASK, framework="tf" ) prediction = handler(INPUT) @@ -116,9 +104,13 @@ def test_tf_predict_call(): def test_tf_custom_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf( - "philschmid/custom-pipeline-text-classification", tmpdirname, framework="tensorflow" + "philschmid/custom-pipeline-text-classification", + tmpdirname, + framework="tensorflow", + ) + h = get_inference_handler_either_custom_or_default_handler( + str(storage_dir), task="custom" ) - h = get_inference_handler_either_custom_or_default_handler(str(storage_dir), task="custom") assert h(INPUT) == INPUT @@ -127,12 +119,9 @@ def test_tf_sentence_transformers_pipeline(): # TODO should fail! 
because TF is not supported yet with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf( - "sentence-transformers/all-MiniLM-L6-v2", - tmpdirname, - framework="tensorflow" + "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="tensorflow" ) - with pytest.raises(Exception) as exc_info: - h = get_inference_handler_either_custom_or_default_handler( - str(storage_dir), - task="sentence-embeddings" + with pytest.raises(Exception) as _exc_info: + get_inference_handler_either_custom_or_default_handler( + str(storage_dir), task="sentence-embeddings" ) diff --git a/tests/unit/test_optimum_utils.py b/tests/unit/test_optimum_utils.py new file mode 100644 index 00000000..8014decc --- /dev/null +++ b/tests/unit/test_optimum_utils.py @@ -0,0 +1,87 @@ +import os +import tempfile + +import pytest +from huggingface_inference_toolkit.optimum_utils import ( + get_input_shapes, + get_optimum_neuron_pipeline, + is_optimum_neuron_available, +) +from huggingface_inference_toolkit.utils import _load_repository_from_hf +from transformers.testing_utils import require_torch + +require_inferentia = pytest.mark.skipif( + not is_optimum_neuron_available(), + reason="Skipping tests, since optimum neuron is not available or not running on inf2 instances.", +) + + +REMOTE_NOT_CONVERTED_MODEL = "hf-internal-testing/tiny-random-BertModel" +REMOTE_CONVERTED_MODEL = "optimum/tiny_random_bert_neuron" +TASK = "text-classification" + + +@require_torch +@require_inferentia +def test_not_supported_task(): + os.environ["HF_TASK"] = "not-supported-task" + with pytest.raises(Exception): # noqa + get_optimum_neuron_pipeline(task=TASK, target_dir=os.getcwd()) + + +@require_torch +@require_inferentia +def test_get_input_shapes_from_file(): + with tempfile.TemporaryDirectory() as tmpdirname: + storage_folder = _load_repository_from_hf( + repository_id=REMOTE_CONVERTED_MODEL, + target_dir=tmpdirname, + ) + input_shapes = get_input_shapes(model_dir=storage_folder) + assert input_shapes["batch_size"] == 1 + assert input_shapes["sequence_length"] == 32 + + +@require_torch +@require_inferentia +def test_get_input_shapes_from_env(): + os.environ["HF_OPTIMUM_BATCH_SIZE"] = "4" + os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" + with tempfile.TemporaryDirectory() as tmpdirname: + storage_folder = _load_repository_from_hf( + repository_id=REMOTE_NOT_CONVERTED_MODEL, + target_dir=tmpdirname, + ) + input_shapes = get_input_shapes(model_dir=storage_folder) + assert input_shapes["batch_size"] == 4 + assert input_shapes["sequence_length"] == 32 + + +@require_torch +@require_inferentia +def test_get_optimum_neuron_pipeline_from_converted_model(): + with tempfile.TemporaryDirectory() as tmpdirname: + os.system( + f"optimum-cli export neuron --model philschmid/tiny-distilbert-classification --sequence_length 32 --batch_size 1 {tmpdirname}" + ) + pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=tmpdirname) + r = pipe("This is a test") + + assert r[0]["score"] > 0.0 + assert isinstance(r[0]["label"], str) + + +@require_torch +@require_inferentia +def test_get_optimum_neuron_pipeline_from_non_converted_model(): + os.environ["HF_OPTIMUM_SEQUENCE_LENGTH"] = "32" + with tempfile.TemporaryDirectory() as tmpdirname: + storage_folder = _load_repository_from_hf( + repository_id=REMOTE_NOT_CONVERTED_MODEL, + target_dir=tmpdirname, + ) + pipe = get_optimum_neuron_pipeline(task=TASK, model_dir=storage_folder) + r = pipe("This is a test") + + assert r[0]["score"] > 0.0 + assert isinstance(r[0]["label"], str) diff 
--git a/tests/unit/test_sentence_transformers.py b/tests/unit/test_sentence_transformers.py index 233da490..f8556ed0 100644 --- a/tests/unit/test_sentence_transformers.py +++ b/tests/unit/test_sentence_transformers.py @@ -1,11 +1,5 @@ -import os import tempfile -from transformers import pipeline -from transformers.file_utils import is_torch_available -from transformers.testing_utils import require_tf, require_torch, slow - -from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler from huggingface_inference_toolkit.sentence_transformers_utils import ( SentenceEmbeddingPipeline, get_sentence_transformers_pipeline, @@ -14,13 +8,14 @@ _load_repository_from_hf, get_pipeline, ) +from transformers.testing_utils import require_torch @require_torch def test_get_sentence_transformers_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf( - "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="pytorch" + "sentence-transformers/all-MiniLM-L6-v2", tmpdirname ) pipe = get_pipeline("sentence-embeddings", storage_dir.as_posix()) assert isinstance(pipe, SentenceEmbeddingPipeline) @@ -30,7 +25,7 @@ def test_get_sentence_transformers_pipeline(): def test_sentence_embedding_task(): with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf( - "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="pytorch" + "sentence-transformers/all-MiniLM-L6-v2", tmpdirname ) pipe = get_sentence_transformers_pipeline("sentence-embeddings", storage_dir.as_posix()) res = pipe("Lets create an embedding") @@ -41,7 +36,7 @@ def test_sentence_embedding_task(): def test_sentence_similarity(): with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf( - "sentence-transformers/all-MiniLM-L6-v2", tmpdirname, framework="pytorch" + "sentence-transformers/all-MiniLM-L6-v2", tmpdirname ) pipe = get_sentence_transformers_pipeline("sentence-similarity", storage_dir.as_posix()) res = pipe({"source_sentence": "Lets create an embedding", "sentences": ["Lets create an embedding"]}) @@ -51,7 +46,7 @@ def test_sentence_similarity(): @require_torch def test_sentence_ranking(): with tempfile.TemporaryDirectory() as tmpdirname: - storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname, framework="pytorch") + storage_dir = _load_repository_from_hf("cross-encoder/ms-marco-MiniLM-L-6-v2", tmpdirname) pipe = get_sentence_transformers_pipeline("sentence-ranking", storage_dir.as_posix()) res = pipe( [ diff --git a/tests/unit/test_serializer.py b/tests/unit/test_serializer.py index 07dfd5c1..0b53995d 100644 --- a/tests/unit/test_serializer.py +++ b/tests/unit/test_serializer.py @@ -1,15 +1,11 @@ -import base64 -import json +import os + import numpy as np import pytest -import os -from huggingface_inference_toolkit.serialization import ( - Jsoner, - Audioer, - Imager -) +from huggingface_inference_toolkit.serialization import Audioer, Imager, Jsoner from PIL import Image + def test_json_serialization(): t = {"res": np.array([2.0]), "text": "I like you.", "float": 1.2} assert b'{"res":[2.0],"text":"I like you.","float":1.2}' == Jsoner.serialize(t) diff --git a/tests/unit/test_utils.py b/tests/unit/test_utils.py index 79cff93d..e7b3eef6 100644 --- a/tests/unit/test_utils.py +++ b/tests/unit/test_utils.py @@ -1,11 +1,7 @@ +import logging import os -from pathlib import Path import tempfile - - -from transformers import pipeline -from 
transformers.file_utils import is_torch_available -from transformers.testing_utils import require_tf, require_torch, slow +from pathlib import Path from huggingface_inference_toolkit.handler import get_inference_handler_either_custom_or_default_handler from huggingface_inference_toolkit.utils import ( @@ -14,13 +10,10 @@ _load_repository_from_hf, check_and_register_custom_pipeline_from_directory, get_pipeline, - wrap_conversation_pipeline, ) +from transformers.file_utils import is_torch_available +from transformers.testing_utils import require_tf, require_torch, slow -import logging - -MODEL = "lysandre/tiny-bert-random" -TASK = "text-classification" TASK_MODEL = "sshleifer/tiny-dbmdz-bert-large-cased-finetuned-conll03-english" @@ -112,12 +105,13 @@ def test_get_framework_tensorflow(): @require_torch def test_get_pipeline(): + MODEL = "hf-internal-testing/tiny-random-BertForSequenceClassification" + TASK = "text-classification" with tempfile.TemporaryDirectory() as tmpdirname: storage_dir = _load_repository_from_hf(MODEL, tmpdirname, framework="pytorch") pipe = get_pipeline( task = TASK, model_dir = storage_dir.as_posix(), - framework = "pytorch" ) res = pipe("Life is good, Life is bad") assert "score" in res[0] @@ -129,8 +123,6 @@ def test_whisper_long_audio(cache_test_dir): storage_dir = _load_repository_from_hf( repository_id = "openai/whisper-tiny", target_dir = tmpdirname, - framework = "pytorch", - revision = "be0ba7c2f24f0127b27863a23a08002af4c2c279" ) logging.info(f"Temp dir: {tmpdirname}") logging.info(f"POSIX Path: {storage_dir.as_posix()}") @@ -138,42 +130,11 @@ def test_whisper_long_audio(cache_test_dir): pipe = get_pipeline( task = "automatic-speech-recognition", model_dir = storage_dir.as_posix(), - framework = "safetensors" ) res = pipe(f"{cache_test_dir}/resources/audio/long_sample.mp3") assert len(res["text"]) > 700 - -@require_torch -def test_wrap_conversation_pipeline(): - init_pipeline = pipeline( - "conversational", - model="microsoft/DialoGPT-small", - tokenizer="microsoft/DialoGPT-small", - framework="pt", - ) - conv_pipe = wrap_conversation_pipeline(init_pipeline) - data = [ - { - "role": "user", - "content": "Which movie is the best ?" - }, - { - "role": "assistant", - "content": "It's Die Hard for sure." - }, - { - "role": "user", - "content": "Can you explain why?" - } - ] - res = conv_pipe(data) - logging.info(f"Response: {res}") - assert res[-1]["role"] == "assistant" - assert "error" not in res[-1]["content"] - - @require_torch def test_wrapped_pipeline(): with tempfile.TemporaryDirectory() as tmpdirname: @@ -199,8 +160,8 @@ def test_wrapped_pipeline(): ] res = conv_pipe(data, max_new_tokens = 100) logging.info(f"Response: {res}") - assert res[-1]["role"] == "assistant" - assert "error" not in res[-1]["content"] + message = res[0]["generated_text"][-1] + assert message["role"] == "assistant" def test_local_custom_pipeline(cache_test_dir):