From 5786f818a60a9c1f33e96776b2adae7a91848a7e Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 09:09:30 +0200
Subject: [PATCH 01/10] Propagate `**kwargs` to `sentence-transformers` and
`diffusers` pipelines
---
src/huggingface_inference_toolkit/diffusers_utils.py | 2 +-
.../sentence_transformers_utils.py | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/src/huggingface_inference_toolkit/diffusers_utils.py b/src/huggingface_inference_toolkit/diffusers_utils.py
index f6241032..54cdb187 100644
--- a/src/huggingface_inference_toolkit/diffusers_utils.py
+++ b/src/huggingface_inference_toolkit/diffusers_utils.py
@@ -66,5 +66,5 @@ def __call__(
def get_diffusers_pipeline(task=None, model_dir=None, device=-1, **kwargs):
"""Get a pipeline for Diffusers models."""
device = "cuda" if device == 0 else "cpu"
- pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device)
+ pipeline = DIFFUSERS_TASKS[task](model_dir=model_dir, device=device, **kwargs)
return pipeline
diff --git a/src/huggingface_inference_toolkit/sentence_transformers_utils.py b/src/huggingface_inference_toolkit/sentence_transformers_utils.py
index 72bb2ee2..a4df0955 100644
--- a/src/huggingface_inference_toolkit/sentence_transformers_utils.py
+++ b/src/huggingface_inference_toolkit/sentence_transformers_utils.py
@@ -54,5 +54,5 @@ def get_sentence_transformers_pipeline(
**kwargs
):
device = "cuda" if device == 0 else "cpu"
- pipeline = SENTENCE_TRANSFORMERS_TASKS[task](model_dir=model_dir, device=device)
+ pipeline = SENTENCE_TRANSFORMERS_TASKS[task](model_dir=model_dir, device=device, **kwargs)
return pipeline
From e8d689c90226aada395f7f136d1a8ab4831689c3 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 09:15:50 +0200
Subject: [PATCH 02/10] Add `HF_TRUST_REMOTE_CODE` env var
---
src/huggingface_inference_toolkit/const.py | 5 ++++-
src/huggingface_inference_toolkit/handler.py | 6 +++++-
2 files changed, 9 insertions(+), 2 deletions(-)
diff --git a/src/huggingface_inference_toolkit/const.py b/src/huggingface_inference_toolkit/const.py
index 993fea26..75dd4ba2 100644
--- a/src/huggingface_inference_toolkit/const.py
+++ b/src/huggingface_inference_toolkit/const.py
@@ -7,7 +7,10 @@
HF_FRAMEWORK = os.environ.get("HF_FRAMEWORK", None)
HF_REVISION = os.environ.get("HF_REVISION", None)
HF_HUB_TOKEN = os.environ.get("HF_HUB_TOKEN", None)
+HF_TRUST_REMOTE_CODE = os.environ.get("HF_TRUST_REMOTE_CODE", False)
# custom handler consts
HF_DEFAULT_PIPELINE_NAME = os.environ.get("HF_DEFAULT_PIPELINE_NAME", "handler.py")
# default is pipeline.PreTrainedPipeline
-HF_MODULE_NAME = os.environ.get("HF_MODULE_NAME", f"{Path(HF_DEFAULT_PIPELINE_NAME).stem}.EndpointHandler")
+HF_MODULE_NAME = os.environ.get(
+ "HF_MODULE_NAME", f"{Path(HF_DEFAULT_PIPELINE_NAME).stem}.EndpointHandler"
+)
diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py
index 5b164af8..b77d9b89 100644
--- a/src/huggingface_inference_toolkit/handler.py
+++ b/src/huggingface_inference_toolkit/handler.py
@@ -2,6 +2,7 @@
from pathlib import Path
from typing import Optional, Union
+from huggingface_inference_toolkit.const import HF_TRUST_REMOTE_CODE
from huggingface_inference_toolkit.utils import (
check_and_register_custom_pipeline_from_directory,
get_pipeline,
@@ -16,7 +17,10 @@ class HuggingFaceHandler:
def __init__(self, model_dir: Union[str, Path], task=None, framework="pt"):
self.pipeline = get_pipeline(
- model_dir=model_dir, task=task, framework=framework
+ model_dir=model_dir,
+ task=task,
+ framework=framework,
+ trust_remote_code=HF_TRUST_REMOTE_CODE,
)
def __call__(self, data):
From 09e8d66f69986b4d1ec528d6fdda4951572b1e56 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 09:22:14 +0200
Subject: [PATCH 03/10] Fix `HF_TRUST_REMOTE_CODE` bool-handling via
`strtobool`
The `strtobool` had to be defined within `huggingface_inference_toolkit`
since it's deprecated and removed from `distutils` from Python 3.10
onwards.
---
src/huggingface_inference_toolkit/const.py | 4 +++-
src/huggingface_inference_toolkit/utils.py | 24 ++++++++++++++++++++++
2 files changed, 27 insertions(+), 1 deletion(-)
diff --git a/src/huggingface_inference_toolkit/const.py b/src/huggingface_inference_toolkit/const.py
index 75dd4ba2..6638cfba 100644
--- a/src/huggingface_inference_toolkit/const.py
+++ b/src/huggingface_inference_toolkit/const.py
@@ -1,13 +1,15 @@
import os
from pathlib import Path
+from huggingface_inference_toolkit.utils import strtobool
+
HF_MODEL_DIR = os.environ.get("HF_MODEL_DIR", "/opt/huggingface/model")
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", None)
HF_TASK = os.environ.get("HF_TASK", None)
HF_FRAMEWORK = os.environ.get("HF_FRAMEWORK", None)
HF_REVISION = os.environ.get("HF_REVISION", None)
HF_HUB_TOKEN = os.environ.get("HF_HUB_TOKEN", None)
-HF_TRUST_REMOTE_CODE = os.environ.get("HF_TRUST_REMOTE_CODE", False)
+HF_TRUST_REMOTE_CODE = strtobool(os.environ.get("HF_TRUST_REMOTE_CODE", "0"))
# custom handler consts
HF_DEFAULT_PIPELINE_NAME = os.environ.get("HF_DEFAULT_PIPELINE_NAME", "handler.py")
# default is pipeline.PreTrainedPipeline
diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py
index a0519d92..a35e5f1e 100644
--- a/src/huggingface_inference_toolkit/utils.py
+++ b/src/huggingface_inference_toolkit/utils.py
@@ -283,3 +283,27 @@ def convert_params_to_int_or_bool(params):
if v == "true":
params[k] = True
return params
+
+
+def strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False booleans.
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+ are 'n', 'no', 'f', 'false', 'off', and '0'.
+
+ Raises:
+ ValueError: if 'val' is anything else.
+
+ Note:
+ Function `strtobool` copied and adapted from `distutils`, as it's deprected from Python 3.10 onwards.
+
+ References:
+ - https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ if val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ raise ValueError(
+ f"Invalid truth value, it should be a string but {val} was provided instead."
+ )
From 5905985a1468d62cb5be4e3892a4d95871d6c6e3 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 09:27:32 +0200
Subject: [PATCH 04/10] Fix some typos with `codespell`
---
setup.py | 2 +-
src/huggingface_inference_toolkit/handler.py | 2 +-
src/huggingface_inference_toolkit/utils.py | 2 +-
src/huggingface_inference_toolkit/webservice_starlette.py | 2 +-
4 files changed, 4 insertions(+), 4 deletions(-)
diff --git a/setup.py b/setup.py
index 7068f7a6..b9f804df 100644
--- a/setup.py
+++ b/setup.py
@@ -10,7 +10,7 @@
# Ubuntu packages
# libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
# ffmpeg: ffmpeg is required for audio processing. On Ubuntu it can be installed as follows: apt install ffmpeg
-# libavcodec-extra : libavcodec-extra inculdes additional codecs for ffmpeg
+# libavcodec-extra : libavcodec-extra includes additional codecs for ffmpeg
install_requires = [
"transformers[sklearn,sentencepiece,audio,vision]==4.41.1",
diff --git a/src/huggingface_inference_toolkit/handler.py b/src/huggingface_inference_toolkit/handler.py
index b77d9b89..636f185b 100644
--- a/src/huggingface_inference_toolkit/handler.py
+++ b/src/huggingface_inference_toolkit/handler.py
@@ -70,7 +70,7 @@ def __call__(self, data):
payload = {"inputs": inputs, "parameters": parameters}
predictions.append(super().__call__(payload))
- # reutrn predictions
+ # return predictions
return {"predictions": predictions}
diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py
index a35e5f1e..6fd1ecc5 100644
--- a/src/huggingface_inference_toolkit/utils.py
+++ b/src/huggingface_inference_toolkit/utils.py
@@ -294,7 +294,7 @@ def strtobool(val: str) -> bool:
ValueError: if 'val' is anything else.
Note:
- Function `strtobool` copied and adapted from `distutils`, as it's deprected from Python 3.10 onwards.
+ Function `strtobool` copied and adapted from `distutils`, as it's deprecated from Python 3.10 onwards.
References:
- https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321
diff --git a/src/huggingface_inference_toolkit/webservice_starlette.py b/src/huggingface_inference_toolkit/webservice_starlette.py
index f8da6dcb..52cf161b 100644
--- a/src/huggingface_inference_toolkit/webservice_starlette.py
+++ b/src/huggingface_inference_toolkit/webservice_starlette.py
@@ -80,7 +80,7 @@ async def predict(request):
# checks if input schema is correct
if "inputs" not in deserialized_body and "instances" not in deserialized_body:
raise ValueError(
- f"Body needs to provide a inputs key, recieved: {orjson.dumps(deserialized_body)}"
+ f"Body needs to provide a inputs key, received: {orjson.dumps(deserialized_body)}"
)
# check for query parameter and add them to the body
From 43ad6e7bd415aadb30a4e691bc2fc98d19f334c4 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 09:27:44 +0200
Subject: [PATCH 05/10] Update `README.md`
---
README.md | 120 +++++++++++++++++++++++++++++-------------------------
1 file changed, 65 insertions(+), 55 deletions(-)
diff --git a/README.md b/README.md
index c7543569..66ef23a7 100644
--- a/README.md
+++ b/README.md
@@ -1,23 +1,22 @@
-

+
Hugging Face Inference Toolkit
-
Hugging Face Inference Toolkit is for serving 🤗 Transformers models in containers. This library provides default pre-processing, predict and postprocessing for Transformers, Sentence Tranfsformers. It is also possible to define custom `handler.py` for customization. The Toolkit is build to work with the [Hugging Face Hub](https://huggingface.co/models).
---
## 💻 Getting Started with Hugging Face Inference Toolkit
-* Clone the repository `git clone https://github.com/huggingface/huggingface-inference-toolkit``
-* Install the dependencies in dev mode `pip install -e ".[torch, st, diffusers, test,quality]"`
- * If you develop on AWS inferentia2 install with `pip install -e ".[test,quality]" optimum-neuron[neuronx] --upgrade`
+* Clone the repository `git clone
+* Install the dependencies in dev mode `pip install -e ".[torch,st,diffusers,test,quality]"`
+ * If you develop on AWS inferentia2 install with `pip install -e ".[test,quality]" optimum-neuron[neuronx] --upgrade`
+ * If you develop on Google Cloud install with `pip install -e ".[torch,st,diffusers,google,test,quality]"`
* Unit Testing: `make unit-test`
* Integration testing: `make integ-test`
-
### Local run
```bash
@@ -27,22 +26,22 @@ HF_MODEL_ID=hf-internal-testing/tiny-random-distilbert HF_MODEL_DIR=tmp2 HF_TASK
### Container
-
1. build the preferred container for either CPU or GPU for PyTorch.
-_cpu images_
+_CPU Images_
+
```bash
make inference-pytorch-cpu
```
-_gpu images_
+_GPU Images_
+
```bash
make inference-pytorch-gpu
```
2. Run the container and provide either environment variables to the HUB model you want to use or mount a volume to the container, where your model is stored.
-
```bash
docker run -ti -p 5000:5000 -e HF_MODEL_ID=distilbert-base-uncased-distilled-squad -e HF_TASK=question-answering integration-test-pytorch:cpu
docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=nlpconnect/vit-gpt2-image-captioning -e HF_TASK=image-to-text integration-test-pytorch:gpu
@@ -51,7 +50,6 @@ docker run -ti -p 5000:5000 --gpus all -e HF_MODEL_ID=stabilityai/stable-diffusi
docker run -ti -p 5000:5000 -e HF_MODEL_DIR=/repository -v $(pwd)/distilbert-base-uncased-emotion:/repository integration-test-pytorch:cpu
```
-
3. Send request. The API schema is the same as from the [inference API](https://huggingface.co/docs/api-inference/detailed_parameters)
```bash
@@ -59,17 +57,19 @@ curl --request POST \
--url http://localhost:5000 \
--header 'Content-Type: application/json' \
--data '{
- "inputs": {
- "question": "What is used for inference?",
- "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference."
- }
+ "inputs": {
+ "question": "What is used for inference?",
+ "context": "My Name is Philipp and I live in Nuremberg. This model is used with sagemaker for inference."
+ }
}'
```
### Custom Handler and dependency support
-The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository.
-For an example check [https://huggingface.co/philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
+The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository.
+
+For an example check [philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
+
```bash
model.tar.gz/
|- pytorch_model.bin
@@ -77,17 +77,17 @@ model.tar.gz/
|- handler.py
|- requirements.txt
```
+
In this example, `pytroch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file to add additional dependencies.
The custom module can override the following methods:
-
### Vertex AI Support
-The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
+The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
#### Local run with HF_MODEL_ID and HF_TASK
-Start Hugging Face Inference Toolkit with the following environment variables.
+Start Hugging Face Inference Toolkit with the following environment variables.
```bash
mkdir tmp2/
@@ -101,8 +101,8 @@ curl --request POST \
--url http://localhost:8080/pred \
--header 'Content-Type: application/json' \
--data '{
- "instances": ["I love this product", "I hate this product"],
- "parameters": { "top_k": 2 }
+ "instances": ["I love this product", "I hate this product"],
+ "parameters": { "top_k": 2 }
}'
```
@@ -124,18 +124,19 @@ docker run -ti -p 8080:8080 -e AIP_MODE=PREDICTION -e AIP_HTTP_PORT=8080 -e AIP_
```bash
curl --request POST \
- --url http://localhost:8080/pred \
- --header 'Content-Type: application/json' \
- --data '{
- "instances": ["I love this product", "I hate this product"],
- "parameters": { "top_k": 2 }
+ --url http://localhost:8080/pred \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "instances": ["I love this product", "I hate this product"],
+ "parameters": { "top_k": 2 }
}'
```
-### AWS Inferentia2 Support
+### AWS Inferentia2 Support
The Hugging Face Inference Toolkit provides support for deploying Hugging Face on AWS Inferentia2. To deploy a model on Inferentia2 you have 3 options:
-* Provide `HF_MODEL_ID`, the model repo id on huggingface.co which contains the compiled model under `.neuron` format. e.g. `optimum/bge-base-en-v1.5-neuronx`
+
+* Provide `HF_MODEL_ID`, the model repo id on huggingface.co which contains the compiled model under `.neuron` format e.g. `optimum/bge-base-en-v1.5-neuronx`
* Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128`
* Include `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file in the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}`
@@ -143,16 +144,19 @@ The currently supported tasks can be found [here](https://huggingface.co/docs/op
#### Local run with HF_MODEL_ID and HF_TASK
-Start Hugging Face Inference Toolkit with the following environment variables.
+Start Hugging Face Inference Toolkit with the following environment variables.
_Note: You need to run this on an Inferentia2 instance._
-- transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+* transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+
```bash
mkdir tmp2/
HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
```
-- sentence transformers `feature-extration` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+
+* sentence transformers `feature-extration` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+
```bash
HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
```
@@ -161,16 +165,15 @@ Send request
```bash
curl --request POST \
- --url http://localhost:5000 \
- --header 'Content-Type: application/json' \
- --data '{
- "inputs": "Wow, this is such a great product. I love it!"
+ --url http://localhost:5000 \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "inputs": "Wow, this is such a great product. I love it!"
}'
```
#### Container run with HF_MODEL_ID and HF_TASK
-
1. build the preferred container for either CPU or GPU for PyTorch o.
```bash
@@ -187,26 +190,25 @@ docker run -ti -p 5000:5000 -e HF_MODEL_ID="distilbert/distilbert-base-uncased-f
```bash
curl --request POST \
- --url http://localhost:5000 \
- --header 'Content-Type: application/json' \
- --data '{
- "inputs": "Wow, this is such a great product. I love it!",
- "parameters": { "top_k": 2 }
+ --url http://localhost:5000 \
+ --header 'Content-Type: application/json' \
+ --data '{
+ "inputs": "Wow, this is such a great product. I love it!",
+ "parameters": { "top_k": 2 }
}'
```
-
---
## 🛠️ Environment variables
-The Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below. All potential environment varialbes can be found in [const.py](src/huggingface_inference_toolkit/const.py)
+The Hugging Face Inference Toolkit implements various additional environment variables to simplify your deployment experience. A full list of environment variables is given below. All potential environment variables can be found in [const.py](src/huggingface_inference_toolkit/const.py)
### `HF_MODEL_DIR`
-The `HF_MODEL_DIR` environment variable defines the directory where your model is stored or will be stored.
-If `HF_MODEL_ID` is not set the toolkit expects a the model artifact at this directory. This value should be set to the value where you mount your model artifacts.
-If `HF_MODEL_ID` is set the toolkit and the directory where `HF_MODEL_DIR` is pointing to is empty. The toolkit will download the model from the Hub to this directory.
+The `HF_MODEL_DIR` environment variable defines the directory where your model is stored or will be stored.
+If `HF_MODEL_ID` is not set the toolkit expects a the model artifact at this directory. This value should be set to the value where you mount your model artifacts.
+If `HF_MODEL_ID` is set the toolkit and the directory where `HF_MODEL_DIR` is pointing to is empty. The toolkit will download the model from the Hub to this directory.
The default value is `/opt/huggingface/model`
@@ -246,6 +248,14 @@ The `HF_HUB_TOKEN` environment variable defines the your Hugging Face authorizat
HF_HUB_TOKEN="api_XXXXXXXXXXXXXXXXXXXXXXXXXXXXX"
```
+### `HF_TRUST_REMOTE_CODE`
+
+The `HF_TRUST_REMOTE_CODE` environment variable defines whether to trust remote code. This flag is already used for community defined inference code, and is therefore quite representative of the level of confidence you are giving the model providers when loading models from the Hugging Face Hub. The default value is `"0"`; set it to `"1"` to trust remote code.
+
+```bash
+HF_TRUST_REMOTE_CODE="0"
+```
+
### `HF_FRAMEWORK`
The `HF_FRAMEWORK` environment variable defines the base deep learning framework used in the container. This is important when loading large models from the Hugguing Face Hub to avoid extra file downloads.
@@ -256,7 +266,7 @@ HF_FRAMEWORK="pytorch"
#### `HF_OPTIMUM_BATCH_SIZE`
-The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted.
+The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size, which is used when compiling the model to Neuron. The default value is `1`. Not required when model is already converted.
```bash
HF_OPTIMUM_BATCH_SIZE="1"
@@ -264,7 +274,7 @@ HF_OPTIMUM_BATCH_SIZE="1"
#### `HF_OPTIMUM_SEQUENCE_LENGTH`
-The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when model is already converted.
+The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length, which is used when compiling the model to Neuron. There is no default value. Not required when model is already converted.
```bash
HF_OPTIMUM_SEQUENCE_LENGTH="128"
@@ -272,12 +282,12 @@ HF_OPTIMUM_SEQUENCE_LENGTH="128"
---
-## ⚙ Supported Frontend
+## ⚙ Supported Front-Ends
-- [x] Starlette (HF Endpoints)
-- [x] Starlette (Vertex AI)
-- [ ] Starlette (Azure ML)
-- [ ] Starlette (SageMaker)
+* [x] Starlette (HF Endpoints)
+* [x] Starlette (Vertex AI)
+* [ ] Starlette (Azure ML)
+* [ ] Starlette (SageMaker)
---
@@ -287,6 +297,6 @@ HF_OPTIMUM_SEQUENCE_LENGTH="128"
## 📜 License
-TBD.
+TBD.
---
From c9384c2a229a31fab7a940b74a6bd5db34e955d0 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 12:15:44 +0200
Subject: [PATCH 06/10] Bump version to `0.4.2`
---
setup.py | 2 +-
1 file changed, 1 insertion(+), 1 deletion(-)
diff --git a/setup.py b/setup.py
index b9f804df..374ee85b 100644
--- a/setup.py
+++ b/setup.py
@@ -5,7 +5,7 @@
# We don't declare our dependency on transformers here because we build with
# different packages for different variants
-VERSION = "0.4.1.dev0"
+VERSION = "0.4.2"
# Ubuntu packages
# libsndfile1-dev: torchaudio requires the development version of the libsndfile package which can be installed via a system package manager. On Ubuntu it can be installed as follows: apt install libsndfile1-dev
From c1c37a364588204d3f59e32a6428ae3255e8298b Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 15:29:28 +0200
Subject: [PATCH 07/10] Move `strtobool` to `env_utils` module to avoid
circular import
---
src/huggingface_inference_toolkit/const.py | 2 +-
.../env_utils.py | 22 +++++++++++
src/huggingface_inference_toolkit/utils.py | 38 ++++++-------------
3 files changed, 35 insertions(+), 27 deletions(-)
create mode 100644 src/huggingface_inference_toolkit/env_utils.py
diff --git a/src/huggingface_inference_toolkit/const.py b/src/huggingface_inference_toolkit/const.py
index 6638cfba..eb6ddcb8 100644
--- a/src/huggingface_inference_toolkit/const.py
+++ b/src/huggingface_inference_toolkit/const.py
@@ -1,7 +1,7 @@
import os
from pathlib import Path
-from huggingface_inference_toolkit.utils import strtobool
+from huggingface_inference_toolkit.env_utils import strtobool
HF_MODEL_DIR = os.environ.get("HF_MODEL_DIR", "/opt/huggingface/model")
HF_MODEL_ID = os.environ.get("HF_MODEL_ID", None)
diff --git a/src/huggingface_inference_toolkit/env_utils.py b/src/huggingface_inference_toolkit/env_utils.py
new file mode 100644
index 00000000..e582ec98
--- /dev/null
+++ b/src/huggingface_inference_toolkit/env_utils.py
@@ -0,0 +1,22 @@
+def strtobool(val: str) -> bool:
+ """Convert a string representation of truth to True or False booleans.
+ True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
+ are 'n', 'no', 'f', 'false', 'off', and '0'.
+
+ Raises:
+ ValueError: if 'val' is anything else.
+
+ Note:
+ Function `strtobool` copied and adapted from `distutils`, as it's deprecated from Python 3.10 onwards.
+
+ References:
+ - https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321
+ """
+ val = val.lower()
+ if val in ("y", "yes", "t", "true", "on", "1"):
+ return True
+ if val in ("n", "no", "f", "false", "off", "0"):
+ return False
+ raise ValueError(
+ f"Invalid truth value, it should be a string but {val} was provided instead."
+ )
diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py
index 6fd1ecc5..9c383735 100644
--- a/src/huggingface_inference_toolkit/utils.py
+++ b/src/huggingface_inference_toolkit/utils.py
@@ -8,7 +8,11 @@
from transformers.file_utils import is_tf_available, is_torch_available
from transformers.pipelines import Pipeline
-from huggingface_inference_toolkit.const import HF_DEFAULT_PIPELINE_NAME, HF_MODULE_NAME
+from huggingface_inference_toolkit.const import (
+ HF_DEFAULT_PIPELINE_NAME,
+ HF_MODULE_NAME,
+ HF_TRUST_REMOTE_CODE,
+)
from huggingface_inference_toolkit.diffusers_utils import (
get_diffusers_pipeline,
is_diffusers_available,
@@ -243,6 +247,10 @@ def get_pipeline(
else:
kwargs["tokenizer"] = model_dir
+ logger.info(f"Creating pipeline for task: {task}")
+ logger.info(f"Using kwargs: {kwargs}")
+ logger.info(f"{HF_TRUST_REMOTE_CODE=}")
+
if is_optimum_neuron_available():
hf_pipeline = get_optimum_neuron_pipeline(task=task, model_dir=model_dir)
elif is_sentence_transformers_available() and task in [
@@ -258,7 +266,9 @@ def get_pipeline(
task=task, model_dir=model_dir, device=device, **kwargs
)
else:
- hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
+ hf_pipeline = pipeline(
+ task=task, model=model_dir, device=device, trust_remote_code=True
+ ) # **kwargs)
if task == "automatic-speech-recognition" and isinstance(
hf_pipeline.model, WhisperForConditionalGeneration
@@ -283,27 +293,3 @@ def convert_params_to_int_or_bool(params):
if v == "true":
params[k] = True
return params
-
-
-def strtobool(val: str) -> bool:
- """Convert a string representation of truth to True or False booleans.
- True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
- are 'n', 'no', 'f', 'false', 'off', and '0'.
-
- Raises:
- ValueError: if 'val' is anything else.
-
- Note:
- Function `strtobool` copied and adapted from `distutils`, as it's deprecated from Python 3.10 onwards.
-
- References:
- - https://github.com/python/cpython/blob/48f9d3e3faec5faaa4f7c9849fecd27eae4da213/Lib/distutils/util.py#L308-L321
- """
- val = val.lower()
- if val in ("y", "yes", "t", "true", "on", "1"):
- return True
- if val in ("n", "no", "f", "false", "off", "0"):
- return False
- raise ValueError(
- f"Invalid truth value, it should be a string but {val} was provided instead."
- )
From 3569eab0f9717cce963fff25f6a540aec2b88182 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Thu, 8 Aug 2024 15:30:11 +0200
Subject: [PATCH 08/10] Revert enforce of `trust_remote_code=True`
---
src/huggingface_inference_toolkit/utils.py | 4 +---
1 file changed, 1 insertion(+), 3 deletions(-)
diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py
index 9c383735..f5d76885 100644
--- a/src/huggingface_inference_toolkit/utils.py
+++ b/src/huggingface_inference_toolkit/utils.py
@@ -266,9 +266,7 @@ def get_pipeline(
task=task, model_dir=model_dir, device=device, **kwargs
)
else:
- hf_pipeline = pipeline(
- task=task, model=model_dir, device=device, trust_remote_code=True
- ) # **kwargs)
+ hf_pipeline = pipeline(task=task, model=model_dir, device=device, **kwargs)
if task == "automatic-speech-recognition" and isinstance(
hf_pipeline.model, WhisperForConditionalGeneration
From 0f7235ab127befcaeb8b3d546bb5636a6f578bfe Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Fri, 9 Aug 2024 10:01:34 +0200
Subject: [PATCH 09/10] Remove `logging` messages for debug
---
src/huggingface_inference_toolkit/utils.py | 4 ----
1 file changed, 4 deletions(-)
diff --git a/src/huggingface_inference_toolkit/utils.py b/src/huggingface_inference_toolkit/utils.py
index f5d76885..89261d71 100644
--- a/src/huggingface_inference_toolkit/utils.py
+++ b/src/huggingface_inference_toolkit/utils.py
@@ -247,10 +247,6 @@ def get_pipeline(
else:
kwargs["tokenizer"] = model_dir
- logger.info(f"Creating pipeline for task: {task}")
- logger.info(f"Using kwargs: {kwargs}")
- logger.info(f"{HF_TRUST_REMOTE_CODE=}")
-
if is_optimum_neuron_available():
hf_pipeline = get_optimum_neuron_pipeline(task=task, model_dir=model_dir)
elif is_sentence_transformers_available() and task in [
From db6e1d9f55de8f9895d8705590c7a39bddf1e649 Mon Sep 17 00:00:00 2001
From: Alvaro Bartolome <36760800+alvarobartt@users.noreply.github.com>
Date: Fri, 9 Aug 2024 10:15:20 +0200
Subject: [PATCH 10/10] Fix `diffusers` propagation of `trust_remote_code=True`
---
src/huggingface_inference_toolkit/diffusers_utils.py | 7 +++++--
1 file changed, 5 insertions(+), 2 deletions(-)
diff --git a/src/huggingface_inference_toolkit/diffusers_utils.py b/src/huggingface_inference_toolkit/diffusers_utils.py
index 54cdb187..afe96676 100644
--- a/src/huggingface_inference_toolkit/diffusers_utils.py
+++ b/src/huggingface_inference_toolkit/diffusers_utils.py
@@ -1,4 +1,5 @@
import importlib.util
+from typing import Union
from transformers.utils.import_utils import is_torch_bf16_gpu_available
@@ -21,14 +22,16 @@ def is_diffusers_available():
class IEAutoPipelineForText2Image:
- def __init__(self, model_dir: str, device: str = None): # needs "cuda" for GPU
+ def __init__(
+ self, model_dir: str, device: Union[str, None] = None, **kwargs
+ ): # needs "cuda" for GPU
dtype = torch.float32
if device == "cuda":
dtype = torch.bfloat16 if is_torch_bf16_gpu_available() else torch.float16
device_map = "auto" if device == "cuda" else None
self.pipeline = AutoPipelineForText2Image.from_pretrained(
- model_dir, torch_dtype=dtype, device_map=device_map
+ model_dir, torch_dtype=dtype, device_map=device_map, **kwargs
)
# try to use DPMSolverMultistepScheduler
if isinstance(self.pipeline, StableDiffusionPipeline):