|
| 1 | +# pylint: disable=line-too-long,useless-suppression |
| 2 | +# ------------------------------------ |
| 3 | +# Copyright (c) Microsoft Corporation. |
| 4 | +# Licensed under the MIT License. |
| 5 | +# ------------------------------------ |
| 6 | +# pylint: disable=line-too-long,R,no-member |
| 7 | +"""Customize generated code here. |
| 8 | +
|
| 9 | +Follow our quickstart for examples: https://aka.ms/azsdk/python/dpcodegen/python/customize |
| 10 | +""" |
| 11 | + |
| 12 | +import traceback |
| 13 | +import sys |
| 14 | +from pathlib import Path |
| 15 | +from typing import Any, Dict, List, Optional |
| 16 | +from typing_extensions import Self |
| 17 | + |
| 18 | + |
| 19 | +class PromptTemplate: |
| 20 | + """A helper class which takes variant of inputs, e.g. Prompty format or string, and returns the parsed prompt in an array. |
| 21 | + Prompty library is required to use this class (`pip install prompty`). |
| 22 | + """ |
| 23 | + |
| 24 | + _MISSING_PROMPTY_PACKAGE_MESSAGE = ( |
| 25 | + "The 'prompty' package is required in order to use the 'PromptTemplate' class. " |
| 26 | + "Please install it by running 'pip install prompty'." |
| 27 | + ) |
| 28 | + |
| 29 | + @classmethod |
| 30 | + def from_prompty(cls, file_path: str) -> Self: |
| 31 | + """Initialize a PromptTemplate object from a prompty file. |
| 32 | +
|
| 33 | + :param file_path: The path to the prompty file. |
| 34 | + :type file_path: str |
| 35 | + :return: The PromptTemplate object. |
| 36 | + :rtype: PromptTemplate |
| 37 | + """ |
| 38 | + if not file_path: |
| 39 | + raise ValueError("Please provide file_path") |
| 40 | + |
| 41 | + try: |
| 42 | + from prompty import load |
| 43 | + except ImportError as exc: |
| 44 | + raise ImportError(cls._MISSING_PROMPTY_PACKAGE_MESSAGE) from exc |
| 45 | + |
| 46 | + # Get the absolute path of the file by `traceback.extract_stack()`, it's "-2" because: |
| 47 | + # In the stack, the last function is the current function. |
| 48 | + # The second last function is the caller function, which is the root of the file_path. |
| 49 | + stack = traceback.extract_stack() |
| 50 | + caller = Path(stack[-2].filename) |
| 51 | + abs_file_path = Path(caller.parent / Path(file_path)).resolve().absolute() |
| 52 | + |
| 53 | + prompty = load(str(abs_file_path)) |
| 54 | + prompty.template.type = "mustache" # For Azure, default to mustache instead of Jinja2 |
| 55 | + return cls(prompty=prompty) |
| 56 | + |
| 57 | + @classmethod |
| 58 | + def from_string(cls, prompt_template: str, api: str = "chat", model_name: Optional[str] = None) -> Self: |
| 59 | + """Initialize a PromptTemplate object from a message template. |
| 60 | +
|
| 61 | + :param prompt_template: The prompt template string. |
| 62 | + :type prompt_template: str |
| 63 | + :param api: The API type, e.g. "chat" or "completion". |
| 64 | + :type api: str |
| 65 | + :param model_name: The model name, e.g. "gpt-4o-mini". |
| 66 | + :type model_name: str |
| 67 | + :return: The PromptTemplate object. |
| 68 | + :rtype: PromptTemplate |
| 69 | + """ |
| 70 | + try: |
| 71 | + from prompty import headless |
| 72 | + except ImportError as exc: |
| 73 | + raise ImportError(cls._MISSING_PROMPTY_PACKAGE_MESSAGE) from exc |
| 74 | + |
| 75 | + prompt_template = cls._remove_leading_empty_space(prompt_template) |
| 76 | + prompty = headless(api=api, content=prompt_template) |
| 77 | + prompty.template.type = "mustache" # For Azure, default to mustache instead of Jinja2 |
| 78 | + prompty.template.parser = "prompty" |
| 79 | + return cls( |
| 80 | + api=api, |
| 81 | + model_name=model_name, |
| 82 | + prompty=prompty, |
| 83 | + ) |
| 84 | + |
| 85 | + @classmethod |
| 86 | + def _remove_leading_empty_space(cls, multiline_str: str) -> str: |
| 87 | + """ |
| 88 | + Processes a multiline string by: |
| 89 | + 1. Removing empty lines |
| 90 | + 2. Finding the minimum leading spaces |
| 91 | + 3. Indenting all lines to the minimum level |
| 92 | +
|
| 93 | + :param multiline_str: The input multiline string. |
| 94 | + :type multiline_str: str |
| 95 | + :return: The processed multiline string. |
| 96 | + :rtype: str |
| 97 | + """ |
| 98 | + lines = multiline_str.splitlines() |
| 99 | + start_index = 0 |
| 100 | + while start_index < len(lines) and lines[start_index].strip() == "": |
| 101 | + start_index += 1 |
| 102 | + |
| 103 | + # Find the minimum number of leading spaces |
| 104 | + min_spaces = sys.maxsize |
| 105 | + for line in lines[start_index:]: |
| 106 | + if len(line.strip()) == 0: |
| 107 | + continue |
| 108 | + spaces = len(line) - len(line.lstrip()) |
| 109 | + spaces += line.lstrip().count("\t") * 2 # Count tabs as 2 spaces |
| 110 | + min_spaces = min(min_spaces, spaces) |
| 111 | + |
| 112 | + # Remove leading spaces and indent to the minimum level |
| 113 | + processed_lines = [] |
| 114 | + for line in lines[start_index:]: |
| 115 | + processed_lines.append(line[min_spaces:]) |
| 116 | + |
| 117 | + return "\n".join(processed_lines) |
| 118 | + |
| 119 | + def __init__( |
| 120 | + self, |
| 121 | + *, |
| 122 | + api: str = "chat", |
| 123 | + prompty: Optional["Prompty"] = None, # type: ignore[name-defined] |
| 124 | + prompt_template: Optional[str] = None, |
| 125 | + model_name: Optional[str] = None, |
| 126 | + ) -> None: |
| 127 | + """Create a PromptTemplate object. |
| 128 | +
|
| 129 | + :keyword api: The API type. |
| 130 | + :paramtype api: str |
| 131 | + :keyword prompty: Optional Prompty object. |
| 132 | + :paramtype prompty: ~prompty.Prompty or None. |
| 133 | + :keyword prompt_template: Optional prompt template string. |
| 134 | + :paramtype prompt_template: str or None. |
| 135 | + :keyword model_name: Optional AI Model name. |
| 136 | + :paramtype model_name: str or None. |
| 137 | + """ |
| 138 | + self.prompty = prompty |
| 139 | + if self.prompty is not None: |
| 140 | + self.model_name = ( |
| 141 | + self.prompty.model.configuration["azure_deployment"] |
| 142 | + if "azure_deployment" in self.prompty.model.configuration |
| 143 | + else None |
| 144 | + ) |
| 145 | + self.parameters = self.prompty.model.parameters |
| 146 | + self._config = {} |
| 147 | + elif prompt_template is not None: |
| 148 | + self.model_name = model_name |
| 149 | + self.parameters = {} |
| 150 | + # _config is a dict to hold the internal configuration |
| 151 | + self._config = { |
| 152 | + "api": api if api is not None else "chat", |
| 153 | + "prompt_template": prompt_template, |
| 154 | + } |
| 155 | + else: |
| 156 | + raise ValueError("Please pass valid arguments for PromptTemplate") |
| 157 | + |
| 158 | + def create_messages(self, data: Optional[Dict[str, Any]] = None, **kwargs) -> List[Dict[str, Any]]: |
| 159 | + """Render the prompt template with the given data. |
| 160 | +
|
| 161 | + :param data: The data to render the prompt template with. |
| 162 | + :type data: Optional[Dict[str, Any]] |
| 163 | + :return: The rendered prompt template. |
| 164 | + :rtype: List[Dict[str, Any]] |
| 165 | + """ |
| 166 | + try: |
| 167 | + from prompty import prepare |
| 168 | + except ImportError as exc: |
| 169 | + raise ImportError(self._MISSING_PROMPTY_PACKAGE_MESSAGE) from exc |
| 170 | + |
| 171 | + if data is None: |
| 172 | + data = kwargs |
| 173 | + |
| 174 | + if self.prompty is not None: |
| 175 | + parsed = prepare(self.prompty, data) |
| 176 | + return parsed # type: ignore |
| 177 | + else: |
| 178 | + raise ValueError("Please provide valid prompt template") |
| 179 | + |
| 180 | + |
def patch_sdk():
    """Do not remove from this file.

    The `patch_sdk` hook is the final escape hatch for customizations that
    cannot be accomplished using the techniques described in
    https://aka.ms/azsdk/python/dpcodegen/python/customize
    """
0 commit comments