Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
17 changes: 9 additions & 8 deletions promptolution/tasks/base_task.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,11 +63,6 @@ def __init__(
seed (int): Random seed for reproducibility.
config (ExperimentConfig, optional): Configuration for the task, overriding defaults.
"""
- self.df = df.drop_duplicates(subset=[x_column])
- if len(self.df) != len(df):
- logger.warning(
- f"Duplicate entries detected for x_column '{x_column}' - dropped {len(df) - len(self.df)} rows to enforce uniqueness."
- )
self.x_column: str = x_column
self.y_column: Optional[str] = y_column
self.task_type: TaskType | None = None
Expand All @@ -80,10 +75,16 @@ def __init__(
if config is not None:
config.apply_to(self)

+ self.df = df.drop_duplicates(subset=[self.x_column])
+ if len(self.df) != len(df):
+ logger.warning(
+ f"Duplicate entries detected for x_column '{self.x_column}' - dropped {len(df) - len(self.df)} rows to enforce uniqueness."
+ )

self.xs: List[str] = self.df[self.x_column].values.astype(str).tolist()
- self.has_y: bool = y_column is not None
- if self.has_y and y_column is not None:
- self.ys: List[str] = self.df[y_column].values.astype(str).tolist()
+ self.has_y: bool = self.y_column is not None
+ if self.has_y and self.y_column is not None:
+ self.ys: List[str] = self.df[self.y_column].values.astype(str).tolist()
else:
# If no y_column is provided, create a dummy y array
self.ys = [""] * len(self.xs)
Expand Down
2 changes: 1 addition & 1 deletion promptolution/tasks/reward_tasks.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,7 +63,7 @@ def __init__(
)
self.task_type = "reward"
# x -> kwargs to reward function
- km = self.df.set_index(x_column)[self.reward_columns].to_dict("index")
+ km = self.df.set_index(self.x_column)[self.reward_columns].to_dict("index")
self.kwargs_map = defaultdict(dict, km)

def _evaluate(self, xs: List[str], ys: List[str], preds: List[str]) -> np.ndarray:
Expand Down
10 changes: 10 additions & 0 deletions tests/tasks/test_reward_tasks.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
import pandas as pd

from promptolution.tasks.reward_tasks import RewardTask
from promptolution.utils.config import ExperimentConfig
from promptolution.utils.prompt import Prompt


Expand Down Expand Up @@ -53,3 +54,12 @@ def reward_fn(prediction: str, reward: float) -> float:

assert scores.tolist() == [0.1, 0.2, -1.0]
assert seen_rewards == [0.1, 0.2, 0.3]


def test_reward_task_x_column_from_config(simple_reward_function):
    """Verify that an x_column name supplied only through the config is honored."""
    # The column name is passed via the config alone; RewardTask receives
    # no explicit x_column argument.
    column_name = "my_input"
    expected_inputs = ["a", "b", "c"]

    frame = pd.DataFrame({"my_input": ["a", "b", "c"]})
    cfg = ExperimentConfig(x_column="my_input")
    task = RewardTask(
        df=frame,
        reward_function=simple_reward_function,
        config=cfg,
    )

    assert task.x_column == column_name
    assert task.xs == expected_inputs
Loading