Skip to content

Commit a2bd0fd

Browse files
authored
Merge pull request #703 from NVIDIA/am/entry-agents
Allow agents to be registered via entry points
2 parents b7f0232 + 2432a0c commit a2bd0fd

File tree

4 files changed

+53
-8
lines changed

4 files changed

+53
-8
lines changed

src/cloudai/cli/cli.py

Lines changed: 2 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -238,9 +238,8 @@ def verify_configs(configs_dir: Path, tests_dir: Path):
238238

239239

240240
@main.command()
241-
@click.argument("type", type=click.Choice(["reports"]))
241+
@click.argument("type", type=click.Choice(["reports", "agents"], case_sensitive=False))
242242
@click.option("-v", "--verbose", is_flag=True, default=False, help="Verbose output.")
243243
def list(type: str, verbose: bool):
244244
"""List available in Registry items."""
245-
args = argparse.Namespace(type=type, verbose=verbose)
246-
handle_list_registered_items(args)
245+
handle_list_registered_items(type, verbose)

src/cloudai/cli/handlers.py

Lines changed: 11 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -539,15 +539,21 @@ def load_tomls_by_type(tomls: List[Path]) -> dict[str, List[Path]]:
539539
return files
540540

541541

542-
def handle_list_registered_items(args: argparse.Namespace) -> int:
543-
item_type = args.type
542+
def handle_list_registered_items(item_type: str, verbose: bool) -> int:
544543
registry = Registry()
545-
if item_type == "reports":
546-
print("Registered scenario reports:")
544+
if item_type.lower() == "reports":
545+
print("Available scenario reports:")
547546
for idx, (name, report) in enumerate(sorted(registry.scenario_reports.items()), start=1):
548547
str = f'{idx}. "{name}" {report.__name__}'
549-
if args.verbose:
548+
if verbose:
550549
str += f" (config={registry.report_configs[name].model_dump_json(indent=None)})"
551550
print(str)
551+
elif item_type.lower() == "agents":
552+
print("Available agents:")
553+
for idx, (name, agent) in enumerate(sorted(registry.agents_map.items()), start=1):
554+
str = f'{idx}. "{name}" class={agent.__name__}'
555+
if verbose:
556+
str += f"{agent.__doc__}"
557+
print(str)
552558

553559
return 0

src/cloudai/registration.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,24 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17+
import warnings
18+
from importlib.metadata import entry_points
19+
20+
21+
def register_entrypoint_agents():
22+
from cloudai.configurator.base_agent import BaseAgent
23+
from cloudai.core import Registry
24+
25+
eps = entry_points(group="cloudai.agents")
26+
for ep in eps:
27+
cls = ep.load()
28+
if issubclass(cls, BaseAgent):
29+
Registry().add_agent(ep.name, cls)
30+
else:
31+
warnings.warn(
32+
f"Skipping entrypoint: {ep.name} -> {ep.value} class={cls} (not a subclass of BaseAgent)", stacklevel=2
33+
)
34+
1735

1836
def register_all():
1937
"""Register all workloads, systems, runners, installers, and strategies."""
@@ -233,3 +251,5 @@ def register_all():
233251
Registry().add_reward_function("ai_dynamo_weighted_normalized", ai_dynamo_weighted_normalized_reward)
234252
Registry().add_reward_function("ai_dynamo_ratio_normalized", ai_dynamo_ratio_normalized_reward)
235253
Registry().add_reward_function("ai_dynamo_log_scale", ai_dynamo_log_scale_reward)
254+
255+
register_entrypoint_agents()

tests/test_registry.py

Lines changed: 20 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -15,6 +15,8 @@
1515
# limitations under the License.
1616

1717
import copy
18+
from typing import Any
19+
from unittest.mock import patch
1820

1921
import pytest
2022

@@ -31,6 +33,7 @@
3133
)
3234
from cloudai.models.scenario import ReportConfig
3335
from cloudai.models.workload import TestDefinition
36+
from cloudai.registration import register_entrypoint_agents
3437

3538

3639
class MyTestDefinition(TestDefinition):
@@ -322,3 +325,20 @@ def test_get_command_gen_strategy_not_found(self, registry: Registry):
322325
with pytest.raises(KeyError) as exc_info:
323326
registry.get_command_gen_strategy(MySystem, AnotherTestDefinition)
324327
assert exc_info.match("Command gen strategy for 'MySystem, AnotherTestDefinition' not found.")
328+
329+
330+
def test_entrypoint_agent_type_verified():
331+
class MockEP:
332+
def __init__(self, load_value: Any):
333+
self._load_value = load_value
334+
self.name = "name"
335+
self.value = "value"
336+
337+
def load(self):
338+
return self._load_value
339+
340+
with (
341+
patch("cloudai.registration.entry_points", return_value=[MockEP(str)]),
342+
pytest.warns(UserWarning, match="(not a subclass of BaseAgent)"),
343+
):
344+
register_entrypoint_agents()

0 commit comments

Comments
 (0)