diff --git a/AI_Agents_Guide/Function_Calling/README.md b/AI_Agents_Guide/Function_Calling/README.md
index c5f4841e..447e539c 100644
--- a/AI_Agents_Guide/Function_Calling/README.md
+++ b/AI_Agents_Guide/Function_Calling/README.md
@@ -26,13 +26,26 @@
 # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 -->
 
-# Function Calling
+# Function Calling with Triton Inference Server
 
 This tutorial focuses on function calling, a common approach to easily connect
 large language models (LLMs) to external tools. This method empowers AI agents
 with effective tool usage and seamless interaction with external APIs,
 significantly expanding their capabilities and practical applications.
 
+## Table of Contents
+
+- [What is Function Calling?](#what-is-function-calling)
+- [Tutorial Overview](#tutorial-overview)
+  + [Prerequisite: Hermes-2-Pro-Llama-3-8B](#prerequisite-hermes-2-pro-llama-3-8b)
+- [Function Definitions](#function-definitions)
+- [Prompt Engineering](#prompt-engineering)
+- [Combining Everything Together](#combining-everything-together)
+- [Further Optimizations](#further-optimizations)
+  + [Enforcing Output Format](#enforcing-output-format)
+  + [Parallel Tool Call](#parallel-tool-call)
+- [References](#references)
+
 ## What is Function Calling?
 
 Function calling refers to the ability of LLMs to:
@@ -40,4 +53,252 @@ Function calling refers to the ability of LLMs to:
 or perform a task.
  * Generate a structured output containing the necessary arguments to call
 that function.
- * Integrate the results of the function call into its response.
\ No newline at end of file
+ * Integrate the results of the function call into its response.
+
+Function calling is a powerful mechanism that allows LLMs to perform
+more complex tasks (e.g. agent orchestration in multi-agent systems)
+that require specific computations or data retrieval
+beyond their inherent knowledge. By recognizing when a particular function
+is needed, LLMs can dynamically extend their functionality, making them more
+versatile and useful in real-world applications.
+
+## Tutorial Overview
+
+This tutorial demonstrates function calling using the
+[Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B)
+model, which is pre-fine-tuned for this capability. We'll create a basic
+stock reporting agent that provides up-to-date stock information and summarizes
+recent company news.
+
+### Prerequisite: Hermes-2-Pro-Llama-3-8B
+
+Before proceeding, please make sure that you've successfully deployed the
+[Hermes-2-Pro-Llama-3-8B](https://huggingface.co/NousResearch/Hermes-2-Pro-Llama-3-8B)
+model with Triton Inference Server and the TensorRT-LLM backend
+following [these steps](../../Popular_Models_Guide/Hermes-2-Pro-Llama-3-8B/README.md).
+
+> [!IMPORTANT]
+> Make sure that the `tutorials` folder is mounted to `/tutorials` when you
+> start the docker container.
+
+## Function Definitions
+
+We'll define three functions for our stock reporting agent:
+1. `get_current_stock_price`: Retrieves the current stock price for a given symbol.
+2. `get_company_news`: Retrieves company news and press releases for a given stock symbol.
+3. `final_answer`: Used as a no-op and to indicate the final response.
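+
+At inference time, the model is expected to pick one of these tools per step
+and emit a structured call for it. For example, a single step of the plan
+might look like this (illustrative values, mirroring the verbose output shown
+later in this tutorial):
+```json
+{
+    "step": "1",
+    "description": "Get the current stock price for Rivian",
+    "tool": "get_current_stock_price",
+    "arguments": {
+        "symbol": "RIVN"
+    }
+}
+```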
+
+Each function includes its name, description, and input parameter schema:
+ ```python
+TOOLS = [
+    {
+        "type": "function",
+        "function": {
+            "name": "get_current_stock_price",
+            "description": "Get the current stock price for a given symbol.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.",
+            "parameters": {
+                "type": "object",
+                "properties": {"symbol": {"type": "string"}},
+                "required": ["symbol"],
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "get_company_news",
+            "description": "Get company news and press releases for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press releases.",
+            "parameters": {
+                "type": "object",
+                "properties": {"symbol": {"type": "string"}},
+                "required": ["symbol"],
+            },
+        },
+    },
+    {
+        "type": "function",
+        "function": {
+            "name": "final_answer",
+            "description": "Return final generated answer",
+            "parameters": {
+                "type": "object",
+                "properties": {"final_response": {"type": "string"}},
+                "required": ["final_response"],
+            },
+        },
+    },
+]
+ ```
+These function definitions will be passed to our model through a prompt,
+enabling it to recognize and utilize them appropriately during the conversation.
+
+For the actual implementations, please refer to [client_utils.py](./artifacts/client_utils.py).
+
+## Prompt Engineering
+
+**Prompt engineering** is a crucial aspect of function calling, as it guides
+the LLM in recognizing when and how to utilize specific functions.
+By carefully crafting prompts, you can effectively define the LLM's role,
+objectives, and the tools it can access, ensuring accurate and efficient task
+execution.
+
+For our task, we've organized a sample prompt structure, provided
+in the accompanying [`system_prompt_schema.yml`](./artifacts/system_prompt_schema.yml)
+file. This file meticulously outlines:
+
+- **Role**: Defines the specific role the LLM is expected to perform.
+- **Objective**: Clearly states the goal or desired outcome of the interaction.
+- **Tools**: Lists the available functions or tools the LLM can use to achieve
+its objective.
+- **Schema**: Specifies the structure and format required for calling each tool
+or function.
+- **Instructions**: Provides a clear set of guidelines to ensure the LLM follows
+the intended path and utilizes the tools appropriately.
+
+By leveraging prompt engineering, you can enhance the LLM's ability
+to perform complex tasks and integrate function calls seamlessly into
+its responses, thereby maximizing its utility in various applications.
+
+## Combining Everything Together
+
+First, let's start the Triton SDK container:
+```bash
+# Using the SDK container as an example
+docker run --rm -it --net host --shm-size=2g \
+    --ulimit memlock=-1 --ulimit stack=67108864 --gpus all \
+    -v /path/to/tutorials/:/tutorials \
+    nvcr.io/nvidia/tritonserver:<xx.yy>-py3-sdk
+```
+
+The provided client script uses the `pydantic` and `yfinance` libraries, which we
+do not ship with the SDK container. Make sure to install them before proceeding:
+
+```bash
+pip install pydantic yfinance
+```
+
+Run the provided [`client.py`](./artifacts/client.py) as follows:
+
+```bash
+python3 /tutorials/AI_Agents_Guide/Function_Calling/artifacts/client.py --prompt "Tell me about Rivian. Include current stock price in your final response." -o 200
+```
+
+You should expect to see a response similar to:
+
+```bash
++++++++++++++++++++++++++++++++++++++
+RESPONSE: Rivian, with its current stock price of ,
++++++++++++++++++++++++++++++++++++++
+```
+
+To see what tools were "called" by our LLM, simply add the `--verbose` flag as follows:
+```bash
+python3 /tutorials/AI_Agents_Guide/Function_Calling/artifacts/client.py --prompt "Tell me about Rivian. Include current stock price in your final response." -o 200 --verbose
+```
+
+This will show the step-by-step process of function calling, including:
+- The tools being called
+- The arguments passed to each tool
+- The responses from each function call
+- The final summarized response
+
+
+```bash
+[b'\n{\n "step": "1",\n "description": "Get the current stock price for Rivian",\n "tool": "get_current_stock_price",\n "arguments": {\n "symbol": "RIVN"\n }\n}']
+=====================================
+Executing function: get_current_stock_price({'symbol': 'RIVN'})
+Function response:
+=====================================
+[b'\n{\n "step": "2",\n "description": "Get company news and press releases for Rivian",\n "tool": "get_company_news",\n "arguments": {\n "symbol": "RIVN"\n }\n}']
+=====================================
+Executing function: get_company_news({'symbol': 'RIVN'})
+Function response: []
+=====================================
+[b'\n{\n "step": "3",\n "description": "Summarize the company news and press releases for Rivian",\n "tool": "final_answer",\n "arguments": {\n "final_response": "Rivian, with its current stock price of , "\n }\n}']
+
+
++++++++++++++++++++++++++++++++++++++
+RESPONSE: Rivian, with its current stock price of ,
++++++++++++++++++++++++++++++++++++++
+```
+
+> [!TIP]
+> In this tutorial, all functionalities (tool definitions, implementations,
+> and executions) are implemented on the client side (see
+> [client.py](./artifacts/client.py)).
+> For production scenarios, especially when functions are known beforehand,
+> consider implementing this logic on the server side.
+> A recommended approach for server-side implementation is to deploy your
+> workflow through a Triton [ensemble](https://github.com/triton-inference-server/server/blob/a6fff975a214ff00221790dd0a5521fb05ce3ac9/docs/user_guide/architecture.md#ensemble-models)
+> or a [BLS](https://github.com/triton-inference-server/python_backend?tab=readme-ov-file#business-logic-scripting).
+> Use a pre-processing model to combine and format the user prompt with the
+> system prompt and available tools. Employ a post-processing model to manage
+> multiple calls to the deployed LLM as needed to reach the final answer.
+
+## Further Optimizations
+
+### Enforcing Output Format
+
+In this tutorial, we demonstrated how to enforce a specific output format
+using prompt engineering. The desired structure is as follows:
+```python
+  {
+      "step" : <Step number>
+      "description": <Description of what the step does and its output>
+      "tool": <Name of the tool to use>,
+      "arguments": {
+          <Arguments to pass to the tool>
+      }
+  }
+```
+However, there may be instances where the output deviates from this
+required schema. For example, consider the following prompt execution:
+
+```bash
+python3 /tutorials/AI_Agents_Guide/Function_Calling/artifacts/client.py --prompt "How Rivian is doing?" -o 500 --verbose
+```
+This execution may fail with an invalid JSON format error. The verbose
+output will reveal that the final LLM response contained plain text
+instead of the expected JSON format:
+```
+{
+  "step": "3",
+  "description":
+  "tool": "final_answer",
+  "arguments": {
+    "final_response":
+  }
+}
+```
+Fortunately, this behavior can be controlled using constrained decoding,
+a technique that guides the model to generate outputs that meet specific
+formatting and content requirements. We strongly recommend exploring our
+dedicated [tutorial](../Constrained_Decoding/README.md) on constrained decoding
+to gain deeper insights and enhance your ability to manage model outputs
+effectively.
+
+> [!TIP]
+> For optimal results, utilize the `FunctionCall` class defined in
+> [client_utils.py](./artifacts/client_utils.py) as the JSON schema
+> for your Logits Post-Processor. This approach ensures consistent
+> and properly formatted outputs, aligning with the structure we've
+> established throughout this tutorial.
+
+### Parallel Tool Call
+
+This tutorial focuses on a single-turn forced call, where the LLM is prompted
+to make a specific function call within a single interaction. This approach is
+useful when a precise action is needed immediately, ensuring that
+the function is executed as part of the current conversation.
+
+It is possible that some function calls can be executed simultaneously.
+This technique is beneficial for tasks that can be divided into independent
+operations, allowing for increased efficiency and reduced response time.
+
+We encourage our readers to take on the challenge of implementing
+parallel tool calls as a practical exercise.
+
+## References
+
+Parts of this tutorial are based on [Hermes-Function-Calling](https://github.com/NousResearch/Hermes-Function-Calling).
\ No newline at end of file
diff --git a/AI_Agents_Guide/Function_Calling/artifacts/client.py b/AI_Agents_Guide/Function_Calling/artifacts/client.py
new file mode 100755
index 00000000..518a33ad
--- /dev/null
+++ b/AI_Agents_Guide/Function_Calling/artifacts/client.py
@@ -0,0 +1,276 @@
+#!/usr/bin/python
+# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions
+# are met:
+#  * Redistributions of source code must retain the above copyright
+#    notice, this list of conditions and the following disclaimer.
+#  * Redistributions in binary form must reproduce the above copyright
+#    notice, this list of conditions and the following disclaimer in the
+#    documentation and/or other materials provided with the distribution.
+#  * Neither the name of NVIDIA CORPORATION nor the names of its
+#    contributors may be used to endorse or promote products derived
+#    from this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY
+# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
+# PURPOSE ARE DISCLAIMED.
IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +import argparse +import json +import sys + +import client_utils +import numpy as np +import tritonclient.grpc as grpcclient + +if __name__ == "__main__": + parser = argparse.ArgumentParser() + parser.add_argument( + "-v", + "--verbose", + action="store_true", + required=False, + default=False, + help="Enable verbose output", + ) + parser.add_argument( + "-u", "--url", type=str, required=False, help="Inference server URL." + ) + + parser.add_argument("-p", "--prompt", type=str, required=True, help="Input prompt.") + + parser.add_argument( + "--model-name", + type=str, + required=False, + default="ensemble", + choices=["ensemble", "tensorrt_llm_bls"], + help="Name of the Triton model to send request to", + ) + + parser.add_argument( + "-S", + "--streaming", + action="store_true", + required=False, + default=False, + help="Enable streaming mode. Default is False.", + ) + + parser.add_argument( + "-b", + "--beam-width", + required=False, + type=int, + default=1, + help="Beam width value", + ) + + parser.add_argument( + "--temperature", + type=float, + required=False, + default=1.0, + help="temperature value", + ) + + parser.add_argument( + "--repetition-penalty", + type=float, + required=False, + default=None, + help="The repetition penalty value", + ) + + parser.add_argument( + "--presence-penalty", + type=float, + required=False, + default=None, + help="The presence penalty value", + ) + + parser.add_argument( + "--frequency-penalty", + type=float, + required=False, + default=None, + help="The frequency penalty value", + ) + + parser.add_argument( + "-o", + "--output-len", + type=int, + default=100, + required=False, + help="Specify output length", + ) + + parser.add_argument( + "--request-id", + type=str, + default="", + required=False, + help="The request_id for the stop request", + ) + + parser.add_argument("--stop-words", nargs="+", default=[], help="The stop words") + + parser.add_argument("--bad-words", nargs="+", default=[], help="The bad words") + + parser.add_argument( + "--embedding-bias-words", nargs="+", default=[], help="The biased words" + ) + + parser.add_argument( + "--embedding-bias-weights", + nargs="+", + default=[], + help="The biased words weights", + ) + + parser.add_argument( + "--overwrite-output-text", + action="store_true", + required=False, + default=False, + help="In streaming mode, overwrite previously received output text instead of appending to it", + ) + + parser.add_argument( + "--return-context-logits", + action="store_true", + required=False, + default=False, + help="Return context logits, the engine must be built with gather_context_logits or gather_all_token_logits", + ) + + parser.add_argument( + "--return-generation-logits", + action="store_true", + required=False, + default=False, + help="Return generation logits, the engine must be built with gather_ generation_logits or gather_all_token_logits", + ) + + parser.add_argument( + "--end-id", type=int, required=False, help="The token id for end token." 
+ ) + + parser.add_argument( + "--pad-id", type=int, required=False, help="The token id for pad token." + ) + + FLAGS = parser.parse_args() + if FLAGS.url is None: + FLAGS.url = "localhost:8001" + + embedding_bias_words = ( + FLAGS.embedding_bias_words if FLAGS.embedding_bias_words else None + ) + embedding_bias_weights = ( + FLAGS.embedding_bias_weights if FLAGS.embedding_bias_weights else None + ) + + try: + client = grpcclient.InferenceServerClient(url=FLAGS.url) + except Exception as e: + print("client creation failed: " + str(e)) + sys.exit(1) + + return_context_logits_data = None + if FLAGS.return_context_logits: + return_context_logits_data = np.array( + [[FLAGS.return_context_logits]], dtype=bool + ) + + return_generation_logits_data = None + if FLAGS.return_generation_logits: + return_generation_logits_data = np.array( + [[FLAGS.return_generation_logits]], dtype=bool + ) + + prompt = client_utils.process_prompt(FLAGS.prompt) + + functions = client_utils.MyFunctions() + + while True: + output_text = client_utils.run_inference( + client, + prompt, + FLAGS.output_len, + FLAGS.request_id, + FLAGS.repetition_penalty, + FLAGS.presence_penalty, + FLAGS.frequency_penalty, + FLAGS.temperature, + FLAGS.stop_words, + FLAGS.bad_words, + embedding_bias_words, + embedding_bias_weights, + FLAGS.model_name, + FLAGS.streaming, + FLAGS.beam_width, + FLAGS.overwrite_output_text, + return_context_logits_data, + return_generation_logits_data, + FLAGS.end_id, + FLAGS.pad_id, + FLAGS.verbose, + ) + + try: + response = json.loads(output_text) + except ValueError: + print("\n[ERROR] LLM responded with invalid JSON format!") + break + + # Repeat the loop until `final_answer` tool is called, which indicates + # that the full response is ready and llm does not require any + # additional information. Additionally, if the loop has taken more + # than 50 steps, the script ends. + if response["tool"] == "final_answer" or response["step"] == "50": + if response["tool"] == "final_answer": + final_response = response["arguments"]["final_response"] + print("\n\n+++++++++++++++++++++++++++++++++++++") + print(f"RESPONSE: {final_response}") + print("+++++++++++++++++++++++++++++++++++++\n\n") + elif response["step"] == "50": + print("\n\n+++++++++++++++++++++++++++++++++++++") + print(f"Reached maximum number of function calls available.") + print("+++++++++++++++++++++++++++++++++++++\n\n") + break + + # Extract tool's name and arguments from the response + function_name = response["tool"] + function_args = response["arguments"] + function_to_call = getattr(functions, function_name) + # Execute function call and store results in `function_response` + function_response = function_to_call(*function_args.values()) + + if FLAGS.verbose: + print("=====================================") + print(f"Executing function: {function_name}({function_args}) ") + print(f"Function response: {str(function_response)}") + print("=====================================") + + # Update prompt with the generated function call and results of that + # call. 
+ results_dict = f'{{"name": "{function_name}", "content": {function_response}}}' + prompt += str( + output_text + + "<|im_end|>\n" + + str(results_dict) + + "\n<|im_start|>assistant" + ) diff --git a/AI_Agents_Guide/Function_Calling/artifacts/client_utils.py b/AI_Agents_Guide/Function_Calling/artifacts/client_utils.py new file mode 100644 index 00000000..552c98fc --- /dev/null +++ b/AI_Agents_Guide/Function_Calling/artifacts/client_utils.py @@ -0,0 +1,442 @@ +# Copyright 2024, NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions +# are met: +# * Redistributions of source code must retain the above copyright +# notice, this list of conditions and the following disclaimer. +# * Redistributions in binary form must reproduce the above copyright +# notice, this list of conditions and the following disclaimer in the +# documentation and/or other materials provided with the distribution. +# * Neither the name of NVIDIA CORPORATION nor the names of its +# contributors may be used to endorse or promote products derived +# from this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS ``AS IS'' AND ANY +# EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +# PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR +# CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, +# EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, +# PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR +# PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY +# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +# (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
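+#
+# This helper module provides the pieces used by client.py:
+#   * TOOLS: JSON tool definitions passed to the model through the system prompt
+#   * MyFunctions: yfinance-backed implementations of the stock tools
+#   * FunctionCall / PromptSchema: pydantic models for tool calls and the system prompt
+#   * process_prompt(): combines the user prompt with the system prompt, tools, and schema
+#   * run_inference(): sends the assembled prompt to Triton over gRPC streaming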
+import os +import sys +from functools import partial +from pathlib import Path +from typing import Dict + +import pandas as pd +import yaml +import yfinance as yf +from pydantic import BaseModel + +sys.path.append(os.path.dirname(os.path.dirname(os.path.realpath(__file__)))) + +import queue +import sys + +import numpy as np +import tritonclient.grpc as grpcclient +from tritonclient.utils import InferenceServerException, np_to_triton_dtype + +############################################################################### +# TOOLS Definition and Implementation # +############################################################################### + +TOOLS = [ + { + "type": "function", + "function": { + "name": "get_current_stock_price", + "description": "Get the current stock price for a given symbol.\n\nArgs:\n symbol (str): The stock symbol.\n\nReturns:\n float: The current stock price, or None if an error occurs.", + "parameters": { + "type": "object", + "properties": {"symbol": {"type": "string"}}, + "required": ["symbol"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "get_company_news", + "description": "Get company news and press releases for a given stock symbol.\n\nArgs:\nsymbol (str): The stock symbol.\n\nReturns:\npd.DataFrame: DataFrame containing company news and press releases.", + "parameters": { + "type": "object", + "properties": {"symbol": {"type": "string"}}, + "required": ["symbol"], + }, + }, + }, + { + "type": "function", + "function": { + "name": "final_answer", + "description": "Return final generated answer", + "parameters": { + "type": "object", + "properties": {"final_response": {"type": "string"}}, + "required": ["final_response"], + }, + }, + }, +] + + +class MyFunctions: + def get_company_news(self, symbol: str) -> pd.DataFrame: + """ + Get company news and press releases for a given stock symbol. + + Args: + symbol (str): The stock symbol. + + Returns: + pd.DataFrame: DataFrame containing company news and press releases. + """ + try: + news = yf.Ticker(symbol).news + title_list = [] + for entry in news: + title_list.append(entry["title"]) + return title_list + except Exception as e: + print(f"Error fetching company news for {symbol}: {e}") + return pd.DataFrame() + + def get_current_stock_price(self, symbol: str) -> float: + """ + Get the current stock price for a given symbol. + + Args: + symbol (str): The stock symbol. + + Returns: + float: The current stock price, or None if an error occurs. + """ + try: + stock = yf.Ticker(symbol) + # Use "regularMarketPrice" for regular market hours, or "currentPrice" for pre/post market + current_price = stock.info.get( + "regularMarketPrice", stock.info.get("currentPrice") + ) + return current_price if current_price else None + except Exception as e: + print(f"Error fetching current price for {symbol}: {e}") + return None + + +############################################################################### +# Helper Schemas # +############################################################################### + + +class FunctionCall(BaseModel): + step: str + """Step number for the action sequence""" + description: str + """Description of what the step does and its output""" + + tool: str + """The name of the tool to call.""" + + arguments: dict + """ + The arguments to call the function with, as generated by the model in JSON + format. Note that the model does not always generate valid JSON, and may + hallucinate parameters not defined by your function schema. 
Validate the + arguments in your code before calling your function. + """ + + +class PromptSchema(BaseModel): + Role: str + """Defines the specific role the LLM is expected to perform.""" + Objective: str + """States the goal or desired outcome of the interaction.""" + Tools: str + """A set of available functions or tools the LLM can use to achieve its + objective.""" + Schema: str + """ Specifies the structure and format required for calling each tool + or function.""" + Instructions: str + """Provides a clear set of guidelines to ensure the LLM follows + the intended path and utilizes the tools appropriately.""" + + +############################################################################### +# Prompt processing helper functions # +############################################################################### + + +def read_yaml_file(file_path: str) -> PromptSchema: + """ + Reads a YAML file and converts its content into a PromptSchema object. + + Args: + file_path (str): The path to the YAML file. + + Returns: + PromptSchema: An object containing the structured prompt data. + """ + with open(file_path, "r") as file: + yaml_content = yaml.safe_load(file) + + prompt_schema = PromptSchema( + Role=yaml_content.get("Role", ""), + Objective=yaml_content.get("Objective", ""), + Tools=yaml_content.get("Tools", ""), + Schema=yaml_content.get("Schema", ""), + Instructions=yaml_content.get("Instructions", ""), + ) + return prompt_schema + + +def format_yaml_prompt(prompt_schema: PromptSchema, variables: Dict) -> str: + """ + Formats the prompt schema with provided variables. + + Args: + prompt_schema (PromptSchema): The prompt schema to format. + variables (Dict): A dictionary of variables to insert into the prompt. + + Returns: + str: The formatted prompt string. + """ + formatted_prompt = "" + for field, value in prompt_schema.model_dump().items(): + formatted_value = value.format(**variables) + if field == "Instructions": + formatted_prompt += f"{formatted_value}" + else: + formatted_value = formatted_value.replace("\n", " ") + formatted_prompt += f"{formatted_value}" + return formatted_prompt + + +def process_prompt( + user_prompt, + system_prompt_yml=Path(__file__).parent.joinpath("./system_prompt_schema.yml"), + tools=TOOLS, + schema_json=FunctionCall.model_json_schema(), +): + """ + Combines and formats the user prompt with a system prompt for model + processing. + + This function reads a system prompt from a YAML file, formats it with the + provided tools and schema, and integrates it with the user's original + prompt. The result is a structured prompt ready for input into a + language model. + + Args: + user_prompt (str): The initial prompt provided by the user. + system_prompt_yml (str, optional): The file path to the system prompt + defined in a YAML file. Defaults to "./system_prompt_schema.yml". + tools (list, optional): A list of tools available for the prompt. + Defaults to the global TOOLS variable. + schema_json (dict, optional): A JSON schema for a generated function call. + Defaults to the schema from FunctionCall.model_json_schema(). + + Returns: + str: A formatted prompt string ready for use by the language model. + """ + prompt_schema = read_yaml_file(system_prompt_yml) + variables = {"tools": tools, "schema": schema_json} + sys_prompt = format_yaml_prompt(prompt_schema, variables) + processed_prompt = f"<|im_start|>system\n {sys_prompt}<|im_end|>\n" + processed_prompt += f"<|im_start|>user\n {user_prompt}\nThis is the first turn and you don't have to analyze yet. 
<|im_end|>\n <|im_start|>assistant" + return processed_prompt + + +############################################################################### +# Triton client helper functions # +############################################################################### + + +def prepare_tensor(name, input): + t = grpcclient.InferInput(name, input.shape, np_to_triton_dtype(input.dtype)) + t.set_data_from_numpy(input) + return t + + +class UserData: + def __init__(self): + self._completed_requests = queue.Queue() + + +def callback(user_data, result, error): + if error: + user_data._completed_requests.put(error) + else: + user_data._completed_requests.put(result) + + +def run_inference( + triton_client, + prompt, + output_len, + request_id, + repetition_penalty, + presence_penalty, + frequency_penalty, + temperature, + stop_words, + bad_words, + embedding_bias_words, + embedding_bias_weights, + model_name, + streaming, + beam_width, + overwrite_output_text, + return_context_logits_data, + return_generation_logits_data, + end_id, + pad_id, + verbose, + num_draft_tokens=0, + use_draft_logits=None, +): + input0 = [[prompt]] + input0_data = np.array(input0).astype(object) + output0_len = np.ones_like(input0).astype(np.int32) * output_len + streaming_data = np.array([[streaming]], dtype=bool) + beam_width_data = np.array([[beam_width]], dtype=np.int32) + temperature_data = np.array([[temperature]], dtype=np.float32) + + inputs = [ + prepare_tensor("text_input", input0_data), + prepare_tensor("max_tokens", output0_len), + prepare_tensor("stream", streaming_data), + prepare_tensor("beam_width", beam_width_data), + prepare_tensor("temperature", temperature_data), + ] + + if num_draft_tokens > 0: + inputs.append( + prepare_tensor( + "num_draft_tokens", np.array([[num_draft_tokens]], dtype=np.int32) + ) + ) + if use_draft_logits is not None: + inputs.append( + prepare_tensor( + "use_draft_logits", np.array([[use_draft_logits]], dtype=bool) + ) + ) + + if bad_words: + bad_words_list = np.array([bad_words], dtype=object) + inputs += [prepare_tensor("bad_words", bad_words_list)] + + if stop_words: + stop_words_list = np.array([stop_words], dtype=object) + inputs += [prepare_tensor("stop_words", stop_words_list)] + + if repetition_penalty is not None: + repetition_penalty = [[repetition_penalty]] + repetition_penalty_data = np.array(repetition_penalty, dtype=np.float32) + inputs += [prepare_tensor("repetition_penalty", repetition_penalty_data)] + + if presence_penalty is not None: + presence_penalty = [[presence_penalty]] + presence_penalty_data = np.array(presence_penalty, dtype=np.float32) + inputs += [prepare_tensor("presence_penalty", presence_penalty_data)] + + if frequency_penalty is not None: + frequency_penalty = [[frequency_penalty]] + frequency_penalty_data = np.array(frequency_penalty, dtype=np.float32) + inputs += [prepare_tensor("frequency_penalty", frequency_penalty_data)] + + if return_context_logits_data is not None: + inputs += [ + prepare_tensor("return_context_logits", return_context_logits_data), + ] + + if return_generation_logits_data is not None: + inputs += [ + prepare_tensor("return_generation_logits", return_generation_logits_data), + ] + + if (embedding_bias_words is not None and embedding_bias_weights is None) or ( + embedding_bias_words is None and embedding_bias_weights is not None + ): + assert 0, "Both embedding bias words and weights must be specified" + + if embedding_bias_words is not None and embedding_bias_weights is not None: + assert len(embedding_bias_words) == len( + 
embedding_bias_weights + ), "Embedding bias weights and words must have same length" + embedding_bias_words_data = np.array([embedding_bias_words], dtype=object) + embedding_bias_weights_data = np.array( + [embedding_bias_weights], dtype=np.float32 + ) + inputs.append(prepare_tensor("embedding_bias_words", embedding_bias_words_data)) + inputs.append( + prepare_tensor("embedding_bias_weights", embedding_bias_weights_data) + ) + if end_id is not None: + end_id_data = np.array([[end_id]], dtype=np.int32) + inputs += [prepare_tensor("end_id", end_id_data)] + + if pad_id is not None: + pad_id_data = np.array([[pad_id]], dtype=np.int32) + inputs += [prepare_tensor("pad_id", pad_id_data)] + + user_data = UserData() + # Establish stream + triton_client.start_stream(callback=partial(callback, user_data)) + # Send request + triton_client.async_stream_infer(model_name, inputs, request_id=request_id) + + # Wait for server to close the stream + triton_client.stop_stream() + + # Parse the responses + output_text = "" + while True: + try: + result = user_data._completed_requests.get(block=False) + except Exception: + break + + if type(result) == InferenceServerException: + print("Received an error from server:") + print(result) + else: + output = result.as_numpy("text_output") + if streaming and beam_width == 1: + new_output = output[0].decode("utf-8") + if overwrite_output_text: + output_text = new_output + else: + output_text += new_output + else: + output_text = output[0].decode("utf-8") + if verbose: + print( + str("\n[VERBOSE MODE] LLM's response:" + output_text), + flush=True, + ) + + if return_context_logits_data is not None: + context_logits = result.as_numpy("context_logits") + if verbose: + print(f"context_logits.shape: {context_logits.shape}") + print(f"context_logits: {context_logits}") + if return_generation_logits_data is not None: + generation_logits = result.as_numpy("generation_logits") + if verbose: + print(f"generation_logits.shape: {generation_logits.shape}") + print(f"generation_logits: {generation_logits}") + + if streaming and beam_width == 1: + if verbose: + print(output_text) + + return output_text diff --git a/AI_Agents_Guide/Function_Calling/artifacts/system_prompt_schema.yml b/AI_Agents_Guide/Function_Calling/artifacts/system_prompt_schema.yml new file mode 100644 index 00000000..80ae7b6b --- /dev/null +++ b/AI_Agents_Guide/Function_Calling/artifacts/system_prompt_schema.yml @@ -0,0 +1,54 @@ +Role: | + You are an expert assistant who can solve any task using JSON tool calls. + You will be given a task to solve as best you can. + These tools are basically Python functions which you can call with code. + If your task is not related to any of available tools, don't use any of + available tools. +Objective: | + You may use agentic frameworks for reasoning and planning to help with user query. + Please call a function and wait for function results to be provided to you in the next iteration. + Don't make assumptions about what values to plug into function arguments. + Once you have called a function, results will be fed back to you within XML tags + in the following form: + {{"name": , "content": }} + Don't make assumptions about tool results if XML tags are not present since function hasn't been executed yet. + Analyze the data once you get the results and call another function. + Your final response should directly answer the user query with an analysis or summary of the results of function calls. + You MUST summarise all previous responses in the final response. 
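+# NOTE: {tools} and {schema} below are template fields; format_yaml_prompt() in
+# client_utils.py fills them in with the TOOLS definitions and the FunctionCall
+# JSON schema, while doubled braces {{ }} render as literal braces after
+# formatting.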
+Tools: |
+  Only use the set of these available tools:
+  {tools}
+  If none of those tools are related to the task, then only use `final_answer`
+  to provide your response.
+Schema: |
+  Use the following pydantic model json schema for each tool call you will make:
+  {schema}
+Instructions: |
+  Output a step-by-step plan to solve the task using the given tools.
+  This plan should involve individual tasks based on the available tools,
+  that if executed correctly will yield the correct answer.
+  Each step should be structured as follows:
+  {{
+    "step" : <Step number>
+    "description": <Description of what the step does and its output>
+    "tool": <Name of the tool to use>,
+    "arguments": {{
+      <Arguments to pass to the tool>
+    }}
+  }}
+  Each step must be necessary to reach the final answer.
+  Steps should reuse outputs produced by earlier steps.
+  The last step must be the final answer. It is the only way to complete
+  the task, else you will be stuck in a loop.
+  So your final output should look like this:
+  {{
+    "step" : <Step number>
+    "description": "Provide the final answer",
+    "tool": "final_answer",
+    "arguments": {{
+      "final_response": <Your final answer to the user query>
+    }}
+  }}
+  Calling multiple functions at once can overload the system and increase
+  cost so call one function at a time please.
+  If you plan to continue with analysis, always call another function.
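
The Instructions block above prescribes exactly the shape captured by the `FunctionCall` pydantic model in `client_utils.py`. As a quick sanity check outside Triton, a model-produced step can be validated against that schema before its tool is dispatched. The sketch below assumes pydantic v2 and that it is run next to the artifacts from this patch; the step values are illustrative:

```python
from client_utils import TOOLS, FunctionCall, MyFunctions

# A single step as the system prompt instructs the model to emit it
# (illustrative values).
raw_step = """
{
    "step": "1",
    "description": "Get the current stock price for Rivian",
    "tool": "get_current_stock_price",
    "arguments": {"symbol": "RIVN"}
}
"""

# Raises pydantic.ValidationError if the output drifts from the schema.
call = FunctionCall.model_validate_json(raw_step)

# Dispatch the way client.py does: look the tool up by name and unpack its arguments.
assert call.tool in {t["function"]["name"] for t in TOOLS}
if call.tool != "final_answer":
    result = getattr(MyFunctions(), call.tool)(*call.arguments.values())
    print(f"{call.tool} -> {result}")
```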