Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@

![Coverage](https://img.shields.io/badge/Coverage-95%25-brightgreen)
![Coverage](https://img.shields.io/badge/Coverage-96%25-brightgreen)
[![CI](https://github.com/automl/promptolution/actions/workflows/ci.yml/badge.svg?branch=main)](https://github.com/automl/promptolution/actions/workflows/ci.yml)
[![Docs](https://github.com/automl/promptolution/actions/workflows/docs.yml/badge.svg?branch=main)](https://github.com/automl/promptolution/actions/workflows/docs.yml)
![Code Style](https://img.shields.io/badge/Code%20Style-black-black)
Expand Down
12 changes: 8 additions & 4 deletions promptolution/exemplar_selectors/random_search_selector.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
"""Random search exemplar selector."""

import random

from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector
from promptolution.utils.prompt import Prompt

Expand All @@ -11,14 +13,15 @@ class RandomSearchSelector(BaseExemplarSelector):
evaluates their performance, and selects the best performing set.
"""

def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt:
def select_exemplars(self, prompt: Prompt, n_examples: int = 5, n_trials: int = 5) -> Prompt:
Comment thread
mo374z marked this conversation as resolved.
"""Select exemplars using a random search strategy.

This method generates multiple sets of random examples, evaluates their performance
when combined with the original prompt, and returns the best performing set.

Args:
prompt (str): The input prompt to base the exemplar selection on.
prompt (Prompt): The input prompt to base the exemplar selection on.
n_examples (int, optional): The number of exemplars to select. Defaults to 5.
n_trials (int, optional): The number of random trials to perform. Defaults to 5.

Returns:
Expand All @@ -29,8 +32,9 @@ def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt:

for _ in range(n_trials):
result = self.task.evaluate(prompt, self.predictor, eval_strategy="subsample")
seq = result.sequences
prompt_with_examples = Prompt(prompt.instruction, [seq[0][0]])
seq = result.sequences[0]
examples = random.sample(list(seq), n_examples)
Copy link
Copy Markdown

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

P1 Badge Guard against sampling more exemplars than available

random.sample(list(seq), n_examples) raises ValueError whenever the sampled subsplit has fewer items than n_examples (e.g., small datasets where len(xs) < 5, or any run with config.n_exemplars larger than the subsample size). Because BaseTask.subsample can legitimately return fewer rows than requested, this turns exemplar selection into a hard runtime failure instead of returning a best-effort prompt.

Useful? React with 👍 / 👎.

Comment on lines +35 to +36
prompt_with_examples = Prompt(prompt.instruction, examples)
# evaluate prompts as few shot prompt
result = self.task.evaluate(prompt_with_examples, self.predictor, eval_strategy="subsample")
score = float(result.agg_scores[0])
Expand Down
76 changes: 76 additions & 0 deletions tests/exemplar_selectors/test_exemplar_selectors.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
"""Tests for exemplar selectors."""

from unittest.mock import MagicMock

import numpy as np
import pytest

from tests.mocks.mock_predictor import MockPredictor

from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector
from promptolution.tasks.base_task import EvalResult
from promptolution.utils.prompt import Prompt


def make_eval_result(sequences, score):
    """Build an ``EvalResult`` describing a single prompt for use in tests.

    Args:
        sequences: The sequences to report for the prompt; one entry per
            evaluated item.
        score: The score assigned to every item and used as the aggregate
            score as well.

    Returns:
        An ``EvalResult`` whose per-item arrays have one row with
        ``len(sequences)`` columns, and whose token counts are all ``1.0``.
    """
    count = len(sequences)

    # Per-item arrays: one row (the single prompt), ``count`` columns.
    item_scores = np.array([[score] * count], dtype=float)
    item_sequences = np.array([sequences], dtype=object)
    item_input_tokens = np.array([[1.0] * count], dtype=float)
    item_output_tokens = np.array([[1.0] * count], dtype=float)

    # Aggregated arrays: one value for the single prompt.
    total_score = np.array([score], dtype=float)
    total_input_tokens = np.array([1.0], dtype=float)
    total_output_tokens = np.array([1.0], dtype=float)

    return EvalResult(
        scores=item_scores,
        agg_scores=total_score,
        sequences=item_sequences,
        input_tokens=item_input_tokens,
        output_tokens=item_output_tokens,
        agg_input_tokens=total_input_tokens,
        agg_output_tokens=total_output_tokens,
    )


@pytest.fixture
def task_and_predictor():
    """Provide a ``(task, predictor)`` pair: a ``MagicMock`` task and a ``MockPredictor``."""
    return MagicMock(), MockPredictor()


def test_random_search_selector_respects_n_examples(task_and_predictor):
    """The selector returns exactly ``n_examples`` exemplars drawn from the evaluated sequences."""
    mock_task, predictor = task_and_predictor
    pool = [f"ex_{idx}" for idx in range(10)]
    # Every call to ``evaluate`` yields the same ten-item pool.
    mock_task.evaluate.return_value = make_eval_result(pool, score=0.8)

    chosen_prompt = RandomSearchSelector(mock_task, predictor).select_exemplars(
        Prompt("Classify:"), n_examples=3, n_trials=1
    )

    assert len(chosen_prompt.few_shots) == 3
    for exemplar in chosen_prompt.few_shots:
        assert exemplar in pool


def test_random_search_selector_returns_best_trial(task_and_predictor):
    """The exemplars of the highest-scoring trial are the ones returned.

    Each trial is given its own, distinguishable exemplar pool whose size
    equals ``n_examples``, so the random sample of a trial is always that
    trial's whole pool (in some order). The best score is placed in the
    middle trial so that a selector which simply keeps the first or the last
    trial fails this test.
    """
    task, pred = task_and_predictor
    low_pool = ["low_0", "low_1"]
    best_pool = ["best_0", "best_1"]
    mid_pool = ["mid_0", "mid_1"]

    # Per trial the selector evaluates twice: first the plain prompt (its
    # sequences are the exemplar pool), then the few-shot prompt (its
    # aggregate score ranks the trial).
    task.evaluate.side_effect = [
        make_eval_result(low_pool, score=0.3),  # trial 1: exemplar pool
        make_eval_result(low_pool, score=0.3),  # trial 1: few-shot score
        make_eval_result(best_pool, score=0.9),  # trial 2: exemplar pool
        make_eval_result(best_pool, score=0.9),  # trial 2: few-shot score
        make_eval_result(mid_pool, score=0.5),  # trial 3: exemplar pool
        make_eval_result(mid_pool, score=0.5),  # trial 3: few-shot score
    ]

    selector = RandomSearchSelector(task, pred)
    result = selector.select_exemplars(Prompt("Classify:"), n_examples=2, n_trials=3)

    assert len(result.few_shots) == 2
    # Sampling order is random, so compare independent of order.
    assert sorted(result.few_shots) == sorted(best_pool)


def test_random_search_selector_n_examples_kwarg(task_and_predictor):
    """Regression test: passing ``n_examples`` as a keyword argument must not raise ``TypeError``."""
    mock_task, predictor = task_and_predictor
    pool = [f"ex_{idx}" for idx in range(5)]
    mock_task.evaluate.return_value = make_eval_result(pool, score=0.5)

    # Keyword-only call shape, as in the original bug report; ``n_trials``
    # is deliberately left at its default.
    chosen_prompt = RandomSearchSelector(mock_task, predictor).select_exemplars(
        prompt=Prompt("Classify:"), n_examples=2
    )

    assert len(chosen_prompt.few_shots) == 2