@@ -272,7 +272,7 @@ def evaluate(self, generated_text: str, dataset_subset: List) -> List:
272272 return wrong_examples
273273
274274 @iolog .log_io_params
275- def select_top_prompts (self , prompt_score_list : List , top_n : int ) -> List :
275+ def select_top_prompts (self , prompt_score_list : List , top_n : int , resolve_tie_criteria : str ) -> List :
276276 """
277277 Sort prompts in prompt_score_list, based on its performance. And return max, top `top_n` prompts.
278278
@@ -281,9 +281,16 @@ def select_top_prompts(self, prompt_score_list: List, top_n: int) -> List:
281281 :param top_n: Max number of prompts from the top of the list, that we need to return
282282 :return: List of top `top_n` prompts.
283283 """
284- sorted_prompts = sorted (prompt_score_list , key = lambda x : [x [self .GetPromptScoreIndex .SCORE ],
285- len (x [self .GetPromptScoreIndex .PROMPT_STR ])],
286- reverse = True )
284+
285+ if resolve_tie_criteria == 'max' :
286+ sorted_prompts = sorted (prompt_score_list , key = lambda x : [x [self .GetPromptScoreIndex .SCORE ],
287+ len (x [self .GetPromptScoreIndex .PROMPT_STR ])],
288+ reverse = True )
289+ else :
290+ sorted_prompts = sorted (prompt_score_list , key = lambda x : [x [self .GetPromptScoreIndex .SCORE ],
291+ - 1 * len (x [self .GetPromptScoreIndex .PROMPT_STR ])],
292+ reverse = True )
293+
287294 sorted_top_n_prompts = sorted_prompts [:top_n ]
288295 self .logger .debug (f"Sorted top n prompts: { sorted_top_n_prompts } " )
289296 return sorted_top_n_prompts
@@ -456,7 +463,7 @@ def get_best_instr_by_critique(self, examples: List, params: PromptOptimizationP
456463
457464 return refined_instructions [0 ] if refined_instructions else None
458465
459- def get_best_prompt (self , params : PromptOptimizationParams ,use_examples = False ,run_without_train_examples = False ,generate_synthetic_examples = False ) -> (str , Any ):
466+ def get_best_prompt (self , params : PromptOptimizationParams ,use_examples = False ,run_without_train_examples = False ,generate_synthetic_examples = False , resolve_tie_criteria = "max" ) -> (str , Any ):
460467 """
461468 Perform `params.max_iterations` iterations for optimizing your prompt. And return the best prompt found so far.
462469
@@ -500,13 +507,13 @@ def get_best_prompt(self, params: PromptOptimizationParams,use_examples=False,ru
500507 prompt_index += 1
501508 return "" ,""
502509 prompt_score_list = self .get_prompt_score (candidate_prompts , params )
503- prompt_score_list = self .select_top_prompts (prompt_score_list , params .top_n )
510+ prompt_score_list = self .select_top_prompts (prompt_score_list , params .top_n , resolve_tie_criteria )
504511
505512 if params .refine_instruction :
506513 refined_prompts = self .refine_prompts (prompt_score_list , params )
507514 refined_prompt_score_list = self .get_prompt_score (refined_prompts , params )
508515 prompt_score_list = self .select_top_prompts (refined_prompt_score_list + prompt_score_list ,
509- params .top_n )
516+ params .top_n , resolve_tie_criteria )
510517
511518 current_base_instruction = prompt_score_list [0 ][self .GetPromptScoreIndex .PROMPT_STR ]
512519 self .iolog .append_dict_to_chained_logs ({"round_num" : round_num ,
@@ -530,7 +537,7 @@ def get_best_prompt(self, params: PromptOptimizationParams,use_examples=False,ru
530537 break
531538
532539 if len (examples ) < params .few_shot_count :
533- examples = random .sample (self .dataset , params .few_shot_count - len (examples ))
540+ examples + = random .sample (self .dataset , params .few_shot_count - len (examples ))
534541
535542 # Refine task description and examples iteratively
536543 print ("\n Refining Task description and Examples iteratively...." )
0 commit comments