Skip to content

Commit b899b53

Browse files
committed
Remove TestTemplate concept
Previously it help all the strategies, but now it is obsolete.
1 parent 328afab commit b899b53

File tree

46 files changed

+103
-281
lines changed

Some content is hidden

Large Commits have some content hidden by default. Use the searchbox below for content that may be hidden.

46 files changed

+103
-281
lines changed

src/cloudai/_core/test.py

Lines changed: 1 addition & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -18,8 +18,6 @@
1818

1919
from typing import TYPE_CHECKING, Dict, List, Union
2020

21-
from .test_template import TestTemplate
22-
2321
if TYPE_CHECKING:
2422
from cloudai.models.workload import TestDefinition
2523

@@ -29,15 +27,13 @@ class Test:
2927

3028
__test__ = False
3129

32-
def __init__(self, test_definition: TestDefinition, test_template: TestTemplate) -> None:
30+
def __init__(self, test_definition: TestDefinition) -> None:
3331
"""
3432
Initialize a Test instance.
3533
3634
Args:
3735
test_definition (TestDefinition): The test definition object.
38-
test_template (TestTemplate): The test template object
3936
"""
40-
self.test_template = test_template
4137
self.test_definition = test_definition
4238

4339
@property

src/cloudai/_core/test_template.py

Lines changed: 0 additions & 38 deletions
This file was deleted.

src/cloudai/core.py

Lines changed: 0 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -41,7 +41,6 @@
4141
from ._core.system import System
4242
from ._core.test import Test
4343
from ._core.test_scenario import METRIC_ERROR, TestDependency, TestRun, TestScenario
44-
from ._core.test_template import TestTemplate
4544
from .configurator.base_agent import BaseAgent
4645
from .configurator.cloudai_gym import CloudAIGymEnv
4746
from .configurator.grid_search import GridSearchAgent
@@ -95,7 +94,6 @@
9594
"TestScenario",
9695
"TestScenarioParser",
9796
"TestScenarioParsingError",
98-
"TestTemplate",
9997
"case_name",
10098
"format_validation_error",
10199
]

src/cloudai/systems/lsf/lsf_command_gen_strategy.py

Lines changed: 1 addition & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -55,9 +55,7 @@ def gen_exec_command(self) -> str:
5555
"""
5656
env_vars = self.final_env_vars
5757
cmd_args = flatten_dict(self.test_run.test.cmd_args)
58-
lsf_args = self._parse_lsf_args(
59-
self.test_run.test.test_template.__class__.__name__, env_vars, cmd_args, self.test_run
60-
)
58+
lsf_args = self._parse_lsf_args(self.test_run.test.name, env_vars, cmd_args, self.test_run)
6159

6260
bsub_command = self._gen_bsub_command(lsf_args, env_vars, cmd_args, self.test_run)
6361

src/cloudai/systems/slurm/slurm_command_gen_strategy.py

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -118,7 +118,7 @@ def gen_srun_command(self) -> str:
118118
return self._gen_srun_command()
119119

120120
def job_name_prefix(self) -> str:
121-
return self.test_run.test.test_template.__class__.__name__
121+
return self.test_run.test.name
122122

123123
def job_name(self) -> str:
124124
job_name_prefix = self.job_name_prefix()

src/cloudai/test_parser.py

Lines changed: 2 additions & 25 deletions
Original file line numberDiff line numberDiff line change
@@ -21,14 +21,7 @@
2121
import toml
2222
from pydantic import ValidationError
2323

24-
from .core import (
25-
Registry,
26-
System,
27-
Test,
28-
TestConfigParsingError,
29-
TestTemplate,
30-
format_validation_error,
31-
)
24+
from .core import Registry, System, Test, TestConfigParsingError, format_validation_error
3225
from .models.workload import TestDefinition
3326

3427

@@ -92,31 +85,15 @@ def load_test_definition(self, data: dict) -> TestDefinition:
9285

9386
return test_def
9487

95-
def _get_test_template(self, name: str, tdef: TestDefinition) -> TestTemplate:
96-
"""
97-
Dynamically retrieves the appropriate TestTemplate subclass based on the given name.
98-
99-
Args:
100-
name (str): The name of the test template.
101-
tdef (TestDefinition): The test definition.
102-
103-
Returns:
104-
Type[TestTemplate]: A subclass of TestTemplate corresponding to the given name.
105-
"""
106-
obj = TestTemplate(system=self.system)
107-
return obj
108-
10988
def _parse_data(self, data: Dict[str, Any]) -> Test:
11089
"""
11190
Parse data for a Test object.
11291
11392
Args:
11493
data (Dict[str, Any]): Data from a source (e.g., a TOML file).
115-
strict (bool): Whether to enforce strict validation for test definition.
11694
11795
Returns:
11896
Test: Parsed Test object.
11997
"""
12098
test_def = self.load_test_definition(data)
121-
test_template = self._get_test_template(test_def.test_template_name, test_def)
122-
return Test(test_definition=test_def, test_template=test_template)
99+
return Test(test_definition=test_def)

src/cloudai/test_scenario_parser.py

Lines changed: 4 additions & 6 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,7 @@
1818
import logging
1919
from datetime import timedelta
2020
from pathlib import Path
21-
from typing import Any, Dict, List, Optional, Set, Tuple, Type
21+
from typing import Any, Dict, List, Optional, Set, Type
2222

2323
import toml
2424
from pydantic import ValidationError
@@ -186,9 +186,7 @@ def _create_test_run(
186186
Raises:
187187
ValueError: If the test or nodes are not found within the system.
188188
"""
189-
original_test, tdef = self._prepare_tdef(test_info)
190-
191-
test = Test(test_definition=tdef, test_template=original_test.test_template)
189+
test = Test(test_definition=self._prepare_tdef(test_info))
192190

