
Commit b551223

Merge pull request #73 from huggingface/add-inf2-support
[Inf2] Add Optimum Neuron support for Encoder models
2 parents f2ae200 + 2185d4a commit b551223

34 files changed: +956 additions, -691 deletions

.github/workflows/build-container.yaml

Lines changed: 4 additions & 13 deletions
```diff
@@ -34,21 +34,12 @@ jobs:
       TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }}
       REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }}
       REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
-  starlette-tensorflow-cpu:
+  starlette-pytorch-inf2:
     uses: ./.github/workflows/docker-build-action.yaml
     with:
-      image: inference-tensorflow-cpu
-      dockerfile: dockerfiles/tensorflow/cpu/Dockerfile
+      image: inference-pytorch-inf2
+      dockerfile: dockerfiles/pytorch/Dockerfile.inf2
     secrets:
       TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }}
       REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }}
-      REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
-  starlette-tensorflow-gpu:
-    uses: ./.github/workflows/docker-build-action.yaml
-    with:
-      image: inference-tensorflow-gpu
-      dockerfile: dockerfiles/tensorflow/gpu/Dockerfile
-    secrets:
-      TAILSCALE_AUTHKEY: ${{ secrets.TAILSCALE_AUTHKEY }}
-      REGISTRY_USERNAME: ${{ secrets.REGISTRY_USERNAME }}
-      REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
+      REGISTRY_PASSWORD: ${{ secrets.REGISTRY_PASSWORD }}
```

.github/workflows/integration-test.yaml

Lines changed: 2 additions & 0 deletions
```diff
@@ -22,6 +22,7 @@ jobs:
     with:
       test_path: "tests/integ/test_pytorch_local_gpu.py"
       build_img_cmd: "make inference-pytorch-gpu"
+      test_parallelism: "1"
   pytorch-integration-remote-gpu:
     name: Remote Integration Tests - GPU
     uses: ./.github/workflows/integration-test-action.yaml
@@ -41,4 +42,5 @@ jobs:
     with:
       test_path: "tests/integ/test_pytorch_local_cpu.py"
       build_img_cmd: "make inference-pytorch-cpu"
+      test_parallelism: "1"
       runs_on: "['ci']"
```

.gitignore

Lines changed: 3 additions & 1 deletion
```diff
@@ -179,4 +179,6 @@ model
 tests/tmp
 tmp/
 act.sh
-.act
+.act
+tmp*
+log-*
```

README.md

Lines changed: 95 additions & 57 deletions
````diff
@@ -8,8 +8,16 @@
 Hugging Face Inference Toolkit is for serving 🤗 Transformers models in containers. This library provides default pre-processing, predict and postprocessing for Transformers and Sentence Transformers. It is also possible to define a custom `handler.py` for customization. The Toolkit is built to work with the [Hugging Face Hub](https://huggingface.co/models).
 
 ---
+
 ## 💻 Getting Started with Hugging Face Inference Toolkit
 
+* Clone the repository: `git clone https://github.com/huggingface/huggingface-inference-toolkit`
+* Install the dependencies in dev mode: `pip install -e ".[torch,st,diffusers,test,quality]"`
+* If you develop on AWS Inferentia2, install with `pip install -e ".[test,quality]" optimum-neuron[neuronx] --upgrade`
+* Unit testing: `make unit-test`
+* Integration testing: `make integ-test`
+
+
 ### Local run
 
 ```bash
````
````diff
@@ -58,6 +66,21 @@
 }'
 ```
 
+### Custom Handler and dependency support
+
+The Hugging Face Inference Toolkit allows users to provide custom inference logic through a `handler.py` file located in the model repository.
+For an example, check [https://huggingface.co/philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
+```bash
+model.tar.gz/
+|- pytorch_model.bin
+|- ....
+|- handler.py
+|- requirements.txt
+```
+In this example, `pytorch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file for additional dependencies.
+The custom module can override the following methods:
+
+
 ### Vertex AI Support
 
 The Hugging Face Inference Toolkit is also supported on Vertex AI, based on [Custom container requirements for prediction](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements). [Environment variables set by Vertex AI](https://cloud.google.com/vertex-ai/docs/predictions/custom-container-requirements#aip-variables) are automatically detected and used by the toolkit.
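For readers unfamiliar with custom handlers, the snippet below is an illustrative sketch only and is not part of this diff. It assumes the `EndpointHandler` convention used for Hugging Face custom handlers, with model loading in `__init__` and inference in `__call__`; the pipeline task and payload keys are placeholders.

```python
# handler.py - illustrative sketch of a custom handler (assumed EndpointHandler convention).
from typing import Any, Dict, List

from transformers import pipeline


class EndpointHandler:
    def __init__(self, model_dir: str = "") -> None:
        # Load the model once at startup; a plain Transformers pipeline is assumed here.
        self.pipeline = pipeline("text-classification", model=model_dir)

    def __call__(self, data: Dict[str, Any]) -> List[Dict[str, Any]]:
        # `data` is the deserialized request payload, e.g. {"inputs": "...", "parameters": {...}}.
        inputs = data.get("inputs", data)
        parameters = data.get("parameters", {})
        return self.pipeline(inputs, **parameters)
```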
````diff
@@ -109,6 +132,69 @@
 }'
 ```
 
+### AWS Inferentia2 Support
+
+The Hugging Face Inference Toolkit also supports deploying Hugging Face models on AWS Inferentia2. To deploy a model on Inferentia2, you have three options:
+* Provide `HF_MODEL_ID`, the model repository ID on huggingface.co that contains a model already compiled to the `.neuron` format, e.g. `optimum/bge-base-en-v1.5-neuronx`
+* Provide the `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH` environment variables to compile the model on the fly, e.g. `HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128`
+* Include a `neuron` dictionary in the [config.json](https://huggingface.co/optimum/tiny_random_bert_neuron/blob/main/config.json) file of the model archive, e.g. `neuron: {"static_batch_size": 1, "static_sequence_length": 128}`
+
+The currently supported tasks can be found [here](https://huggingface.co/docs/optimum-neuron/en/package_reference/supported_models). If you plan to deploy an LLM, we recommend taking a look at [Neuronx TGI](https://huggingface.co/blog/text-generation-inference-on-inferentia2), which is purpose-built for LLMs.
+
+#### Local run with HF_MODEL_ID and HF_TASK
+
+Start the Hugging Face Inference Toolkit with the following environment variables.
+
+_Note: You need to run this on an Inferentia2 instance._
+
+- transformers `text-classification` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+```bash
+mkdir tmp2/
+HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" HF_TASK="text-classification" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
+```
+- sentence transformers `feature-extraction` with `HF_OPTIMUM_BATCH_SIZE` and `HF_OPTIMUM_SEQUENCE_LENGTH`
+```bash
+HF_MODEL_ID="sentence-transformers/all-MiniLM-L6-v2" HF_TASK="feature-extraction" HF_OPTIMUM_BATCH_SIZE=1 HF_OPTIMUM_SEQUENCE_LENGTH=128 HF_MODEL_DIR=tmp2 uvicorn src.huggingface_inference_toolkit.webservice_starlette:app --port 5000
+```
+
+Send a request:
+
+```bash
+curl --request POST \
+  --url http://localhost:5000 \
+  --header 'Content-Type: application/json' \
+  --data '{
+  "inputs": "Wow, this is such a great product. I love it!"
+}'
+```
+
+#### Container run with HF_MODEL_ID and HF_TASK
+
+
+1. Build the container for PyTorch on AWS Inferentia2.
+
+```bash
+make inference-pytorch-inf2
+```
+
+2. Run the container, providing either the environment variables for the Hub model you want to use or a mounted volume containing your model.
+
+```bash
+docker run -ti -p 5000:5000 -e HF_MODEL_ID="distilbert/distilbert-base-uncased-finetuned-sst-2-english" -e HF_TASK="text-classification" -e HF_OPTIMUM_BATCH_SIZE=1 -e HF_OPTIMUM_SEQUENCE_LENGTH=128 --device=/dev/neuron0 integration-test-pytorch:inf2
+```
+
+3. Send a request:
+
+```bash
+curl --request POST \
+  --url http://localhost:5000 \
+  --header 'Content-Type: application/json' \
+  --data '{
+  "inputs": "Wow, this is such a great product. I love it!",
+  "parameters": { "top_k": 2 }
+}'
+```
+
 
 ---
 
````
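As a side note that is not part of the diff: the locally running server can also be exercised from Python with the standard `requests` library; the URL and payload simply mirror the curl examples above.

```python
# Minimal Python client for the locally running toolkit; mirrors the curl examples above.
import requests

response = requests.post(
    "http://localhost:5000",
    json={
        "inputs": "Wow, this is such a great product. I love it!",
        "parameters": {"top_k": 2},  # optional pipeline parameters
    },
)
response.raise_for_status()
print(response.json())
```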

````diff
@@ -168,61 +254,23 @@ The `HF_FRAMEWORK` environment variable defines the base deep learning framework
 HF_FRAMEWORK="pytorch"
 ```
 
-### `HF_ENDPOINT`
+#### `HF_OPTIMUM_BATCH_SIZE`
 
-The `HF_ENDPOINT` environment variable indicates whether the service is run inside the HF Inference endpoint service to adjust the `logging` config.
+The `HF_OPTIMUM_BATCH_SIZE` environment variable defines the batch size used when compiling the model to Neuron. The default value is `1`. Not required when the model is already converted.
 
 ```bash
-HF_ENDPOINT="True"
+HF_OPTIMUM_BATCH_SIZE="1"
 ```
 
+#### `HF_OPTIMUM_SEQUENCE_LENGTH`
 
----
+The `HF_OPTIMUM_SEQUENCE_LENGTH` environment variable defines the sequence length used when compiling the model to Neuron. There is no default value. Not required when the model is already converted.
 
-## 🧑🏻‍💻 Custom Handler and dependency support
-
-The Hugging Face Inference Toolkit allows user to provide a custom inference through a `handler.py` file which is located in the repository.
-For an example check [https://huggingface.co/philschmid/custom-pipeline-text-classification](https://huggingface.co/philschmid/custom-pipeline-text-classification):
 ```bash
-model.tar.gz/
-|- pytorch_model.bin
-|- ....
-|- handler.py
-|- requirements.txt
+HF_OPTIMUM_SEQUENCE_LENGTH="128"
 ```
-In this example, `pytroch_model.bin` is the model file saved from training, `handler.py` is the custom inference handler, and `requirements.txt` is a requirements file to add additional dependencies.
-The custom module can override the following methods:
-
 
-## ☑️ Supported & Tested Tasks
-
-Below you ll find a list of supported and tested transformers and sentence transformers tasks. Each of those are always tested through integration tests. In addition to those tasks you can always provide `custom`, which expect a `handler.py` file to be provided.
-
-```bash
-"text-classification",
-"zero-shot-classification",
-"ner",
-"question-answering",
-"fill-mask",
-"summarization",
-"translation_xx_to_yy",
-"text2text-generation",
-"text-generation",
-"feature-extraction",
-"image-classification",
-"automatic-speech-recognition",
-"audio-classification",
-"object-detection",
-"image-segmentation",
-"table-question-answering",
-"conversational"
-"sentence-similarity",
-"sentence-embeddings",
-"sentence-ranking",
-# TODO currently not supported due to multimodality input
-# "visual-question-answering",
-# "zero-shot-image-classification",
-```
+---
 
 ## ⚙ Supported Frontend
 
````
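For context, and not as part of this diff: the two variables above correspond to the static input shapes that optimum-neuron uses when exporting a model for Inferentia2. A rough sketch of that on-the-fly compilation, assuming optimum-neuron's `export=True` API and reusing the example model and shapes from the README:

```python
# Sketch of what on-the-fly Neuron compilation roughly corresponds to
# (assumes optimum-neuron's export API; model name and shapes are the README examples).
from optimum.neuron import NeuronModelForSequenceClassification

model = NeuronModelForSequenceClassification.from_pretrained(
    "distilbert/distilbert-base-uncased-finetuned-sst-2-english",
    export=True,          # compile the vanilla checkpoint to a Neuron model
    batch_size=1,         # roughly HF_OPTIMUM_BATCH_SIZE
    sequence_length=128,  # roughly HF_OPTIMUM_SEQUENCE_LENGTH
)
model.save_pretrained("tmp2/")  # the saved artifacts can then be served by the toolkit
```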

````diff
@@ -232,21 +280,11 @@ Below you ll find a list of supported and tested transformers and sentence trans
 - [ ] Starlette (SageMaker)
 
 ---
-## 🤝 Contributing
-
-### Development
-
-* Recommended Python version: 3.11
-* We recommend `pyenv` for easily switching between different Python versions
-* There are two options for unit and integration tests:
-  * `Make` - see `makefile`
 
-#### Testing with Make
-
-* Unit Testing: `make unit-test`
-* Integration testing: `make integ-test`
+## 🤝 Contributing
 
 ---
+
 ## 📜 License
 
 TBD.
````

dockerfiles/pytorch/Dockerfile

Lines changed: 1 addition & 1 deletion
```diff
@@ -51,4 +51,4 @@ ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]
 from base as vertex
 
 # Install Vertex AI requiremented packages
-RUN pip install --no-cache-dir google-cloud-storage
+RUN pip install --no-cache-dir google-cloud-storage
```

The removed and added lines are identical; the change most likely only adds a missing trailing newline at the end of the file.

dockerfiles/pytorch/Dockerfile.inf2

Lines changed: 122 additions & 0 deletions
New file:

```dockerfile
# Build based on https://github.com/aws/deep-learning-containers/blob/master/huggingface/pytorch/inference/docker/2.1/py3/sdk2.18.0/Dockerfile.neuronx
FROM ubuntu:20.04

LABEL maintainer="Hugging Face"

ARG PYTHON=python3.10
ARG PYTHON_VERSION=3.10.12
ARG MAMBA_VERSION=23.1.0-4

# Neuron SDK components version numbers
ARG NEURONX_FRAMEWORK_VERSION=2.1.2.2.1.0
ARG NEURONX_DISTRIBUTED_VERSION=0.7.0
ARG NEURONX_CC_VERSION=2.13.66.0
ARG NEURONX_TRANSFORMERS_VERSION=0.10.0.21
ARG NEURONX_COLLECTIVES_LIB_VERSION=2.20.22.0-c101c322e
ARG NEURONX_RUNTIME_LIB_VERSION=2.20.22.0-1b3ca6425
ARG NEURONX_TOOLS_VERSION=2.17.1.0

# HF ARGS
ARG OPTIMUM_NEURON_VERSION=0.0.23

# See http://bugs.python.org/issue19846
ENV LANG C.UTF-8
ENV LD_LIBRARY_PATH /opt/aws/neuron/lib:/lib/x86_64-linux-gnu:/opt/conda/lib/:$LD_LIBRARY_PATH
ENV PATH /opt/conda/bin:/opt/aws/neuron/bin:$PATH

RUN apt-get update \
    && apt-get upgrade -y \
    && apt-get install -y --no-install-recommends software-properties-common \
    && add-apt-repository ppa:openjdk-r/ppa \
    && apt-get update \
    && apt-get install -y --no-install-recommends \
    build-essential \
    apt-transport-https \
    ca-certificates \
    cmake \
    curl \
    emacs \
    git \
    jq \
    libgl1-mesa-glx \
    libsm6 \
    libxext6 \
    libxrender-dev \
    openjdk-11-jdk \
    vim \
    wget \
    unzip \
    zlib1g-dev \
    libcap-dev \
    gpg-agent \
    && rm -rf /var/lib/apt/lists/* \
    && rm -rf /tmp/tmp* \
    && apt-get clean

RUN echo "deb https://apt.repos.neuron.amazonaws.com focal main" > /etc/apt/sources.list.d/neuron.list
RUN wget -qO - https://apt.repos.neuron.amazonaws.com/GPG-PUB-KEY-AMAZON-AWS-NEURON.PUB | apt-key add -

# Install Neuronx tools
RUN apt-get update \
    && apt-get install -y \
    aws-neuronx-tools=$NEURONX_TOOLS_VERSION \
    aws-neuronx-collectives=$NEURONX_COLLECTIVES_LIB_VERSION \
    aws-neuronx-runtime-lib=$NEURONX_RUNTIME_LIB_VERSION \
    && rm -rf /var/lib/apt/lists/* \
    && rm -rf /tmp/tmp* \
    && apt-get clean

# https://github.com/docker-library/openjdk/issues/261 https://github.com/docker-library/openjdk/pull/263/files
RUN keytool -importkeystore -srckeystore /etc/ssl/certs/java/cacerts -destkeystore /etc/ssl/certs/java/cacerts.jks -deststoretype JKS -srcstorepass changeit -deststorepass changeit -noprompt; \
    mv /etc/ssl/certs/java/cacerts.jks /etc/ssl/certs/java/cacerts; \
    /var/lib/dpkg/info/ca-certificates-java.postinst configure;

RUN curl -L -o ~/mambaforge.sh https://github.com/conda-forge/miniforge/releases/download/${MAMBA_VERSION}/Mambaforge-${MAMBA_VERSION}-Linux-x86_64.sh \
    && chmod +x ~/mambaforge.sh \
    && ~/mambaforge.sh -b -p /opt/conda \
    && rm ~/mambaforge.sh \
    && /opt/conda/bin/conda update -y conda \
    && /opt/conda/bin/conda install -c conda-forge -y \
    python=$PYTHON_VERSION \
    pyopenssl \
    cython \
    mkl-include \
    mkl \
    botocore \
    parso \
    scipy \
    typing \
    # Below 2 are included in miniconda base, but not mamba so need to install
    conda-content-trust \
    charset-normalizer \
    && /opt/conda/bin/conda update -y conda \
    && /opt/conda/bin/conda clean -ya

RUN conda install -c conda-forge \
    scikit-learn \
    h5py \
    requests \
    && conda clean -ya \
    && pip install --upgrade pip --trusted-host pypi.org --trusted-host files.pythonhosted.org \
    && ln -s /opt/conda/bin/pip /usr/local/bin/pip3 \
    && pip install --no-cache-dir "protobuf>=3.18.3,<4" setuptools==69.5.1 packaging

WORKDIR /

# install Hugging Face libraries and its dependencies
RUN pip install --extra-index-url https://pip.repos.neuron.amazonaws.com --no-cache-dir optimum-neuron[neuronx]==${OPTIMUM_NEURON_VERSION} \
    && pip install --no-deps --no-cache-dir -U torchvision==0.16.*


COPY . .
# install wheel and setuptools
RUN pip install --no-cache-dir -U pip ".[st]"

# copy application
COPY src/huggingface_inference_toolkit huggingface_inference_toolkit
COPY src/huggingface_inference_toolkit/webservice_starlette.py webservice_starlette.py

# copy entrypoint and change permissions
COPY --chmod=0755 scripts/entrypoint.sh entrypoint.sh

ENTRYPOINT ["bash", "-c", "./entrypoint.sh"]
```
