1919@dataclass
2020class Demonstration :
2121 """A single, mathematically verified few-shot example."""
22+
2223 input_text : str
2324 output_text : str
2425 golden_index : int = - 1
@@ -27,6 +28,7 @@ class Demonstration:
2728@dataclass
2829class DemonstrationSet :
2930 """A set of demonstrations to be dynamically injected into a prompt."""
31+
3032 demonstrations : List [Demonstration ] = field (default_factory = list )
3133 id : str = ""
3234
@@ -37,10 +39,11 @@ def __post_init__(self):
3739 def to_text (self , max_demonstrations : Optional [int ] = None ) -> str :
3840 """Render demonstrations as text for inclusion in prompts."""
3941 demos_to_use = (
40- self .demonstrations [:max_demonstrations ]
41- if max_demonstrations else self .demonstrations
42+ self .demonstrations [:max_demonstrations ]
43+ if max_demonstrations
44+ else self .demonstrations
4245 )
43-
46+
4447 if not demos_to_use :
4548 return ""
4649
@@ -56,7 +59,7 @@ def to_text(self, max_demonstrations: Optional[int] = None) -> str:
5659
5760class DemonstrationBootstrapper :
5861 """
59- Bootstraps few-shot demonstrations by running the prompt on training
62+ Bootstraps few-shot demonstrations by running the prompt on training
6063 examples and keeping strictly successful outputs based on metric success.
6164 """
6265
@@ -78,28 +81,42 @@ def __init__(
7881 else :
7982 self .random_state = random_state or random .Random ()
8083
81- def _extract_input (self , golden : Union [Golden , ConversationalGolden ]) -> str :
84+ def _extract_input (
85+ self , golden : Union [Golden , ConversationalGolden ]
86+ ) -> str :
8287 """Strictly extract the input text, throwing errors on invalid state."""
8388 if isinstance (golden , Golden ):
8489 if not golden .input :
85- raise DeepEvalError ("Golden must have a valid 'input' for MIPROv2 bootstrapping." )
90+ raise DeepEvalError (
91+ "Golden must have a valid 'input' for MIPROv2 bootstrapping."
92+ )
8693 return golden .input
87-
94+
8895 else :
89- user_turns = [t .content for t in (golden .turns or []) if t .role == "user" ]
96+ user_turns = [
97+ t .content for t in (golden .turns or []) if t .role == "user"
98+ ]
9099 if not user_turns :
91- raise DeepEvalError ("ConversationalGolden must have at least one 'user' turn for MIPROv2 bootstrapping." )
100+ raise DeepEvalError (
101+ "ConversationalGolden must have at least one 'user' turn for MIPROv2 bootstrapping."
102+ )
92103 return "\n " .join (user_turns )
93104
94- def _extract_expected_output (self , golden : Union [Golden , ConversationalGolden ]) -> Optional [str ]:
105+ def _extract_expected_output (
106+ self , golden : Union [Golden , ConversationalGolden ]
107+ ) -> Optional [str ]:
95108 """Strictly extract the expected output/outcome if it exists."""
96109 if isinstance (golden , Golden ):
97110 if not golden .expected_output :
98- raise DeepEvalError ("Golden must have a valid 'expected_output' for MIPROv2 bootstrapping." )
111+ raise DeepEvalError (
112+ "Golden must have a valid 'expected_output' for MIPROv2 bootstrapping."
113+ )
99114 return str (golden .expected_output )
100115 else :
101116 if not golden .expected_outcome :
102- raise DeepEvalError ("ConversationalGolden must have a valid 'expected_outcome' for MIPROv2 bootstrapping." )
117+ raise DeepEvalError (
118+ "ConversationalGolden must have a valid 'expected_outcome' for MIPROv2 bootstrapping."
119+ )
103120 return golden .expected_outcome
104121
105122 def bootstrap (
@@ -114,25 +131,43 @@ def bootstrap(
114131 shuffled_indices = list (range (len (goldens )))
115132 self .random_state .shuffle (shuffled_indices )
116133
117- max_attempts = min (len (goldens ), self .max_bootstrapped_demonstrations * 3 )
134+ max_attempts = min (
135+ len (goldens ), self .max_bootstrapped_demonstrations * 3
136+ )
118137 prompt_dict = {"__module__" : prompt }
119138
120139 for idx in shuffled_indices [:max_attempts ]:
121140 golden = goldens [idx ]
122141 input_text = self ._extract_input (golden )
123142 expected = self ._extract_expected_output (golden )
124143
125- if expected and len (labeled_demonstrations ) < self .max_labeled_demonstrations * self .num_demonstration_sets :
126- labeled_demonstrations .append (Demonstration (input_text = input_text , output_text = expected , golden_index = idx ))
127-
128- if len (all_demonstrations ) < self .max_bootstrapped_demonstrations * self .num_demonstration_sets :
144+ if (
145+ expected
146+ and len (labeled_demonstrations )
147+ < self .max_labeled_demonstrations * self .num_demonstration_sets
148+ ):
149+ labeled_demonstrations .append (
150+ Demonstration (
151+ input_text = input_text ,
152+ output_text = expected ,
153+ golden_index = idx ,
154+ )
155+ )
156+
157+ if (
158+ len (all_demonstrations )
159+ < self .max_bootstrapped_demonstrations
160+ * self .num_demonstration_sets
161+ ):
129162 try :
130163 # 1. Generate actual output
131164 actual_output = self .scorer .generate (prompt_dict , golden )
132-
165+
133166 # 2. Build the test case safely
134- test_case = self .scorer ._golden_to_test_case (golden , actual_output )
135-
167+ test_case = self .scorer ._golden_to_test_case (
168+ golden , actual_output
169+ )
170+
136171 # 3. Evaluate against all metrics
137172 metrics = copy_metrics (self .scorer .metrics )
138173 is_successful = True
@@ -141,19 +176,32 @@ def bootstrap(
141176 if not metric .is_successful ():
142177 is_successful = False
143178 break
144-
179+
145180 # 4. Save if all metrics passed
146181 if is_successful :
147- all_demonstrations .append (Demonstration (input_text = input_text , output_text = actual_output , golden_index = idx ))
182+ all_demonstrations .append (
183+ Demonstration (
184+ input_text = input_text ,
185+ output_text = actual_output ,
186+ golden_index = idx ,
187+ )
188+ )
148189 except Exception :
149190 continue
150191
151- if (len (all_demonstrations ) >= self .max_bootstrapped_demonstrations * self .num_demonstration_sets and
152- len (labeled_demonstrations ) >= self .max_labeled_demonstrations * self .num_demonstration_sets ):
192+ if (
193+ len (all_demonstrations )
194+ >= self .max_bootstrapped_demonstrations
195+ * self .num_demonstration_sets
196+ and len (labeled_demonstrations )
197+ >= self .max_labeled_demonstrations * self .num_demonstration_sets
198+ ):
153199 break
154200
155- demo_sets = self ._create_demonstration_sets (all_demonstrations , labeled_demonstrations )
156-
201+ demo_sets = self ._create_demonstration_sets (
202+ all_demonstrations , labeled_demonstrations
203+ )
204+
157205 if not demo_sets or all (not ds .demonstrations for ds in demo_sets ):
158206 raise DeepEvalError (
159207 "Bootstrapper failed to generate any demonstrations. "
@@ -172,7 +220,9 @@ async def a_bootstrap(
172220 shuffled_indices = list (range (len (goldens )))
173221 self .random_state .shuffle (shuffled_indices )
174222
175- max_attempts = min (len (goldens ), self .max_bootstrapped_demonstrations * 3 )
223+ max_attempts = min (
224+ len (goldens ), self .max_bootstrapped_demonstrations * 3
225+ )
176226 selected_indices = shuffled_indices [:max_attempts ]
177227
178228 tasks_info : List [Tuple [int , str , Optional [str ]]] = []
@@ -183,23 +233,41 @@ async def a_bootstrap(
183233 input_text = self ._extract_input (golden )
184234 expected = self ._extract_expected_output (golden )
185235
186- if expected and len (labeled_demonstrations ) < self .max_labeled_demonstrations * self .num_demonstration_sets :
187- labeled_demonstrations .append (Demonstration (input_text = input_text , output_text = expected , golden_index = idx ))
236+ if (
237+ expected
238+ and len (labeled_demonstrations )
239+ < self .max_labeled_demonstrations * self .num_demonstration_sets
240+ ):
241+ labeled_demonstrations .append (
242+ Demonstration (
243+ input_text = input_text ,
244+ output_text = expected ,
245+ golden_index = idx ,
246+ )
247+ )
188248
189249 tasks_info .append ((idx , input_text , expected ))
190250
191- max_bootstrapped = self .max_bootstrapped_demonstrations * self .num_demonstration_sets
251+ max_bootstrapped = (
252+ self .max_bootstrapped_demonstrations * self .num_demonstration_sets
253+ )
192254 tasks_info = tasks_info [:max_bootstrapped ]
193255
194- async def evaluate_one (idx : int , input_text : str , expected : Optional [str ]) -> Optional [Demonstration ]:
256+ async def evaluate_one (
257+ idx : int , input_text : str , expected : Optional [str ]
258+ ) -> Optional [Demonstration ]:
195259 golden = goldens [idx ]
196260 try :
197261 # 1. Generate actual output
198- actual_output = await self .scorer .a_generate (prompt_dict , golden )
199-
262+ actual_output = await self .scorer .a_generate (
263+ prompt_dict , golden
264+ )
265+
200266 # 2. Build the test case safely
201- test_case = self .scorer ._golden_to_test_case (golden , actual_output )
202-
267+ test_case = self .scorer ._golden_to_test_case (
268+ golden , actual_output
269+ )
270+
203271 # 3. Evaluate against all metrics
204272 metrics = copy_metrics (self .scorer .metrics )
205273 is_successful = True
@@ -211,16 +279,24 @@ async def evaluate_one(idx: int, input_text: str, expected: Optional[str]) -> Op
211279
212280 # 4. Save if all metrics passed
213281 if is_successful :
214- return Demonstration (input_text = input_text , output_text = actual_output , golden_index = idx )
282+ return Demonstration (
283+ input_text = input_text ,
284+ output_text = actual_output ,
285+ golden_index = idx ,
286+ )
215287 except Exception :
216288 pass
217289 return None
218290
219- results = await asyncio .gather (* [evaluate_one (idx , inp , exp ) for idx , inp , exp in tasks_info ])
291+ results = await asyncio .gather (
292+ * [evaluate_one (idx , inp , exp ) for idx , inp , exp in tasks_info ]
293+ )
220294 all_demonstrations = [demo for demo in results if demo is not None ]
221295
222- demo_sets = self ._create_demonstration_sets (all_demonstrations , labeled_demonstrations )
223-
296+ demo_sets = self ._create_demonstration_sets (
297+ all_demonstrations , labeled_demonstrations
298+ )
299+
224300 if not demo_sets or all (not ds .demonstrations for ds in demo_sets ):
225301 raise DeepEvalError (
226302 "Bootstrapper failed to generate any demonstrations. "
@@ -230,23 +306,36 @@ async def evaluate_one(idx: int, input_text: str, expected: Optional[str]) -> Op
230306 return demo_sets
231307
232308 def _create_demonstration_sets (
233- self ,
234- bootstrapped_demonstrations : List [Demonstration ],
235- labeled_demonstrations : List [Demonstration ]
309+ self ,
310+ bootstrapped_demonstrations : List [Demonstration ],
311+ labeled_demonstrations : List [Demonstration ],
236312 ) -> List [DemonstrationSet ]:
237-
238- demo_sets : List [DemonstrationSet ] = [DemonstrationSet (demonstrations = [], id = "0-shot" )]
313+
314+ demo_sets : List [DemonstrationSet ] = [
315+ DemonstrationSet (demonstrations = [], id = "0-shot" )
316+ ]
239317
240318 for _ in range (self .num_demonstration_sets ):
241319 demos : List [Demonstration ] = []
242320
243321 if bootstrapped_demonstrations :
244- n_boot = min (self .max_bootstrapped_demonstrations , len (bootstrapped_demonstrations ))
245- demos .extend (self .random_state .sample (bootstrapped_demonstrations , n_boot ))
322+ n_boot = min (
323+ self .max_bootstrapped_demonstrations ,
324+ len (bootstrapped_demonstrations ),
325+ )
326+ demos .extend (
327+ self .random_state .sample (
328+ bootstrapped_demonstrations , n_boot
329+ )
330+ )
246331
247332 if labeled_demonstrations :
248- n_labeled = min (self .max_labeled_demonstrations , len (labeled_demonstrations ))
249- labeled_sample = self .random_state .sample (labeled_demonstrations , n_labeled )
333+ n_labeled = min (
334+ self .max_labeled_demonstrations , len (labeled_demonstrations )
335+ )
336+ labeled_sample = self .random_state .sample (
337+ labeled_demonstrations , n_labeled
338+ )
250339 existing_indices = {d .golden_index for d in demos }
251340 for demo in labeled_sample :
252341 if demo .golden_index not in existing_indices :
@@ -261,9 +350,9 @@ def _create_demonstration_sets(
261350
262351
263352def render_prompt_with_demonstrations (
264- prompt : Prompt ,
265- demonstration_set : Optional [DemonstrationSet ],
266- max_demonstrations : int = 8
353+ prompt : Prompt ,
354+ demonstration_set : Optional [DemonstrationSet ],
355+ max_demonstrations : int = 8 ,
267356) -> Prompt :
268357 from deepeval .prompt .api import PromptType , PromptMessage
269358
@@ -277,14 +366,20 @@ def render_prompt_with_demonstrations(
277366 demo_added = False
278367 for msg in prompt .messages_template :
279368 if not demo_added and msg .role == "system" :
280- new_messages .append (PromptMessage (role = msg .role , content = f"{ msg .content } \n \n { demo_text } " ))
369+ new_messages .append (
370+ PromptMessage (
371+ role = msg .role , content = f"{ msg .content } \n \n { demo_text } "
372+ )
373+ )
281374 demo_added = True
282375 else :
283376 new_messages .append (msg )
284377
285378 if not demo_added and new_messages :
286379 first = new_messages [0 ]
287- new_messages [0 ] = PromptMessage (role = first .role , content = f"{ demo_text } \n \n { first .content } " )
380+ new_messages [0 ] = PromptMessage (
381+ role = first .role , content = f"{ demo_text } \n \n { first .content } "
382+ )
288383 return Prompt (messages_template = new_messages )
289384 else :
290- return Prompt (text_template = f"{ demo_text } \n \n { prompt .text_template } " )
385+ return Prompt (text_template = f"{ demo_text } \n \n { prompt .text_template } " )
0 commit comments