Fix the parameter to tensor conversion in TRTLLM FastAPI implementation #98

Merged (2 commits, Jul 25, 2024)
Triton_Inference_Server_Python_API/examples/fastapi/README.md (12 changes: 6 additions & 6 deletions)
@@ -26,15 +26,15 @@
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
-->

# Triton Inference Server OpenAI Compatible Server

Using the Triton In-Process Python API you can integrate Triton Server
based models into any Python framework, including FastAPI, with an
OpenAI compatible interface.
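
As a rough sketch of what that integration looks like in code (the model repository path is hypothetical, and the method names follow the tritonserver in-process Python API used elsewhere in these examples, so they may differ between releases):

```python
import tritonserver  # in-process Triton Python API

# Start Triton in-process; the repository path here is an assumption for this sketch.
server = tritonserver.Server(model_repository="/workspace/llm-models")
server.start()

# Look up a model and run an inference; the input names mirror the TRT-LLM
# example in this directory and may differ for other backends.
model = server.model("llama-3-8b-instruct")
responses = model.infer(inputs={"text_input": [["Hello, world"]], "stream": [[False]]})
for response in responses:
    # response.outputs holds tritonserver tensors; convert to numpy/bytes as needed.
    print(response.outputs["text_output"])
```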

This directory contains a FastAPI-based Triton Inference Server
supporting `llama-3-8b-instruct` with both the vLLM and TRT-LLM
backends.

The front end application was generated using a trimmed version of the
OpenAI OpenAPI [specification](api-spec/openai_trimmed.yml) and the
@@ -118,7 +118,7 @@ curl -X 'POST' \
"stream": false,
"stop": "string",
"frequency_penalty": 0.0
}' | jq .
```

#### Chat Completions `/v1/chat/completions`
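
The chat endpoint accepts a standard OpenAI-style request body. A representative request as a sketch (field values are illustrative and assume the same server address and model name as the completions example above):

```python
import requests

payload = {
    "model": "llama-3-8b-instruct",
    "messages": [{"role": "user", "content": "Hello, what is machine learning?"}],
    "max_tokens": 128,
    "temperature": 0.7,
    "stream": False,
}

# Assumes the FastAPI server from this example is listening on localhost:8000.
response = requests.post("http://localhost:8000/v1/chat/completions", json=payload)
print(response.json())
```
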
@@ -165,7 +165,7 @@ curl -s http://localhost:8000/v1/models | jq .
curl -s http://localhost:8000/v1/models/llama-3-8b-instruct | jq .
```

## Comparison to vLLM

The vLLM container can also be used to run the vLLM FastAPI server.

@@ -185,7 +185,7 @@ Note: the following command requires the 24.05 pre-release version of genai-perf
Preliminary results show performance on par with vLLM at concurrency 2

```
genai-perf -m meta-llama/Meta-Llama-3-8B-Instruct --endpoint v1/chat/completions --endpoint-type chat --service-kind openai -u http://localhost:8000 --num-prompts 100 --synthetic-input-tokens-mean 1024 --synthetic-input-tokens-stddev 50 --concurrency 2 --measurement-interval 40000 --extra-inputs max_tokens:512 --extra-input ignore_eos:true -- -v --max-threads=256
```

@@ -195,5 +195,5 @@
* `max_tokens` is not processed by the trt-llm backend correctly
* Usage information is not populated
* `finish_reason` is currently always set to `stop`
* Limited performance testing has been done
* Using genai-perf to test streaming requires changes to genai-perf SSE handling
@@ -165,21 +165,21 @@ def create_trtllm_inference_request(
inputs["text_input"] = [[prompt]]
inputs["stream"] = [[request.stream]]
if request.max_tokens:
inputs["max_tokens"] = [[numpy.int32(request.max_tokens)]]
inputs["max_tokens"] = numpy.int32([[request.max_tokens]])
if request.stop:
if isinstance(request.stop, str):
request.stop = [request.stop]
inputs["stop_words"] = [request.stop]
if request.top_p:
inputs["top_p"] = [[numpy.float32(request.top_p)]]
inputs["top_p"] = numpy.float32([[request.top_p]])
if request.frequency_penalty:
inputs["frequency_penalty"] = [[numpy.float32(request.frequency_penalty)]]
inputs["frequency_penalty"] = numpy.float32([[request.frequency_penalty]])
if request.presence_penalty:
inputs["presence_penalty":] = [[numpy.int32(request.presence_penalty)]]
inputs["presence_penalty":] = numpy.int32([[request.presence_penalty]])
if request.seed:
inputs["random_seed"] = [[numpy.uint64(request.seed)]]
inputs["random_seed"] = numpy.uint64([[request.seed]])
if request.temperature:
inputs["temperature"] = [[numpy.float32(request.temperature)]]
inputs["temperature"] = numpy.float32([[request.temperature]])

return model.create_request(inputs=inputs)
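
A note on why the conversion order matters, as a minimal standalone sketch (the values are illustrative): wrapping a numpy scalar in nested Python lists produces a plain Python list, whereas calling the dtype constructor on the nested list produces a (1, 1) numpy array with an explicit dtype.

```python
import numpy

# Old form: a Python list of lists that happens to contain a numpy scalar.
old_style = [[numpy.int32(512)]]
print(type(old_style))            # <class 'list'>

# New form: convert the nested list as a whole, yielding a (1, 1) ndarray
# with the intended dtype for the model's input tensor.
new_style = numpy.int32([[512]])
print(type(new_style), new_style.dtype, new_style.shape)  # ndarray int32 (1, 1)

# For comparison, letting numpy infer the dtype from a plain Python int
# typically gives int64 on Linux, which may not match the expected input type.
inferred = numpy.array([[512]])
print(inferred.dtype)             # int64 (platform dependent)
```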

Expand Down
Loading