|
16 | 16 |
|
17 | 17 | from __future__ import annotations |
18 | 18 |
|
19 | | -from typing import TYPE_CHECKING, ClassVar, List, Set, Tuple, Type, Union |
| 19 | +from typing import TYPE_CHECKING, Callable, ClassVar, List, Set, Tuple, Type, Union |
20 | 20 |
|
21 | 21 | if TYPE_CHECKING: |
22 | 22 | from ..configurator.base_agent import BaseAgent |
|
32 | 32 | from .system import System |
33 | 33 | from .test_template_strategy import TestTemplateStrategy |
34 | 34 |
|
| 35 | +RewardFunction = Callable[[List[float]], float] |
| 36 | + |
35 | 37 |
|
36 | 38 | class Singleton(type): |
37 | 39 | """Singleton metaclass.""" |
@@ -79,6 +81,7 @@ class Registry(metaclass=Singleton): |
79 | 81 | reports_map: ClassVar[dict[Type[TestDefinition], Set[Type[ReportGenerationStrategy]]]] = {} |
80 | 82 | scenario_reports: ClassVar[dict[str, type[Reporter]]] = {} |
81 | 83 | report_configs: ClassVar[dict[str, ReportConfig]] = {} |
| 84 | + reward_functions_map: ClassVar[dict[str, RewardFunction]] = {} |
82 | 85 |
|
83 | 86 | def add_runner(self, name: str, value: Type[BaseRunner]) -> None: |
84 | 87 | """ |
@@ -276,3 +279,18 @@ def add_scenario_report(self, name: str, report: type[Reporter], config: ReportC |
276 | 279 | def update_scenario_report(self, name: str, report: type[Reporter], config: ReportConfig) -> None: |
277 | 280 | self.scenario_reports[name] = report |
278 | 281 | self.report_configs[name] = config |
| 282 | + |
| 283 | + def add_reward_function(self, name: str, value: RewardFunction) -> None: |
| 284 | + if name in self.reward_functions_map: |
| 285 | + raise ValueError(f"Duplicating implementation for '{name}', use 'update()' for replacement.") |
| 286 | + self.update_reward_function(name, value) |
| 287 | + |
| 288 | + def update_reward_function(self, name: str, value: RewardFunction) -> None: |
| 289 | + self.reward_functions_map[name] = value |
| 290 | + |
| 291 | + def get_reward_function(self, name: str) -> RewardFunction: |
| 292 | + if name not in self.reward_functions_map: |
| 293 | + raise KeyError( |
| 294 | + f"Reward function '{name}' not found. Available functions: {list(self.reward_functions_map.keys())}" |
| 295 | + ) |
| 296 | + return self.reward_functions_map[name] |
0 commit comments