@@ -42,12 +42,15 @@ class Dataset(Generic[InputT, OutputT]):
4242 expected_trajectory=["calculator"],
4343 metadata={"category": "math"})
4444 ],
45-             evaluator=OutputEvaluator(rubric = "The output is relevant and complete. 0 if the output is incorrect or irrelevant.")
45+             evaluator=OutputEvaluator(rubric="The output is relevant and complete. "
46+                                       "0 if the output is incorrect or irrelevant.")
4647 )
4748 """
4849
4950 def __init__ (
50- self , cases : list [Case [InputT , OutputT ]] | None = None , evaluator : Evaluator [InputT , OutputT ] | None = None
51+ self ,
52+ cases : list [Case [InputT , OutputT ]] | None = None ,
53+ evaluator : Evaluator [InputT , OutputT ] | None = None ,
5154 ):
5255 self ._cases = cases or []
5356 self ._evaluator = evaluator or Evaluator ()
@@ -102,7 +105,8 @@ def _run_task(
102105 Run the task with the inputs from the test case.
103106
104107 Args:
105- task: The task to run the test case on. This function should take in InputT and returns either OutputT or {"output": OutputT, "trajectory": ...}.
108+ task: The task to run the test case on. This function should take in InputT and returns either
109+ OutputT or {"output": OutputT, "trajectory": ...}.
106110             case: The test case containing necessary information to run the task
107111
108112         Returns:
@@ -138,8 +142,9 @@ async def _run_task_async(
138142 Run the task with the inputs from the test case asynchronously.
139143
140144 Args:
141- task: The task to run the test case on. This function should take in InputT and returns either OutputT or {"output": OutputT, "trajectory": ...}.
142- The task can either run synchronously or asynchronously.
145+ task: The task to run the test case on. This function should take in InputT and returns either
146+ OutputT or {"output": OutputT, "trajectory": ...}. The task can either run synchronously
147+ or asynchronously.
143148             case: The test case containing necessary information to run the task
144149
145150         Returns:
@@ -220,10 +225,12 @@ def run_evaluations(self, task: Callable[[InputT], OutputT | dict[str, Any]]) ->
220225 Run the evaluations for all of the test cases with the evaluator.
221226
222227 Args:
223- task: The task to run the test case on. This function should take in InputT and returns either OutputT or {"output": OutputT, "trajectory": ...}.
228+ task: The task to run the test case on. This function should take in InputT and returns either
229+ OutputT or {"output": OutputT, "trajectory": ...}.
224230
225231         Returns:
226- An EvaluationReport containing the overall score, individual case results, and basic feedback for each test case.
232+ An EvaluationReport containing the overall score, individual case results, and basic feedback
233+ for each test case.
227234 """
228235 scores = []
229236 test_passes = []
@@ -261,15 +268,16 @@ async def run_evaluations_async(self, task: Callable, max_workers: int = 10) ->
261268 Run evaluations asynchronously using a queue for parallel processing.
262269
263270 Args:
264- task: The task function to run on each case. This function should take in InputT and returns either OutputT or {"output": OutputT, "trajectory": ...}.
265- The task can either run synchronously or asynchronously.
271+ task: The task function to run on each case. This function should take in InputT and returns
272+ either OutputT or {"output": OutputT, "trajectory": ...}. The task can either run
273+ synchronously or asynchronously.
266274 max_workers: Maximum number of parallel workers (default: 10)
267275
268276 Returns:
269277 EvaluationReport containing evaluation results
270278 """
271- queue = asyncio .Queue ()
272- results = []
279+ queue : asyncio . Queue [ Case [ InputT , OutputT ]] = asyncio .Queue ()
280+ results : list [ Any ] = []
273281
274282 for case in self ._cases :
275283 queue .put_nowait (case )
@@ -325,7 +333,7 @@ def to_file(self, file_name: str, format: str = "json", directory: str = "datase
325333 raise Exception (f"Format { format } is not supported." )
326334
327335 @classmethod
328- def from_dict (cls , data : dict , custom_evaluators : list [Evaluator ] = None ):
336+ def from_dict (cls , data : dict , custom_evaluators : list [type [ Evaluator ]] | None = None ):
329337 """
330338 Create a dataset from a dictionary.
331339
@@ -337,14 +345,17 @@ def from_dict(cls, data: dict, custom_evaluators: list[Evaluator] = None):
337345 A Dataset object.
338346 """
339347 custom_evaluators = custom_evaluators or []
340- cases = [Case .model_validate (case_data ) for case_data in data ["cases" ]]
341- default_evaluators = {
348+ cases : list [ Case ] = [Case .model_validate (case_data ) for case_data in data ["cases" ]]
349+ default_evaluators : dict [ str , type [ Evaluator ]] = {
342350 "Evaluator" : Evaluator ,
343351 "OutputEvaluator" : OutputEvaluator ,
344352 "TrajectoryEvaluator" : TrajectoryEvaluator ,
345353 "InteractionsEvaluator" : InteractionsEvaluator ,
346354 }
347- all_evaluators = {** default_evaluators , ** {v .get_type_name (): v for v in custom_evaluators }}
355+ all_evaluators : dict [str , type [Evaluator ]] = {
356+ ** default_evaluators ,
357+ ** {v .get_type_name (): v for v in custom_evaluators },
358+ }
348359
349360 evaluator_type = data ["evaluator" ]["evaluator_type" ]
350361 evaluator_args = {k : v for k , v in data ["evaluator" ].items () if k != "evaluator_type" }
@@ -353,13 +364,14 @@ def from_dict(cls, data: dict, custom_evaluators: list[Evaluator] = None):
353364 evaluator = all_evaluators [evaluator_type ](** evaluator_args )
354365 else :
355366 raise Exception (
356- f"Cannot find { evaluator_type } . Make sure the evaluator type is spelled correctly and all relevant custom evaluators are passed in."
367+ f"Cannot find { evaluator_type } . Make sure the evaluator type is spelled correctly and "
368+ f"all relevant custom evaluators are passed in."
357369 )
358370
359371 return cls (cases = cases , evaluator = evaluator )
360372
361373 @classmethod
362- def from_file (cls , file_path : str , format : str = "json" , custom_evaluators : list [Evaluator ] = None ):
374+ def from_file (cls , file_path : str , format : str = "json" , custom_evaluators : list [type [ Evaluator ]] | None = None ):
363375 """
364376 Create a dataset from a file.
365377
0 commit comments