Skip to content

Commit de3b610

Browse files
committed
fix faulty header of random search selector
1 parent 63cdec0 commit de3b610

2 files changed

Lines changed: 84 additions & 4 deletions

File tree

promptolution/exemplar_selectors/random_search_selector.py

Lines changed: 8 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
"""Random search exemplar selector."""
22

3+
import random
4+
35
from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
46
from promptolution.utils.prompt import Prompt
57

@@ -11,14 +13,15 @@ class RandomSearchSelector(BaseExemplarSelector):
1113
evaluates their performance, and selects the best performing set.
1214
"""
1315

14-
def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt:
16+
def select_exemplars(self, prompt: Prompt, n_examples: int = 5, n_trials: int = 5) -> Prompt:
1517
"""Select exemplars using a random search strategy.
1618
1719
This method generates multiple sets of random examples, evaluates their performance
1820
when combined with the original prompt, and returns the best performing set.
1921
2022
Args:
21-
prompt (str): The input prompt to base the exemplar selection on.
23+
prompt (Prompt): The input prompt to base the exemplar selection on.
24+
n_examples (int, optional): The number of exemplars to select. Defaults to 5.
2225
n_trials (int, optional): The number of random trials to perform. Defaults to 5.
2326
2427
Returns:
@@ -29,8 +32,9 @@ def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt:
2932

3033
for _ in range(n_trials):
3134
result = self.task.evaluate(prompt, self.predictor, eval_strategy="subsample")
32-
seq = result.sequences
33-
prompt_with_examples = Prompt(prompt.instruction, [seq[0][0]])
35+
seq = result.sequences[0]
36+
examples = random.sample(list(seq), n_examples)
37+
prompt_with_examples = Prompt(prompt.instruction, examples)
3438
# evaluate prompts as few shot prompt
3539
result = self.task.evaluate(prompt_with_examples, self.predictor, eval_strategy="subsample")
3640
score = float(result.agg_scores[0])
Lines changed: 76 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,76 @@
1+
"""Tests for exemplar selectors."""
2+
3+
from unittest.mock import MagicMock
4+
5+
import numpy as np
6+
import pytest
7+
8+
from tests.mocks.mock_predictor import MockPredictor
9+
10+
from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector
11+
from promptolution.tasks.base_task import EvalResult
12+
from promptolution.utils.prompt import Prompt
13+
14+
15+
def make_eval_result(sequences, score):
    """Build an ``EvalResult`` describing one prompt evaluated on ``sequences``.

    Every per-sequence score is set to ``score`` and every token count to
    ``1.0``; the aggregated fields hold the same values for the single prompt.

    Args:
        sequences: The sequences to expose through ``EvalResult.sequences``.
        score: The score assigned to each sequence and to the aggregate.

    Returns:
        An ``EvalResult`` whose array fields each contain a single row.
    """
    count = len(sequences)

    def _row(value):
        # A fresh single-row float array per field, so no two fields share storage.
        return np.array([[value] * count], dtype=float)

    def _agg(value):
        return np.array([value], dtype=float)

    return EvalResult(
        scores=_row(score),
        agg_scores=_agg(score),
        sequences=np.array([sequences], dtype=object),
        input_tokens=_row(1.0),
        output_tokens=_row(1.0),
        agg_input_tokens=_agg(1.0),
        agg_output_tokens=_agg(1.0),
    )
26+
27+
28+
@pytest.fixture
29+
def task_and_predictor():
30+
task = MagicMock()
31+
pred = MockPredictor()
32+
return task, pred
33+
34+
35+
def test_random_search_selector_respects_n_examples(task_and_predictor):
    """The selected prompt carries exactly ``n_examples`` exemplars taken from the pool."""
    mock_task, predictor = task_and_predictor
    pool = [f"ex_{i}" for i in range(10)]
    # Every evaluation, zero-shot or few-shot, reports the same canned result.
    mock_task.evaluate.return_value = make_eval_result(pool, score=0.8)

    selected = RandomSearchSelector(mock_task, predictor).select_exemplars(
        Prompt("Classify:"), n_examples=3, n_trials=1
    )

    assert len(selected.few_shots) == 3
    for exemplar in selected.few_shots:
        assert exemplar in pool
45+
46+
47+
def test_random_search_selector_returns_best_trial(task_and_predictor):
    """The exemplars of the highest-scoring trial are the ones returned.

    Each trial draws its exemplars from its own, disjoint pool, so the origin
    of the returned exemplars identifies which trial the selector kept.
    """
    task, pred = task_and_predictor
    low_pool = [f"low_{i}" for i in range(5)]
    high_pool = [f"high_{i}" for i in range(5)]

    # Per trial the selector evaluates twice: first the incoming prompt (whose
    # sequences become the exemplar candidates), then the few-shot prompt
    # (whose aggregated score ranks the trial).
    task.evaluate.side_effect = [
        make_eval_result(low_pool, score=0.3),  # trial 1: candidate sequences
        make_eval_result(low_pool, score=0.3),  # trial 1: few-shot score
        make_eval_result(high_pool, score=0.9),  # trial 2: candidate sequences
        make_eval_result(high_pool, score=0.9),  # trial 2: few-shot score
    ]

    selector = RandomSearchSelector(task, pred)
    result = selector.select_exemplars(Prompt("Classify:"), n_examples=2, n_trials=2)

    assert task.evaluate.call_count == 4
    assert len(result.few_shots) == 2
    # Only trial 2 could have produced "high_*" exemplars.
    assert all(ex in high_pool for ex in result.few_shots)
64+
65+
66+
def test_random_search_selector_n_examples_kwarg(task_and_predictor):
    """Regression guard: passing ``n_examples`` by keyword must not raise ``TypeError``."""
    mock_task, predictor = task_and_predictor
    pool = [f"ex_{i}" for i in range(5)]
    mock_task.evaluate.return_value = make_eval_result(pool, score=0.5)

    # Keyword-style invocation mirrors the call from the original bug report.
    outcome = RandomSearchSelector(mock_task, predictor).select_exemplars(
        prompt=Prompt("Classify:"), n_examples=2
    )

    assert len(outcome.few_shots) == 2

0 commit comments

Comments
 (0)