Commit a018fc1: Update core_logic.py
Parent commit: e542434

File tree: 1 file changed (15 additions, 8 deletions)

promptwizard/glue/promptopt/techniques/critique_n_refine/core_logic.py
@@ -272,7 +272,7 @@ def evaluate(self, generated_text: str, dataset_subset: List) -> List:
         return wrong_examples
 
     @iolog.log_io_params
-    def select_top_prompts(self, prompt_score_list: List, top_n: int) -> List:
+    def select_top_prompts(self, prompt_score_list: List, top_n: int, resolve_tie_criteria: str) -> List:
         """
         Sort prompts in prompt_score_list, based on its performance. And return max, top `top_n` prompts.
 
@@ -281,9 +281,16 @@ def select_top_prompts(self, prompt_score_list: List, top_n: int) -> List:
         :param top_n: Max number of prompts from the top of the list, that we need to return
         :return: List of top `top_n` prompts.
         """
-        sorted_prompts = sorted(prompt_score_list, key=lambda x: [x[self.GetPromptScoreIndex.SCORE],
-                                                                  len(x[self.GetPromptScoreIndex.PROMPT_STR])],
-                                reverse=True)
+
+        if resolve_tie_criteria == 'max':
+            sorted_prompts = sorted(prompt_score_list, key=lambda x: [x[self.GetPromptScoreIndex.SCORE],
+                                                                      len(x[self.GetPromptScoreIndex.PROMPT_STR])],
+                                    reverse=True)
+        else:
+            sorted_prompts = sorted(prompt_score_list, key=lambda x: [x[self.GetPromptScoreIndex.SCORE],
+                                                                      -1 * len(x[self.GetPromptScoreIndex.PROMPT_STR])],
+                                    reverse=True)
+
         sorted_top_n_prompts = sorted_prompts[:top_n]
         self.logger.debug(f"Sorted top n prompts: {sorted_top_n_prompts}")
         return sorted_top_n_prompts
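
The sort key is the pair [score, prompt_length] with reverse=True, so prompts are ranked by score first; when scores tie, 'max' prefers the longer prompt, while any other value (length negated) prefers the shorter one. A minimal standalone sketch of that tie-break, using hypothetical (prompt, score) tuples in place of the repo's prompt_score_list entries and a plain function instead of the class method:

# Sketch only: PROMPT_STR and SCORE mimic GetPromptScoreIndex; the data is made up.
PROMPT_STR, SCORE = 0, 1

def select_top(prompt_score_list, top_n, resolve_tie_criteria="max"):
    # 'max' keeps longer prompts on score ties; anything else keeps shorter ones.
    sign = 1 if resolve_tie_criteria == "max" else -1
    ranked = sorted(prompt_score_list,
                    key=lambda x: (x[SCORE], sign * len(x[PROMPT_STR])),
                    reverse=True)
    return ranked[:top_n]

candidates = [("short", 0.8), ("a much longer prompt", 0.8), ("medium one", 0.6)]
print(select_top(candidates, 1, "max"))  # [('a much longer prompt', 0.8)]
print(select_top(candidates, 1, "min"))  # [('short', 0.8)]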
@@ -456,7 +463,7 @@ def get_best_instr_by_critique(self, examples: List, params: PromptOptimizationP
 
         return refined_instructions[0] if refined_instructions else None
 
-    def get_best_prompt(self, params: PromptOptimizationParams,use_examples=False,run_without_train_examples=False,generate_synthetic_examples=False) -> (str, Any):
+    def get_best_prompt(self, params: PromptOptimizationParams,use_examples=False,run_without_train_examples=False,generate_synthetic_examples=False,resolve_tie_criteria="max") -> (str, Any):
         """
         Perform `params.max_iterations` iterations for optimizing your prompt. And return the best prompt found so far.
 
@@ -500,13 +507,13 @@ def get_best_prompt(self, params: PromptOptimizationParams,use_examples=False,ru
                     prompt_index += 1
                 return "",""
             prompt_score_list = self.get_prompt_score(candidate_prompts, params)
-            prompt_score_list = self.select_top_prompts(prompt_score_list, params.top_n)
+            prompt_score_list = self.select_top_prompts(prompt_score_list, params.top_n,resolve_tie_criteria)
 
             if params.refine_instruction:
                 refined_prompts = self.refine_prompts(prompt_score_list, params)
                 refined_prompt_score_list = self.get_prompt_score(refined_prompts, params)
                 prompt_score_list = self.select_top_prompts(refined_prompt_score_list + prompt_score_list,
-                                                            params.top_n)
+                                                            params.top_n,resolve_tie_criteria)
 
             current_base_instruction = prompt_score_list[0][self.GetPromptScoreIndex.PROMPT_STR]
             self.iolog.append_dict_to_chained_logs({"round_num": round_num,
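
Because resolve_tie_criteria defaults to "max", existing callers of get_best_prompt keep the old behavior; the keyword only matters when passed explicitly. A hypothetical call (the instance name and argument values are illustrative, not from this commit):

# prompt_opt is assumed to be an instance of the critique-n-refine class.
best_prompt, extra = prompt_opt.get_best_prompt(
    params,                       # a populated PromptOptimizationParams
    resolve_tie_criteria="min",   # any value other than "max" prefers shorter prompts on ties
)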
@@ -530,7 +537,7 @@ def get_best_prompt(self, params: PromptOptimizationParams,use_examples=False,ru
                 break
 
         if len(examples) < params.few_shot_count:
-            examples = random.sample(self.dataset, params.few_shot_count - len(examples))
+            examples += random.sample(self.dataset, params.few_shot_count - len(examples))
 
         # Refine task description and examples iteratively
         print("\nRefining Task description and Examples iteratively....")
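
The last hunk fixes an accumulation bug: with "=", the examples gathered so far were discarded and replaced by only few_shot_count - len(examples) random samples, leaving the few-shot pool short; with "+=", the random samples top the pool up to exactly few_shot_count. A toy illustration (data hypothetical):

import random

dataset = [f"sample_{i}" for i in range(10)]
examples = ["hard_example_from_eval"]
few_shot_count = 3

old = random.sample(dataset, few_shot_count - len(examples))             # old: 2 items, hard example lost
new = examples + random.sample(dataset, few_shot_count - len(examples))  # new: 3 items, hard example kept
print(len(old), len(new))  # 2 3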
