Skip to content

Commit 91c8e4d

Browse files
authored
Merge pull request #716 from NVIDIA/am/upds
Simplify internal hierarchy of classes
2 parents a8b105b + cb494db commit 91c8e4d

File tree

103 files changed

+799
-1304
lines changed

Some content is hidden

Large commits have some content hidden by default. Use the search box below to find content that may be hidden.

103 files changed

+799
-1304
lines changed

src/cloudai/__init__.py

Lines changed: 0 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -18,7 +18,6 @@
1818
Registry,
1919
Runner,
2020
System,
21-
Test,
2221
TestDefinition,
2322
TestRun,
2423
TestScenario,
@@ -29,7 +28,6 @@
2928
"Registry",
3029
"Runner",
3130
"System",
32-
"Test",
3331
"TestDefinition",
3432
"TestRun",
3533
"TestScenario",

src/cloudai/_core/base_reporter.py

Lines changed: 1 addition & 1 deletion
Original file line number | Diff line number | Diff line change
@@ -67,7 +67,7 @@ def load_test_runs(self):
6767
tr_file = tr.output_path / "test-run.toml"
6868
if tr_file.exists():
6969
tr_file = toml.load(tr_file)
70-
tr.test.test_definition = tr.test.test_definition.model_validate(tr_file["test_definition"])
70+
tr.test = tr.test.model_validate(tr_file["test_definition"])
7171
self.trs.append(copy.deepcopy(tr))
7272
else:
7373
tr.current_iteration = int(iter.name)

src/cloudai/_core/base_runner.py

Lines changed: 7 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -25,6 +25,7 @@
2525
from .command_gen_strategy import CommandGenStrategy
2626
from .exceptions import JobFailureError, JobSubmissionError
2727
from .job_status_result import JobStatusResult
28+
from .json_gen_strategy import JsonGenStrategy
2829
from .registry import Registry
2930
from .system import System
3031
from .test_scenario import TestRun, TestScenario
@@ -288,7 +289,7 @@ def get_job_status(self, job: BaseJob) -> JobStatusResult:
288289
JobStatusResult: The result containing the job status and an optional error message.
289290
"""
290291
runner_job_status_result = self.get_runner_job_status(job)
291-
workload_run_results = job.test_run.test.test_definition.was_run_successful(job.test_run)
292+
workload_run_results = job.test_run.test.was_run_successful(job.test_run)
292293
if not runner_job_status_result.is_successful:
293294
return runner_job_status_result
294295
if not workload_run_results.is_successful:
@@ -375,5 +376,9 @@ async def delayed_kill_job(self, job: BaseJob, delay: int = 0):
375376
self.system.kill(job)
376377

377378
def get_cmd_gen_strategy(self, system: System, test_run: TestRun) -> CommandGenStrategy:
378-
strategy_cls = Registry().get_command_gen_strategy(type(system), type(test_run.test.test_definition))
379+
strategy_cls = Registry().get_command_gen_strategy(type(system), type(test_run.test))
380+
return strategy_cls(system, test_run)
381+
382+
def get_json_gen_strategy(self, system: System, test_run: TestRun) -> JsonGenStrategy:
383+
strategy_cls = Registry().get_json_gen_strategy(type(system), type(test_run.test))
379384
return strategy_cls(system, test_run)

src/cloudai/_core/command_gen_strategy.py

Lines changed: 2 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -18,16 +18,15 @@
1818

1919
from .system import System
2020
from .test_scenario import TestRun
21-
from .test_template_strategy import TestTemplateStrategy
2221

2322

24-
class CommandGenStrategy(TestTemplateStrategy, ABC):
23+
class CommandGenStrategy(ABC):
2524
"""Abstract base class defining the interface for command generation strategies across different systems."""
2625

2726
TEST_RUN_DUMP_FILE_NAME: str = "test-run.toml"
2827

2928
def __init__(self, system: System, test_run: TestRun) -> None:
30-
super().__init__(system)
29+
self.system = system
3130
self.test_run = test_run
3231
self._final_env_vars: dict[str, str | list[str]] = {}
3332

src/cloudai/_core/grader.py

Lines changed: 8 additions & 3 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2-
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -19,6 +19,8 @@
1919
from pathlib import Path
2020
from typing import Dict, List
2121

22+
from .registry import Registry
23+
from .system import System
2224
from .test_scenario import TestRun, TestScenario
2325

2426

@@ -31,8 +33,9 @@ class Grader:
3133
logger (logging.Logger): Logger for the class, used to log messages related to the grading process.
3234
"""
3335