193191
hooks = [hook for hook in [pre_test, post_test] if hook is not None]
194192
total_time_limit = calculate_total_time_limit(test_hooks=hooks, time_limit=test_info.time_limit)
@@ -210,7 +208,7 @@ def _create_test_run(
210208

211209
return tr
212210

213-
def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]:
211+
def _prepare_tdef(self, test_info: TestRunModel) -> TestDefinition:
214212
tp = TestParser([self.file_path], self.system)
215213
tp.current_file = self.file_path
216214

@@ -231,4 +229,4 @@ def _prepare_tdef(self, test_info: TestRunModel) -> Tuple[Test, TestDefinition]:
231229
f"Cannot configure test case '{test_info.id}' with both 'test_name' and 'test_template_name'."
232230
)
233231

234-
return test, test.test_definition
232+
return test.test_definition

tests/conftest.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -23,7 +23,7 @@
2323
import toml
2424
import yaml
2525

26-
from cloudai.core import CommandGenStrategy, Test, TestRun, TestTemplate
26+
from cloudai.core import CommandGenStrategy, Test, TestRun
2727
from cloudai.models.scenario import TestRunDetails
2828
from cloudai.models.workload import CmdArgs, TestDefinition
2929
from cloudai.systems.kubernetes import KubernetesSystem
@@ -125,8 +125,7 @@ def base_tr(slurm_system: SlurmSystem) -> TestRun:
125125
return TestRun(
126126
name="tr-name",
127127
test=Test(
128-
test_definition=TestDefinition(name="n", description="d", test_template_name="tt", cmd_args=CmdArgs()),
129-
test_template=TestTemplate(slurm_system),
128+
test_definition=TestDefinition(name="n", description="d", test_template_name="tt", cmd_args=CmdArgs())
130129
),
131130
num_nodes=1,
132131
nodes=[],
@@ -161,8 +160,7 @@ def benchmark_tr(slurm_system: SlurmSystem) -> TestRun:
161160
test_template_name="NcclTest",
162161
cmd_args=NCCLCmdArgs(docker_image_url="fake://url/nccl"),
163162
)
164-
test_template = TestTemplate(system=slurm_system)
165-
test = Test(test_definition=test_definition, test_template=test_template)
163+
test = Test(test_definition=test_definition)
166164
tr = TestRun(name="benchmark", test=test, num_nodes=1, nodes=["node1"], iterations=3)
167165
create_test_directories(slurm_system, tr)
168166
return tr
@@ -178,8 +176,7 @@ def dse_tr(slurm_system: SlurmSystem) -> TestRun:
178176
extra_env_vars={"VAR1": ["value1", "value2"]},
179177
agent_steps=12,
180178
)
181-
test_template = TestTemplate(system=slurm_system)
182-
test = Test(test_definition=test_definition, test_template=test_template)
179+
test = Test(test_definition=test_definition)
183180

184181
tr = TestRun(name="dse", test=test, num_nodes=1, nodes=["node1"], iterations=12)
185182
create_test_directories(slurm_system, tr)

tests/json_gen_strategy/test_nccl_kubernetes_json_gen_strategy.py

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -17,7 +17,7 @@
1717

1818
import pytest
1919

