-
Notifications
You must be signed in to change notification settings - Fork 14
fix faulty header of random search selector #70
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Changes from all commits
File filter
Filter by extension
Conversations
Jump to
Diff view
Diff view
There are no files selected for viewing
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -1,5 +1,7 @@ | ||
| """Random search exemplar selector.""" | ||
|
|
||
| import random | ||
|
|
||
| from promptolution.exemplar_selectors.base_exemplar_selector import BaseExemplarSelector | ||
| from promptolution.utils.prompt import Prompt | ||
|
|
||
|
|
@@ -11,14 +13,15 @@ class RandomSearchSelector(BaseExemplarSelector): | |
| evaluates their performance, and selects the best performing set. | ||
| """ | ||
|
|
||
| def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt: | ||
| def select_exemplars(self, prompt: Prompt, n_examples: int = 5, n_trials: int = 5) -> Prompt: | ||
| """Select exemplars using a random search strategy. | ||
|
|
||
| This method generates multiple sets of random examples, evaluates their performance | ||
| when combined with the original prompt, and returns the best performing set. | ||
|
|
||
| Args: | ||
| prompt (str): The input prompt to base the exemplar selection on. | ||
| prompt (Prompt): The input prompt to base the exemplar selection on. | ||
| n_examples (int, optional): The number of exemplars to select. Defaults to 5. | ||
| n_trials (int, optional): The number of random trials to perform. Defaults to 5. | ||
|
|
||
| Returns: | ||
|
|
@@ -29,8 +32,9 @@ def select_exemplars(self, prompt: Prompt, n_trials: int = 5) -> Prompt: | |
|
|
||
| for _ in range(n_trials): | ||
| result = self.task.evaluate(prompt, self.predictor, eval_strategy="subsample") | ||
| seq = result.sequences | ||
| prompt_with_examples = Prompt(prompt.instruction, [seq[0][0]]) | ||
| seq = result.sequences[0] | ||
| examples = random.sample(list(seq), n_examples) | ||
|
There was a problem hiding this comment. Choose a reason for hiding this commentThe reason will be displayed to describe this comment to others. Learn more.
Useful? React with 👍 / 👎.
Comment on lines
+35
to
+36
|
||
| prompt_with_examples = Prompt(prompt.instruction, examples) | ||
| # evaluate prompts as few shot prompt | ||
| result = self.task.evaluate(prompt_with_examples, self.predictor, eval_strategy="subsample") | ||
| score = float(result.agg_scores[0]) | ||
|
|
||
| Original file line number | Diff line number | Diff line change |
|---|---|---|
| @@ -0,0 +1,76 @@ | ||
| """Tests for exemplar selectors.""" | ||
|
|
||
| from unittest.mock import MagicMock | ||
|
|
||
| import numpy as np | ||
| import pytest | ||
|
|
||
| from tests.mocks.mock_predictor import MockPredictor | ||
|
|
||
| from promptolution.exemplar_selectors.random_search_selector import RandomSearchSelector | ||
| from promptolution.tasks.base_task import EvalResult | ||
| from promptolution.utils.prompt import Prompt | ||
|
|
||
|
|
||
| def make_eval_result(sequences, score): | ||
| n = len(sequences) | ||
| return EvalResult( | ||
| scores=np.array([[score] * n], dtype=float), | ||
| agg_scores=np.array([score], dtype=float), | ||
| sequences=np.array([sequences], dtype=object), | ||
| input_tokens=np.array([[1.0] * n], dtype=float), | ||
| output_tokens=np.array([[1.0] * n], dtype=float), | ||
| agg_input_tokens=np.array([1.0], dtype=float), | ||
| agg_output_tokens=np.array([1.0], dtype=float), | ||
| ) | ||
|
|
||
|
|
||
| @pytest.fixture | ||
| def task_and_predictor(): | ||
| task = MagicMock() | ||
| pred = MockPredictor() | ||
| return task, pred | ||
|
|
||
|
|
||
| def test_random_search_selector_respects_n_examples(task_and_predictor): | ||
| task, pred = task_and_predictor | ||
| sequences = [f"ex_{i}" for i in range(10)] | ||
| task.evaluate.return_value = make_eval_result(sequences, score=0.8) | ||
|
|
||
| selector = RandomSearchSelector(task, pred) | ||
| result = selector.select_exemplars(Prompt("Classify:"), n_examples=3, n_trials=1) | ||
|
|
||
| assert len(result.few_shots) == 3 | ||
| assert all(ex in sequences for ex in result.few_shots) | ||
|
|
||
|
|
||
| def test_random_search_selector_returns_best_trial(task_and_predictor): | ||
| task, pred = task_and_predictor | ||
| sequences = [f"ex_{i}" for i in range(5)] | ||
|
|
||
| # First trial scores low, second scores high | ||
| task.evaluate.side_effect = [ | ||
| make_eval_result(sequences, score=0.3), # zero-shot eval trial 1 | ||
| make_eval_result(sequences, score=0.3), # few-shot eval trial 1 | ||
| make_eval_result(sequences, score=0.9), # zero-shot eval trial 2 | ||
| make_eval_result(sequences, score=0.9), # few-shot eval trial 2 | ||
| ] | ||
|
|
||
| selector = RandomSearchSelector(task, pred) | ||
| result = selector.select_exemplars(Prompt("Classify:"), n_examples=2, n_trials=2) | ||
|
|
||
| assert len(result.few_shots) == 2 | ||
| assert result.few_shots != [] | ||
|
|
||
|
|
||
| def test_random_search_selector_n_examples_kwarg(task_and_predictor): | ||
| """Regression test: calling with n_examples as keyword arg must not raise TypeError.""" | ||
| task, pred = task_and_predictor | ||
| sequences = [f"ex_{i}" for i in range(5)] | ||
| task.evaluate.return_value = make_eval_result(sequences, score=0.5) | ||
|
|
||
| selector = RandomSearchSelector(task, pred) | ||
| # This call pattern is what triggered the original bug report | ||
| result = selector.select_exemplars(prompt=Prompt("Classify:"), n_examples=2) | ||
|
|
||
| assert len(result.few_shots) == 2 |
Uh oh!
There was an error while loading. Please reload this page.