Skip to content

Commit 4d40bf3

Browse files
authored
AI Agents Guide (#107)
Adding AI Agents Guide with Constrained decoding tutorial (#109) and Function Calling tutorial (#113)
1 parent e23775d commit 4d40bf3

File tree

11 files changed

+2768
-0
lines changed

11 files changed

+2768
-0
lines changed

AI_Agents_Guide/Constrained_Decoding/README.md

Lines changed: 703 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 278 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,278 @@
1+
#!/usr/bin/python
2+
# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
#
4+
# Redistribution and use in source and binary forms, with or without
5+
# modification, are permitted provided that the following conditions
6+
# are met:
7+
# * Redistributions of source code must retain the above copyright
8+
# notice, this list of conditions and the following disclaimer.
9+
# * Redistributions in binary form must reproduce the above copyright
10+
# notice, this list of conditions and the following disclaimer in the
11+
# documentation and/or other materials provided with the distribution.
12+
# * Neither the name of NVIDIA CORPORATION nor the names of its
13+
# contributors may be used to endorse or promote products derived
14+
# from this software without specific prior written permission.
15+
#
16+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
17+
# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
18+
# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
19+
# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR
20+
# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL,
21+
# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO,
22+
# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR
23+
# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
24+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
25+
# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
26+
# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
27+
28+
import argparse
29+
import sys
30+
31+
import client_utils
32+
import numpy as np
33+
import tritonclient.grpc as grpcclient
34+
from pydantic import BaseModel
35+
36+
37+
class AnswerFormat(BaseModel):
    """Pydantic model describing the expected JSON answer about a movie.

    Its ``model_json_schema()`` output is embedded into the prompt (via
    ``--use-schema``) so the server can constrain decoding to this shape.
    Field order and names define the schema — do not reorder or rename.
    """

    # Movie title.
    title: str
    # Release year.
    year: int
    # Director's name.
    director: str
    # Producer's name.
    producer: str
    # Short plot summary.
    plot: str
43+
44+
45+
if __name__ == "__main__":
    # Command-line gRPC client for a Triton TensorRT-LLM deployment.
    # Parses sampling/penalty options, optionally wraps the user prompt in a
    # ChatML template (with an embedded JSON schema for constrained decoding),
    # sends the request via client_utils.run_inference, and prints the output.
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "-v",
        "--verbose",
        action="store_true",
        required=False,
        default=False,
        help="Enable verbose output",
    )
    parser.add_argument(
        "-u", "--url", type=str, required=False, help="Inference server URL."
    )

    parser.add_argument("-p", "--prompt", type=str, required=True, help="Input prompt.")

    parser.add_argument(
        "--model-name",
        type=str,
        required=False,
        default="ensemble",
        choices=["ensemble", "tensorrt_llm_bls"],
        help="Name of the Triton model to send request to",
    )

    parser.add_argument(
        "-S",
        "--streaming",
        action="store_true",
        required=False,
        default=False,
        help="Enable streaming mode. Default is False.",
    )

    parser.add_argument(
        "-b",
        "--beam-width",
        required=False,
        type=int,
        default=1,
        help="Beam width value",
    )

    parser.add_argument(
        "--temperature",
        type=float,
        required=False,
        default=1.0,
        help="temperature value",
    )

    parser.add_argument(
        "--repetition-penalty",
        type=float,
        required=False,
        default=None,
        help="The repetition penalty value",
    )

    parser.add_argument(
        "--presence-penalty",
        type=float,
        required=False,
        default=None,
        help="The presence penalty value",
    )

    parser.add_argument(
        "--frequency-penalty",
        type=float,
        required=False,
        default=None,
        help="The frequency penalty value",
    )

    parser.add_argument(
        "-o",
        "--output-len",
        type=int,
        default=100,
        required=False,
        help="Specify output length",
    )

    parser.add_argument(
        "--request-id",
        type=str,
        default="",
        required=False,
        help="The request_id for the stop request",
    )

    parser.add_argument("--stop-words", nargs="+", default=[], help="The stop words")

    parser.add_argument("--bad-words", nargs="+", default=[], help="The bad words")

    parser.add_argument(
        "--embedding-bias-words", nargs="+", default=[], help="The biased words"
    )

    parser.add_argument(
        "--embedding-bias-weights",
        nargs="+",
        default=[],
        help="The biased words weights",
    )

    parser.add_argument(
        "--overwrite-output-text",
        action="store_true",
        required=False,
        default=False,
        help="In streaming mode, overwrite previously received output text instead of appending to it",
    )

    parser.add_argument(
        "--return-context-logits",
        action="store_true",
        required=False,
        default=False,
        help="Return context logits, the engine must be built with gather_context_logits or gather_all_token_logits",
    )

    parser.add_argument(
        "--return-generation-logits",
        action="store_true",
        required=False,
        default=False,
        # Fixed typo: was "gather_ generation_logits" (stray space) — the
        # actual engine build flag is gather_generation_logits.
        help="Return generation logits, the engine must be built with gather_generation_logits or gather_all_token_logits",
    )

    parser.add_argument(
        "--end-id", type=int, required=False, help="The token id for end token."
    )

    parser.add_argument(
        "--pad-id", type=int, required=False, help="The token id for pad token."
    )

    parser.add_argument(
        "--use-system-prompt",
        action="store_true",
        required=False,
        default=False,
        help="Enhance text input with system prompt.",
    )

    parser.add_argument(
        "--use-schema",
        action="store_true",
        required=False,
        default=False,
        help="Use client-defined JSON schema.",
    )

    parser.add_argument(
        "--logits-post-processor-name",
        type=str,
        required=False,
        default=None,
        help="Logits Post-Processor to use for output generation.",
    )

    FLAGS = parser.parse_args()
    # Default to a local Triton gRPC endpoint when no URL is supplied.
    if FLAGS.url is None:
        FLAGS.url = "localhost:8001"

    # Empty lists (the argparse defaults) are normalized to None so that
    # downstream request construction can distinguish "not requested".
    embedding_bias_words = (
        FLAGS.embedding_bias_words if FLAGS.embedding_bias_words else None
    )
    embedding_bias_weights = (
        FLAGS.embedding_bias_weights if FLAGS.embedding_bias_weights else None
    )

    try:
        client = grpcclient.InferenceServerClient(url=FLAGS.url)
    except Exception as e:
        print("client creation failed: " + str(e))
        sys.exit(1)

    # Logit-return switches are passed to the server as 1x1 boolean tensors.
    return_context_logits_data = None
    if FLAGS.return_context_logits:
        return_context_logits_data = np.array(
            [[FLAGS.return_context_logits]], dtype=bool
        )

    return_generation_logits_data = None
    if FLAGS.return_generation_logits:
        return_generation_logits_data = np.array(
            [[FLAGS.return_generation_logits]], dtype=bool
        )

    prompt = FLAGS.prompt

    # NOTE(review): the schema embedding and ChatML wrapping below are nested
    # under --use-system-prompt (the scraped source lost its indentation);
    # with the flag unset the raw prompt is sent unchanged — confirm against
    # the upstream tutorial if this nesting is ever in doubt.
    if FLAGS.use_system_prompt:
        prompt = (
            "<|im_start|>system\n You are a helpful assistant that answers in JSON."
        )

        if FLAGS.use_schema:
            # Embed the client-side JSON schema so the model (and any
            # constrained-decoding post-processor) targets AnswerFormat.
            prompt += "Here's the json schema you must adhere to:\n<schema>\n{schema}\n</schema>".format(
                schema=AnswerFormat.model_json_schema()
            )

        prompt += "<|im_end|>\n<|im_start|>user\n{user_prompt}<|im_end|>\n<|im_start|>assistant\n".format(
            user_prompt=FLAGS.prompt
        )

    output_text = client_utils.run_inference(
        client,
        prompt,
        FLAGS.output_len,
        FLAGS.request_id,
        FLAGS.repetition_penalty,
        FLAGS.presence_penalty,
        FLAGS.frequency_penalty,
        FLAGS.temperature,
        FLAGS.stop_words,
        FLAGS.bad_words,
        embedding_bias_words,
        embedding_bias_weights,
        FLAGS.model_name,
        FLAGS.streaming,
        FLAGS.beam_width,
        FLAGS.overwrite_output_text,
        return_context_logits_data,
        return_generation_logits_data,
        FLAGS.end_id,
        FLAGS.pad_id,
        FLAGS.verbose,
        logits_post_processor_name=FLAGS.logits_post_processor_name,
    )

    print(output_text)

0 commit comments

Comments
 (0)