20-
from cloudai.core import Test, TestRun, TestTemplate
20+
from cloudai.core import Test, TestRun
2121
from cloudai.systems.kubernetes import KubernetesSystem
2222
from cloudai.workloads.nccl_test import NCCLCmdArgs, NCCLTestDefinition, NcclTestKubernetesJsonGenStrategy
2323

@@ -27,8 +27,7 @@ class TestNcclTestKubernetesJsonGenStrategy:
2727
def basic_test_run(self, kubernetes_system: KubernetesSystem) -> TestRun:
2828
cmd_args = NCCLCmdArgs.model_validate({"subtest_name": "all_reduce_perf", "docker_image_url": "fake_image_url"})
2929
nccl = NCCLTestDefinition(name="name", description="desc", test_template_name="NcclTest", cmd_args=cmd_args)
30-
test_template = TestTemplate(kubernetes_system)
31-
t = Test(test_definition=nccl, test_template=test_template)
30+
t = Test(test_definition=nccl)
3231
return TestRun(name="t1", test=t, nodes=["node1", "node2"], num_nodes=2)
3332

3433
@pytest.fixture
@@ -41,8 +40,7 @@ def test_run_with_env_vars(self, kubernetes_system: KubernetesSystem) -> TestRun
4140
cmd_args=cmd_args,
4241
extra_env_vars={"TEST_VAR": "test_value", "LIST_VAR": ["item1", "item2"]},
4342
)
44-
test_template = TestTemplate(kubernetes_system)
45-
t = Test(test_definition=nccl, test_template=test_template)
43+
t = Test(test_definition=nccl)
4644
return TestRun(name="t1", test=t, nodes=["node1"], num_nodes=1)
4745

4846
@pytest.fixture
@@ -64,8 +62,7 @@ def test_run_with_extra_args(self, kubernetes_system: KubernetesSystem) -> TestR
6462
cmd_args=cmd_args,
6563
extra_cmd_args={"extra-flag": "value"},
6664
)
67-
test_template = TestTemplate(kubernetes_system)
68-
t = Test(test_definition=nccl, test_template=test_template)
65+
t = Test(test_definition=nccl)
6966
return TestRun(name="t1", test=t, nodes=["node1"], num_nodes=1)
7067

7168
def json_gen_strategy(

tests/json_gen_strategy/test_nccl_runai_json_gen_strategy.py

Lines changed: 3 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -14,7 +14,6 @@
1414
# See the License for the specific language governing permissions and
1515
# limitations under the License.
1616

17-
from unittest.mock import Mock
1817

1918
from cloudai.core import Test, TestRun
2019
from cloudai.systems.runai.runai_system import RunAISystem
@@ -28,7 +27,7 @@ def json_gen_strategy(self, runai_system: RunAISystem, tr: TestRun) -> NcclTestR
2827
def test_gen_json(self, runai_system: RunAISystem) -> None:
2928
cmd_args = NCCLCmdArgs.model_validate({"subtest_name": "all_reduce_perf", "docker_image_url": "fake_image_url"})
3029
nccl = NCCLTestDefinition(name="name", description="desc", test_template_name="tt", cmd_args=cmd_args)
31-
t = Test(test_definition=nccl, test_template=Mock())
30+
t = Test(test_definition=nccl)
3231
tr = TestRun(name="t1", test=t, nodes=["node1", "node2"], num_nodes=2)
3332
json_payload = self.json_gen_strategy(runai_system, tr).gen_json()
3433

@@ -63,7 +62,7 @@ def test_gen_json_with_cmd_args(self, runai_system: RunAISystem) -> None:
6362
}
6463
)
6564
nccl = NCCLTestDefinition(name="name", description="desc", test_template_name="tt", cmd_args=cmd_args)
66-
t = Test(test_definition=nccl, test_template=Mock())
65+
t = Test(test_definition=nccl)
6766
tr = TestRun(name="t1", test=t, nodes=["node1", "node2"], num_nodes=2)
6867

6968
json_payload = self.json_gen_strategy(runai_system, tr).gen_json()
@@ -107,7 +106,7 @@ def test_gen_json_with_extra_cmd_args(self, runai_system: RunAISystem) -> None:
107106
cmd_args=cmd_args,
108107
extra_cmd_args={"--extra-arg": "value"},
109108
)
110-
t = Test(test_definition=nccl, test_template=Mock())
109+
t = Test(test_definition=nccl)
111110
tr = TestRun(name="t1", test=t, nodes=["node1", "node2"], num_nodes=2)
112111

113112
json_payload = self.json_gen_strategy(runai_system, tr).gen_json()

0 commit comments

Comments
 (0)