feat: Add blueprint TensorRT-LLM + Triton #65
Open: aasthavar wants to merge 2 commits into awslabs:main from aasthavar:main
blueprints/inference/trtllm-nvidia-triton-server-gpu/Dockerfile (19 additions, 0 deletions)
FROM nvcr.io/nvidia/tritonserver:25.01-trtllm-python-py3

COPY llama /llama
COPY triton_engines /engines
COPY triton_model_files /triton_model_files

RUN git clone https://github.com/triton-inference-server/tensorrtllm_backend.git /tensorrtllm_backend \
    && cd /tensorrtllm_backend \
    && apt-get update \
    && apt-get install git-lfs -y --no-install-recommends \
    && git submodule update --init --recursive \
    && git lfs install

COPY launch_triton_server.py /tensorrtllm_backend/scripts/launch_triton_server.py

COPY start.sh /start.sh
RUN chmod +x /start.sh

ENTRYPOINT ["/bin/bash", "/start.sh"]
Review comment: newline
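Once the container's start.sh has launched Triton, readiness can be verified over gRPC before running the benchmark client. A minimal sketch, assuming port 8001 is published from the container and tritonclient[grpc] is installed on the caller (neither step is part of this PR):

import tritonclient.grpc as grpcclient

client = grpcclient.InferenceServerClient(url="localhost:8001")
if client.is_server_ready():
    # List the models Triton loaded from the copied model repository.
    for model in client.get_model_repository_index().models:
        print(model.name, model.state)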
blueprints/inference/trtllm-nvidia-triton-server-gpu/benchmark-grpc/prompts.txt (20 additions, 0 deletions)
Explain the concept of generative adversarial networks (GANs).
How does a variational autoencoder (VAE) work?
What are the applications of generative AI in healthcare?
Describe the process of training a GAN.
How can generative AI be used in creative arts?
What is the difference between supervised and unsupervised learning in the context of generative AI?
Explain the role of a discriminator in a GAN.
How can generative AI improve natural language processing (NLP)?
What are the ethical considerations of using generative AI?
How is generative AI used in drug discovery?
Describe the architecture of a Transformer model.
How can generative AI be applied in the gaming industry?
What is the purpose of latent space in generative models?
How does text generation with GPT-3 work?
What are the challenges of using generative AI in finance?
Explain the concept of zero-shot learning in generative AI.
How can generative AI be used for image synthesis?
What are some real-world applications of deepfakes?
How can generative AI contribute to personalized medicine?
Describe the use of generative AI in autonomous vehicles.
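The client below loads this file with readlines(), which keeps the trailing newline on every prompt and therefore sends it to the model. If that is unwanted, stripping at load time is a one-line change (a minor variation, not part of this PR):

with open("prompts.txt") as f:
    # Drop trailing newlines and skip blank lines.
    prompts = [line.strip() for line in f if line.strip()]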
blueprints/inference/trtllm-nvidia-triton-server-gpu/benchmark-grpc/triton-client.py (248 additions, 0 deletions)
import argparse
import asyncio
import sys
import time
from os import system

import numpy as np

import tritonclient.grpc.aio as grpcclient
from tritonclient.utils import InferenceServerException


def count_tokens(text):
    # Rough proxy: counts whitespace-separated words, not model tokens.
    return len(text.split())
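count_tokens above approximates tokens by whitespace splitting, which undercounts subword tokens. If token-accurate throughput matters, a tokenizer matching the served model can be swapped in; a sketch assuming the transformers package is installed and the tokenizer files sit under /llama as copied in the Dockerfile (both assumptions, not part of this PR):

from transformers import AutoTokenizer

# Assumed path: the tokenizer shipped alongside the model weights.
tokenizer = AutoTokenizer.from_pretrained("/llama")

def count_tokens_exact(text):
    # Count model tokens instead of whitespace-separated words.
    return len(tokenizer.encode(text, add_special_tokens=False))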
def create_request(
    prompt,
    stream,
    request_id,
    model_name,
    sampling_parameters,
    max_tokens,
):
    inputs = []

    # Required input
    prompt_data = np.array([[prompt.encode("utf-8")]], dtype=np.object_)
    inputs.append(grpcclient.InferInput("text_input", prompt_data.shape, "BYTES"))
    inputs[-1].set_data_from_numpy(prompt_data)

    # Required max_tokens
    max_tokens_data = np.array([[max_tokens]], dtype=np.int32)
    inputs.append(grpcclient.InferInput("max_tokens", max_tokens_data.shape, "INT32"))
    inputs[-1].set_data_from_numpy(max_tokens_data)

    # Optional stream
    stream_data = np.array([[stream]], dtype=bool)
    inputs.append(grpcclient.InferInput("stream", stream_data.shape, "BOOL"))
    inputs[-1].set_data_from_numpy(stream_data)

    # Optional sampling parameters. Map of field name -> (numpy dtype, Triton dtype)
    optional_inputs = {
        "temperature": (np.float32, "FP32"),
        "top_k": (np.int32, "INT32"),
        "top_p": (np.float32, "FP32"),
        "length_penalty": (np.float32, "FP32"),
        "repetition_penalty": (np.float32, "FP32"),
        "min_length": (np.int32, "INT32"),
        "presence_penalty": (np.float32, "FP32"),
        "frequency_penalty": (np.float32, "FP32"),
        "end_id": (np.int32, "INT32"),
        "pad_id": (np.int32, "INT32"),
        "beam_width": (np.int32, "INT32"),
        "num_return_sequences": (np.int32, "INT32"),
        "random_seed": (np.uint64, "UINT64"),
        "return_log_probs": (np.bool_, "BOOL"),
        "return_context_logits": (np.bool_, "BOOL"),
        "return_generation_logits": (np.bool_, "BOOL"),
        "return_kv_cache_reuse_stats": (np.bool_, "BOOL"),
        "exclude_input_in_output": (np.bool_, "BOOL"),
    }

    for key, (dtype, dtype_str) in optional_inputs.items():
        if key in sampling_parameters:
            val = np.array([[sampling_parameters[key]]], dtype=dtype)
            inp = grpcclient.InferInput(key, val.shape, dtype_str)
            inp.set_data_from_numpy(val)
            inputs.append(inp)

    # Optional bad_words, stop_words, and embedding_bias_words
    for key in ["bad_words", "stop_words", "embedding_bias_words"]:
        if key in sampling_parameters:
            words = sampling_parameters[key]
            arr = np.array([w.encode("utf-8") for w in words], dtype=np.object_)
            inp = grpcclient.InferInput(key, [len(arr)], "BYTES")
            inp.set_data_from_numpy(arr)
            inputs.append(inp)

    # Optional embedding_bias_weights
    if "embedding_bias_weights" in sampling_parameters:
        weights = np.array(sampling_parameters["embedding_bias_weights"], dtype=np.float32)
        inp = grpcclient.InferInput("embedding_bias_weights", list(weights.shape), "FP32")
        inp.set_data_from_numpy(weights)
        inputs.append(inp)

    # Outputs (only text_output is requested for now)
    outputs = [grpcclient.InferRequestedOutput("text_output")]

    return {
        "model_name": model_name,
        "inputs": inputs,
        "outputs": outputs,
        "request_id": str(request_id),
    }
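create_request returns the keyword arguments that stream_infer expects, but the same inputs can also drive a one-off call through the async client's infer method. A minimal sketch of such a call, assuming the server above is reachable on localhost:8001 and the model is deployed in non-decoupled mode (one_shot is a hypothetical helper, not part of this PR):

async def one_shot(url, model_name, prompt):
    # Build inputs with create_request, then issue a single non-streaming infer.
    req = create_request(prompt, stream=False, request_id=0,
                         model_name=model_name, sampling_parameters={}, max_tokens=64)
    async with grpcclient.InferenceServerClient(url=url) as client:
        result = await client.infer(
            model_name=req["model_name"],
            inputs=req["inputs"],
            outputs=req["outputs"],
            request_id=req["request_id"],
        )
        return result.as_numpy("text_output")[0].decode("utf-8")

# Example: print(asyncio.run(one_shot("localhost:8001", "<model-name>", "Hello")))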
async def main(FLAGS):
    sampling_parameters = {"temperature": 0.01, "top_p": 1.0, "top_k": 20, "min_length": 1}
    max_tokens = 128
    stream = FLAGS.streaming_mode
    model_name = FLAGS.model_name
    with open(FLAGS.input_prompts, "r") as file:
        print(f"Loading inputs from `{FLAGS.input_prompts}`...")
        prompts = file.readlines()

    results_dict = {}
    total_time_sec = 0.0  # Wall-clock time from stream start to the last response

    async with grpcclient.InferenceServerClient(url=FLAGS.url, verbose=FLAGS.verbose) as triton_client:

        async def async_request_iterator():
            try:
                for it in range(FLAGS.iterations):
                    for i, prompt in enumerate(prompts):
                        prompt_id = FLAGS.offset + (len(prompts) * it) + i
                        results_dict[str(prompt_id)] = []
                        yield create_request(
                            prompt=prompt,
                            stream=stream,
                            request_id=prompt_id,
                            model_name=model_name,
                            max_tokens=max_tokens,
                            sampling_parameters=sampling_parameters,
                        )
            except Exception as error:
                print(f"ERROR: caught error in request iterator: {error}")

        try:
            start_time = time.time()  # Stream start; every latency below is measured from here
            response_iterator = triton_client.stream_infer(
                inputs_iterator=async_request_iterator(),
                stream_timeout=FLAGS.stream_timeout,
            )
            async for response in response_iterator:
                result, error = response
                end_time = time.time()

                if error:
                    print(f"Encountered error while processing: {error}")
                else:
                    request_id = result.get_response().id
                    # Map the request ID back to its prompt; IDs exceed
                    # len(prompts) when --iterations > 1 or --offset is set.
                    prompt_index = (int(request_id) - FLAGS.offset) % len(prompts)
                    output = result.as_numpy("text_output")
                    for i in output:
                        results_dict[request_id] = {
                            "Prompt": prompts[prompt_index],
                            "Response Time": end_time - start_time,
                            "Tokens": count_tokens(i.decode("utf-8")),
                            "Response": i.decode("utf-8"),
                        }

                    duration = end_time - start_time
                    total_time_sec = duration  # Elapsed wall-clock time up to the latest response
                    print(f"Model {FLAGS.model_name} - Request {request_id}: {duration:.2f} seconds")

        except InferenceServerException as error:
            print(error)
            sys.exit(1)

    with open(FLAGS.results_file, "w") as file:
        for key, val in results_dict.items():
            if not val:
                # Requests that never received a response stay as empty lists.
                continue
            file.write(
                f"Prompt: {val['Prompt']}\nResponse Time: {val['Response Time']}\n"
                f"Tokens: {val['Tokens']}\nResponse: {val['Response']}"
            )
            file.write("\n=========\n\n")
    print(f"Storing results into `{FLAGS.results_file}`...")

    if FLAGS.verbose:
        print(f"\nContents of `{FLAGS.results_file}` ===>")
        system(f"cat {FLAGS.results_file}")

    print(f"Total time for all requests: {total_time_sec:.2f} seconds")
    print("PASS: trtllm example")
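The run records a per-request latency (measured from stream start) and a whitespace token count, so a small post-processing step can derive an aggregate throughput figure from results_dict. A minimal sketch (summarize is a hypothetical helper, not part of this PR):

def summarize(results_dict, total_time_sec):
    # Aggregate whitespace-token throughput across completed requests.
    done = [v for v in results_dict.values() if v]
    total_tokens = sum(v["Tokens"] for v in done)
    print(f"Requests completed: {len(done)}")
    if total_time_sec > 0:
        print(f"Throughput: {total_tokens / total_time_sec:.1f} tokens/sec")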
if __name__ == "__main__":
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u",
        "--url",
        type=str,
        required=False,
        default="localhost:8001",
        help="Inference server URL and its gRPC port. Default is localhost:8001.",
    )
    parser.add_argument(
        "-t",
        "--stream-timeout",
        type=float,
        required=False,
        default=None,
        help="Stream timeout in seconds. Default is None.",
    )
    parser.add_argument(
        "--offset",
        type=int,
        required=False,
        default=0,
        help="Offset added to the request IDs used",
    )
    parser.add_argument(
        "--input-prompts",
        type=str,
        required=False,
        default="prompts.txt",
        help="Text file with input prompts",
    )
    parser.add_argument(
        "--results-file",
        type=str,
        required=False,
        default="results.txt",
        help="File to write the output results to",
    )
    parser.add_argument(
        "--iterations",
        type=int,
        required=False,
        default=1,
        help="Number of iterations through the prompts file",
    )
    parser.add_argument(
        "-s",
        "--streaming-mode",
        action="store_true",
        required=False,
        default=False,
        help="Enable streaming mode",
    )
    parser.add_argument(
        "--model-name",
        type=str,
        required=True,
        help="Name of the model to test",
    )
    FLAGS = parser.parse_args()
    asyncio.run(main(FLAGS))
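Only --model-name is required; the other flags default to a local server, so a typical run against the container above would look like: python triton-client.py --model-name <model-name> --streaming-mode --iterations 2, where <model-name> presumably matches a model deployed under /triton_model_files.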
Review comment: As .ecr_repo_uri and .eks_region seem to be a common pattern, can we gitignore them with