Skip to content

Commit e9bc384

Browse files
authored
Add configurable reward functions to CloudAIGym (#566)
1 parent 0c35fd3 commit e9bc384

File tree

7 files changed

+120
-15
lines changed

7 files changed

+120
-15
lines changed

src/cloudai/_core/registry.py

Lines changed: 19 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING, ClassVar, List, Set, Tuple, Type, Union
19+
from typing import TYPE_CHECKING, Callable, ClassVar, List, Set, Tuple, Type, Union
2020

2121
if TYPE_CHECKING:
2222
from ..configurator.base_agent import BaseAgent
@@ -32,6 +32,8 @@
3232
from .system import System
3333
from .test_template_strategy import TestTemplateStrategy
3434

35+
RewardFunction = Callable[[List[float]], float]
36+
3537

3638
class Singleton(type):
3739
"""Singleton metaclass."""
@@ -79,6 +81,7 @@ class Registry(metaclass=Singleton):
7981
reports_map: ClassVar[dict[Type[TestDefinition], Set[Type[ReportGenerationStrategy]]]] = {}
8082
scenario_reports: ClassVar[dict[str, type[Reporter]]] = {}
8183
report_configs: ClassVar[dict[str, ReportConfig]] = {}
84+
reward_functions_map: ClassVar[dict[str, RewardFunction]] = {}
8285

8386
def add_runner(self, name: str, value: Type[BaseRunner]) -> None:
8487
"""
@@ -276,3 +279,18 @@ def add_scenario_report(self, name: str, report: type[Reporter], config: ReportC
276279
def update_scenario_report(self, name: str, report: type[Reporter], config: ReportConfig) -> None:
277280
self.scenario_reports[name] = report
278281
self.report_configs[name] = config
282+
283+
def add_reward_function(self, name: str, value: RewardFunction) -> None:
284+
if name in self.reward_functions_map:
285+
raise ValueError(f"Duplicating implementation for '{name}', use 'update()' for replacement.")
286+
self.update_reward_function(name, value)
287+
288+
def update_reward_function(self, name: str, value: RewardFunction) -> None:
289+
self.reward_functions_map[name] = value
290+
291+
def get_reward_function(self, name: str) -> RewardFunction:
292+
if name not in self.reward_functions_map:
293+
raise KeyError(
294+
f"Reward function '{name}' not found. Available functions: {list(self.reward_functions_map.keys())}"
295+
)
296+
return self.reward_functions_map[name]

src/cloudai/configurator/__init__.py

Lines changed: 6 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -19,4 +19,9 @@
1919
from .cloudai_gym import CloudAIGymEnv
2020
from .grid_search import GridSearchAgent
2121

22-
__all__ = ["BaseAgent", "BaseGym", "CloudAIGymEnv", "GridSearchAgent"]
22+
__all__ = [
23+
"BaseAgent",
24+
"BaseGym",
25+
"CloudAIGymEnv",
26+
"GridSearchAgent",
27+
]

src/cloudai/configurator/cloudai_gym.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -20,7 +20,7 @@
2020
import logging
2121
from typing import Any, Dict, Optional, Tuple
2222

23-
from cloudai.core import METRIC_ERROR, Runner, TestRun
23+
from cloudai.core import METRIC_ERROR, Registry, Runner, TestRun
2424
from cloudai.util.lazy_imports import lazy
2525

2626
from .base_gym import BaseGym
@@ -44,6 +44,7 @@ def __init__(self, test_run: TestRun, runner: Runner):
4444
self.test_run = test_run
4545
self.runner = runner
4646
self.max_steps = test_run.test.test_definition.agent_steps
47+
self.reward_function = Registry().get_reward_function(test_run.test.test_definition.agent_reward_function)
4748
super().__init__()
4849

4950
def define_action_space(self) -> Dict[str, Any]:
@@ -144,9 +145,7 @@ def compute_reward(self, observation: list) -> float:
144145
Returns:
145146
float: Reward value.
146147
"""
147-
if observation and observation[0] != 0:
148-
return 1.0 / observation[0]
149-
return 0.0
148+
return self.reward_function(observation)
150149

151150
def get_observation(self, action: Any) -> list:
152151
"""
Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,35 @@
1+
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2+
# Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
3+
# SPDX-License-Identifier: Apache-2.0
4+
#
5+
# Licensed under the Apache License, Version 2.0 (the "License");
6+
# you may not use this file except in compliance with the License.
7+
# You may obtain a copy of the License at
8+
#
9+
# http://www.apache.org/licenses/LICENSE-2.0
10+
#
11+
# Unless required by applicable law or agreed to in writing, software
12+
# distributed under the License is distributed on an "AS IS" BASIS,
13+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14+
# See the License for the specific language governing permissions and
15+
# limitations under the License.
16+
17+
from typing import List
18+
19+
20+
def inverse_reward(observation: List[float]) -> float:
21+
if observation and observation[0] != 0:
22+
return 1.0 / observation[0]
23+
return 0.0
24+
25+
26+
def negative_reward(observation: List[float]) -> float:
27+
if observation:
28+
return -observation[0]
29+
return 0.0
30+
31+
32+
def identity_reward(observation: List[float]) -> float:
33+
if observation:
34+
return observation[0]
35+
return 0.0

src/cloudai/models/workload.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -106,6 +106,7 @@ class TestDefinition(BaseModel, ABC):
106106
agent: str = "grid_search"
107107
agent_steps: int = 1
108108
agent_metrics: list[str] = Field(default=["default"])
109+
agent_reward_function: str = "inverse"
109110

110111
@property
111112
def cmd_args_dict(self) -> Dict[str, Union[str, List[str]]]:

src/cloudai/registration.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,11 @@
1818
def register_all():
1919
"""Register all workloads, systems, runners, installers, and strategies."""
2020
from cloudai.configurator.grid_search import GridSearchAgent
21+
from cloudai.configurator.reward_functions import (
22+
identity_reward,
23+
inverse_reward,
24+
negative_reward,
25+
)
2126
from cloudai.core import (
2227
CommandGenStrategy,
2328
GradingStrategy,
@@ -308,3 +313,7 @@ def register_all():
308313
Registry().add_scenario_report("per_test", PerTestReporter, ReportConfig(enable=True))
309314
Registry().add_scenario_report("status", StatusReporter, ReportConfig(enable=True))
310315
Registry().add_scenario_report("tarball", TarballReporter, ReportConfig(enable=True))
316+
317+
Registry().add_reward_function("inverse", inverse_reward)
318+
Registry().add_reward_function("negative", negative_reward)
319+
Registry().add_reward_function("identity", identity_reward)

tests/test_cloudaigym.py

Lines changed: 47 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -92,18 +92,56 @@ def test_observation_space(setup_env):
9292

9393

9494
@pytest.mark.parametrize(
95-
"observation,expected_reward",
95+
"reward_function,test_cases",
9696
[
97-
([0.34827126874999986], pytest.approx(2.871, 0.001)),
98-
([0.0], 0.0),
99-
([], 0.0),
100-
([2.0, 2.0], 0.5),
97+
(
98+
"inverse",
99+
[
100+
([0.34827126874999986], pytest.approx(2.871, 0.001)),
101+
([0.0], 0.0),
102+
([], 0.0),
103+
([2.0, 2.0], 0.5),
104+
],
105+
),
106+
(
107+
"negative",
108+
[
109+
([2.0], -2.0),
110+
([-1.5], 1.5),
111+
([0.0], 0.0),
112+
([], 0.0),
113+
],
114+
),
115+
(
116+
"identity",
117+
[
118+
([2.0], 2.0),
119+
([-1.5], -1.5),
120+
([0.0], 0.0),
121+
([], 0.0),
122+
],
123+
),
101124
],
102125
)
103-
def test_compute_reward(observation: list[float], expected_reward: float):
104-
env = CloudAIGymEnv(test_run=MagicMock(), runner=MagicMock())
105-
reward = env.compute_reward(observation)
106-
assert reward == expected_reward
126+
def test_compute_reward(reward_function, test_cases):
127+
test_run = MagicMock()
128+
test_run.test.test_definition.agent_reward_function = reward_function
129+
env = CloudAIGymEnv(test_run=test_run, runner=MagicMock())
130+
131+
for input_value, expected_reward in test_cases:
132+
reward = env.compute_reward(input_value)
133+
assert reward == expected_reward
134+
135+
136+
def test_compute_reward_invalid():
137+
test_run = MagicMock()
138+
test_run.test.test_definition.agent_reward_function = "nonexistent"
139+
140+
with pytest.raises(KeyError) as exc_info:
141+
CloudAIGymEnv(test_run=test_run, runner=MagicMock())
142+
143+
assert "Reward function 'nonexistent' not found" in str(exc_info.value)
144+
assert "Available functions: ['inverse', 'negative', 'identity']" in str(exc_info.value)
107145

108146

109147
def test_tr_output_path(setup_env: tuple[TestRun, Runner]):

0 commit comments

Comments
 (0)