34-
def __init__(self, output_path: Path) -> None:
36+
def __init__(self, output_path: Path, system: System) -> None:
3537
self.output_path = output_path
38+
self.system = system
3639

3740
def grade(self, test_scenario: TestScenario) -> str:
3841
"""
@@ -80,9 +83,11 @@ def _get_perfs_from_subdirs(self, directory_path: Path, tr: TestRun) -> List[flo
8083
List[float]: A list of performance values.
8184
"""
8285
perfs = []
86+
8387
for subdir in directory_path.iterdir():
8488
if subdir.is_dir() and subdir.name.isdigit():
85-
perf = tr.test.test_template.grade(subdir, tr.ideal_perf)
89+
grading_strategy = Registry().get_grading_strategy(type(self.system), type(tr.test))()
90+
perf = grading_strategy.grade(subdir, tr.ideal_perf)
8691
perfs.append(perf)
8792
return perfs
8893

src/cloudai/_core/json_gen_strategy.py

Lines changed: 8 additions & 4 deletions
Original file line number | Diff line number | Diff line change
@@ -1,5 +1,5 @@
11
# SPDX-FileCopyrightText: NVIDIA CORPORATION & AFFILIATES
2-
# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2+
# Copyright (c) 2024-2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
33
# SPDX-License-Identifier: Apache-2.0
44
#
55
# Licensed under the Apache License, Version 2.0 (the "License");
@@ -18,17 +18,21 @@
1818
from abc import ABC, abstractmethod
1919
from typing import Any, Dict
2020

21+
from .system import System
2122
from .test_scenario import TestRun
22-
from .test_template_strategy import TestTemplateStrategy
2323

2424

25-
class JsonGenStrategy(TestTemplateStrategy, ABC):
25+
class JsonGenStrategy(ABC):
2626
"""
2727
Abstract base class for generating Kubernetes job specifications based on system and test parameters.
2828
2929
It specifies how to generate JSON job specifications based on system and test parameters.
3030
"""
3131

32+
def __init__(self, system: System, test_run: TestRun) -> None:
33+
self.system = system
34+
self.test_run = test_run
35+
3236
def sanitize_k8s_job_name(self, job_name: str) -> str:
3337
"""
3438
Sanitize the job name to ensure it follows Kubernetes naming rules.
@@ -51,7 +55,7 @@ def sanitize_k8s_job_name(self, job_name: str) -> str:
5155
return sanitized_name[:253]
5256

5357
@abstractmethod
54-
def gen_json(self, tr: TestRun) -> Dict[Any, Any]:
58+
def gen_json(self) -> Dict[Any, Any]:
5559
"""
5660
Generate the Kubernetes job specification based on the given parameters.
5761

src/cloudai/_core/registry.py

Lines changed: 40 additions & 33 deletions
Original file line number | Diff line number | Diff line change
@@ -16,7 +16,7 @@
1616

1717
from __future__ import annotations
1818

19-
from typing import TYPE_CHECKING, Callable, ClassVar, List, Set, Tuple, Type, Union
19+
from typing import TYPE_CHECKING, Callable, ClassVar, List, Set, Tuple, Type
2020

2121
if TYPE_CHECKING:
2222
from ..configurator.base_agent import BaseAgent
@@ -49,16 +49,6 @@ class Registry(metaclass=Singleton):
4949
"""Registry for implementations mappings."""
5050

5151
runners_map: ClassVar[dict[str, Type[BaseRunner]]] = {}
52-
strategies_map: ClassVar[
53-
dict[
54-
Tuple[
55-
Type[Union[JsonGenStrategy, GradingStrategy]],
56-
Type[System],
57-
Type[TestDefinition],
58-
],
59-
Type[Union[JsonGenStrategy, GradingStrategy]],
60-
]
61-
] = {}
6252
installers_map: ClassVar[dict[str, Type[BaseInstaller]]] = {}
6353
systems_map: ClassVar[dict[str, Type[System]]] = {}
6454
test_definitions_map: ClassVar[dict[str, Type[TestDefinition]]] = {}
@@ -68,6 +58,8 @@ class Registry(metaclass=Singleton):
6858
report_configs: ClassVar[dict[str, ReportConfig]] = {}
6959
reward_functions_map: ClassVar[dict[str, RewardFunction]] = {}
7060
command_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[CommandGenStrategy]]] = {}
61+
json_gen_strategies_map: ClassVar[dict[tuple[Type[System], Type[TestDefinition]], Type[JsonGenStrategy]]] = {}
62+
grading_strategies_map: ClassVar[dict[Tuple[Type[System], Type[TestDefinition]], Type[GradingStrategy]]] = {}
7163

