Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
6 changes: 3 additions & 3 deletions promptolution/exemplar_selectors/base_exemplar_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

from abc import ABC, abstractmethod

from typing import TYPE_CHECKING, Optional
from typing import TYPE_CHECKING, Optional, Union

from promptolution.utils.prompt import Prompt

Expand Down Expand Up @@ -35,11 +35,11 @@ def __init__(self, task: "BaseTask", predictor: "BasePredictor", config: Optiona
config.apply_to(self)

@abstractmethod
def select_exemplars(self, prompt: Prompt, n_examples: int = 5) -> Prompt:
def select_exemplars(self, prompt: Union[str, Prompt], n_examples: int = 5) -> Prompt:
"""Select exemplars based on the given prompt.

Args:
prompt (Prompt): The input prompt to base the exemplar selection on.
prompt (Union[str, Prompt]): The input prompt to base the exemplar selection on. A raw string is coerced to a Prompt.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.

Returns:
Expand Down
9 changes: 7 additions & 2 deletions promptolution/exemplar_selectors/random_search_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@

import random

from typing import Union

from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
from promptolution.utils.prompt import Prompt

Expand All @@ -13,20 +15,23 @@ class RandomSearchSelector(BaseExemplarSelector):
evaluates their performance, and selects the best performing set.
"""

def select_exemplars(self, prompt: Prompt, n_examples: int = 5, n_trials: int = 5) -> Prompt:
def select_exemplars(self, prompt: Union[str, Prompt], n_examples: int = 5, n_trials: int = 5) -> Prompt:
"""Select exemplars using a random search strategy.

This method generates multiple sets of random examples, evaluates their performance
when combined with the original prompt, and returns the best performing set.

Args:
prompt (Prompt): The input prompt to base the exemplar selection on.
prompt (Union[str, Prompt]): The input prompt to base the exemplar selection on. A raw string is coerced to a Prompt.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.
n_trials (int, optional): The number of random trials to perform. Defaults to 5.

Returns:
Prompt: The best performing prompt, which includes the original prompt and the selected exemplars.
"""
if isinstance(prompt, str):
prompt = Prompt(prompt)

best_score = 0.0
best_prompt = prompt

Expand Down
9 changes: 6 additions & 3 deletions promptolution/exemplar_selectors/random_selector.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@

import numpy as np

from typing import TYPE_CHECKING, List, Optional
from typing import TYPE_CHECKING, List, Optional, Union

from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
from promptolution.utils.prompt import Prompt
Expand Down Expand Up @@ -38,19 +38,22 @@ def __init__(
self.desired_score = desired_score
super().__init__(task, predictor, config)

def select_exemplars(self, prompt: Prompt, n_examples: int = 5) -> Prompt:
def select_exemplars(self, prompt: Union[str, Prompt], n_examples: int = 5) -> Prompt:
"""Select exemplars using a random selection strategy.

This method generates random examples and selects those that are evaluated as correct
(score == self.desired_score) until the desired number of exemplars is reached.

Args:
prompt (Prompt): The input prompt to base the exemplar selection on.
prompt (Union[str, Prompt]): The input prompt to base the exemplar selection on. A raw string is coerced to a Prompt.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.

Returns:
Prompt: A new prompt that includes the original prompt and the selected exemplars.
"""
if isinstance(prompt, str):
prompt = Prompt(prompt)

examples: List[str] = []
while len(examples) < n_examples:
result = self.task.evaluate(prompt, self.predictor, eval_strategy="subsample")
Expand Down
54 changes: 35 additions & 19 deletions tests/exemplar_selectors/test_exemplar_selectors.py
Original file line number Diff line number Diff line change
Expand Up @@ -8,9 +8,13 @@
from tests.mocks.mock_predictor import MockPredictor

from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector
from promptolution.exemplar_selectors.random_selector import RandomSelector
from promptolution.tasks.base_task import EvalResult
from promptolution.utils.prompt import Prompt

# Concrete selector classes exercised by the parametrized tests in this module.
# Every concrete selector must satisfy the shared BaseExemplarSelector contract.
SELECTOR_CLASSES = [RandomSelector, RandomSearchSelector]


def make_eval_result(sequences, score):
n = len(sequences)
Expand All @@ -29,19 +33,44 @@ def make_eval_result(sequences, score):
def task_and_predictor():
    """Build the (task, predictor) pair shared by the selector tests.

    The task is a ``MagicMock`` whose ``evaluate`` returns a canned result with
    ten example sequences; the predictor is a ``MockPredictor``.
    """
    mock_task = MagicMock()
    mock_predictor = MockPredictor()
    example_sequences = [f"ex_{idx}" for idx in range(10)]
    # score 1.0 satisfies both RandomSelector (desired_score == 1) and RandomSearchSelector.
    mock_task.evaluate.return_value = make_eval_result(example_sequences, score=1.0)
    return mock_task, mock_predictor


def test_random_search_selector_respects_n_examples(task_and_predictor):
@pytest.mark.parametrize("selector_cls", SELECTOR_CLASSES)
def test_select_exemplars_respects_n_examples(selector_cls, task_and_predictor):
task, pred = task_and_predictor
sequences = [f"ex_{i}" for i in range(10)]
task.evaluate.return_value = make_eval_result(sequences, score=0.8)

selector = RandomSearchSelector(task, pred)
result = selector.select_exemplars(Prompt("Classify:"), n_examples=3, n_trials=1)
selector = selector_cls(task, pred)
result = selector.select_exemplars(Prompt("Classify:"), n_examples=3)

assert isinstance(result, Prompt)
assert len(result.few_shots) == 3
assert all(ex in sequences for ex in result.few_shots)


@pytest.mark.parametrize("selector_cls", SELECTOR_CLASSES)
def test_select_exemplars_accepts_str_prompt(selector_cls, task_and_predictor):
    """Regression: a plain ``str`` prompt must be coerced as a whole.

    The returned object has to be a Prompt whose instruction equals the
    original text, rather than the string being split into characters.
    """
    mock_task, mock_predictor = task_and_predictor
    exemplar_selector = selector_cls(mock_task, mock_predictor)

    selected = exemplar_selector.select_exemplars(prompt="Classify:", n_examples=2)

    # Type check comes first so a failure points at the coercion itself.
    assert isinstance(selected, Prompt)
    assert selected.instruction == "Classify:"
    assert len(selected.few_shots) == 2


@pytest.mark.parametrize("selector_cls", SELECTOR_CLASSES)
def test_select_exemplars_n_examples_kwarg(selector_cls, task_and_predictor):
    """Regression: passing ``n_examples`` by keyword must not raise TypeError."""
    mock_task, mock_predictor = task_and_predictor
    exemplar_selector = selector_cls(mock_task, mock_predictor)

    # Both arguments are passed by keyword on purpose.
    selected = exemplar_selector.select_exemplars(prompt=Prompt("Classify:"), n_examples=2)

    assert len(selected.few_shots) == 2


def test_random_search_selector_returns_best_trial(task_and_predictor):
Expand All @@ -61,16 +90,3 @@ def test_random_search_selector_returns_best_trial(task_and_predictor):

assert len(result.few_shots) == 2
assert result.few_shots != []


def test_random_search_selector_n_examples_kwarg(task_and_predictor):
    """Regression test: ``n_examples`` passed by keyword must not raise TypeError."""
    mock_task, mock_predictor = task_and_predictor
    example_sequences = [f"ex_{idx}" for idx in range(5)]
    mock_task.evaluate.return_value = make_eval_result(example_sequences, score=0.5)

    exemplar_selector = RandomSearchSelector(mock_task, mock_predictor)
    # Keyword-style call: this pattern is what triggered the original bug report.
    selected = exemplar_selector.select_exemplars(prompt=Prompt("Classify:"), n_examples=2)

    assert len(selected.few_shots) == 2
Loading