Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 4 additions & 3 deletions promptolution/exemplar_selectors/base_exemplar_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from promptolution.utils.prompt import Prompt

Expand Down Expand Up @@ -35,11 +35,12 @@ def __init__(self, task: "BaseTask", predictor: "BasePredictor", config: Optiona
config.apply_to(self)

@abstractmethod
def select_exemplars(self, prompt: Prompt, n_examples: int = 5) -> Prompt:
def select_exemplars(self, prompt: Union[str, Prompt], system_prompt: Union[str, Prompt] = None, n_examples: int = 5) -> Prompt:
"""Select exemplars based on the given prompt.

Args:
prompt (Prompt): The input prompt to base the exemplar selection on.
prompt (Union[str, Prompt]): The input prompt to base the exemplar selection on. A raw string is coerced to a Prompt.
system_prompt (Union[str, Prompt]): The system prompt to be used.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.

Returns:
Expand Down
16 changes: 12 additions & 4 deletions promptolution/exemplar_selectors/random_search_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import random

from typing import Union

from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
from promptolution.utils.prompt import Prompt

Expand All @@ -13,30 +15,36 @@ class RandomSearchSelector(BaseExemplarSelector):
evaluates their performance, and selects the best performing set.
"""

def select_exemplars(self, prompt: Prompt, n_examples: int = 5, n_trials: int = 5) -> Prompt:
def select_exemplars(self, prompt: Union[str, Prompt],system_prompt: Union[str, Prompt] = None, n_examples: int = 5, n_trials: int = 5) -> Prompt:
"""Select exemplars using a random search strategy.

This method generates multiple sets of random examples, evaluates their performance
when combined with the original prompt, and returns the best performing set.

Args:
prompt (Prompt): The input prompt to base the exemplar selection on.
prompt (Union[str, Prompt]): The input prompt to base the exemplar selection on. A raw string is coerced to a Prompt.
system_prompt (Union[str, Prompt]): The system prompt to use.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.
n_trials (int, optional): The number of random trials to perform. Defaults to 5.

Returns:
Prompt: The best performing prompt, which includes the original prompt and the selected exemplars.
"""
if isinstance(prompt, str):
prompt = Prompt(prompt)
if isinstance(system_prompt, str):
system_prompt = Prompt(system_prompt)

best_score = 0.0
best_prompt = prompt

for _ in range(n_trials):
result = self.task.evaluate(prompt, self.predictor, eval_strategy="subsample")
result = self.task.evaluate(prompt, self.predictor,system_prompt, eval_strategy="subsample")
seq = result.sequences[0]
examples = random.sample(list(seq), n_examples)
prompt_with_examples = Prompt(prompt.instruction, examples)
# evaluate prompts as few shot prompt
result = self.task.evaluate(prompt_with_examples, self.predictor, eval_strategy="subsample")
result = self.task.evaluate(prompt_with_examples, self.predictor,system_prompt, eval_strategy="subsample")
score = float(result.agg_scores[0])
if score > best_score:
best_score = score
Expand Down
14 changes: 10 additions & 4 deletions promptolution/exemplar_selectors/random_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union

from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
from promptolution.utils.prompt import Prompt
Expand Down Expand Up @@ -38,22 +38,28 @@ def __init__(
self.desired_score = desired_score
super().__init__(task, predictor, config)

def select_exemplars(self, prompt: Prompt, n_examples: int = 5) -> Prompt:
def select_exemplars(self, prompt: Union[str, Prompt], system_prompt: Union[str, Prompt] = None, n_examples: int = 5) -> Prompt:
"""Select exemplars using a random selection strategy.

This method generates random examples and selects those that are evaluated as correct
(score == self.desired_score) until the desired number of exemplars is reached.

Args:
prompt (Prompt): The input prompt to base the exemplar selection on.
prompt (Union[str, Prompt]): The input prompt to base the exemplar selection on. A raw string is coerced to a Prompt.
system_prompt (Union[str, Prompt]): The system prompt to be used.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.

Returns:
Prompt: A new prompt that includes the original prompt and the selected exemplars.
"""
if isinstance(prompt, str):
prompt = Prompt(prompt)
if isinstance(system_prompt, str):
system_prompt = Prompt(system_prompt)

examples: List[str] = []
while len(examples) < n_examples:
result = self.task.evaluate(prompt, self.predictor, eval_strategy="subsample")
result = self.task.evaluate(prompt, self.predictor, system_prompt, eval_strategy="subsample")
scores = result.scores
seqs = result.sequences
score = float(np.mean(scores))
Expand Down
2 changes: 1 addition & 1 deletion promptolution/llms/local_llm.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,7 @@ def _get_response(self, prompts: List[str], system_prompts: List[str]) -> List[s
"""
inputs: List[List[Dict[str, str]]] = []
for prompt, sys_prompt in zip(prompts, system_prompts):
inputs.append([{"role": "system", "prompt": sys_prompt}, {"role": "user", "prompt": prompt}])
inputs.append([{"role": "system", "content": sys_prompt}, {"role": "user", "content": prompt}])

with torch.no_grad():
response = self.pipeline(inputs, pad_token_id=self.eos_token_id)
Expand Down
Loading