|
| 1 | +""" |
| 2 | +Prompt Library Module for Flare AI RAG |
| 3 | +
|
| 4 | +This module provides a centralized management system for AI prompts used throughout |
| 5 | +the Flare AI RAG application. It handles the organization, storage, and retrieval |
| 6 | +of various prompt templates used for different operations like token transactions, |
| 7 | +account generation, and user interactions. |
| 8 | +
|
| 9 | +The module implements a PromptLibrary class that maintains a collection of Prompt |
| 10 | +objects, each representing a specific type of interaction or operation template. |
| 11 | +Prompts are categorized for easy management and retrieval. |
| 12 | +""" |
| 13 | + |
| 14 | +import structlog |
| 15 | + |
| 16 | +from flare_ai_rag.prompts.schemas import ( |
| 17 | + Prompt, |
| 18 | + RAGRouterResponse, |
| 19 | + SemanticRouterResponse, |
| 20 | +) |
| 21 | +from flare_ai_rag.prompts.templates import ( |
| 22 | + CONVERSATIONAL, |
| 23 | + RAG_RESPONDER, |
| 24 | + RAG_ROUTER, |
| 25 | + REMOTE_ATTESTATION, |
| 26 | + SEMANTIC_ROUTER, |
| 27 | +) |
| 28 | + |
| 29 | +logger = structlog.get_logger(__name__) |
| 30 | + |
| 31 | + |
| 32 | +class PromptLibrary: |
| 33 | + """ |
| 34 | + A library for managing and organizing AI prompts used in the Flare AI RAG. |
| 35 | +
|
| 36 | + This class serves as a central repository for all prompt templates used in |
| 37 | + the application. It provides functionality to add, retrieve, and categorize |
| 38 | + prompts for various operations such as token transactions, account management, |
| 39 | + and user interactions. |
| 40 | +
|
| 41 | + Attributes: |
| 42 | + prompts (dict[str, Prompt]): Dictionary storing prompt objects |
| 43 | + with their names as keys. |
| 44 | + """ |
| 45 | + |
| 46 | + def __init__(self) -> None: |
| 47 | + """ |
| 48 | + Initialize a new PromptLibrary instance. |
| 49 | +
|
| 50 | + Creates an empty prompt dictionary and populates it with default prompts |
| 51 | + through the _initialize_default_prompts method. |
| 52 | + """ |
| 53 | + self.prompts: dict[str, Prompt] = {} |
| 54 | + self._initialize_default_prompts() |
| 55 | + |
| 56 | + def _initialize_default_prompts(self) -> None: |
| 57 | + """ |
| 58 | + Initialize the library with a set of default prompts. |
| 59 | +
|
| 60 | + Creates and adds the following default prompts: |
| 61 | + - semantic_router: For routing user queries |
| 62 | + - token_send: For token transfer operations |
| 63 | + - token_swap: For token swap operations |
| 64 | + - generate_account: For wallet generation |
| 65 | + - conversational: For general user interactions |
| 66 | + - request_attestation: For remote attestation requests |
| 67 | + - tx_confirmation: For transaction confirmation |
| 68 | +
|
| 69 | + This method is called automatically during instance initialization. |
| 70 | + """ |
| 71 | + default_prompts = [ |
| 72 | + Prompt( |
| 73 | + name="semantic_router", |
| 74 | + description="Route user query based on user input", |
| 75 | + template=SEMANTIC_ROUTER, |
| 76 | + required_inputs=["user_input"], |
| 77 | + response_mime_type="text/x.enum", |
| 78 | + response_schema=SemanticRouterResponse, |
| 79 | + category="router", |
| 80 | + ), |
| 81 | + Prompt( |
| 82 | + name="conversational", |
| 83 | + description="Converse with a user", |
| 84 | + template=CONVERSATIONAL, |
| 85 | + required_inputs=["user_input"], |
| 86 | + response_schema=None, |
| 87 | + response_mime_type=None, |
| 88 | + category="conversational", |
| 89 | + ), |
| 90 | + Prompt( |
| 91 | + name="rag_router", |
| 92 | + description="The ", |
| 93 | + template=RAG_ROUTER, |
| 94 | + required_inputs=["user_input"], |
| 95 | + response_mime_type="application/json", |
| 96 | + response_schema=RAGRouterResponse, |
| 97 | + category="rag-router", |
| 98 | + ), |
| 99 | + Prompt( |
| 100 | + name="rag_responder", |
| 101 | + description="The ", |
| 102 | + template=RAG_RESPONDER, |
| 103 | + required_inputs=["user_input"], |
| 104 | + response_schema=None, |
| 105 | + response_mime_type=None, |
| 106 | + category="conversational", |
| 107 | + ), |
| 108 | + Prompt( |
| 109 | + name="request_attestation", |
| 110 | + description="User has requested a remote attestation", |
| 111 | + template=REMOTE_ATTESTATION, |
| 112 | + required_inputs=None, |
| 113 | + response_schema=None, |
| 114 | + response_mime_type=None, |
| 115 | + category="conversational", |
| 116 | + ), |
| 117 | + ] |
| 118 | + |
| 119 | + for prompt in default_prompts: |
| 120 | + self.add_prompt(prompt) |
| 121 | + |
| 122 | + def add_prompt(self, prompt: Prompt) -> None: |
| 123 | + """ |
| 124 | + Add a new prompt to the library. |
| 125 | +
|
| 126 | + Args: |
| 127 | + prompt (Prompt): The prompt object to add to the library. |
| 128 | +
|
| 129 | + Logs: |
| 130 | + Debug log entry when prompt is successfully added. |
| 131 | +
|
| 132 | + Example: |
| 133 | + ```python |
| 134 | + custom_prompt = Prompt(name="custom", template="...", category="misc") |
| 135 | + library.add_prompt(custom_prompt) |
| 136 | + ``` |
| 137 | + """ |
| 138 | + self.prompts[prompt.name] = prompt |
| 139 | + logger.debug("prompt_added", name=prompt.name, category=prompt.category) |
| 140 | + |
| 141 | + def get_prompt(self, name: str) -> Prompt: |
| 142 | + """ |
| 143 | + Retrieve a prompt by its name. |
| 144 | +
|
| 145 | + Args: |
| 146 | + name (str): The name of the prompt to retrieve. |
| 147 | +
|
| 148 | + Returns: |
| 149 | + Prompt: The requested prompt object. |
| 150 | +
|
| 151 | + Raises: |
| 152 | + KeyError: If the prompt name doesn't exist in the library. |
| 153 | +
|
| 154 | + Example: |
| 155 | + ```python |
| 156 | + try: |
| 157 | + prompt = library.get_prompt("token_send") |
| 158 | + except KeyError: |
| 159 | + print("Prompt not found") |
| 160 | + ``` |
| 161 | + """ |
| 162 | + if name not in self.prompts: |
| 163 | + logger.error("prompt_not_found", name=name) |
| 164 | + msg = f"Prompt '{name}' not found in library" |
| 165 | + raise KeyError(msg) |
| 166 | + return self.prompts[name] |
| 167 | + |
| 168 | + def get_prompts_by_category(self, category: str) -> list[Prompt]: |
| 169 | + """ |
| 170 | + Get all prompts in a specific category. |
| 171 | +
|
| 172 | + Args: |
| 173 | + category (str): The category to filter prompts by. |
| 174 | +
|
| 175 | + Returns: |
| 176 | + list[Prompt]: A list of all prompts in the specified category. |
| 177 | + """ |
| 178 | + return [ |
| 179 | + prompt for prompt in self.prompts.values() if prompt.category == category |
| 180 | + ] |
| 181 | + |
| 182 | + def list_categories(self) -> list[str]: |
| 183 | + """ |
| 184 | + List all available prompt categories. |
| 185 | +
|
| 186 | + Returns: |
| 187 | + list[str]: A list of unique category names used in the library. |
| 188 | +
|
| 189 | + Example: |
| 190 | + ```python |
| 191 | + categories = library.list_categories() |
| 192 | + print("Available categories:", categories) |
| 193 | + ``` |
| 194 | + """ |
| 195 | + return list( |
| 196 | + { |
| 197 | + prompt.category |
| 198 | + for prompt in self.prompts.values() |
| 199 | + if prompt.category is not None |
| 200 | + } |
| 201 | + ) |
0 commit comments