-
Notifications
You must be signed in to change notification settings - Fork 118
added a jax example #11
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
base: main
Are you sure you want to change the base?
Changes from 2 commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,100 @@ | ||
<!-- | ||
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
--> | ||
|
||
# Deploying a JAX Model | ||
|
||
This README showcases how to deploy a simple ResNet model on Triton Inference Server. While Triton doesn't yet have a dedicated JAX backend, JAX/Flax models can be deployed using [Python Backend](https://github.com/triton-inference-server/python_backend). If you are new to Triton, it is recommended to watch this [getting started video](https://www.youtube.com/watch?v=NQDtfSi5QF4) and review [Part 1](https://github.com/triton-inference-server/tutorials/tree/main/Conceptual_Guide/Part_1-model_deployment) of the conceptual guide before proceeding. For the purposes of demonstration, we are using a pre-trained model provided by [flaxmodels](https://github.com/matthias-wright/flaxmodels). | ||
|
||
Before diving into the specifics execution, an understanding of the underlying structure is needed. To use a JAX or a Flax model, the recommended path for this is using a ["Python Model"](https://github.com/triton-inference-server/python_backend#python-backend). Python models in Triton are basically classes with three Triton-specific functions: `initialize`, `execute` and `finalize`. Users can customize this class to serve any python function they write or any model they want as long as it can be loaded in python runtime. The `initialize` function runs when the python model is loaded into memory, and the `finalize` function runs when the model is unloaded from memory. Both of these functions are optional to define. For the purposes of this example, we will use the `initialize` and the `execute` functions to load and run(respectively) a `resnet18` model. | ||
|
||
``` | ||
class TritonPythonModel: | ||
def initialize(self, args): | ||
... | ||
|
||
def execute(self, requests): | ||
... | ||
``` | ||
|
||
## Step 1: Set Up Triton Inference Server | ||
|
||
To use Triton, we need to build a model repository. The structure of the repository as follows: | ||
``` | ||
model_repository | ||
| | ||
+-- resnet50 | ||
| | ||
+-- config.pbtxt | ||
+-- 1 | ||
| | ||
+-- model.py | ||
``` | ||
For this example, we have pre-built the model repository. Next, we install the required dependencies and launch the Triton Inference Server. | ||
|
||
``` | ||
# Replace the yy.mm in the image name with the release year and month | ||
# of the Triton version needed, eg. 22.12 | ||
docker run --gpus=all -it --shm-size=256m --rm -p8000:8000 -p8001:8001 -p8002:8002 -v$(pwd):/workspace/ -v/$(pwd)/model_repository:/models nvcr.io/nvidia/tritonserver:<yy.mm>-py3 bash | ||
|
||
pip install --upgrade pip | ||
pip install --upgrade "jax[cuda]" -f https://storage.googleapis.com/jax-releases/jax_cuda_releases.html | ||
pip install --upgrade git+https://github.com/matthias-wright/flaxmodels.git | ||
|
||
``` | ||
|
||
## Step 2: Using a Triton Client to Query the Server | ||
|
||
Let's breakdown the client application. First, we setup a connection with the Triton Inference Server. | ||
``` | ||
client = httpclient.InferenceServerClient(url="localhost:8000") | ||
``` | ||
Then we set the input and output arrays. | ||
``` | ||
# Set Inputs | ||
input_tensors = [ | ||
httpclient.InferInput("image", image.shape, datatype="FP32") | ||
] | ||
input_tensors[0].set_data_from_numpy(image) | ||
|
||
# Set outputs | ||
outputs = [ | ||
httpclient.InferRequestedOutput("fc_out") | ||
] | ||
``` | ||
Lastly, we query send a request to the Triton Inference Server. | ||
|
||
``` | ||
# Query | ||
query_response = client.infer(model_name="resnet50", | ||
inputs=input_tensors, | ||
outputs=outputs) | ||
|
||
# Output | ||
out = query_response.as_numpy("fc_out") | ||
``` | ||
|
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,63 @@ | ||
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
import numpy as np | ||
from tritonclient.utils import * | ||
from PIL import Image | ||
import tritonclient.http as httpclient | ||
import requests | ||
|
||
|
||
def main(): | ||
client = httpclient.InferenceServerClient(url="localhost:8000") | ||
|
||
# Inputs | ||
url = "http://images.cocodataset.org/val2017/000000161642.jpg" | ||
image = np.asarray(Image.open(requests.get(url, stream=True).raw)).astype(np.float32) | ||
image = np.expand_dims(image, axis=0) | ||
|
||
# Set Inputs | ||
input_tensors = [ | ||
httpclient.InferInput("image", image.shape, datatype="FP32") | ||
] | ||
input_tensors[0].set_data_from_numpy(image) | ||
|
||
# Set outputs | ||
outputs = [ | ||
httpclient.InferRequestedOutput("fc_out") | ||
] | ||
|
||
# Query | ||
query_response = client.infer(model_name="resnet50", | ||
inputs=input_tensors, | ||
outputs=outputs) | ||
|
||
# Output | ||
out = query_response.as_numpy("fc_out") | ||
print(out.shape) | ||
|
||
if __name__ == "__main__": | ||
main() |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,58 @@ | ||
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
import triton_python_backend_utils as pb_utils | ||
import jax | ||
import jax.numpy as jnp | ||
import flaxmodels as fm | ||
|
||
import numpy as np | ||
from flax.jax_utils import replicate | ||
|
||
class TritonPythonModel: | ||
|
||
def initialize(self, args): | ||
self.key = jax.random.PRNGKey(0) | ||
self.resnet18 = fm.ResNet18(output='logits', pretrained='imagenet') | ||
|
||
|
||
def execute(self, requests): | ||
responses = [] | ||
for request in requests: | ||
inp = pb_utils.get_input_tensor_by_name(request, "image") | ||
input_image = inp.as_numpy() | ||
|
||
params = self.resnet18.init(self.key, input_image) | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. I'm surprised by this There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. Ah, good catch, I missed this. |
||
out = self.resnet18.apply(params, input_image, train=False) | ||
|
||
inference_response = pb_utils.InferenceResponse(output_tensors=[ | ||
pb_utils.Tensor( | ||
"fc_out", | ||
np.array(out), | ||
) | ||
]) | ||
responses.append(inference_response) | ||
return responses |
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,50 @@ | ||
# Copyright 2023, NVIDIA CORPORATION & AFFILIATES. All rights reserved. | ||
# | ||
# Redistribution and use in source and binary forms, with or without | ||
# modification, are permitted provided that the following conditions | ||
# are met: | ||
# * Redistributions of source code must retain the above copyright | ||
# notice, this list of conditions and the following disclaimer. | ||
# * Redistributions in binary form must reproduce the above copyright | ||
# notice, this list of conditions and the following disclaimer in the | ||
# documentation and/or other materials provided with the distribution. | ||
# * Neither the name of NVIDIA CORPORATION nor the names of its | ||
# contributors may be used to endorse or promote products derived | ||
# from this software without specific prior written permission. | ||
# | ||
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY | ||
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE | ||
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR | ||
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR | ||
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, | ||
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, | ||
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR | ||
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY | ||
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT | ||
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE | ||
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. | ||
|
||
name: "resnet50" | ||
backend: "python" | ||
max_batch_size: 8 | ||
|
||
input [ | ||
{ | ||
name: "image" | ||
data_type: TYPE_FP32 | ||
dims: [-1, -1, -1] | ||
tanayvarshney marked this conversation as resolved.
Show resolved
Hide resolved
|
||
} | ||
] | ||
output [ | ||
{ | ||
name: "fc_out" | ||
data_type: TYPE_FP32 | ||
dims: [-1, -1] | ||
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. This doesn't seem right. Shouldn't it be There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more. It [-1,1000] if I remember it correctly. I can re-run and check |
||
} | ||
] | ||
|
||
instance_group [ | ||
{ | ||
kind: KIND_GPU | ||
} | ||
] | ||
tanayvarshney marked this conversation as resolved.
Show resolved
Hide resolved
|
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can you add some explanation of what you are putting in the config.pbtxt file? Mainly the input/output names and how they have to correspond to what's in the execute method
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
We have this info in other places, I am trying to avoid replication of information, but I see why it is needed. We should discuss more about this. We are bound to run into this issue later down the line