7264
def add_runner(self, name: str, value: Type[BaseRunner]) -> None:
7365
"""
@@ -94,30 +86,23 @@ def update_runner(self, name: str, value: Type[BaseRunner]) -> None:
9486
"""
9587
self.runners_map[name] = value
9688

97-
def add_strategy(
98-
self,
99-
strategy_interface: Type[Union[JsonGenStrategy, GradingStrategy]],
100-
system_types: List[Type[System]],
101-
definition_types: List[Type[TestDefinition]],
102-
strategy: Type[Union[JsonGenStrategy, GradingStrategy]],
89+
def add_grading_strategy(
90+
self, system_type: Type[System], tdef_type: Type[TestDefinition], strategy: Type[GradingStrategy]
10391
) -> None:
104-
for system_type in system_types:
105-
for def_type in definition_types:
106-
key = (strategy_interface, system_type, def_type)
107-
if key in self.strategies_map:
108-
raise ValueError(f"Duplicating implementation for '{key}', use 'update()' for replacement.")
109-
self.update_strategy(key, strategy)
110-
111-
def update_strategy(
112-
self,
113-
key: Tuple[
114-
Type[Union[JsonGenStrategy, GradingStrategy]],
115-
Type[System],
116-
Type[TestDefinition],
117-
],
118-
value: Type[Union[JsonGenStrategy, GradingStrategy]],
92+
key = (system_type, tdef_type)
93+
if key in self.grading_strategies_map:
94+
raise ValueError(f"Duplicating implementation for '{key}', use 'update()' for replacement.")
95+
self.update_grading_strategy(key, strategy)
96+
97+
def update_grading_strategy(
98+
self, key: Tuple[Type[System], Type[TestDefinition]], value: Type[GradingStrategy]
11999
) -> None:
120-
self.strategies_map[key] = value
100+
self.grading_strategies_map[key] = value
101+
102+
def get_grading_strategy(self, system_type: Type[System], tdef_type: Type[TestDefinition]) -> Type[GradingStrategy]:
103+
if (system_type, tdef_type) not in self.grading_strategies_map:
104+
raise KeyError(f"Grading gen strategy for '{system_type.__name__}, {tdef_type.__name__}' not found.")
105+
return self.grading_strategies_map[(system_type, tdef_type)]
121106

122107
def add_installer(self, name: str, value: Type[BaseInstaller]) -> None:
123108
"""
@@ -274,3 +259,25 @@ def get_command_gen_strategy(
274259
if (system_type, tdef_type) not in self.command_gen_strategies_map:
275260
raise KeyError(f"Command gen strategy for '{system_type.__name__}, {tdef_type.__name__}' not found.")
276261
return self.command_gen_strategies_map[(system_type, tdef_type)]
262+
263+
def add_json_gen_strategy(
264+
self, system_type: Type[System], tdef_type: Type[TestDefinition], value: Type[JsonGenStrategy]
265+
) -> None:
266+
if (system_type, tdef_type) in self.json_gen_strategies_map:
267+
raise ValueError(
268+
f"Duplicating implementation for '{system_type.__name__}, {tdef_type.__name__}', use 'update()' "
269+
"for replacement."
270+
)
271+
self.update_json_gen_strategy(system_type, tdef_type, value)
272+
273+
def update_json_gen_strategy(
274+
self, system_type: Type[System], tdef_type: Type[TestDefinition], value: Type[JsonGenStrategy]
275+
) -> None:
276+
self.json_gen_strategies_map[(system_type, tdef_type)] = value
277+
278+
def get_json_gen_strategy(
279+
self, system_type: Type[System], tdef_type: Type[TestDefinition]
280+
) -> Type[JsonGenStrategy]:
281+
if (system_type, tdef_type) not in self.json_gen_strategies_map:
282+
raise KeyError(f"JSON gen strategy for '{system_type.__name__}, {tdef_type.__name__}' not found.")
283+
return self.json_gen_strategies_map[(system_type, tdef_type)]

src/cloudai/_core/strategy_registry.py

Lines changed: 0 additions & 45 deletions
This file was deleted.

src/cloudai/_core/test.py

Lines changed: 0 additions & 61 deletions
This file was deleted.

0 commit comments

Comments
 (0)