Skip to content

Commit 59ab6ff

Browse files
committed
precommit changes
1 parent 8f563be commit 59ab6ff

File tree

2 files changed

+7
-4
lines changed

2 files changed

+7
-4
lines changed

prompting/tasks/qa.py

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -68,6 +68,7 @@ def make_reference(self, dataset_entry: Context):
6868
self.reference = self.generate_reference(messages=[{"role": "user", "content": reference_prompt}])
6969
return self.reference
7070

71+
7172
class WebQuestionAnsweringTask(BaseTextTask):
7273
"""QuestionAnsweringTasks must be initialised with an LLM pipeline to generate query and reference plus
7374
context from a dataset to base the query on"""
@@ -87,4 +88,4 @@ def make_query(self, dataset_entry: Context):
8788
def make_reference(self, dataset_entry: Context):
8889
reference_prompt = REFERENCE_PROMPT_TEMPLATE.format(context=dataset_entry.content, question=self.query)
8990
self.reference = self.generate_reference(messages=[{"role": "user", "content": reference_prompt}])
90-
return self.reference
91+
return self.reference

prompting/tasks/task_registry.py

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -15,7 +15,7 @@
1515
from prompting.tasks.multi_choice import MultiChoiceRewardConfig, MultiChoiceTask
1616
from prompting.tasks.multi_step_reasoning import MultiStepReasoningRewardConfig, MultiStepReasoningTask
1717
from prompting.tasks.programming_task import ProgrammingRewardConfig, ProgrammingTask
18-
from prompting.tasks.qa import QARewardConfig, WikiQuestionAnsweringTask, WebQuestionAnsweringTask
18+
from prompting.tasks.qa import QARewardConfig, WebQuestionAnsweringTask, WikiQuestionAnsweringTask
1919
from prompting.tasks.web_retrieval import WebRetrievalRewardConfig, WebRetrievalTask
2020
from shared.base import BaseDataset
2121

@@ -34,7 +34,9 @@ def __hash__(self):
3434

3535
class TaskRegistry(BaseModel):
3636
task_configs: ClassVar[list[TaskConfig]] = [
37-
TaskConfig(task=WikiQuestionAnsweringTask, probability=0.2, datasets=[WikiDataset], reward_model=QARewardConfig),
37+
TaskConfig(
38+
task=WikiQuestionAnsweringTask, probability=0.2, datasets=[WikiDataset], reward_model=QARewardConfig
39+
),
3840
TaskConfig(task=WebQuestionAnsweringTask, probability=0.1, datasets=[DDGDataset], reward_model=QARewardConfig),
3941
TaskConfig(
4042
task=InferenceTask,
@@ -65,7 +67,7 @@ class TaskRegistry(BaseModel):
6567
probability=0.1,
6668
datasets=[WikiDataset],
6769
reward_model=MultiStepReasoningRewardConfig,
68-
)
70+
),
6971
]
7072

7173
@classmethod

0 commit comments

Comments
 (0)